1 #ifndef SIMD_OP_CHECK_H
2 #define SIMD_OP_CHECK_H
3 
#include "Halide.h"
#include "halide_test_dirs.h"

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <future>
#include <iostream>
#include <random>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
8 
9 namespace Halide {
// Outcome of one SimdOpCheckTest::check_one() call.
struct TestResult {
    // The instruction pattern that was checked.
    std::string op;
    // Accumulated failure text; empty when the check passed.
    std::string error_msg;
};
14 
// One queued check, created by SimdOpCheckTest::check() and executed by
// SimdOpCheckTest::test_all().
struct Task {
    // Instruction pattern ('*' wildcards allowed) expected in the generated assembly.
    std::string op;
    // Sanitized, uniquified name used for the generated Funcs and output files.
    std::string name;
    // Vectorization factor applied to the x dimension.
    int vector_width;
    // Expression assigned to f(x, y) in the generated test pipeline.
    Expr expr;
};
21 
22 class SimdOpCheckTest {
23 public:
24     std::string filter{"*"};
25     std::string output_directory{Internal::get_test_tmp_dir()};
26     std::vector<Task> tasks;
27 
28     Target target;
29 
30     ImageParam in_f32{Float(32), 1, "in_f32"};
31     ImageParam in_f64{Float(64), 1, "in_f64"};
32     ImageParam in_i8{Int(8), 1, "in_i8"};
33     ImageParam in_u8{UInt(8), 1, "in_u8"};
34     ImageParam in_i16{Int(16), 1, "in_i16"};
35     ImageParam in_u16{UInt(16), 1, "in_u16"};
36     ImageParam in_i32{Int(32), 1, "in_i32"};
37     ImageParam in_u32{UInt(32), 1, "in_u32"};
38     ImageParam in_i64{Int(64), 1, "in_i64"};
39     ImageParam in_u64{UInt(64), 1, "in_u64"};
40 
41     const std::vector<ImageParam> image_params{in_f32, in_f64, in_i8, in_u8, in_i16, in_u16, in_i32, in_u32, in_i64, in_u64};
42     const std::vector<Argument> arg_types{in_f32, in_f64, in_i8, in_u8, in_i16, in_u16, in_i32, in_u32, in_i64, in_u64};
43     int W;
44     int H;
45 
SimdOpCheckTest(const Target t,int w,int h)46     SimdOpCheckTest(const Target t, int w, int h)
47         : target(t), W(w), H(h) {
48         target = target
49                      .with_feature(Target::NoBoundsQuery)
50                      .with_feature(Target::NoAsserts)
51                      .with_feature(Target::NoRuntime)
52                      .with_feature(Target::DisableLLVMLoopOpt);
53         num_threads = Internal::ThreadPool<void>::num_processors_online();
54     }
55     virtual ~SimdOpCheckTest() = default;
get_num_threads()56     size_t get_num_threads() const {
57         return num_threads;
58     }
59 
set_num_threads(size_t n)60     void set_num_threads(size_t n) {
61         num_threads = n;
62     }
can_run_code()63     bool can_run_code() const {
64         // Assume we are configured to run wasm if requested
65         // (we'll fail further downstream if not)
66         if (target.arch == Target::WebAssembly) {
67             return true;
68         }
69         // If we can (target matches host), run the error checking Halide::Func.
70         Target host_target = get_host_target();
71         bool can_run_the_code =
72             (target.arch == host_target.arch &&
73              target.bits == host_target.bits &&
74              target.os == host_target.os);
75         // A bunch of feature flags also need to match between the
76         // compiled code and the host in order to run the code.
77         for (Target::Feature f : {Target::SSE41, Target::AVX,
78                                   Target::AVX2, Target::AVX512,
79                                   Target::FMA, Target::FMA4, Target::F16C,
80                                   Target::VSX, Target::POWER_ARCH_2_07,
81                                   Target::ARMv7s, Target::NoNEON,
82                                   Target::WasmSimd128}) {
83             if (target.has_feature(f) != host_target.has_feature(f)) {
84                 can_run_the_code = false;
85             }
86         }
87         return can_run_the_code;
88     }
89 
90     // Check if pattern p matches str, allowing for wildcards (*).
wildcard_match(const char * p,const char * str)91     bool wildcard_match(const char *p, const char *str) const {
92         // Match all non-wildcard characters.
93         while (*p && *str && *p == *str && *p != '*') {
94             str++;
95             p++;
96         }
97 
98         if (!*p) {
99             return *str == 0;
100         } else if (*p == '*') {
101             p++;
102             do {
103                 if (wildcard_match(p, str)) {
104                     return true;
105                 }
106             } while (*str++);
107         } else if (*p == ' ') {  // ignore whitespace in pattern
108             p++;
109             if (wildcard_match(p, str)) {
110                 return true;
111             }
112         } else if (*str == ' ') {  // ignore whitespace in string
113             str++;
114             if (wildcard_match(p, str)) {
115                 return true;
116             }
117         }
118         return !*p;
119     }
120 
wildcard_match(const std::string & p,const std::string & str)121     bool wildcard_match(const std::string &p, const std::string &str) const {
122         return wildcard_match(p.c_str(), str.c_str());
123     }
124 
125     // Check if a substring of str matches a pattern p.
wildcard_search(const std::string & p,const std::string & str)126     bool wildcard_search(const std::string &p, const std::string &str) const {
127         return wildcard_match("*" + p + "*", str);
128     }
129 
check_one(const std::string & op,const std::string & name,int vector_width,Expr e)130     TestResult check_one(const std::string &op, const std::string &name, int vector_width, Expr e) {
131         std::ostringstream error_msg;
132 
133         class HasInlineReduction : public Internal::IRVisitor {
134             using Internal::IRVisitor::visit;
135             void visit(const Internal::Call *op) override {
136                 if (op->call_type == Internal::Call::Halide) {
137                     Internal::Function f(op->func);
138                     if (f.has_update_definition()) {
139                         inline_reduction = f;
140                         result = true;
141                     }
142                 }
143                 IRVisitor::visit(op);
144             }
145 
146         public:
147             Internal::Function inline_reduction;
148             bool result = false;
149         } has_inline_reduction;
150         e.accept(&has_inline_reduction);
151 
152         // Define a vectorized Halide::Func that uses the pattern.
153         Halide::Func f(name);
154         f(x, y) = e;
155         f.bound(x, 0, W).vectorize(x, vector_width);
156         f.compute_root();
157 
158         // Include a scalar version
159         Halide::Func f_scalar("scalar_" + name);
160         f_scalar(x, y) = e;
161         f_scalar.bound(x, 0, W);
162         f_scalar.compute_root();
163 
164         if (has_inline_reduction.result) {
165             // If there's an inline reduction, we want to vectorize it
166             // over the RVar.
167             Var xo, xi;
168             RVar rxi;
169             Func g{has_inline_reduction.inline_reduction};
170 
171             // Do the reduction separately in f_scalar
172             g.clone_in(f_scalar);
173 
174             g.compute_at(f, x)
175                 .update()
176                 .split(x, xo, xi, vector_width)
177                 .fuse(g.rvars()[0], xi, rxi)
178                 .atomic()
179                 .vectorize(rxi);
180         }
181 
182         // The output to the pipeline is the maximum absolute difference as a double.
183         RDom r_check(0, W, 0, H);
184         Halide::Func error("error_" + name);
185         error() = Halide::cast<double>(maximum(absd(f(r_check.x, r_check.y), f_scalar(r_check.x, r_check.y))));
186 
187         setup_images();
188         {
189             // Compile just the vector Func to assembly.
190             std::string asm_filename = output_directory + "check_" + name + ".s";
191             f.compile_to_assembly(asm_filename, arg_types, target);
192 
193             std::ifstream asm_file;
194             asm_file.open(asm_filename);
195 
196             bool found_it = false;
197 
198             std::ostringstream msg;
199             msg << op << " did not generate for target=" << target.to_string() << " vector_width=" << vector_width << ". Instead we got:\n";
200 
201             std::string line;
202             while (getline(asm_file, line)) {
203                 msg << line << "\n";
204 
205                 // Check for the op in question
206                 found_it |= wildcard_search(op, line) && !wildcard_search("_" + op, line);
207             }
208 
209             if (!found_it) {
210                 error_msg << "Failed: " << msg.str() << "\n";
211             }
212 
213             asm_file.close();
214         }
215 
216         // Also compile the error checking Func (to be sure it compiles without error)
217         std::string fn_name = "test_" + name;
218         error.compile_to_file(output_directory + fn_name, arg_types, fn_name, target);
219 
220         bool can_run_the_code = can_run_code();
221         if (can_run_the_code) {
222             Target run_target = target
223                                     .without_feature(Target::NoRuntime)
224                                     .without_feature(Target::NoAsserts)
225                                     .without_feature(Target::NoBoundsQuery);
226 
227             error.infer_input_bounds({}, run_target);
228             // Fill the inputs with noise
229             std::mt19937 rng(123);
230             for (auto p : image_params) {
231                 Halide::Buffer<> buf = p.get();
232                 if (!buf.defined()) continue;
233                 assert(buf.data());
234                 Type t = buf.type();
235                 // For floats/doubles, we only use values that aren't
236                 // subject to rounding error that may differ between
237                 // vectorized and non-vectorized versions
238                 if (t == Float(32)) {
239                     buf.as<float>().for_each_value([&](float &f) { f = (rng() & 0xfff) / 8.0f - 0xff; });
240                 } else if (t == Float(64)) {
241                     buf.as<double>().for_each_value([&](double &f) { f = (rng() & 0xfff) / 8.0 - 0xff; });
242                 } else {
243                     // Random bits is fine
244                     for (uint32_t *ptr = (uint32_t *)buf.data();
245                          ptr != (uint32_t *)buf.data() + buf.size_in_bytes() / 4;
246                          ptr++) {
247                         // Never use the top four bits, to avoid
248                         // signed integer overflow.
249                         *ptr = ((uint32_t)rng()) & 0x0fffffff;
250                     }
251                 }
252             }
253             Realization r = error.realize();
254             double e = Buffer<double>(r[0])();
255             // Use a very loose tolerance for floating point tests. The
256             // kinds of bugs we're looking for are codegen bugs that
257             // return the wrong value entirely, not floating point
258             // accuracy differences between vectors and scalars.
259             if (e > 0.001) {
260                 error_msg << "The vector and scalar versions of " << name << " disagree. Maximum error: " << e << "\n";
261 
262                 std::string error_filename = output_directory + "error_" + name + ".s";
263                 error.compile_to_assembly(error_filename, arg_types, target);
264 
265                 std::ifstream error_file;
266                 error_file.open(error_filename);
267 
268                 error_msg << "Error assembly: \n";
269                 std::string line;
270                 while (getline(error_file, line)) {
271                     error_msg << line << "\n";
272                 }
273 
274                 error_file.close();
275             }
276         }
277 
278         return {op, error_msg.str()};
279     }
280 
check(std::string op,int vector_width,Expr e)281     void check(std::string op, int vector_width, Expr e) {
282         // Make a name for the test by uniquing then sanitizing the op name
283         std::string name = "op_" + op;
284         for (size_t i = 0; i < name.size(); i++) {
285             if (!isalnum(name[i])) name[i] = '_';
286         }
287 
288         name += "_" + std::to_string(tasks.size());
289 
290         // Bail out after generating the unique_name, so that names are
291         // unique across different processes and don't depend on filter
292         // settings.
293         if (!wildcard_match(filter, op)) return;
294 
295         tasks.emplace_back(Task{op, name, vector_width, e});
296     }
297     virtual void add_tests() = 0;
setup_images()298     virtual void setup_images() {
299         for (auto p : image_params) {
300             p.reset();
301         }
302     }
test_all()303     virtual bool test_all() {
304         /* First add some tests based on the target */
305         add_tests();
306         Internal::ThreadPool<TestResult> pool(num_threads);
307         std::vector<std::future<TestResult>> futures;
308         for (const Task &task : tasks) {
309             futures.push_back(pool.async([this, task]() {
310                 return check_one(task.op, task.name, task.vector_width, task.expr);
311             }));
312         }
313 
314         bool success = true;
315         for (auto &f : futures) {
316             const TestResult &result = f.get();
317             std::cout << result.op << "\n";
318             if (!result.error_msg.empty()) {
319                 std::cerr << result.error_msg;
320                 success = false;
321             }
322         }
323 
324         return success;
325     }
326 
327 private:
328     size_t num_threads;
329     const Halide::Var x{"x"}, y{"y"};
330 };
331 }  // namespace Halide
332 #endif  // SIMD_OP_CHECK_H
333