1 #include <algorithm>
2 #include <atomic>
3 #include <cmath>
4 #include <iostream>
5 #include <sstream>
6 #include <utility>
7 
8 #include "CSE.h"
9 #include "Debug.h"
10 #include "IREquality.h"
11 #include "IRMutator.h"
12 #include "IROperator.h"
13 #include "IRPrinter.h"
14 #include "Util.h"
15 #include "Var.h"
16 
17 namespace Halide {
18 
19 // Evaluate a float polynomial efficiently, taking instruction latency
20 // into account. The high order terms come first. n is the number of
21 // terms, which is the degree plus one.
22 namespace {
23 
evaluate_polynomial(Expr x,float * coeff,int n)24 Expr evaluate_polynomial(Expr x, float *coeff, int n) {
25     internal_assert(n >= 2);
26 
27     Expr x2 = x * x;
28 
29     Expr even_terms = coeff[0];
30     Expr odd_terms = coeff[1];
31 
32     for (int i = 2; i < n; i++) {
33         if ((i & 1) == 0) {
34             if (coeff[i] == 0.0f) {
35                 even_terms *= x2;
36             } else {
37                 even_terms = even_terms * x2 + coeff[i];
38             }
39         } else {
40             if (coeff[i] == 0.0f) {
41                 odd_terms *= x2;
42             } else {
43                 odd_terms = odd_terms * x2 + coeff[i];
44             }
45         }
46     }
47 
48     if ((n & 1) == 0) {
49         return even_terms * std::move(x) + odd_terms;
50     } else {
51         return odd_terms * std::move(x) + even_terms;
52     }
53 }
54 
stringify(const std::vector<Expr> & args)55 Expr stringify(const std::vector<Expr> &args) {
56     if (args.empty()) {
57         return Expr("");
58     }
59 
60     return Internal::Call::make(type_of<const char *>(), Internal::Call::stringify,
61                                 args, Internal::Call::Intrinsic);
62 }
63 
combine_strings(const std::vector<Expr> & args)64 Expr combine_strings(const std::vector<Expr> &args) {
65     // Insert spaces between each expr.
66     std::vector<Expr> strings(args.size() * 2);
67     for (size_t i = 0; i < args.size(); i++) {
68         strings[i * 2] = args[i];
69         if (i < args.size() - 1) {
70             strings[i * 2 + 1] = Expr(" ");
71         } else {
72             strings[i * 2 + 1] = Expr("\n");
73         }
74     }
75 
76     return stringify(strings);
77 }
78 
79 }  // namespace
80 
81 namespace Internal {
82 
is_const(const Expr & e)83 bool is_const(const Expr &e) {
84     if (e.as<IntImm>() ||
85         e.as<UIntImm>() ||
86         e.as<FloatImm>() ||
87         e.as<StringImm>()) {
88         return true;
89     } else if (const Cast *c = e.as<Cast>()) {
90         return is_const(c->value);
91     } else if (const Ramp *r = e.as<Ramp>()) {
92         return is_const(r->base) && is_const(r->stride);
93     } else if (const Broadcast *b = e.as<Broadcast>()) {
94         return is_const(b->value);
95     } else {
96         return false;
97     }
98 }
99 
is_const(const Expr & e,int64_t value)100 bool is_const(const Expr &e, int64_t value) {
101     if (const IntImm *i = e.as<IntImm>()) {
102         return i->value == value;
103     } else if (const UIntImm *i = e.as<UIntImm>()) {
104         return (value >= 0) && (i->value == (uint64_t)value);
105     } else if (const FloatImm *i = e.as<FloatImm>()) {
106         return i->value == value;
107     } else if (const Cast *c = e.as<Cast>()) {
108         return is_const(c->value, value);
109     } else if (const Broadcast *b = e.as<Broadcast>()) {
110         return is_const(b->value, value);
111     } else {
112         return false;
113     }
114 }
115 
is_no_op(const Stmt & s)116 bool is_no_op(const Stmt &s) {
117     if (!s.defined()) return true;
118     const Evaluate *e = s.as<Evaluate>();
119     return e && is_const(e->value);
120 }
121 
122 namespace {
123 
124 class ExprIsPure : public IRGraphVisitor {
125     using IRVisitor::visit;
126 
visit(const Call * op)127     void visit(const Call *op) override {
128         if (!op->is_pure()) {
129             result = false;
130         } else {
131             IRGraphVisitor::visit(op);
132         }
133     }
134 
visit(const Load * op)135     void visit(const Load *op) override {
136         if (!op->image.defined() && !op->param.defined()) {
137             // It's a load from an internal buffer, which could
138             // mutate.
139             result = false;
140         } else {
141             IRGraphVisitor::visit(op);
142         }
143     }
144 
145 public:
146     bool result = true;
147 };
148 
149 }  // namespace
150 
is_pure(const Expr & e)151 bool is_pure(const Expr &e) {
152     ExprIsPure pure;
153     e.accept(&pure);
154     return pure.result;
155 }
156 
as_const_int(const Expr & e)157 const int64_t *as_const_int(const Expr &e) {
158     if (!e.defined()) {
159         return nullptr;
160     } else if (const Broadcast *b = e.as<Broadcast>()) {
161         return as_const_int(b->value);
162     } else if (const IntImm *i = e.as<IntImm>()) {
163         return &(i->value);
164     } else {
165         return nullptr;
166     }
167 }
168 
as_const_uint(const Expr & e)169 const uint64_t *as_const_uint(const Expr &e) {
170     if (!e.defined()) {
171         return nullptr;
172     } else if (const Broadcast *b = e.as<Broadcast>()) {
173         return as_const_uint(b->value);
174     } else if (const UIntImm *i = e.as<UIntImm>()) {
175         return &(i->value);
176     } else {
177         return nullptr;
178     }
179 }
180 
as_const_float(const Expr & e)181 const double *as_const_float(const Expr &e) {
182     if (!e.defined()) {
183         return nullptr;
184     } else if (const Broadcast *b = e.as<Broadcast>()) {
185         return as_const_float(b->value);
186     } else if (const FloatImm *f = e.as<FloatImm>()) {
187         return &(f->value);
188     } else {
189         return nullptr;
190     }
191 }
192 
is_const_power_of_two_integer(const Expr & e,int * bits)193 bool is_const_power_of_two_integer(const Expr &e, int *bits) {
194     if (!(e.type().is_int() || e.type().is_uint())) return false;
195 
196     const Broadcast *b = e.as<Broadcast>();
197     if (b) return is_const_power_of_two_integer(b->value, bits);
198 
199     const Cast *c = e.as<Cast>();
200     if (c) return is_const_power_of_two_integer(c->value, bits);
201 
202     uint64_t val = 0;
203 
204     if (const int64_t *i = as_const_int(e)) {
205         if (*i < 0) return false;
206         val = (uint64_t)(*i);
207     } else if (const uint64_t *u = as_const_uint(e)) {
208         val = *u;
209     }
210 
211     if (val && ((val & (val - 1)) == 0)) {
212         *bits = 0;
213         for (; val; val >>= 1) {
214             if (val == 1) return true;
215             (*bits)++;
216         }
217     }
218 
219     return false;
220 }
221 
is_positive_const(const Expr & e)222 bool is_positive_const(const Expr &e) {
223     if (const IntImm *i = e.as<IntImm>()) return i->value > 0;
224     if (const UIntImm *u = e.as<UIntImm>()) return u->value > 0;
225     if (const FloatImm *f = e.as<FloatImm>()) return f->value > 0.0f;
226     if (const Cast *c = e.as<Cast>()) {
227         return is_positive_const(c->value);
228     }
229     if (const Ramp *r = e.as<Ramp>()) {
230         // slightly conservative
231         return is_positive_const(r->base) && is_positive_const(r->stride);
232     }
233     if (const Broadcast *b = e.as<Broadcast>()) {
234         return is_positive_const(b->value);
235     }
236     return false;
237 }
238 
is_negative_const(const Expr & e)239 bool is_negative_const(const Expr &e) {
240     if (const IntImm *i = e.as<IntImm>()) return i->value < 0;
241     if (const FloatImm *f = e.as<FloatImm>()) return f->value < 0.0f;
242     if (const Cast *c = e.as<Cast>()) {
243         return is_negative_const(c->value);
244     }
245     if (const Ramp *r = e.as<Ramp>()) {
246         // slightly conservative
247         return is_negative_const(r->base) && is_negative_const(r->stride);
248     }
249     if (const Broadcast *b = e.as<Broadcast>()) {
250         return is_negative_const(b->value);
251     }
252     return false;
253 }
254 
is_negative_negatable_const(const Expr & e,Type T)255 bool is_negative_negatable_const(const Expr &e, Type T) {
256     if (const IntImm *i = e.as<IntImm>()) {
257         return (i->value < 0 && !T.is_min(i->value));
258     }
259     if (const FloatImm *f = e.as<FloatImm>()) return f->value < 0.0f;
260     if (const Cast *c = e.as<Cast>()) {
261         return is_negative_negatable_const(c->value, c->type);
262     }
263     if (const Ramp *r = e.as<Ramp>()) {
264         // slightly conservative
265         return is_negative_negatable_const(r->base) && is_negative_const(r->stride);
266     }
267     if (const Broadcast *b = e.as<Broadcast>()) {
268         return is_negative_negatable_const(b->value);
269     }
270     return false;
271 }
272 
is_negative_negatable_const(const Expr & e)273 bool is_negative_negatable_const(const Expr &e) {
274     return is_negative_negatable_const(e, e.type());
275 }
276 
is_undef(const Expr & e)277 bool is_undef(const Expr &e) {
278     if (const Call *c = e.as<Call>()) return c->is_intrinsic(Call::undef);
279     return false;
280 }
281 
is_zero(const Expr & e)282 bool is_zero(const Expr &e) {
283     if (const IntImm *int_imm = e.as<IntImm>()) return int_imm->value == 0;
284     if (const UIntImm *uint_imm = e.as<UIntImm>()) return uint_imm->value == 0;
285     if (const FloatImm *float_imm = e.as<FloatImm>()) return float_imm->value == 0.0;
286     if (const Cast *c = e.as<Cast>()) return is_zero(c->value);
287     if (const Broadcast *b = e.as<Broadcast>()) return is_zero(b->value);
288     if (const Call *c = e.as<Call>()) {
289         return (c->is_intrinsic(Call::bool_to_mask) || c->is_intrinsic(Call::cast_mask)) &&
290                is_zero(c->args[0]);
291     }
292     return false;
293 }
294 
is_one(const Expr & e)295 bool is_one(const Expr &e) {
296     if (const IntImm *int_imm = e.as<IntImm>()) return int_imm->value == 1;
297     if (const UIntImm *uint_imm = e.as<UIntImm>()) return uint_imm->value == 1;
298     if (const FloatImm *float_imm = e.as<FloatImm>()) return float_imm->value == 1.0;
299     if (const Cast *c = e.as<Cast>()) return is_one(c->value);
300     if (const Broadcast *b = e.as<Broadcast>()) return is_one(b->value);
301     if (const Call *c = e.as<Call>()) {
302         return (c->is_intrinsic(Call::bool_to_mask) || c->is_intrinsic(Call::cast_mask)) &&
303                is_one(c->args[0]);
304     }
305     return false;
306 }
307 
is_two(const Expr & e)308 bool is_two(const Expr &e) {
309     if (e.type().bits() < 2) return false;
310     if (const IntImm *int_imm = e.as<IntImm>()) return int_imm->value == 2;
311     if (const UIntImm *uint_imm = e.as<UIntImm>()) return uint_imm->value == 2;
312     if (const FloatImm *float_imm = e.as<FloatImm>()) return float_imm->value == 2.0;
313     if (const Cast *c = e.as<Cast>()) return is_two(c->value);
314     if (const Broadcast *b = e.as<Broadcast>()) return is_two(b->value);
315     return false;
316 }
317 
318 namespace {
319 
320 template<typename T>
make_const_helper(Type t,T val)321 Expr make_const_helper(Type t, T val) {
322     if (t.is_vector()) {
323         return Broadcast::make(make_const(t.element_of(), val), t.lanes());
324     } else if (t.is_int()) {
325         return IntImm::make(t, (int64_t)val);
326     } else if (t.is_uint()) {
327         return UIntImm::make(t, (uint64_t)val);
328     } else if (t.is_float()) {
329         return FloatImm::make(t, (double)val);
330     } else {
331         internal_error << "Can't make a constant of type " << t << "\n";
332         return Expr();
333     }
334 }
335 
336 }  // namespace
337 
make_const(Type t,int64_t val)338 Expr make_const(Type t, int64_t val) {
339     return make_const_helper(t, val);
340 }
341 
make_const(Type t,uint64_t val)342 Expr make_const(Type t, uint64_t val) {
343     return make_const_helper(t, val);
344 }
345 
make_const(Type t,double val)346 Expr make_const(Type t, double val) {
347     return make_const_helper(t, val);
348 }
349 
make_bool(bool val,int w)350 Expr make_bool(bool val, int w) {
351     return make_const(UInt(1, w), val);
352 }
353 
make_zero(Type t)354 Expr make_zero(Type t) {
355     if (t.is_handle()) {
356         return reinterpret(t, make_zero(UInt(64)));
357     } else {
358         return make_const(t, 0);
359     }
360 }
361 
make_one(Type t)362 Expr make_one(Type t) {
363     return make_const(t, 1);
364 }
365 
make_two(Type t)366 Expr make_two(Type t) {
367     return make_const(t, 2);
368 }
369 
make_signed_integer_overflow(Type type)370 Expr make_signed_integer_overflow(Type type) {
371     static std::atomic<int> counter{0};
372     return Call::make(type, Call::signed_integer_overflow, {counter++}, Call::Intrinsic);
373 }
374 
const_true(int w)375 Expr const_true(int w) {
376     return make_one(UInt(1, w));
377 }
378 
const_false(int w)379 Expr const_false(int w) {
380     return make_zero(UInt(1, w));
381 }
382 
lossless_cast(Type t,Expr e)383 Expr lossless_cast(Type t, Expr e) {
384     if (!e.defined() || t == e.type()) {
385         return e;
386     } else if (t.can_represent(e.type())) {
387         return cast(t, std::move(e));
388     }
389 
390     if (const Cast *c = e.as<Cast>()) {
391         if (t.can_represent(c->value.type())) {
392             // We can recurse into widening casts.
393             return lossless_cast(t, c->value);
394         } else {
395             return Expr();
396         }
397     }
398 
399     if (const Broadcast *b = e.as<Broadcast>()) {
400         Expr v = lossless_cast(t.element_of(), b->value);
401         if (v.defined()) {
402             return Broadcast::make(v, b->lanes);
403         } else {
404             return Expr();
405         }
406     }
407 
408     if (const IntImm *i = e.as<IntImm>()) {
409         if (t.can_represent(i->value)) {
410             return make_const(t, i->value);
411         } else {
412             return Expr();
413         }
414     }
415 
416     if (const UIntImm *i = e.as<UIntImm>()) {
417         if (t.can_represent(i->value)) {
418             return make_const(t, i->value);
419         } else {
420             return Expr();
421         }
422     }
423 
424     if (const FloatImm *f = e.as<FloatImm>()) {
425         if (t.can_represent(f->value)) {
426             return make_const(t, f->value);
427         } else {
428             return Expr();
429         }
430     }
431 
432     if ((t.is_int() || t.is_uint()) && t.bits() >= 16) {
433         if (const Add *add = e.as<Add>()) {
434             // If we can losslessly narrow the args even more
435             // aggressively, we're good.
436             // E.g. lossless_cast(uint16, (uint32)(some_u8) + 37)
437             // = (uint16)(some_u8) + 37
438             Expr a = lossless_cast(t.with_bits(t.bits() / 2), add->a);
439             Expr b = lossless_cast(t.with_bits(t.bits() / 2), add->b);
440             if (a.defined() && b.defined()) {
441                 return cast(t, a) + cast(t, b);
442             } else {
443                 return Expr();
444             }
445         }
446 
447         if (const Sub *sub = e.as<Sub>()) {
448             Expr a = lossless_cast(t.with_bits(t.bits() / 2), sub->a);
449             Expr b = lossless_cast(t.with_bits(t.bits() / 2), sub->b);
450             if (a.defined() && b.defined()) {
451                 return cast(t, a) + cast(t, b);
452             } else {
453                 return Expr();
454             }
455         }
456 
457         if (const Mul *mul = e.as<Mul>()) {
458             Expr a = lossless_cast(t.with_bits(t.bits() / 2), mul->a);
459             Expr b = lossless_cast(t.with_bits(t.bits() / 2), mul->b);
460             if (a.defined() && b.defined()) {
461                 return cast(t, a) * cast(t, b);
462             } else {
463                 return Expr();
464             }
465         }
466 
467         if (const VectorReduce *reduce = e.as<VectorReduce>()) {
468             const int factor = reduce->value.type().lanes() / reduce->type.lanes();
469             switch (reduce->op) {
470             case VectorReduce::Add:
471                 // A horizontal add requires one extra bit per factor
472                 // of two in the reduction factor. E.g. a reduction of
473                 // 8 vector lanes down to 2 requires 2 extra bits in
474                 // the output. We only deal with power-of-two types
475                 // though, so just make sure the reduction factor
476                 // isn't so large that it will more than double the
477                 // number of bits required.
478                 if (factor < (1 << (t.bits() / 2))) {
479                     Type narrower = reduce->value.type().with_bits(t.bits() / 2);
480                     Expr val = lossless_cast(narrower, reduce->value);
481                     if (val.defined()) {
482                         return VectorReduce::make(reduce->op, val, reduce->type.lanes());
483                     }
484                 }
485                 break;
486             case VectorReduce::Max:
487             case VectorReduce::Min: {
488                 Expr val = lossless_cast(t, reduce->value);
489                 if (val.defined()) {
490                     return VectorReduce::make(reduce->op, val, reduce->type.lanes());
491                 }
492                 break;
493             }
494             default:
495                 break;
496             }
497         }
498     }
499 
500     return Expr();
501 }
502 
check_representable(Type dst,int64_t x)503 void check_representable(Type dst, int64_t x) {
504     if (dst.is_handle()) {
505         user_assert(dst.can_represent(x))
506             << "Integer constant " << x
507             << " will be implicitly coerced to type " << dst
508             << ", but Halide does not support pointer arithmetic.\n";
509     } else {
510         user_assert(dst.can_represent(x))
511             << "Integer constant " << x
512             << " will be implicitly coerced to type " << dst
513             << ", which changes its value to " << make_const(dst, x)
514             << ".\n";
515     }
516 }
517 
match_types(Expr & a,Expr & b)518 void match_types(Expr &a, Expr &b) {
519     if (a.type() == b.type()) return;
520 
521     user_assert(!a.type().is_handle() && !b.type().is_handle())
522         << "Can't do arithmetic on opaque pointer types: "
523         << a << ", " << b << "\n";
524 
525     // Broadcast scalar to match vector
526     if (a.type().is_scalar() && b.type().is_vector()) {
527         a = Broadcast::make(std::move(a), b.type().lanes());
528     } else if (a.type().is_vector() && b.type().is_scalar()) {
529         b = Broadcast::make(std::move(b), a.type().lanes());
530     } else {
531         internal_assert(a.type().lanes() == b.type().lanes()) << "Can't match types of differing widths";
532     }
533 
534     Type ta = a.type(), tb = b.type();
535 
536     // If type broadcasting has made the types match no additional casts are needed
537     if (ta == tb) return;
538 
539     if (!ta.is_float() && tb.is_float()) {
540         // int(a) * float(b) -> float(b)
541         // uint(a) * float(b) -> float(b)
542         a = cast(tb, std::move(a));
543     } else if (ta.is_float() && !tb.is_float()) {
544         b = cast(ta, std::move(b));
545     } else if (ta.is_float() && tb.is_float()) {
546         // float(a) * float(b) -> float(max(a, b))
547         if (ta.bits() > tb.bits())
548             b = cast(ta, std::move(b));
549         else
550             a = cast(tb, std::move(a));
551     } else if (ta.is_uint() && tb.is_uint()) {
552         // uint(a) * uint(b) -> uint(max(a, b))
553         if (ta.bits() > tb.bits())
554             b = cast(ta, std::move(b));
555         else
556             a = cast(tb, std::move(a));
557     } else if (!ta.is_float() && !tb.is_float()) {
558         // int(a) * (u)int(b) -> int(max(a, b))
559         int bits = std::max(ta.bits(), tb.bits());
560         int lanes = a.type().lanes();
561         a = cast(Int(bits, lanes), std::move(a));
562         b = cast(Int(bits, lanes), std::move(b));
563     } else {
564         internal_error << "Could not match types: " << ta << ", " << tb << "\n";
565     }
566 }
567 
// Cast to the wider type of the two. Already guaranteed to leave
// signed/unsigned and number of lanes unchanged.
match_bits(Expr & x,Expr & y)570 void match_bits(Expr &x, Expr &y) {
571     // The signedness doesn't match, so just match the bits.
572     if (x.type().bits() < y.type().bits()) {
573         Type t;
574         if (x.type().is_int()) {
575             t = Int(y.type().bits(), y.type().lanes());
576         } else {
577             t = UInt(y.type().bits(), y.type().lanes());
578         }
579         x = cast(t, x);
580     } else if (y.type().bits() < x.type().bits()) {
581         Type t;
582         if (y.type().is_int()) {
583             t = Int(x.type().bits(), x.type().lanes());
584         } else {
585             t = UInt(x.type().bits(), x.type().lanes());
586         }
587         y = cast(t, y);
588     }
589 }
590 
match_types_bitwise(Expr & x,Expr & y,const char * op_name)591 void match_types_bitwise(Expr &x, Expr &y, const char *op_name) {
592     user_assert(x.defined() && y.defined()) << op_name << " of undefined Expr\n";
593     user_assert(x.type().is_int() || x.type().is_uint())
594         << "The first argument to " << op_name << " must be an integer or unsigned integer";
595     user_assert(y.type().is_int() || y.type().is_uint())
596         << "The second argument to " << op_name << " must be an integer or unsigned integer";
597     user_assert(y.type().is_int() == x.type().is_int())
598         << "Arguments to " << op_name
599         << " must be both be signed or both be unsigned.\n"
600         << "LHS type: " << x.type() << " RHS type: " << y.type() << "\n"
601         << "LHS value: " << x << " RHS value: " << y << "\n";
602 
603     // Broadcast scalar to match vector
604     if (x.type().is_scalar() && y.type().is_vector()) {
605         x = Broadcast::make(std::move(x), y.type().lanes());
606     } else if (x.type().is_vector() && y.type().is_scalar()) {
607         y = Broadcast::make(std::move(y), x.type().lanes());
608     } else {
609         internal_assert(x.type().lanes() == y.type().lanes()) << "Can't match types of differing widths";
610     }
611 
612     // Cast to the wider type of the two. Already guaranteed to leave
613     // signed/unsigned on number of lanes unchanged.
614     if (x.type().bits() < y.type().bits()) {
615         x = cast(y.type(), x);
616     } else if (y.type().bits() < x.type().bits()) {
617         y = cast(x.type(), y);
618     }
619 }
620 
621 // Fast math ops based on those from Syrah (http://github.com/boulos/syrah). Thanks, Solomon!
622 
623 // Factor a float into 2^exponent * reduced, where reduced is between 0.75 and 1.5
range_reduce_log(const Expr & input,Expr * reduced,Expr * exponent)624 void range_reduce_log(const Expr &input, Expr *reduced, Expr *exponent) {
625     Type type = input.type();
626     Type int_type = Int(32, type.lanes());
627     Expr int_version = reinterpret(int_type, input);
628 
629     // single precision = SEEE EEEE EMMM MMMM MMMM MMMM MMMM MMMM
630     // exponent mask    = 0111 1111 1000 0000 0000 0000 0000 0000
631     //                    0x7  0xF  0x8  0x0  0x0  0x0  0x0  0x0
632     // non-exponent     = 1000 0000 0111 1111 1111 1111 1111 1111
633     //                  = 0x8  0x0  0x7  0xF  0xF  0xF  0xF  0xF
634     Expr non_exponent_mask = make_const(int_type, 0x807fffff);
635 
636     // Extract a version with no exponent (between 1.0 and 2.0)
637     Expr no_exponent = int_version & non_exponent_mask;
638 
639     // If > 1.5, we want to divide by two, to normalize back into the
640     // range (0.75, 1.5). We can detect this by sniffing the high bit
641     // of the mantissa.
642     Expr new_exponent = no_exponent >> 22;
643 
644     Expr new_biased_exponent = 127 - new_exponent;
645     Expr old_biased_exponent = int_version >> 23;
646     *exponent = old_biased_exponent - new_biased_exponent;
647 
648     Expr blended = (int_version & non_exponent_mask) | (new_biased_exponent << 23);
649 
650     *reduced = reinterpret(type, blended);
651 }
652 
// Polynomial approximation of the natural logarithm for Float(32)
// inputs (scalar or vector). Negative inputs produce nan and zero
// produces -inf.
Expr halide_log(const Expr &x_full) {
    Type type = x_full.type();
    internal_assert(type.element_of() == Float(32));

    // Inputs whose answer has to be patched in at the end.
    Expr input_is_negative = x_full < 0.0f;  // log of a negative is nan
    Expr input_is_zero = x_full == 0.0f;     // log of zero is -inf
    Expr needs_fixup = input_is_negative | input_is_zero;

    // Feed 1.0 through the approximation for those inputs so that the
    // main path never produces nans or infs itself.
    Expr safe_input = select(needs_fixup, make_one(type), x_full);
    Expr mantissa, exponent;
    range_reduce_log(safe_input, &mantissa, &exponent);

    // Very close to the Taylor series for log about 1, but tuned to
    // have minimum relative error over the reduced domain (0.75 - 1.5).
    float coeff[] = {
        0.05111976432738144643f,
        -0.11793923497136414580f,
        0.14971993724699017569f,
        -0.16862004708254804686f,
        0.19980668101718729313f,
        -0.24991211576292837737f,
        0.33333435275479328386f,
        -0.50000106292873236491f,
        1.0f,
        0.0f};
    const int num_coeffs = sizeof(coeff) / sizeof(coeff[0]);

    Expr centered = mantissa - 1.0f;
    Expr log_value = evaluate_polynomial(centered, coeff, num_coeffs);

    // Add back the contribution of the factored-out power of two.
    log_value += cast(type, exponent) * logf(2.0);

    Expr nan = Call::make(type, "nan_f32", {}, Call::PureExtern);
    Expr neg_inf = Call::make(type, "neg_inf_f32", {}, Call::PureExtern);
    log_value = select(needs_fixup, select(input_is_negative, nan, neg_inf), log_value);

    // The construction above repeats many subexpressions; share them.
    return common_subexpression_elimination(log_value);
}
696 
// Polynomial approximation of e^x for Float(32) inputs (scalar or
// vector). The input is split as x = k*ln(2) + r, a polynomial is
// evaluated on r, and the result is scaled by 2^k built directly from
// float bits. Overflow yields inf and underflow yields zero.
Expr halide_exp(const Expr &x_full) {
    Type type = x_full.type();
    internal_assert(type.element_of() == Float(32));

    // ln(2) is subtracted in two pieces (a large part and a small
    // correction) when forming the remainder.
    float ln2_hi = 0.6931457519f;
    float ln2_lo = 1.4286067653e-6f;
    float inv_ln2 = 1.0f / logf(2.0f);

    Expr in_units_of_ln2 = x_full * inv_ln2;
    Expr k_as_float = floor(in_units_of_ln2);
    Expr k = cast(Int(32, type.lanes()), k_as_float);

    Expr remainder = x_full - k_as_float * ln2_hi;
    remainder -= k_as_float * ln2_lo;

    float coeff[] = {
        0.00031965933071842413f,
        0.00119156835564003744f,
        0.00848988645943932717f,
        0.04160188091348320655f,
        0.16667983794100929562f,
        0.49999899033463041098f,
        1.0f,
        1.0f};
    const int num_coeffs = sizeof(coeff) / sizeof(coeff[0]);
    Expr exp_value = evaluate_polynomial(remainder, coeff, num_coeffs);

    // Build 2^k: add the single-precision exponent bias, shift into the
    // exponent field, and reinterpret the bits as a float.
    int exponent_bias = 127;
    Expr biased_exponent = k + exponent_bias;

    Expr inf = Call::make(type, "inf_f32", {}, Call::PureExtern);

    Expr power_of_two = reinterpret(type, biased_exponent << 23);
    exp_value *= power_of_two;

    // A biased exponent outside (0, 255) does not fit the exponent
    // field: too large means overflow, too small means underflow.
    exp_value = select(biased_exponent < 255, exp_value, inf);
    exp_value = select(biased_exponent > 0, exp_value, make_zero(type));

    // The construction above repeats many subexpressions; share them.
    return common_subexpression_elimination(exp_value);
}
743 
// Approximation of the error function for scalar Float(32) inputs. The
// magnitude is run through two approximations, one selected for
// |x| > 1 and one otherwise, and the sign is re-applied at the end.
Expr halide_erf(const Expr &x_full) {
    user_assert(x_full.type() == Float(32)) << "halide_erf only works for Float(32)";

    // Work on |x| and remember the sign for later.
    Expr sign = select(x_full < 0, -1.0f, 1.0f);
    Expr magnitude = abs(x_full);

    // Large-magnitude branch: very similar to an approximation from
    // Abramowitz and Stegun, but tuned for values > 1. It has the form
    // 1 - P(x)^-16.
    float large_coeff[] = {0.0000818502f,
                           -0.0000026500f,
                           0.0009353904f,
                           0.0081960206f,
                           0.0430054424f,
                           0.0703310579f,
                           1.0f};
    Expr large_approx = evaluate_polynomial(magnitude, large_coeff,
                                            sizeof(large_coeff) / sizeof(large_coeff[0]));
    large_approx = 1.0f - pow(large_approx, -16);

    // Small-magnitude branch: an odd polynomial tuned for values < 1,
    // similar to the Taylor expansion of erf. It is evaluated as a
    // polynomial in x^2 and then multiplied by x.
    float small_coeff[] = {-0.0005553339f,
                           0.0048937243f,
                           -0.0266849239f,
                           0.1127890132f,
                           -0.3761207240f,
                           1.1283789803f};
    Expr small_approx = evaluate_polynomial(magnitude * magnitude, small_coeff,
                                            sizeof(small_coeff) / sizeof(small_coeff[0]));
    small_approx *= magnitude;

    // Choose a branch by magnitude, restore the sign, and share the
    // repeated subexpressions.
    Expr unsigned_result = select(magnitude > 1.0f, large_approx, small_approx);
    return common_subexpression_elimination(sign * unsigned_result);
}
783 
raise_to_integer_power(Expr e,int64_t p)784 Expr raise_to_integer_power(Expr e, int64_t p) {
785     Expr result;
786     if (p == 0) {
787         result = make_one(e.type());
788     } else if (p == 1) {
789         result = std::move(e);
790     } else if (p < 0) {
791         result = make_one(e.type());
792         result /= raise_to_integer_power(std::move(e), -p);
793     } else {
794         // p is at least 2
795         if (p & 1) {
796             Expr y = raise_to_integer_power(e, p >> 1);
797             result = y * y * std::move(e);
798         } else {
799             e = raise_to_integer_power(std::move(e), p >> 1);
800             result = e * e;
801         }
802     }
803     return result;
804 }
805 
split_into_ands(const Expr & cond,std::vector<Expr> & result)806 void split_into_ands(const Expr &cond, std::vector<Expr> &result) {
807     if (!cond.defined()) {
808         return;
809     }
810     internal_assert(cond.type().is_bool()) << "Should be a boolean condition\n";
811     if (const And *a = cond.as<And>()) {
812         split_into_ands(a->a, result);
813         split_into_ands(a->b, result);
814     } else if (!is_one(cond)) {
815         result.push_back(cond);
816     }
817 }
818 
build() const819 Expr BufferBuilder::build() const {
820     std::vector<Expr> args(10);
821     if (buffer_memory.defined()) {
822         args[0] = buffer_memory;
823     } else {
824         Expr sz = Call::make(Int(32), Call::size_of_halide_buffer_t, {}, Call::Intrinsic);
825         args[0] = Call::make(type_of<struct halide_buffer_t *>(), Call::alloca, {sz}, Call::Intrinsic);
826     }
827 
828     std::string shape_var_name = unique_name('t');
829     Expr shape_var = Variable::make(type_of<halide_dimension_t *>(), shape_var_name);
830     if (shape_memory.defined()) {
831         args[1] = shape_memory;
832     } else if (dimensions == 0) {
833         args[1] = make_zero(type_of<halide_dimension_t *>());
834     } else {
835         args[1] = shape_var;
836     }
837 
838     if (host.defined()) {
839         args[2] = host;
840     } else {
841         args[2] = make_zero(type_of<void *>());
842     }
843 
844     if (device.defined()) {
845         args[3] = device;
846     } else {
847         args[3] = make_zero(UInt(64));
848     }
849 
850     if (device_interface.defined()) {
851         args[4] = device_interface;
852     } else {
853         args[4] = make_zero(type_of<struct halide_device_interface_t *>());
854     }
855 
856     args[5] = (int)type.code();
857     args[6] = type.bits();
858     args[7] = dimensions;
859 
860     std::vector<Expr> shape;
861     for (size_t i = 0; i < (size_t)dimensions; i++) {
862         if (i < mins.size()) {
863             shape.push_back(mins[i]);
864         } else {
865             shape.emplace_back(0);
866         }
867         if (i < extents.size()) {
868             shape.push_back(extents[i]);
869         } else {
870             shape.emplace_back(0);
871         }
872         if (i < strides.size()) {
873             shape.push_back(strides[i]);
874         } else {
875             shape.emplace_back(0);
876         }
877         // per-dimension flags, currently unused.
878         shape.emplace_back(0);
879     }
880     for (const Expr &e : shape) {
881         internal_assert(e.type() == Int(32))
882             << "Buffer shape fields must be int32_t:" << e << "\n";
883     }
884     Expr shape_arg = Call::make(type_of<halide_dimension_t *>(), Call::make_struct, shape, Call::Intrinsic);
885     if (shape_memory.defined()) {
886         args[8] = shape_arg;
887     } else if (dimensions == 0) {
888         args[8] = make_zero(type_of<halide_dimension_t *>());
889     } else {
890         args[8] = shape_var;
891     }
892 
893     Expr flags = make_zero(UInt(64));
894     if (host_dirty.defined()) {
895         flags = select(host_dirty,
896                        make_const(UInt(64), halide_buffer_flag_host_dirty),
897                        make_zero(UInt(64)));
898     }
899     if (device_dirty.defined()) {
900         flags = flags | select(device_dirty,
901                                make_const(UInt(64), halide_buffer_flag_device_dirty),
902                                make_zero(UInt(64)));
903     }
904     args[9] = flags;
905 
906     Expr e = Call::make(type_of<struct halide_buffer_t *>(), Call::buffer_init, args, Call::Extern);
907 
908     if (!shape_memory.defined() && dimensions != 0) {
909         e = Let::make(shape_var_name, shape_arg, e);
910     }
911 
912     return e;
913 }
914 
// If e is a Ramp whose stride is an IntImm equal to `stride`, return the
// ramp's base. In every other case return an undefined Expr.
Expr strided_ramp_base(const Expr &e, int stride) {
    if (const Ramp *ramp = e.as<Ramp>()) {
        const IntImm *step = ramp->stride.as<IntImm>();
        if (step != nullptr && step->value == stride) {
            return ramp->base;
        }
    }
    return Expr();
}
928 
929 namespace {
930 
931 struct RemoveLikelies : public IRMutator {
932     using IRMutator::visit;
visitHalide::Internal::__anon460940500411::RemoveLikelies933     Expr visit(const Call *op) override {
934         if (op->is_intrinsic(Call::likely) ||
935             op->is_intrinsic(Call::likely_if_innermost)) {
936             return mutate(op->args[0]);
937         } else {
938             return IRMutator::visit(op);
939         }
940     }
941 };
942 
943 }  // namespace
944 
// Return a copy of the expression e with likely / likely_if_innermost
// intrinsics removed by RemoveLikelies.
Expr remove_likelies(const Expr &e) {
    RemoveLikelies stripper;
    return stripper.mutate(e);
}
948 
// Return a copy of the statement s with likely / likely_if_innermost
// intrinsics removed by RemoveLikelies.
Stmt remove_likelies(const Stmt &s) {
    RemoveLikelies stripper;
    return stripper.mutate(s);
}
952 
requirement_failed_error(Expr condition,const std::vector<Expr> & args)953 Expr requirement_failed_error(Expr condition, const std::vector<Expr> &args) {
954     return Internal::Call::make(Int(32),
955                                 "halide_error_requirement_failed",
956                                 {stringify({std::move(condition)}), combine_strings(args)},
957                                 Internal::Call::Extern);
958 }
959 
memoize_tag_helper(Expr result,const std::vector<Expr> & cache_key_values)960 Expr memoize_tag_helper(Expr result, const std::vector<Expr> &cache_key_values) {
961     Type t = result.type();
962     std::vector<Expr> args;
963     args.push_back(std::move(result));
964     args.insert(args.end(), cache_key_values.begin(), cache_key_values.end());
965     return Internal::Call::make(t, Internal::Call::memoize_expr,
966                                 args, Internal::Call::PureIntrinsic);
967 }
968 
969 }  // namespace Internal
970 
// Polynomial approximation of the natural log for Float(32) inputs.
// range_reduce_log splits x into a reduced value and an exponent. A
// polynomial is evaluated at (reduced - 1) and exponent * log(2) is added.
// The result is run through common subexpression elimination.
Expr fast_log(const Expr &x) {
    user_assert(x.type() == Float(32)) << "fast_log only works for Float(32)";

    Expr reduced;
    Expr exponent;
    range_reduce_log(x, &reduced, &exponent);

    // Polynomial coefficients, highest-order term first.
    float poly_coeffs[] = {
        0.07640318789187280912f,
        -0.16252961013874300811f,
        0.20625219040645212387f,
        -0.25110261010892864775f,
        0.33320464908377461777f,
        -0.49997513376789826101f,
        1.0f,
        0.0f};
    const int num_terms = sizeof(poly_coeffs) / sizeof(poly_coeffs[0]);

    Expr offset = reduced - 1.0f;
    Expr approx = evaluate_polynomial(offset, poly_coeffs, num_terms);
    approx = approx + cast<float>(exponent) * logf(2);
    return common_subexpression_elimination(approx);
}
994 
995 namespace {
996 
997 // A vectorizable sine and cosine implementation. Based on syrah fast vector math
998 // https://github.com/boulos/syrah/blob/master/src/include/syrah/FixedVectorMath.h#L55
// Shared implementation of fast_sin (is_sin == true) and fast_cos
// (is_sin == false). The input is split as x_full = k * (pi/2) + x with
// k = floor(x_full * 2/pi). The value of k mod 4 decides whether the
// cosine or the sine coefficients are used and whether the result is
// negated.
Expr fast_sin_cos(const Expr &x_full, bool is_sin) {
    const float two_over_pi = 0.636619746685028076171875f;
    const float pi_over_two = 1.57079637050628662109375f;

    Expr scaled = x_full * two_over_pi;
    Expr k_real = floor(scaled);
    Expr quadrant = cast<int>(k_real) % 4;

    // use_cos_poly: evaluate the cosine polynomial instead of the sine one.
    // negate: flip the sign of the final value.
    Expr use_cos_poly;
    Expr negate;
    if (is_sin) {
        use_cos_poly = (quadrant == 1) || (quadrant == 3);
        negate = quadrant > 1;
    } else {
        use_cos_poly = (quadrant == 0) || (quadrant == 2);
        negate = (quadrant == 1) || (quadrant == 2);
    }

    // Reduce the angle modulo pi/2.
    Expr x = x_full - k_real * pi_over_two;

    const float sin_c2 = -0.16666667163372039794921875f;
    const float sin_c4 = 8.333347737789154052734375e-3;
    const float sin_c6 = -1.9842604524455964565277099609375e-4;
    const float sin_c8 = 2.760012648650445044040679931640625e-6;
    const float sin_c10 = -2.50293279435709337121807038784027099609375e-8;

    const float cos_c2 = -0.5f;
    const float cos_c4 = 4.166664183139801025390625e-2;
    const float cos_c6 = -1.388833043165504932403564453125e-3;
    const float cos_c8 = 2.47562347794882953166961669921875e-5;
    const float cos_c10 = -2.59630184018533327616751194000244140625e-7;

    // The cosine form has leading factor 1, the sine form has leading
    // factor x.
    Expr outside = select(use_cos_poly, 1, x);
    Expr c2 = select(use_cos_poly, cos_c2, sin_c2);
    Expr c4 = select(use_cos_poly, cos_c4, sin_c4);
    Expr c6 = select(use_cos_poly, cos_c6, sin_c6);
    Expr c8 = select(use_cos_poly, cos_c8, sin_c8);
    Expr c10 = select(use_cos_poly, cos_c10, sin_c10);

    // Horner evaluation in x^2, one step per statement.
    Expr x2 = x * x;
    Expr poly = x2 * c10 + c8;
    poly = x2 * poly + c6;
    poly = x2 * poly + c4;
    poly = x2 * poly + c2;
    poly = x2 * poly + 1;

    Expr tri_func = outside * poly;
    return select(negate, -tri_func, tri_func);
}
1035 
1036 }  // namespace
1037 
// Approximate sine: forwards to fast_sin_cos with is_sin set.
Expr fast_sin(const Expr &x_full) {
    const bool compute_sine = true;
    return fast_sin_cos(x_full, compute_sine);
}
1041 
// Approximate cosine: forwards to fast_sin_cos with is_sin cleared.
Expr fast_cos(const Expr &x_full) {
    const bool compute_sine = false;
    return fast_sin_cos(x_full, compute_sine);
}
1045 
// Polynomial approximation of exp for Float(32) inputs. The input is split
// as x_full = k * log(2) + x with k = floor(x_full / log(2)). A polynomial
// is evaluated at x and the result is multiplied by 2^k, which is built
// directly from float exponent bits.
Expr fast_exp(const Expr &x_full) {
    user_assert(x_full.type() == Float(32)) << "fast_exp only works for Float(32)";

    Expr scaled = x_full / logf(2.0);
    Expr k_real = floor(scaled);
    Expr k = cast<int>(k_real);
    Expr remainder = x_full - k_real * logf(2.0);

    // Polynomial coefficients, highest-order term first.
    float poly_coeffs[] = {
        0.01314350012789660196f,
        0.03668965196652099192f,
        0.16873890085469545053f,
        0.49970514590562437052f,
        1.0f,
        1.0f};
    const int num_terms = sizeof(poly_coeffs) / sizeof(poly_coeffs[0]);
    Expr approx = evaluate_polynomial(remainder, poly_coeffs, num_terms);

    // Compute 2^k: add the exponent bias, clamp to [0, 255], shift the
    // value into the exponent field and reinterpret the bits as a float.
    const int fpbias = 127;
    Expr biased = clamp(k + fpbias, 0, 255);
    Expr two_to_the_k = reinterpret<float>(biased << 23);

    approx *= two_to_the_k;
    return common_subexpression_elimination(approx);
}
1074 
print(const std::vector<Expr> & args)1075 Expr print(const std::vector<Expr> &args) {
1076     Expr combined_string = combine_strings(args);
1077 
1078     // Call halide_print.
1079     Expr print_call =
1080         Internal::Call::make(Int(32), "halide_print",
1081                              {combined_string}, Internal::Call::Extern);
1082 
1083     // Return the first argument.
1084     Expr result =
1085         Internal::Call::make(args[0].type(), Internal::Call::return_second,
1086                              {print_call, args[0]}, Internal::Call::PureIntrinsic);
1087     return result;
1088 }
1089 
print_when(Expr condition,const std::vector<Expr> & args)1090 Expr print_when(Expr condition, const std::vector<Expr> &args) {
1091     Expr p = print(args);
1092     return Internal::Call::make(p.type(),
1093                                 Internal::Call::if_then_else,
1094                                 {std::move(condition), p, args[0]},
1095                                 Internal::Call::PureIntrinsic);
1096 }
1097 
require(Expr condition,const std::vector<Expr> & args)1098 Expr require(Expr condition, const std::vector<Expr> &args) {
1099     user_assert(condition.defined()) << "Require of undefined condition.\n";
1100     user_assert(condition.type().is_bool()) << "Require condition must be a boolean type.\n";
1101     user_assert(args.at(0).defined()) << "Require of undefined value.\n";
1102 
1103     Expr err = requirement_failed_error(condition, args);
1104 
1105     return Internal::Call::make(args[0].type(),
1106                                 Internal::Call::require,
1107                                 {likely(std::move(condition)), args[0], std::move(err)},
1108                                 Internal::Call::PureIntrinsic);
1109 }
1110 
// Cast e to type t, clamping the value to bounds derived from t rather
// than letting it wrap or overflow. Three cases are handled:
//   - float to float: clamp to [t.min(), t.max()], before the cast when
//     narrowing and after it otherwise;
//   - float to a non-float type at least as wide: handled specially because
//     t.max() is not exactly representable in the source float type;
//   - everything else with differing types: clamp in the source type using
//     whichever of t's bounds are losslessly representable there, then cast.
// If e already has type t (and is not float), e is returned unchanged.
Expr saturating_cast(Type t, Expr e) {
    // For float to float, guarantee infinities are always pinned to range.
    if (t.is_float() && e.type().is_float()) {
        if (t.bits() < e.type().bits()) {
            // Narrowing: clamp in the wider source type, then cast.
            e = cast(t, clamp(std::move(e), t.min(), t.max()));
        } else {
            // Same width or widening: cast first, then clamp in t.
            e = clamp(cast(t, std::move(e)), t.min(), t.max());
        }
    } else if (e.type() != t) {
        // Limits for Int(2^n) or UInt(2^n) are not exactly representable in Float(2^n)
        if (e.type().is_float() && !t.is_float() && t.bits() >= e.type().bits()) {
            e = max(std::move(e), t.min());  // min values turn out to be always representable

            // This line depends on t.max() rounding upward, which should always
            // be the case as it is one less than a representable value, thus
            // the one larger is always the closest.
            e = select(e >= cast(e.type(), t.max()), t.max(), cast(t, e));
        } else {
            // Collect the bounds of t that can be expressed in e's type.
            // The lower bound is not computed at all for unsigned sources.
            Expr min_bound;
            if (!e.type().is_uint()) {
                min_bound = lossless_cast(e.type(), t.min());
            }
            Expr max_bound = lossless_cast(e.type(), t.max());

            // Apply whichever bounds are defined; with neither, no clamp
            // is applied before the cast.
            if (min_bound.defined() && max_bound.defined()) {
                e = clamp(std::move(e), min_bound, max_bound);
            } else if (min_bound.defined()) {
                e = max(std::move(e), min_bound);
            } else if (max_bound.defined()) {
                e = min(std::move(e), max_bound);
            }
            e = cast(t, std::move(e));
        }
    }
    return e;
}
1147 
select(Expr condition,Expr true_value,Expr false_value)1148 Expr select(Expr condition, Expr true_value, Expr false_value) {
1149     if (as_const_int(condition)) {
1150         // Why are you doing this? We'll preserve the select node until constant folding for you.
1151         condition = cast(Bool(), std::move(condition));
1152     }
1153 
1154     // Coerce int literals to the type of the other argument
1155     if (as_const_int(true_value)) {
1156         true_value = cast(false_value.type(), std::move(true_value));
1157     }
1158     if (as_const_int(false_value)) {
1159         false_value = cast(true_value.type(), std::move(false_value));
1160     }
1161 
1162     user_assert(condition.type().is_bool())
1163         << "The first argument to a select must be a boolean:\n"
1164         << "  " << condition << " has type " << condition.type() << "\n";
1165 
1166     user_assert(true_value.type() == false_value.type())
1167         << "The second and third arguments to a select do not have a matching type:\n"
1168         << "  " << true_value << " has type " << true_value.type() << "\n"
1169         << "  " << false_value << " has type " << false_value.type() << "\n";
1170 
1171     return Internal::Select::make(std::move(condition), std::move(true_value), std::move(false_value));
1172 }
1173 
// Element-wise select over three Tuples of equal size: element i of the
// result is select(condition[i], true_value[i], false_value[i]).
Tuple tuple_select(const Tuple &condition, const Tuple &true_value, const Tuple &false_value) {
    user_assert(condition.size() == true_value.size() && true_value.size() == false_value.size())
        << "tuple_select() requires all Tuples to have identical sizes.";

    const size_t count = condition.size();
    std::vector<Expr> selected;
    selected.reserve(count);
    for (size_t idx = 0; idx < count; idx++) {
        selected.push_back(select(condition[idx], true_value[idx], false_value[idx]));
    }
    return Tuple(selected);
}
1183 
// Select between two equally sized Tuples using one shared condition:
// element i of the result is select(condition, true_value[i], false_value[i]).
Tuple tuple_select(const Expr &condition, const Tuple &true_value, const Tuple &false_value) {
    user_assert(true_value.size() == false_value.size())
        << "tuple_select() requires all Tuples to have identical sizes.";

    const size_t count = true_value.size();
    std::vector<Expr> selected;
    selected.reserve(count);
    for (size_t idx = 0; idx < count; idx++) {
        selected.push_back(select(condition, true_value[idx], false_value[idx]));
    }
    return Tuple(selected);
}
1193 
// Choose values[id] using a chain of selects. All values must share one
// type and there must be at least two of them. The last value is the
// fallback when id matches none of the earlier indices.
Expr mux(const Expr &id, const std::vector<Expr> &values) {
    user_assert(values.size() >= 2) << "mux() only accepts values with size >= 2.\n";
    // Check if all the values have the same type.
    Type t = values[0].type();
    for (size_t i = 1; i < values.size(); i++) {
        user_assert(values[i].type() == t) << "mux() requires all the values to have the same type.";
    }

    // Build the chain from the back so index 0 ends up outermost.
    const int last = (int)values.size() - 1;
    Expr chain = values[last];
    for (int idx = last - 1; idx >= 0; idx--) {
        chain = select(id == idx, values[idx], chain);
    }
    return chain;
}
1207 
// Tuple overload: forwards the Tuple's elements to the vector overload.
Expr mux(const Expr &id, const Tuple &tup) {
    const std::vector<Expr> &values = tup.as_vector();
    return mux(id, values);
}
1211 
// Initializer-list overload: copies the list into a vector and forwards
// to the vector overload.
Expr mux(const Expr &id, const std::initializer_list<Expr> &values) {
    const std::vector<Expr> as_vector(values);
    return mux(id, as_vector);
}
1215 
// Wrap value in an unsafe_promise_clamped intrinsic with the given bounds.
// A defined bound is passed through lossless_cast to value's type. An
// undefined bound means "no bound on that side" and is replaced by the
// corresponding limit of value's type.
Expr unsafe_promise_clamped(const Expr &value, const Expr &min, const Expr &max) {
    user_assert(value.defined()) << "unsafe_promise_clamped with undefined value.\n";
    const Type value_type = value.type();

    Expr lower;
    if (min.defined()) {
        lower = lossless_cast(value_type, min);
    } else {
        lower = value_type.min();
    }

    Expr upper;
    if (max.defined()) {
        upper = lossless_cast(value_type, max);
    } else {
        upper = value_type.max();
    }

    return Internal::Call::make(value_type,
                                Internal::Call::unsafe_promise_clamped,
                                {value, lower, upper},
                                Internal::Call::Intrinsic);
}
1228 
1229 namespace Internal {
// Wrap value in a promise_clamped intrinsic with the given bounds. A
// defined bound is passed through lossless_cast to value's type. An
// undefined bound means "no bound on that side" and is replaced by the
// corresponding limit of value's type.
Expr promise_clamped(const Expr &value, const Expr &min, const Expr &max) {
    internal_assert(value.defined()) << "promise_clamped with undefined value.\n";
    const Type value_type = value.type();

    Expr lower;
    if (min.defined()) {
        lower = lossless_cast(value_type, min);
    } else {
        lower = value_type.min();
    }

    Expr upper;
    if (max.defined()) {
        upper = lossless_cast(value_type, max);
    } else {
        upper = value_type.max();
    }

    return Internal::Call::make(value_type,
                                Internal::Call::promise_clamped,
                                {value, lower, upper},
                                Internal::Call::Intrinsic);
}
1241 }  // namespace Internal
1242 
operator +(Expr a,Expr b)1243 Expr operator+(Expr a, Expr b) {
1244     user_assert(a.defined() && b.defined()) << "operator+ of undefined Expr\n";
1245     Internal::match_types(a, b);
1246     return Internal::Add::make(std::move(a), std::move(b));
1247 }
1248 
operator +(Expr a,int b)1249 Expr operator+(Expr a, int b) {
1250     user_assert(a.defined()) << "operator+ of undefined Expr\n";
1251     Type t = a.type();
1252     Internal::check_representable(t, b);
1253     return Internal::Add::make(std::move(a), Internal::make_const(t, b));
1254 }
1255 
operator +(int a,Expr b)1256 Expr operator+(int a, Expr b) {
1257     user_assert(b.defined()) << "operator+ of undefined Expr\n";
1258     Type t = b.type();
1259     Internal::check_representable(t, a);
1260     return Internal::Add::make(Internal::make_const(t, a), std::move(b));
1261 }
1262 
operator +=(Expr & a,Expr b)1263 Expr &operator+=(Expr &a, Expr b) {
1264     user_assert(a.defined() && b.defined()) << "operator+= of undefined Expr\n";
1265     Type t = a.type();
1266     a = Internal::Add::make(std::move(a), cast(t, std::move(b)));
1267     return a;
1268 }
1269 
operator -(Expr a,Expr b)1270 Expr operator-(Expr a, Expr b) {
1271     user_assert(a.defined() && b.defined()) << "operator- of undefined Expr\n";
1272     Internal::match_types(a, b);
1273     return Internal::Sub::make(std::move(a), std::move(b));
1274 }
1275 
operator -(Expr a,int b)1276 Expr operator-(Expr a, int b) {
1277     user_assert(a.defined()) << "operator- of undefined Expr\n";
1278     Type t = a.type();
1279     Internal::check_representable(t, b);
1280     return Internal::Sub::make(std::move(a), Internal::make_const(t, b));
1281 }
1282 
operator -(int a,Expr b)1283 Expr operator-(int a, Expr b) {
1284     user_assert(b.defined()) << "operator- of undefined Expr\n";
1285     Type t = b.type();
1286     Internal::check_representable(t, a);
1287     return Internal::Sub::make(Internal::make_const(t, a), std::move(b));
1288 }
1289 
operator -(Expr a)1290 Expr operator-(Expr a) {
1291     user_assert(a.defined()) << "operator- of undefined Expr\n";
1292     Type t = a.type();
1293     return Internal::Sub::make(Internal::make_zero(t), std::move(a));
1294 }
1295 
operator -=(Expr & a,Expr b)1296 Expr &operator-=(Expr &a, Expr b) {
1297     user_assert(a.defined() && b.defined()) << "operator-= of undefined Expr\n";
1298     Type t = a.type();
1299     a = Internal::Sub::make(std::move(a), cast(t, std::move(b)));
1300     return a;
1301 }
1302 
operator *(Expr a,Expr b)1303 Expr operator*(Expr a, Expr b) {
1304     user_assert(a.defined() && b.defined()) << "operator* of undefined Expr\n";
1305     Internal::match_types(a, b);
1306     return Internal::Mul::make(std::move(a), std::move(b));
1307 }
1308 
operator *(Expr a,int b)1309 Expr operator*(Expr a, int b) {
1310     user_assert(a.defined()) << "operator* of undefined Expr\n";
1311     Type t = a.type();
1312     Internal::check_representable(t, b);
1313     return Internal::Mul::make(std::move(a), Internal::make_const(t, b));
1314 }
1315 
operator *(int a,Expr b)1316 Expr operator*(int a, Expr b) {
1317     user_assert(b.defined()) << "operator* of undefined Expr\n";
1318     Type t = b.type();
1319     Internal::check_representable(t, a);
1320     return Internal::Mul::make(Internal::make_const(t, a), std::move(b));
1321 }
1322 
operator *=(Expr & a,Expr b)1323 Expr &operator*=(Expr &a, Expr b) {
1324     user_assert(a.defined() && b.defined()) << "operator*= of undefined Expr\n";
1325     Type t = a.type();
1326     a = Internal::Mul::make(std::move(a), cast(t, std::move(b)));
1327     return a;
1328 }
1329 
operator /(Expr a,Expr b)1330 Expr operator/(Expr a, Expr b) {
1331     user_assert(a.defined() && b.defined()) << "operator/ of undefined Expr\n";
1332     Internal::match_types(a, b);
1333     return Internal::Div::make(std::move(a), std::move(b));
1334 }
1335 
operator /=(Expr & a,Expr b)1336 Expr &operator/=(Expr &a, Expr b) {
1337     user_assert(a.defined() && b.defined()) << "operator/= of undefined Expr\n";
1338     Type t = a.type();
1339     a = Internal::Div::make(std::move(a), cast(t, std::move(b)));
1340     return a;
1341 }
1342 
operator /(Expr a,int b)1343 Expr operator/(Expr a, int b) {
1344     user_assert(a.defined()) << "operator/ of undefined Expr\n";
1345     Type t = a.type();
1346     Internal::check_representable(t, b);
1347     return Internal::Div::make(std::move(a), Internal::make_const(t, b));
1348 }
1349 
operator /(int a,Expr b)1350 Expr operator/(int a, Expr b) {
1351     user_assert(b.defined()) << "operator- of undefined Expr\n";
1352     Type t = b.type();
1353     Internal::check_representable(t, a);
1354     return Internal::Div::make(Internal::make_const(t, a), std::move(b));
1355 }
1356 
operator %(Expr a,Expr b)1357 Expr operator%(Expr a, Expr b) {
1358     user_assert(a.defined() && b.defined()) << "operator% of undefined Expr\n";
1359     Internal::match_types(a, b);
1360     return Internal::Mod::make(std::move(a), std::move(b));
1361 }
1362 
operator %(Expr a,int b)1363 Expr operator%(Expr a, int b) {
1364     user_assert(a.defined()) << "operator% of undefined Expr\n";
1365     Type t = a.type();
1366     Internal::check_representable(t, b);
1367     return Internal::Mod::make(std::move(a), Internal::make_const(t, b));
1368 }
1369 
operator %(int a,Expr b)1370 Expr operator%(int a, Expr b) {
1371     user_assert(b.defined()) << "operator% of undefined Expr\n";
1372     Type t = b.type();
1373     Internal::check_representable(t, a);
1374     return Internal::Mod::make(Internal::make_const(t, a), std::move(b));
1375 }
1376 
operator >(Expr a,Expr b)1377 Expr operator>(Expr a, Expr b) {
1378     user_assert(a.defined() && b.defined()) << "operator> of undefined Expr\n";
1379     Internal::match_types(a, b);
1380     return Internal::GT::make(std::move(a), std::move(b));
1381 }
1382 
operator >(Expr a,int b)1383 Expr operator>(Expr a, int b) {
1384     user_assert(a.defined()) << "operator> of undefined Expr\n";
1385     Type t = a.type();
1386     Internal::check_representable(t, b);
1387     return Internal::GT::make(std::move(a), Internal::make_const(t, b));
1388 }
1389 
operator >(int a,Expr b)1390 Expr operator>(int a, Expr b) {
1391     user_assert(b.defined()) << "operator> of undefined Expr\n";
1392     Type t = b.type();
1393     Internal::check_representable(t, a);
1394     return Internal::GT::make(Internal::make_const(t, a), std::move(b));
1395 }
1396 
operator <(Expr a,Expr b)1397 Expr operator<(Expr a, Expr b) {
1398     user_assert(a.defined() && b.defined()) << "operator< of undefined Expr\n";
1399     Internal::match_types(a, b);
1400     return Internal::LT::make(std::move(a), std::move(b));
1401 }
1402 
operator <(Expr a,int b)1403 Expr operator<(Expr a, int b) {
1404     user_assert(a.defined()) << "operator< of undefined Expr\n";
1405     Type t = a.type();
1406     Internal::check_representable(t, b);
1407     return Internal::LT::make(std::move(a), Internal::make_const(t, b));
1408 }
1409 
operator <(int a,Expr b)1410 Expr operator<(int a, Expr b) {
1411     user_assert(b.defined()) << "operator< of undefined Expr\n";
1412     Type t = b.type();
1413     Internal::check_representable(t, a);
1414     return Internal::LT::make(Internal::make_const(t, a), std::move(b));
1415 }
1416 
operator <=(Expr a,Expr b)1417 Expr operator<=(Expr a, Expr b) {
1418     user_assert(a.defined() && b.defined()) << "operator<= of undefined Expr\n";
1419     Internal::match_types(a, b);
1420     return Internal::LE::make(std::move(a), std::move(b));
1421 }
1422 
operator <=(Expr a,int b)1423 Expr operator<=(Expr a, int b) {
1424     user_assert(a.defined()) << "operator<= of undefined Expr\n";
1425     Type t = a.type();
1426     Internal::check_representable(t, b);
1427     return Internal::LE::make(std::move(a), Internal::make_const(t, b));
1428 }
1429 
operator <=(int a,Expr b)1430 Expr operator<=(int a, Expr b) {
1431     user_assert(b.defined()) << "operator<= of undefined Expr\n";
1432     Type t = b.type();
1433     Internal::check_representable(t, a);
1434     return Internal::LE::make(Internal::make_const(t, a), std::move(b));
1435 }
1436 
operator >=(Expr a,Expr b)1437 Expr operator>=(Expr a, Expr b) {
1438     user_assert(a.defined() && b.defined()) << "operator>= of undefined Expr\n";
1439     Internal::match_types(a, b);
1440     return Internal::GE::make(std::move(a), std::move(b));
1441 }
1442 
// Expr >= int. The integer is validated by check_representable against
// a's type and then turned into a constant of that type.
Expr operator>=(const Expr &a, int b) {
    user_assert(a.defined()) << "operator>= of undefined Expr\n";
    const Type expr_type = a.type();
    Internal::check_representable(expr_type, b);
    Expr rhs = Internal::make_const(expr_type, b);
    return Internal::GE::make(a, std::move(rhs));
}
1449 
// int >= Expr. The integer is validated by check_representable against
// b's type and then turned into a constant of that type.
Expr operator>=(int a, const Expr &b) {
    user_assert(b.defined()) << "operator>= of undefined Expr\n";
    const Type expr_type = b.type();
    Internal::check_representable(expr_type, a);
    Expr lhs = Internal::make_const(expr_type, a);
    return Internal::GE::make(std::move(lhs), b);
}
1456 
operator ==(Expr a,Expr b)1457 Expr operator==(Expr a, Expr b) {
1458     user_assert(a.defined() && b.defined()) << "operator== of undefined Expr\n";
1459     Internal::match_types(a, b);
1460     return Internal::EQ::make(std::move(a), std::move(b));
1461 }
1462 
operator ==(Expr a,int b)1463 Expr operator==(Expr a, int b) {
1464     user_assert(a.defined()) << "operator== of undefined Expr\n";
1465     Type t = a.type();
1466     Internal::check_representable(t, b);
1467     return Internal::EQ::make(std::move(a), Internal::make_const(t, b));
1468 }
1469 
operator ==(int a,Expr b)1470 Expr operator==(int a, Expr b) {
1471     user_assert(b.defined()) << "operator== of undefined Expr\n";
1472     Type t = b.type();
1473     Internal::check_representable(t, a);
1474     return Internal::EQ::make(Internal::make_const(t, a), std::move(b));
1475 }
1476 
operator !=(Expr a,Expr b)1477 Expr operator!=(Expr a, Expr b) {
1478     user_assert(a.defined() && b.defined()) << "operator!= of undefined Expr\n";
1479     Internal::match_types(a, b);
1480     return Internal::NE::make(std::move(a), std::move(b));
1481 }
1482 
operator !=(Expr a,int b)1483 Expr operator!=(Expr a, int b) {
1484     user_assert(a.defined()) << "operator!= of undefined Expr\n";
1485     Type t = a.type();
1486     Internal::check_representable(t, b);
1487     return Internal::NE::make(std::move(a), Internal::make_const(t, b));
1488 }
1489 
operator !=(int a,Expr b)1490 Expr operator!=(int a, Expr b) {
1491     user_assert(b.defined()) << "operator!= of undefined Expr\n";
1492     Type t = b.type();
1493     Internal::check_representable(t, a);
1494     return Internal::NE::make(Internal::make_const(t, a), std::move(b));
1495 }
1496 
operator &&(Expr a,Expr b)1497 Expr operator&&(Expr a, Expr b) {
1498     Internal::match_types(a, b);
1499     return Internal::And::make(std::move(a), std::move(b));
1500 }
1501 
operator &&(Expr a,bool b)1502 Expr operator&&(Expr a, bool b) {
1503     internal_assert(a.defined()) << "operator&& of undefined Expr\n";
1504     internal_assert(a.type().is_bool()) << "operator&& of Expr of type " << a.type() << "\n";
1505     if (b) {
1506         return a;
1507     } else {
1508         return Internal::make_zero(a.type());
1509     }
1510 }
1511 
operator &&(bool a,Expr b)1512 Expr operator&&(bool a, Expr b) {
1513     return std::move(b) && a;
1514 }
1515 
operator ||(Expr a,Expr b)1516 Expr operator||(Expr a, Expr b) {
1517     Internal::match_types(a, b);
1518     return Internal::Or::make(std::move(a), std::move(b));
1519 }
1520 
operator ||(Expr a,bool b)1521 Expr operator||(Expr a, bool b) {
1522     internal_assert(a.defined()) << "operator|| of undefined Expr\n";
1523     internal_assert(a.type().is_bool()) << "operator|| of Expr of type " << a.type() << "\n";
1524     if (b) {
1525         return Internal::make_one(a.type());
1526     } else {
1527         return a;
1528     }
1529 }
1530 
operator ||(bool a,Expr b)1531 Expr operator||(bool a, Expr b) {
1532     return std::move(b) || a;
1533 }
1534 
operator !(Expr a)1535 Expr operator!(Expr a) {
1536     return Internal::Not::make(std::move(a));
1537 }
1538 
max(Expr a,Expr b)1539 Expr max(Expr a, Expr b) {
1540     user_assert(a.defined() && b.defined())
1541         << "max of undefined Expr\n";
1542     Internal::match_types(a, b);
1543     return Internal::Max::make(std::move(a), std::move(b));
1544 }
1545 
max(Expr a,int b)1546 Expr max(Expr a, int b) {
1547     user_assert(a.defined()) << "max of undefined Expr\n";
1548     Type t = a.type();
1549     Internal::check_representable(t, b);
1550     return Internal::Max::make(std::move(a), Internal::make_const(t, b));
1551 }
1552 
max(int a,Expr b)1553 Expr max(int a, Expr b) {
1554     user_assert(b.defined()) << "max of undefined Expr\n";
1555     Type t = b.type();
1556     Internal::check_representable(t, a);
1557     return Internal::Max::make(Internal::make_const(t, a), std::move(b));
1558 }
1559 
min(Expr a,Expr b)1560 Expr min(Expr a, Expr b) {
1561     user_assert(a.defined() && b.defined())
1562         << "min of undefined Expr\n";
1563     Internal::match_types(a, b);
1564     return Internal::Min::make(std::move(a), std::move(b));
1565 }
1566 
min(Expr a,int b)1567 Expr min(Expr a, int b) {
1568     user_assert(a.defined()) << "max of undefined Expr\n";
1569     Type t = a.type();
1570     Internal::check_representable(t, b);
1571     return Internal::Min::make(std::move(a), Internal::make_const(t, b));
1572 }
1573 
min(int a,Expr b)1574 Expr min(int a, Expr b) {
1575     user_assert(b.defined()) << "max of undefined Expr\n";
1576     Type t = b.type();
1577     Internal::check_representable(t, a);
1578     return Internal::Min::make(Internal::make_const(t, a), std::move(b));
1579 }
1580 
cast(Type t,Expr a)1581 Expr cast(Type t, Expr a) {
1582     user_assert(a.defined()) << "cast of undefined Expr\n";
1583     if (a.type() == t) {
1584         return a;
1585     }
1586 
1587     if (t.is_handle() && !a.type().is_handle()) {
1588         user_error << "Can't cast \"" << a << "\" to a handle. "
1589                    << "The only legal cast from scalar types to a handle is: "
1590                    << "reinterpret(Handle(), cast<uint64_t>(" << a << "));\n";
1591     } else if (a.type().is_handle() && !t.is_handle()) {
1592         user_error << "Can't cast handle \"" << a << "\" to type " << t << ". "
1593                    << "The only legal cast from handles to scalar types is: "
1594                    << "reinterpret(UInt(64), " << a << ");\n";
1595     }
1596 
1597     // Fold constants early
1598     if (const int64_t *i = as_const_int(a)) {
1599         return Internal::make_const(t, *i);
1600     }
1601     if (const uint64_t *u = as_const_uint(a)) {
1602         return Internal::make_const(t, *u);
1603     }
1604     if (const double *f = as_const_float(a)) {
1605         return Internal::make_const(t, *f);
1606     }
1607 
1608     if (t.is_vector()) {
1609         if (a.type().is_scalar()) {
1610             return Internal::Broadcast::make(cast(t.element_of(), std::move(a)), t.lanes());
1611         } else if (const Internal::Broadcast *b = a.as<Internal::Broadcast>()) {
1612             internal_assert(b->lanes == t.lanes());
1613             return Internal::Broadcast::make(cast(t.element_of(), b->value), t.lanes());
1614         }
1615     }
1616     return Internal::Cast::make(t, std::move(a));
1617 }
1618 
clamp(Expr a,const Expr & min_val,const Expr & max_val)1619 Expr clamp(Expr a, const Expr &min_val, const Expr &max_val) {
1620     user_assert(a.defined() && min_val.defined() && max_val.defined())
1621         << "clamp of undefined Expr\n";
1622     Expr n_min_val = lossless_cast(a.type(), min_val);
1623     user_assert(n_min_val.defined())
1624         << "Type mismatch in call to clamp. First argument ("
1625         << a << ") has type " << a.type() << ", but second argument ("
1626         << min_val << ") has type " << min_val.type() << ". Use an explicit cast.\n";
1627     Expr n_max_val = lossless_cast(a.type(), max_val);
1628     user_assert(n_max_val.defined())
1629         << "Type mismatch in call to clamp. First argument ("
1630         << a << ") has type " << a.type() << ", but third argument ("
1631         << max_val << ") has type " << max_val.type() << ". Use an explicit cast.\n";
1632     return Internal::Max::make(Internal::Min::make(std::move(a), std::move(n_max_val)), std::move(n_min_val));
1633 }
1634 
abs(Expr a)1635 Expr abs(Expr a) {
1636     user_assert(a.defined())
1637         << "abs of undefined Expr\n";
1638     Type t = a.type();
1639     if (t.is_uint()) {
1640         user_warning << "Warning: abs of an unsigned type is a no-op\n";
1641         return a;
1642     }
1643     return Internal::Call::make(t.with_code(t.is_int() ? Type::UInt : t.code()),
1644                                 Internal::Call::abs, {std::move(a)}, Internal::Call::PureIntrinsic);
1645 }
1646 
absd(Expr a,Expr b)1647 Expr absd(Expr a, Expr b) {
1648     user_assert(a.defined() && b.defined()) << "absd of undefined Expr\n";
1649     Internal::match_types(a, b);
1650     Type t = a.type();
1651 
1652     if (t.is_float()) {
1653         // Floats can just use abs.
1654         return abs(std::move(a) - std::move(b));
1655     }
1656 
1657     // The argument may be signed, but the return type is unsigned.
1658     return Internal::Call::make(t.with_code(t.is_int() ? Type::UInt : t.code()),
1659                                 Internal::Call::absd, {std::move(a), std::move(b)},
1660                                 Internal::Call::PureIntrinsic);
1661 }
1662 
sin(Expr x)1663 Expr sin(Expr x) {
1664     user_assert(x.defined()) << "sin of undefined Expr\n";
1665     if (x.type() == Float(64)) {
1666         return Internal::Call::make(Float(64), "sin_f64", {std::move(x)}, Internal::Call::PureExtern);
1667     } else if (x.type() == Float(16)) {
1668         return Internal::Call::make(Float(16), "sin_f16", {std::move(x)}, Internal::Call::PureExtern);
1669     } else {
1670         return Internal::Call::make(Float(32), "sin_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1671     }
1672 }
1673 
asin(Expr x)1674 Expr asin(Expr x) {
1675     user_assert(x.defined()) << "asin of undefined Expr\n";
1676     if (x.type() == Float(64)) {
1677         return Internal::Call::make(Float(64), "asin_f64", {std::move(x)}, Internal::Call::PureExtern);
1678     } else if (x.type() == Float(16)) {
1679         return Internal::Call::make(Float(16), "asin_f16", {std::move(x)}, Internal::Call::PureExtern);
1680     } else {
1681         return Internal::Call::make(Float(32), "asin_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1682     }
1683 }
1684 
cos(Expr x)1685 Expr cos(Expr x) {
1686     user_assert(x.defined()) << "cos of undefined Expr\n";
1687     if (x.type() == Float(64)) {
1688         return Internal::Call::make(Float(64), "cos_f64", {std::move(x)}, Internal::Call::PureExtern);
1689     } else if (x.type() == Float(16)) {
1690         return Internal::Call::make(Float(16), "cos_f16", {std::move(x)}, Internal::Call::PureExtern);
1691     } else {
1692         return Internal::Call::make(Float(32), "cos_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1693     }
1694 }
1695 
acos(Expr x)1696 Expr acos(Expr x) {
1697     user_assert(x.defined()) << "acos of undefined Expr\n";
1698     if (x.type() == Float(64)) {
1699         return Internal::Call::make(Float(64), "acos_f64", {std::move(x)}, Internal::Call::PureExtern);
1700     } else if (x.type() == Float(16)) {
1701         return Internal::Call::make(Float(16), "acos_f16", {std::move(x)}, Internal::Call::PureExtern);
1702     } else {
1703         return Internal::Call::make(Float(32), "acos_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1704     }
1705 }
1706 
tan(Expr x)1707 Expr tan(Expr x) {
1708     user_assert(x.defined()) << "tan of undefined Expr\n";
1709     if (x.type() == Float(64)) {
1710         return Internal::Call::make(Float(64), "tan_f64", {std::move(x)}, Internal::Call::PureExtern);
1711     } else if (x.type() == Float(16)) {
1712         return Internal::Call::make(Float(16), "tan_f16", {std::move(x)}, Internal::Call::PureExtern);
1713     } else {
1714         return Internal::Call::make(Float(32), "tan_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1715     }
1716 }
1717 
atan(Expr x)1718 Expr atan(Expr x) {
1719     user_assert(x.defined()) << "atan of undefined Expr\n";
1720     if (x.type() == Float(64)) {
1721         return Internal::Call::make(Float(64), "atan_f64", {std::move(x)}, Internal::Call::PureExtern);
1722     } else if (x.type() == Float(16)) {
1723         return Internal::Call::make(Float(16), "atan_f16", {std::move(x)}, Internal::Call::PureExtern);
1724     } else {
1725         return Internal::Call::make(Float(32), "atan_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1726     }
1727 }
1728 
atan2(Expr y,Expr x)1729 Expr atan2(Expr y, Expr x) {
1730     user_assert(x.defined() && y.defined()) << "atan2 of undefined Expr\n";
1731 
1732     if (y.type() == Float(64)) {
1733         x = cast<double>(x);
1734         return Internal::Call::make(Float(64), "atan2_f64", {std::move(y), std::move(x)}, Internal::Call::PureExtern);
1735     } else if (y.type() == Float(16)) {
1736         x = cast<float16_t>(x);
1737         return Internal::Call::make(Float(16), "atan2_f16", {std::move(y), std::move(x)}, Internal::Call::PureExtern);
1738     } else {
1739         y = cast<float>(y);
1740         x = cast<float>(x);
1741         return Internal::Call::make(Float(32), "atan2_f32", {std::move(y), std::move(x)}, Internal::Call::PureExtern);
1742     }
1743 }
1744 
sinh(Expr x)1745 Expr sinh(Expr x) {
1746     user_assert(x.defined()) << "sinh of undefined Expr\n";
1747     if (x.type() == Float(64)) {
1748         return Internal::Call::make(Float(64), "sinh_f64", {std::move(x)}, Internal::Call::PureExtern);
1749     } else if (x.type() == Float(16)) {
1750         return Internal::Call::make(Float(16), "sinh_f16", {std::move(x)}, Internal::Call::PureExtern);
1751     } else {
1752         return Internal::Call::make(Float(32), "sinh_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1753     }
1754 }
1755 
asinh(Expr x)1756 Expr asinh(Expr x) {
1757     user_assert(x.defined()) << "asinh of undefined Expr\n";
1758     if (x.type() == Float(64)) {
1759         return Internal::Call::make(Float(64), "asinh_f64", {std::move(x)}, Internal::Call::PureExtern);
1760     } else if (x.type() == Float(16)) {
1761         return Internal::Call::make(Float(16), "asinh_f16", {std::move(x)}, Internal::Call::PureExtern);
1762     } else {
1763         return Internal::Call::make(Float(32), "asinh_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1764     }
1765 }
1766 
cosh(Expr x)1767 Expr cosh(Expr x) {
1768     user_assert(x.defined()) << "cosh of undefined Expr\n";
1769     if (x.type() == Float(64)) {
1770         return Internal::Call::make(Float(64), "cosh_f64", {std::move(x)}, Internal::Call::PureExtern);
1771     } else if (x.type() == Float(16)) {
1772         return Internal::Call::make(Float(16), "cosh_f16", {std::move(x)}, Internal::Call::PureExtern);
1773     } else {
1774         return Internal::Call::make(Float(32), "cosh_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1775     }
1776 }
1777 
acosh(Expr x)1778 Expr acosh(Expr x) {
1779     user_assert(x.defined()) << "acosh of undefined Expr\n";
1780     if (x.type() == Float(64)) {
1781         return Internal::Call::make(Float(64), "acosh_f64", {std::move(x)}, Internal::Call::PureExtern);
1782     } else if (x.type() == Float(16)) {
1783         return Internal::Call::make(Float(16), "acosh_f16", {std::move(x)}, Internal::Call::PureExtern);
1784     } else {
1785         return Internal::Call::make(Float(32), "acosh_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1786     }
1787 }
1788 
tanh(Expr x)1789 Expr tanh(Expr x) {
1790     user_assert(x.defined()) << "tanh of undefined Expr\n";
1791     if (x.type() == Float(64)) {
1792         return Internal::Call::make(Float(64), "tanh_f64", {std::move(x)}, Internal::Call::PureExtern);
1793     } else if (x.type() == Float(16)) {
1794         return Internal::Call::make(Float(16), "tanh_f16", {std::move(x)}, Internal::Call::PureExtern);
1795     } else {
1796         return Internal::Call::make(Float(32), "tanh_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1797     }
1798 }
1799 
atanh(Expr x)1800 Expr atanh(Expr x) {
1801     user_assert(x.defined()) << "atanh of undefined Expr\n";
1802     if (x.type() == Float(64)) {
1803         return Internal::Call::make(Float(64), "atanh_f64", {std::move(x)}, Internal::Call::PureExtern);
1804     } else if (x.type() == Float(16)) {
1805         return Internal::Call::make(Float(16), "atanh_f16", {std::move(x)}, Internal::Call::PureExtern);
1806     } else {
1807         return Internal::Call::make(Float(32), "atanh_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1808     }
1809 }
1810 
sqrt(Expr x)1811 Expr sqrt(Expr x) {
1812     user_assert(x.defined()) << "sqrt of undefined Expr\n";
1813     if (x.type() == Float(64)) {
1814         return Internal::Call::make(Float(64), "sqrt_f64", {std::move(x)}, Internal::Call::PureExtern);
1815     } else if (x.type() == Float(16)) {
1816         return Internal::Call::make(Float(16), "sqrt_f16", {std::move(x)}, Internal::Call::PureExtern);
1817     } else {
1818         return Internal::Call::make(Float(32), "sqrt_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1819     }
1820 }
1821 
// hypot(x, y), built directly as sqrt(x * x + y * y).
Expr hypot(const Expr &x, const Expr &y) {
    Expr sum_of_squares = x * x + y * y;
    return sqrt(std::move(sum_of_squares));
}
1825 
exp(Expr x)1826 Expr exp(Expr x) {
1827     user_assert(x.defined()) << "exp of undefined Expr\n";
1828     if (x.type() == Float(64)) {
1829         return Internal::Call::make(Float(64), "exp_f64", {std::move(x)}, Internal::Call::PureExtern);
1830     } else if (x.type() == Float(16)) {
1831         return Internal::Call::make(Float(16), "exp_f16", {std::move(x)}, Internal::Call::PureExtern);
1832     } else {
1833         return Internal::Call::make(Float(32), "exp_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1834     }
1835 }
1836 
log(Expr x)1837 Expr log(Expr x) {
1838     user_assert(x.defined()) << "log of undefined Expr\n";
1839     if (x.type() == Float(64)) {
1840         return Internal::Call::make(Float(64), "log_f64", {std::move(x)}, Internal::Call::PureExtern);
1841     } else if (x.type() == Float(16)) {
1842         return Internal::Call::make(Float(16), "log_f16", {std::move(x)}, Internal::Call::PureExtern);
1843     } else {
1844         return Internal::Call::make(Float(32), "log_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1845     }
1846 }
1847 
pow(Expr x,Expr y)1848 Expr pow(Expr x, Expr y) {
1849     user_assert(x.defined() && y.defined()) << "pow of undefined Expr\n";
1850 
1851     if (const int64_t *i = as_const_int(y)) {
1852         return raise_to_integer_power(std::move(x), *i);
1853     }
1854 
1855     if (x.type() == Float(64)) {
1856         y = cast<double>(std::move(y));
1857         return Internal::Call::make(Float(64), "pow_f64", {std::move(x), std::move(y)}, Internal::Call::PureExtern);
1858     } else if (x.type() == Float(16)) {
1859         y = cast<float16_t>(std::move(y));
1860         return Internal::Call::make(Float(16), "pow_f16", {std::move(x), std::move(y)}, Internal::Call::PureExtern);
1861     } else {
1862         x = cast<float>(std::move(x));
1863         y = cast<float>(std::move(y));
1864         return Internal::Call::make(Float(32), "pow_f32", {std::move(x), std::move(y)}, Internal::Call::PureExtern);
1865     }
1866 }
1867 
// erf(x). Only defined Float(32) arguments are accepted; the actual
// expression is produced by Internal::halide_erf.
Expr erf(const Expr &x) {
    user_assert(x.defined()) << "erf of undefined Expr\n";
    user_assert(x.type() == Float(32)) << "erf only takes float arguments\n";

    Expr result = Internal::halide_erf(x);
    return result;
}
1873 
fast_pow(Expr x,Expr y)1874 Expr fast_pow(Expr x, Expr y) {
1875     if (const int64_t *i = as_const_int(y)) {
1876         return raise_to_integer_power(std::move(x), *i);
1877     }
1878 
1879     x = cast<float>(std::move(x));
1880     y = cast<float>(std::move(y));
1881     return select(x == 0.0f, 0.0f, fast_exp(fast_log(x) * std::move(y)));
1882 }
1883 
fast_inverse(Expr x)1884 Expr fast_inverse(Expr x) {
1885     user_assert(x.type() == Float(32)) << "fast_inverse only takes float arguments\n";
1886     Type t = x.type();
1887     return Internal::Call::make(t, "fast_inverse_f32", {std::move(x)}, Internal::Call::PureExtern);
1888 }
1889 
fast_inverse_sqrt(Expr x)1890 Expr fast_inverse_sqrt(Expr x) {
1891     user_assert(x.type() == Float(32)) << "fast_inverse_sqrt only takes float arguments\n";
1892     Type t = x.type();
1893     return Internal::Call::make(t, "fast_inverse_sqrt_f32", {std::move(x)}, Internal::Call::PureExtern);
1894 }
1895 
floor(Expr x)1896 Expr floor(Expr x) {
1897     user_assert(x.defined()) << "floor of undefined Expr\n";
1898     Type t = x.type();
1899     if (t.element_of() == Float(64)) {
1900         return Internal::Call::make(t, "floor_f64", {std::move(x)}, Internal::Call::PureExtern);
1901     } else if (t.element_of() == Float(16)) {
1902         return Internal::Call::make(t, "floor_f16", {std::move(x)}, Internal::Call::PureExtern);
1903     } else {
1904         t = Float(32, t.lanes());
1905         if (t.is_int() || t.is_uint()) {
1906             // Already an integer
1907             return cast(t, std::move(x));
1908         } else {
1909             return Internal::Call::make(t, "floor_f32", {cast(t, std::move(x))}, Internal::Call::PureExtern);
1910         }
1911     }
1912 }
1913 
ceil(Expr x)1914 Expr ceil(Expr x) {
1915     user_assert(x.defined()) << "ceil of undefined Expr\n";
1916     Type t = x.type();
1917     if (t.element_of() == Float(64)) {
1918         return Internal::Call::make(t, "ceil_f64", {std::move(x)}, Internal::Call::PureExtern);
1919     } else if (x.type().element_of() == Float(16)) {
1920         return Internal::Call::make(t, "ceil_f16", {std::move(x)}, Internal::Call::PureExtern);
1921     } else {
1922         t = Float(32, t.lanes());
1923         if (t.is_int() || t.is_uint()) {
1924             // Already an integer
1925             return cast(t, std::move(x));
1926         } else {
1927             return Internal::Call::make(t, "ceil_f32", {cast(t, std::move(x))}, Internal::Call::PureExtern);
1928         }
1929     }
1930 }
1931 
round(Expr x)1932 Expr round(Expr x) {
1933     user_assert(x.defined()) << "round of undefined Expr\n";
1934     Type t = x.type();
1935     if (t.element_of() == Float(64)) {
1936         return Internal::Call::make(t, "round_f64", {std::move(x)}, Internal::Call::PureExtern);
1937     } else if (t.element_of() == Float(16)) {
1938         return Internal::Call::make(t, "round_f16", {std::move(x)}, Internal::Call::PureExtern);
1939     } else {
1940         t = Float(32, t.lanes());
1941         if (t.is_int() || t.is_uint()) {
1942             // Already an integer
1943             return cast(t, std::move(x));
1944         } else {
1945             return Internal::Call::make(t, "round_f32", {cast(t, std::move(x))}, Internal::Call::PureExtern);
1946         }
1947     }
1948 }
1949 
trunc(Expr x)1950 Expr trunc(Expr x) {
1951     user_assert(x.defined()) << "trunc of undefined Expr\n";
1952     Type t = x.type();
1953     if (t.element_of() == Float(64)) {
1954         return Internal::Call::make(t, "trunc_f64", {std::move(x)}, Internal::Call::PureExtern);
1955     } else if (t.element_of() == Float(16)) {
1956         return Internal::Call::make(t, "trunc_f16", {std::move(x)}, Internal::Call::PureExtern);
1957     } else {
1958         t = Float(32, t.lanes());
1959         if (t.is_int() || t.is_uint()) {
1960             // Already an integer
1961             return cast(t, std::move(x));
1962         } else {
1963             return Internal::Call::make(t, "trunc_f32", {cast(t, std::move(x))}, Internal::Call::PureExtern);
1964         }
1965     }
1966 }
1967 
is_nan(Expr x)1968 Expr is_nan(Expr x) {
1969     user_assert(x.defined()) << "is_nan of undefined Expr\n";
1970     user_assert(x.type().is_float()) << "is_nan only works for float";
1971     Type t = Bool(x.type().lanes());
1972     if (!is_const(x)) {
1973         x = strict_float(x);
1974     }
1975     if (x.type().element_of() == Float(64)) {
1976         return Internal::Call::make(t, "is_nan_f64", {std::move(x)}, Internal::Call::PureExtern);
1977     } else if (x.type().element_of() == Float(16)) {
1978         return Internal::Call::make(t, "is_nan_f16", {std::move(x)}, Internal::Call::PureExtern);
1979     } else {
1980         Type ft = Float(32, t.lanes());
1981         return Internal::Call::make(t, "is_nan_f32", {cast(ft, std::move(x))}, Internal::Call::PureExtern);
1982     }
1983 }
1984 
is_inf(Expr x)1985 Expr is_inf(Expr x) {
1986     user_assert(x.defined()) << "is_inf of undefined Expr\n";
1987     user_assert(x.type().is_float()) << "is_inf only works for float";
1988     Type t = Bool(x.type().lanes());
1989     if (!is_const(x)) {
1990         x = strict_float(x);
1991     }
1992     if (x.type().element_of() == Float(64)) {
1993         return Internal::Call::make(t, "is_inf_f64", {std::move(x)}, Internal::Call::PureExtern);
1994     } else if (x.type().element_of() == Float(16)) {
1995         return Internal::Call::make(t, "is_inf_f16", {std::move(x)}, Internal::Call::PureExtern);
1996     } else {
1997         Type ft = Float(32, t.lanes());
1998         return Internal::Call::make(t, "is_inf_f32", {cast(ft, std::move(x))}, Internal::Call::PureExtern);
1999     }
2000 }
2001 
is_finite(Expr x)2002 Expr is_finite(Expr x) {
2003     user_assert(x.defined()) << "is_finite of undefined Expr\n";
2004     user_assert(x.type().is_float()) << "is_finite only works for float";
2005     Type t = Bool(x.type().lanes());
2006     if (!is_const(x)) {
2007         x = strict_float(x);
2008     }
2009     if (x.type().element_of() == Float(64)) {
2010         return Internal::Call::make(t, "is_finite_f64", {std::move(x)}, Internal::Call::PureExtern);
2011     } else if (x.type().element_of() == Float(16)) {
2012         return Internal::Call::make(t, "is_finite_f16", {std::move(x)}, Internal::Call::PureExtern);
2013     } else {
2014         Type ft = Float(32, t.lanes());
2015         return Internal::Call::make(t, "is_finite_f32", {cast(ft, std::move(x))}, Internal::Call::PureExtern);
2016     }
2017 }
2018 
// Fractional part of x, computed as x - trunc(x).
Expr fract(const Expr &x) {
    user_assert(x.defined()) << "fract of undefined Expr\n";

    Expr whole_part = trunc(x);
    return x - std::move(whole_part);
}
2023 
reinterpret(Type t,Expr e)2024 Expr reinterpret(Type t, Expr e) {
2025     user_assert(e.defined()) << "reinterpret of undefined Expr\n";
2026     int from_bits = e.type().bits() * e.type().lanes();
2027     int to_bits = t.bits() * t.lanes();
2028     user_assert(from_bits == to_bits)
2029         << "Reinterpret cast from type " << e.type()
2030         << " which has " << from_bits
2031         << " bits, to type " << t
2032         << " which has " << to_bits << " bits\n";
2033     return Internal::Call::make(t, Internal::Call::reinterpret, {std::move(e)}, Internal::Call::PureIntrinsic);
2034 }
2035 
operator &(Expr x,Expr y)2036 Expr operator&(Expr x, Expr y) {
2037     match_types_bitwise(x, y, "bitwise and");
2038     Type t = x.type();
2039     return Internal::Call::make(t, Internal::Call::bitwise_and, {std::move(x), std::move(y)}, Internal::Call::PureIntrinsic);
2040 }
2041 
operator &(Expr x,int y)2042 Expr operator&(Expr x, int y) {
2043     Type t = x.type();
2044     Internal::check_representable(t, y);
2045     return Internal::Call::make(t, Internal::Call::bitwise_and, {std::move(x), Internal::make_const(t, y)}, Internal::Call::PureIntrinsic);
2046 }
2047 
operator &(int x,Expr y)2048 Expr operator&(int x, Expr y) {
2049     Type t = y.type();
2050     Internal::check_representable(t, x);
2051     return Internal::Call::make(t, Internal::Call::bitwise_and, {Internal::make_const(t, x), std::move(y)}, Internal::Call::PureIntrinsic);
2052 }
2053 
operator |(Expr x,Expr y)2054 Expr operator|(Expr x, Expr y) {
2055     match_types_bitwise(x, y, "bitwise or");
2056     Type t = x.type();
2057     return Internal::Call::make(t, Internal::Call::bitwise_or, {std::move(x), std::move(y)}, Internal::Call::PureIntrinsic);
2058 }
2059 
operator |(Expr x,int y)2060 Expr operator|(Expr x, int y) {
2061     Type t = x.type();
2062     Internal::check_representable(t, y);
2063     return Internal::Call::make(t, Internal::Call::bitwise_or, {std::move(x), Internal::make_const(t, y)}, Internal::Call::PureIntrinsic);
2064 }
2065 
operator |(int x,Expr y)2066 Expr operator|(int x, Expr y) {
2067     Type t = y.type();
2068     Internal::check_representable(t, x);
2069     return Internal::Call::make(t, Internal::Call::bitwise_or, {Internal::make_const(t, x), std::move(y)}, Internal::Call::PureIntrinsic);
2070 }
2071 
operator ^(Expr x,Expr y)2072 Expr operator^(Expr x, Expr y) {
2073     match_types_bitwise(x, y, "bitwise xor");
2074     Type t = x.type();
2075     return Internal::Call::make(t, Internal::Call::bitwise_xor, {std::move(x), std::move(y)}, Internal::Call::PureIntrinsic);
2076 }
2077 
operator ^(Expr x,int y)2078 Expr operator^(Expr x, int y) {
2079     Type t = x.type();
2080     Internal::check_representable(t, y);
2081     return Internal::Call::make(t, Internal::Call::bitwise_xor, {std::move(x), Internal::make_const(t, y)}, Internal::Call::PureIntrinsic);
2082 }
2083 
operator ^(int x,Expr y)2084 Expr operator^(int x, Expr y) {
2085     Type t = y.type();
2086     Internal::check_representable(t, x);
2087     return Internal::Call::make(t, Internal::Call::bitwise_xor, {Internal::make_const(t, x), std::move(y)}, Internal::Call::PureIntrinsic);
2088 }
2089 
operator ~(Expr x)2090 Expr operator~(Expr x) {
2091     user_assert(x.defined()) << "bitwise not of undefined Expr\n";
2092     user_assert(x.type().is_int() || x.type().is_uint())
2093         << "Argument to bitwise not must be an integer or unsigned integer";
2094     Type t = x.type();
2095     return Internal::Call::make(t, Internal::Call::bitwise_not, {std::move(x)}, Internal::Call::PureIntrinsic);
2096 }
2097 
operator <<(Expr x,Expr y)2098 Expr operator<<(Expr x, Expr y) {
2099     if (y.type().is_vector() && !x.type().is_vector()) {
2100         x = Internal::Broadcast::make(x, y.type().lanes());
2101     }
2102     match_bits(x, y);
2103     Type t = x.type();
2104     return Internal::Call::make(t, Internal::Call::shift_left, {std::move(x), std::move(y)}, Internal::Call::PureIntrinsic);
2105 }
2106 
operator <<(Expr x,int y)2107 Expr operator<<(Expr x, int y) {
2108     Type t = Int(x.type().bits(), x.type().lanes());
2109     Internal::check_representable(t, y);
2110     return std::move(x) << Internal::make_const(t, y);
2111 }
2112 
operator >>(Expr x,Expr y)2113 Expr operator>>(Expr x, Expr y) {
2114     if (y.type().is_vector() && !x.type().is_vector()) {
2115         x = Internal::Broadcast::make(x, y.type().lanes());
2116     }
2117     match_bits(x, y);
2118     Type t = x.type();
2119     return Internal::Call::make(t, Internal::Call::shift_right, {std::move(x), std::move(y)}, Internal::Call::PureIntrinsic);
2120 }
2121 
operator >>(Expr x,int y)2122 Expr operator>>(Expr x, int y) {
2123     Type t = Int(x.type().bits(), x.type().lanes());
2124     Internal::check_representable(t, y);
2125     return std::move(x) >> Internal::make_const(t, y);
2126 }
2127 
lerp(Expr zero_val,Expr one_val,Expr weight)2128 Expr lerp(Expr zero_val, Expr one_val, Expr weight) {
2129     user_assert(zero_val.defined()) << "lerp with undefined zero value";
2130     user_assert(one_val.defined()) << "lerp with undefined one value";
2131     user_assert(weight.defined()) << "lerp with undefined weight";
2132 
2133     // We allow integer constants through, so that you can say things
2134     // like lerp(0, cast<uint8_t>(x), alpha) and produce an 8-bit
2135     // result. Note that lerp(0.0f, cast<uint8_t>(x), alpha) will
2136     // produce an error, as will lerp(0.0f, cast<double>(x),
2137     // alpha). lerp(0, cast<float>(x), alpha) is also allowed and will
2138     // produce a float result.
2139     if (as_const_int(zero_val)) {
2140         zero_val = cast(one_val.type(), std::move(zero_val));
2141     }
2142     if (as_const_int(one_val)) {
2143         one_val = cast(zero_val.type(), std::move(one_val));
2144     }
2145 
2146     user_assert(zero_val.type() == one_val.type())
2147         << "Can't lerp between " << zero_val << " of type " << zero_val.type()
2148         << " and " << one_val << " of different type " << one_val.type() << "\n";
2149     user_assert((weight.type().is_uint() || weight.type().is_float()))
2150         << "A lerp weight must be an unsigned integer or a float, but "
2151         << "lerp weight " << weight << " has type " << weight.type() << ".\n";
2152     user_assert((zero_val.type().is_float() || zero_val.type().lanes() <= 32))
2153         << "Lerping between 64-bit integers is not supported\n";
2154     // Compilation error for constant weight that is out of range for integer use
2155     // as this seems like an easy to catch gotcha.
2156     if (!zero_val.type().is_float()) {
2157         const double *const_weight = as_const_float(weight);
2158         if (const_weight) {
2159             user_assert(*const_weight >= 0.0 && *const_weight <= 1.0)
2160                 << "Floating-point weight for lerp with integer arguments is "
2161                 << *const_weight << ", which is not in the range [0.0, 1.0].\n";
2162         }
2163     }
2164     Type t = zero_val.type();
2165     return Internal::Call::make(t, Internal::Call::lerp,
2166                                 {std::move(zero_val), std::move(one_val), std::move(weight)},
2167                                 Internal::Call::PureIntrinsic);
2168 }
2169 
popcount(Expr x)2170 Expr popcount(Expr x) {
2171     user_assert(x.defined()) << "popcount of undefined Expr\n";
2172     Type t = x.type();
2173     user_assert(t.is_uint() || t.is_int())
2174         << "Argument to popcount must be an integer\n";
2175     return Internal::Call::make(t, Internal::Call::popcount,
2176                                 {std::move(x)}, Internal::Call::PureIntrinsic);
2177 }
2178 
count_leading_zeros(Expr x)2179 Expr count_leading_zeros(Expr x) {
2180     user_assert(x.defined()) << "count leading zeros of undefined Expr\n";
2181     Type t = x.type();
2182     user_assert(t.is_uint() || t.is_int())
2183         << "Argument to count_leading_zeros must be an integer\n";
2184     return Internal::Call::make(t, Internal::Call::count_leading_zeros,
2185                                 {std::move(x)}, Internal::Call::PureIntrinsic);
2186 }
2187 
count_trailing_zeros(Expr x)2188 Expr count_trailing_zeros(Expr x) {
2189     user_assert(x.defined()) << "count trailing zeros of undefined Expr\n";
2190     Type t = x.type();
2191     user_assert(t.is_uint() || t.is_int())
2192         << "Argument to count_trailing_zeros must be an integer\n";
2193     return Internal::Call::make(t, Internal::Call::count_trailing_zeros,
2194                                 {std::move(x)}, Internal::Call::PureIntrinsic);
2195 }
2196 
div_round_to_zero(Expr x,Expr y)2197 Expr div_round_to_zero(Expr x, Expr y) {
2198     user_assert(x.defined()) << "div_round_to_zero of undefined dividend\n";
2199     user_assert(y.defined()) << "div_round_to_zero of undefined divisor\n";
2200     Internal::match_types(x, y);
2201     if (x.type().is_uint()) {
2202         return std::move(x) / std::move(y);
2203     }
2204     user_assert(x.type().is_int()) << "First argument to div_round_to_zero is not an integer: " << x << "\n";
2205     user_assert(y.type().is_int()) << "Second argument to div_round_to_zero is not an integer: " << y << "\n";
2206     Type t = x.type();
2207     return Internal::Call::make(t, Internal::Call::div_round_to_zero,
2208                                 {std::move(x), std::move(y)},
2209                                 Internal::Call::Intrinsic);
2210 }
2211 
mod_round_to_zero(Expr x,Expr y)2212 Expr mod_round_to_zero(Expr x, Expr y) {
2213     user_assert(x.defined()) << "mod_round_to_zero of undefined dividend\n";
2214     user_assert(y.defined()) << "mod_round_to_zero of undefined divisor\n";
2215     Internal::match_types(x, y);
2216     if (x.type().is_uint()) {
2217         return std::move(x) % std::move(y);
2218     }
2219     user_assert(x.type().is_int()) << "First argument to mod_round_to_zero is not an integer: " << x << "\n";
2220     user_assert(y.type().is_int()) << "Second argument to mod_round_to_zero is not an integer: " << y << "\n";
2221     Type t = x.type();
2222     return Internal::Call::make(t, Internal::Call::mod_round_to_zero,
2223                                 {std::move(x), std::move(y)},
2224                                 Internal::Call::Intrinsic);
2225 }
2226 
random_float(Expr seed)2227 Expr random_float(Expr seed) {
2228     // Random floats get even IDs
2229     static std::atomic<int> counter{0};
2230     int id = (counter++) * 2;
2231 
2232     std::vector<Expr> args;
2233     if (seed.defined()) {
2234         user_assert(seed.type() == Int(32))
2235             << "The seed passed to random_float must have type Int(32), but instead is "
2236             << seed << " of type " << seed.type() << "\n";
2237         args.push_back(std::move(seed));
2238     }
2239     args.emplace_back(id);
2240 
2241     // This is (surprisingly) pure - it's a fixed psuedo-random
2242     // function of its inputs.
2243     return Internal::Call::make(Float(32), Internal::Call::random,
2244                                 args, Internal::Call::PureIntrinsic);
2245 }
2246 
random_uint(Expr seed)2247 Expr random_uint(Expr seed) {
2248     // Random ints get odd IDs
2249     static std::atomic<int> counter{0};
2250     int id = (counter++) * 2 + 1;
2251 
2252     std::vector<Expr> args;
2253     if (seed.defined()) {
2254         user_assert(seed.type() == Int(32) || seed.type() == UInt(32))
2255             << "The seed passed to random_int must have type Int(32) or UInt(32), but instead is "
2256             << seed << " of type " << seed.type() << "\n";
2257         args.push_back(std::move(seed));
2258     }
2259     args.emplace_back(id);
2260 
2261     return Internal::Call::make(UInt(32), Internal::Call::random,
2262                                 args, Internal::Call::PureIntrinsic);
2263 }
2264 
random_int(Expr seed)2265 Expr random_int(Expr seed) {
2266     return cast<int32_t>(random_uint(std::move(seed)));
2267 }
2268 
likely(Expr e)2269 Expr likely(Expr e) {
2270     Type t = e.type();
2271     return Internal::Call::make(t, Internal::Call::likely,
2272                                 {std::move(e)}, Internal::Call::PureIntrinsic);
2273 }
2274 
likely_if_innermost(Expr e)2275 Expr likely_if_innermost(Expr e) {
2276     Type t = e.type();
2277     return Internal::Call::make(t, Internal::Call::likely_if_innermost,
2278                                 {std::move(e)}, Internal::Call::PureIntrinsic);
2279 }
2280 
strict_float(Expr e)2281 Expr strict_float(Expr e) {
2282     Type t = e.type();
2283     return Internal::Call::make(t, Internal::Call::strict_float,
2284                                 {std::move(e)}, Internal::Call::PureIntrinsic);
2285 }
2286 
undef(Type t)2287 Expr undef(Type t) {
2288     return Internal::Call::make(t, Internal::Call::undef,
2289                                 std::vector<Expr>(),
2290                                 Internal::Call::PureIntrinsic);
2291 }
2292 
2293 }  // namespace Halide
2294