1 #include "Halide.h"
2 
3 using namespace Halide;
4 using namespace Halide::Internal;
5 
6 class Checker : public IRMutator {
visit(const Atomic * op)7     Stmt visit(const Atomic *op) override {
8         count_atomics++;
9         if (!op->mutex_name.empty()) {
10             count_atomics_with_mutexes++;
11         }
12         return IRMutator::visit(op);
13     }
14 
15 public:
16     int count_atomics = 0;
17     int count_atomics_with_mutexes = 0;
18 };
19 
main(int argc,char ** argv)20 int main(int argc, char **argv) {
21     if (get_jit_target_from_environment().arch == Target::WebAssembly) {
22         printf("[SKIP] Skipping test for WebAssembly as it does not support atomics yet.\n");
23         return 0;
24     }
25 
26     {
27         Func f, g;
28         Var x, y;
29 
30         f(x, y) = {x, y};
31         f(x, y) = {f(x, y)[0] + x,
32                    f(x, y)[1] + y};
33 
34         // The summation is independent in the two tuple components,
35         // so it can just be two atomic add instructions. No CAS loop
36         // required.
37         f.compute_root().update().parallel(y).atomic();
38 
39         g(x, y) = f(x, y)[0] + f(x, y)[1];
40 
41         Checker checker;
42         g.add_custom_lowering_pass(&checker, []() {});
43 
44         Buffer<int> out = g.realize(128, 128);
45         for (int y = 0; y < 128; y++) {
46             for (int x = 0; x < 128; x++) {
47                 int correct = 2 * x + 2 * y;
48                 if (out(x, y) != correct) {
49                     printf("out(%d, %d) = %d instead of %d\n",
50                            x, y, out(x, y), correct);
51                     return -1;
52                 }
53             }
54         }
55 
56         if (checker.count_atomics != 2 || checker.count_atomics_with_mutexes != 0) {
57             printf("Expected two atomic nodes, neither of them with mutexes\n");
58             return -1;
59         }
60     }
61 
62     {
63         Func f, g;
64         Var x, y;
65 
66         f(x, y) = {x, y};
67         f(x, y) = {f(x, y)[1] + x,
68                    f(x, y)[0] + y};
69 
70         // The summation is coupled across the two tuple components
71         // and there are two stores, so we need a mutex.
72         f.compute_root().update().parallel(y).atomic();
73 
74         g(x, y) = f(x, y)[0] + f(x, y)[1];
75 
76         Checker checker;
77         g.add_custom_lowering_pass(&checker, []() {});
78 
79         Buffer<int> out = g.realize(128, 128);
80         for (int y = 0; y < 128; y++) {
81             for (int x = 0; x < 128; x++) {
82                 int correct = 2 * x + 2 * y;
83                 if (out(x, y) != correct) {
84                     printf("out(%d, %d) = %d instead of %d\n",
85                            x, y, out(x, y), correct);
86                     return -1;
87                 }
88             }
89         }
90 
91         if (checker.count_atomics != 1 || checker.count_atomics_with_mutexes != 1) {
92             printf("Expected one atomic node, with mutex\n");
93             return -1;
94         }
95     }
96 
97     {
98         Func f, g;
99         Var x, y;
100 
101         f(x, y) = {x, y, 0};
102         f(x, y) = {f(x, y)[1] + x,
103                    f(x, y)[0] + y,
104                    f(x, y)[2] + 1};
105 
106         // The summation is coupled across the first two tuple
107         // components and there are two stores, so we need a mutex
108         // there. The last store could in principle be a separate atomic
109         // add, but we instead just pack it into the critical section.
110         f.compute_root().update().parallel(y).atomic();
111 
112         g(x, y) = f(x, y)[0] + f(x, y)[1] + f(x, y)[2];
113 
114         Checker checker;
115         g.add_custom_lowering_pass(&checker, []() {});
116 
117         Buffer<int> out = g.realize(128, 128);
118         for (int y = 0; y < 128; y++) {
119             for (int x = 0; x < 128; x++) {
120                 int correct = 2 * x + 2 * y + 1;
121                 if (out(x, y) != correct) {
122                     printf("out(%d, %d) = %d instead of %d\n",
123                            x, y, out(x, y), correct);
124                     return -1;
125                 }
126             }
127         }
128 
129         if (checker.count_atomics != 1 || checker.count_atomics_with_mutexes != 1) {
130             printf("Expected one atomic nodes, with mutex\n");
131             return -1;
132         }
133     }
134 
135     {
136         Func f, g;
137         Var x, y;
138 
139         f(x, y) = {x, y, x, y};
140         f(x, y) = {f(x, y)[1] + x,
141                    f(x, y)[0] + y,
142                    f(x, y)[3] + x,
143                    f(x, y)[2] + y};
144 
145         // The summation is coupled across the first two tuple
146         // components and the last two components, but they're
147         // independent so they *could* get two critical sections, but
148         // it would be on the same mutex, so we just pack them all
149         // into one critical section.
150         f.compute_root().update().parallel(y).atomic();
151 
152         g(x, y) = f(x, y)[0] + f(x, y)[1] + f(x, y)[2] + f(x, y)[3];
153 
154         Checker checker;
155         g.add_custom_lowering_pass(&checker, []() {});
156 
157         Buffer<int> out = g.realize(128, 128);
158         for (int y = 0; y < 128; y++) {
159             for (int x = 0; x < 128; x++) {
160                 int correct = 4 * x + 4 * y;
161                 if (out(x, y) != correct) {
162                     printf("out(%d, %d) = %d instead of %d\n",
163                            x, y, out(x, y), correct);
164                     return -1;
165                 }
166             }
167         }
168 
169         if (checker.count_atomics != 1 || checker.count_atomics_with_mutexes != 1) {
170             printf("Expected one atomic node, with mutex\n");
171             return -1;
172         }
173     }
174 
175     {
176         Func f, g;
177         Var x, y;
178 
179         f(x, y) = {x, y};
180         RDom r(0, 65);
181         // Update even rows
182         f(x, r * 2) = {f(x, r * 2 + 1)[1] + x,
183                        f(x, r * 2 - 1)[0] + r * 2};
184         // Update odd rows using even rows
185         f(x, r * 2 + 1) = {f(x, r * 2)[1] + x,
186                            f(x, r * 2 + 2)[0] + r * 2 + 1};
187 
188         // The tuple components have cross-talk, but the loads
189         // couldn't possibly alias with the stores because of the
190         // even/odd split. We can just use four atomic adds safely.
191         f.compute_root();
192         f.update(0)
193             .parallel(r)
194             .atomic();
195         f.update(1)
196             .parallel(r)
197             .atomic();
198 
199         g(x, y) = f(x, y)[0] + f(x, y)[1];
200 
201         Checker checker;
202         g.add_custom_lowering_pass(&checker, []() {});
203 
204         Buffer<int> out = g.realize(128, 128);
205         for (int y = 0; y < 128; y++) {
206             for (int x = 0; x < 128; x++) {
207                 int correct = 2 * x + 2 * y + 1;
208                 if (y & 1) {
209                     // The odd rows happen after the even rows, so
210                     // they get another dose of x + y.
211                     correct += x + y;
212                 }
213                 if (out(x, y) != correct) {
214                     printf("out(%d, %d) = %d instead of %d\n",
215                            x, y, out(x, y), correct);
216                     //return -1;
217                 }
218             }
219         }
220 
221         if (checker.count_atomics != 4 || checker.count_atomics_with_mutexes != 0) {
222             printf("Expected four atomic nodes, with no mutexes\n");
223             return -1;
224         }
225     }
226 
227     printf("Success!\n");
228     return 0;
229 }
230