1 #include "Halide.h"
2
3 using namespace Halide;
4 using namespace Halide::Internal;
5
6 class Checker : public IRMutator {
visit(const Atomic * op)7 Stmt visit(const Atomic *op) override {
8 count_atomics++;
9 if (!op->mutex_name.empty()) {
10 count_atomics_with_mutexes++;
11 }
12 return IRMutator::visit(op);
13 }
14
15 public:
16 int count_atomics = 0;
17 int count_atomics_with_mutexes = 0;
18 };
19
main(int argc,char ** argv)20 int main(int argc, char **argv) {
21 if (get_jit_target_from_environment().arch == Target::WebAssembly) {
22 printf("[SKIP] Skipping test for WebAssembly as it does not support atomics yet.\n");
23 return 0;
24 }
25
26 {
27 Func f, g;
28 Var x, y;
29
30 f(x, y) = {x, y};
31 f(x, y) = {f(x, y)[0] + x,
32 f(x, y)[1] + y};
33
34 // The summation is independent in the two tuple components,
35 // so it can just be two atomic add instructions. No CAS loop
36 // required.
37 f.compute_root().update().parallel(y).atomic();
38
39 g(x, y) = f(x, y)[0] + f(x, y)[1];
40
41 Checker checker;
42 g.add_custom_lowering_pass(&checker, []() {});
43
44 Buffer<int> out = g.realize(128, 128);
45 for (int y = 0; y < 128; y++) {
46 for (int x = 0; x < 128; x++) {
47 int correct = 2 * x + 2 * y;
48 if (out(x, y) != correct) {
49 printf("out(%d, %d) = %d instead of %d\n",
50 x, y, out(x, y), correct);
51 return -1;
52 }
53 }
54 }
55
56 if (checker.count_atomics != 2 || checker.count_atomics_with_mutexes != 0) {
57 printf("Expected two atomic nodes, neither of them with mutexes\n");
58 return -1;
59 }
60 }
61
62 {
63 Func f, g;
64 Var x, y;
65
66 f(x, y) = {x, y};
67 f(x, y) = {f(x, y)[1] + x,
68 f(x, y)[0] + y};
69
70 // The summation is coupled across the two tuple components
71 // and there are two stores, so we need a mutex.
72 f.compute_root().update().parallel(y).atomic();
73
74 g(x, y) = f(x, y)[0] + f(x, y)[1];
75
76 Checker checker;
77 g.add_custom_lowering_pass(&checker, []() {});
78
79 Buffer<int> out = g.realize(128, 128);
80 for (int y = 0; y < 128; y++) {
81 for (int x = 0; x < 128; x++) {
82 int correct = 2 * x + 2 * y;
83 if (out(x, y) != correct) {
84 printf("out(%d, %d) = %d instead of %d\n",
85 x, y, out(x, y), correct);
86 return -1;
87 }
88 }
89 }
90
91 if (checker.count_atomics != 1 || checker.count_atomics_with_mutexes != 1) {
92 printf("Expected one atomic node, with mutex\n");
93 return -1;
94 }
95 }
96
97 {
98 Func f, g;
99 Var x, y;
100
101 f(x, y) = {x, y, 0};
102 f(x, y) = {f(x, y)[1] + x,
103 f(x, y)[0] + y,
104 f(x, y)[2] + 1};
105
106 // The summation is coupled across the first two tuple
107 // components and there are two stores, so we need a mutex
108 // there. The last store could in principle be a separate atomic
109 // add, but we instead just pack it into the critical section.
110 f.compute_root().update().parallel(y).atomic();
111
112 g(x, y) = f(x, y)[0] + f(x, y)[1] + f(x, y)[2];
113
114 Checker checker;
115 g.add_custom_lowering_pass(&checker, []() {});
116
117 Buffer<int> out = g.realize(128, 128);
118 for (int y = 0; y < 128; y++) {
119 for (int x = 0; x < 128; x++) {
120 int correct = 2 * x + 2 * y + 1;
121 if (out(x, y) != correct) {
122 printf("out(%d, %d) = %d instead of %d\n",
123 x, y, out(x, y), correct);
124 return -1;
125 }
126 }
127 }
128
129 if (checker.count_atomics != 1 || checker.count_atomics_with_mutexes != 1) {
130 printf("Expected one atomic nodes, with mutex\n");
131 return -1;
132 }
133 }
134
135 {
136 Func f, g;
137 Var x, y;
138
139 f(x, y) = {x, y, x, y};
140 f(x, y) = {f(x, y)[1] + x,
141 f(x, y)[0] + y,
142 f(x, y)[3] + x,
143 f(x, y)[2] + y};
144
145 // The summation is coupled across the first two tuple
146 // components and the last two components, but they're
147 // independent so they *could* get two critical sections, but
148 // it would be on the same mutex, so we just pack them all
149 // into one critical section.
150 f.compute_root().update().parallel(y).atomic();
151
152 g(x, y) = f(x, y)[0] + f(x, y)[1] + f(x, y)[2] + f(x, y)[3];
153
154 Checker checker;
155 g.add_custom_lowering_pass(&checker, []() {});
156
157 Buffer<int> out = g.realize(128, 128);
158 for (int y = 0; y < 128; y++) {
159 for (int x = 0; x < 128; x++) {
160 int correct = 4 * x + 4 * y;
161 if (out(x, y) != correct) {
162 printf("out(%d, %d) = %d instead of %d\n",
163 x, y, out(x, y), correct);
164 return -1;
165 }
166 }
167 }
168
169 if (checker.count_atomics != 1 || checker.count_atomics_with_mutexes != 1) {
170 printf("Expected one atomic node, with mutex\n");
171 return -1;
172 }
173 }
174
175 {
176 Func f, g;
177 Var x, y;
178
179 f(x, y) = {x, y};
180 RDom r(0, 65);
181 // Update even rows
182 f(x, r * 2) = {f(x, r * 2 + 1)[1] + x,
183 f(x, r * 2 - 1)[0] + r * 2};
184 // Update odd rows using even rows
185 f(x, r * 2 + 1) = {f(x, r * 2)[1] + x,
186 f(x, r * 2 + 2)[0] + r * 2 + 1};
187
188 // The tuple components have cross-talk, but the loads
189 // couldn't possibly alias with the stores because of the
190 // even/odd split. We can just use four atomic adds safely.
191 f.compute_root();
192 f.update(0)
193 .parallel(r)
194 .atomic();
195 f.update(1)
196 .parallel(r)
197 .atomic();
198
199 g(x, y) = f(x, y)[0] + f(x, y)[1];
200
201 Checker checker;
202 g.add_custom_lowering_pass(&checker, []() {});
203
204 Buffer<int> out = g.realize(128, 128);
205 for (int y = 0; y < 128; y++) {
206 for (int x = 0; x < 128; x++) {
207 int correct = 2 * x + 2 * y + 1;
208 if (y & 1) {
209 // The odd rows happen after the even rows, so
210 // they get another dose of x + y.
211 correct += x + y;
212 }
213 if (out(x, y) != correct) {
214 printf("out(%d, %d) = %d instead of %d\n",
215 x, y, out(x, y), correct);
216 //return -1;
217 }
218 }
219 }
220
221 if (checker.count_atomics != 4 || checker.count_atomics_with_mutexes != 0) {
222 printf("Expected four atomic nodes, with no mutexes\n");
223 return -1;
224 }
225 }
226
227 printf("Success!\n");
228 return 0;
229 }
230