1 #include "Simplify_Internal.h"
2
3 namespace Halide {
4 namespace Internal {
5
visit(const And * op,ExprInfo * bounds)6 Expr Simplify::visit(const And *op, ExprInfo *bounds) {
7 if (falsehoods.count(op)) {
8 return const_false(op->type.lanes());
9 }
10
11 Expr a = mutate(op->a, nullptr);
12 Expr b = mutate(op->b, nullptr);
13
14 // Order commutative operations by node type
15 if (should_commute(a, b)) {
16 std::swap(a, b);
17 }
18
19 auto rewrite = IRMatcher::rewriter(IRMatcher::and_op(a, b), op->type);
20
21 // clang-format off
22 if (EVAL_IN_LAMBDA
23 (rewrite(x && true, a) ||
24 rewrite(x && false, b) ||
25 rewrite(x && x, a) ||
26
27 rewrite((x && y) && x, a) ||
28 rewrite(x && (x && y), b) ||
29 rewrite((x && y) && y, a) ||
30 rewrite(y && (x && y), b) ||
31
32 rewrite(((x && y) && z) && x, a) ||
33 rewrite(x && ((x && y) && z), b) ||
34 rewrite((z && (x && y)) && x, a) ||
35 rewrite(x && (z && (x && y)), b) ||
36 rewrite(((x && y) && z) && y, a) ||
37 rewrite(y && ((x && y) && z), b) ||
38 rewrite((z && (x && y)) && y, a) ||
39 rewrite(y && (z && (x && y)), b) ||
40
41 rewrite((x || y) && x, b) ||
42 rewrite(x && (x || y), a) ||
43 rewrite((x || y) && y, b) ||
44 rewrite(y && (x || y), a) ||
45
46 rewrite(x != y && x == y, false) ||
47 rewrite(x != y && y == x, false) ||
48 rewrite((z && x != y) && x == y, false) ||
49 rewrite((z && x != y) && y == x, false) ||
50 rewrite((x != y && z) && x == y, false) ||
51 rewrite((x != y && z) && y == x, false) ||
52 rewrite((z && x == y) && x != y, false) ||
53 rewrite((z && x == y) && y != x, false) ||
54 rewrite((x == y && z) && x != y, false) ||
55 rewrite((x == y && z) && y != x, false) ||
56 rewrite(x && !x, false) ||
57 rewrite(!x && x, false) ||
58 rewrite(y <= x && x < y, false) ||
59 rewrite(x != c0 && x == c1, b, c0 != c1) ||
60 // Note: In the predicate below, if undefined overflow
61 // occurs, the predicate counts as false. If well-defined
62 // overflow occurs, the condition couldn't possibly
63 // trigger because c0 + 1 will be the smallest possible
64 // value.
65 rewrite(c0 < x && x < c1, false, !is_float(x) && c1 <= c0 + 1) ||
66 rewrite(x < c1 && c0 < x, false, !is_float(x) && c1 <= c0 + 1) ||
67 rewrite(x <= c1 && c0 < x, false, c1 <= c0) ||
68 rewrite(c0 <= x && x < c1, false, c1 <= c0) ||
69 rewrite(c0 <= x && x <= c1, false, c1 < c0) ||
70 rewrite(x <= c1 && c0 <= x, false, c1 < c0) ||
71 rewrite(c0 < x && c1 < x, fold(max(c0, c1)) < x) ||
72 rewrite(c0 <= x && c1 <= x, fold(max(c0, c1)) <= x) ||
73 rewrite(x < c0 && x < c1, x < fold(min(c0, c1))) ||
74 rewrite(x <= c0 && x <= c1, x <= fold(min(c0, c1))))) {
75 return rewrite.result;
76 }
77 // clang-format on
78
79 if (rewrite(broadcast(x) && broadcast(y), broadcast(x && y, op->type.lanes())) ||
80
81 rewrite((x || (y && z)) && y, (x || z) && y) ||
82 rewrite((x || (z && y)) && y, (x || z) && y) ||
83 rewrite(y && (x || (y && z)), y && (x || z)) ||
84 rewrite(y && (x || (z && y)), y && (x || z)) ||
85
86 rewrite(((y && z) || x) && y, (z || x) && y) ||
87 rewrite(((z && y) || x) && y, (z || x) && y) ||
88 rewrite(y && ((y && z) || x), y && (z || x)) ||
89 rewrite(y && ((z && y) || x), y && (z || x)) ||
90
91 rewrite((x && (y || z)) && y, x && y) ||
92 rewrite((x && (z || y)) && y, x && y) ||
93 rewrite(y && (x && (y || z)), y && x) ||
94 rewrite(y && (x && (z || y)), y && x) ||
95
96 rewrite(((y || z) && x) && y, x && y) ||
97 rewrite(((z || y) && x) && y, x && y) ||
98 rewrite(y && ((y || z) && x), y && x) ||
99 rewrite(y && ((z || y) && x), y && x) ||
100
101 rewrite((x || y) && (x || z), x || (y && z)) ||
102 rewrite((x || y) && (z || x), x || (y && z)) ||
103 rewrite((y || x) && (x || z), x || (y && z)) ||
104 rewrite((y || x) && (z || x), x || (y && z)) ||
105
106 rewrite(x < y && x < z, x < min(y, z)) ||
107 rewrite(y < x && z < x, max(y, z) < x) ||
108 rewrite(x <= y && x <= z, x <= min(y, z)) ||
109 rewrite(y <= x && z <= x, max(y, z) <= x)) {
110
111 return mutate(rewrite.result, bounds);
112 }
113
114 if (a.same_as(op->a) &&
115 b.same_as(op->b)) {
116 return op;
117 } else {
118 return And::make(a, b);
119 }
120 }
121
122 } // namespace Internal
123 } // namespace Halide
124