1 #ifndef SIMD_OP_CHECK_H 2 #define SIMD_OP_CHECK_H 3 4 #include "Halide.h" 5 #include "halide_test_dirs.h" 6 7 #include <fstream> 8 9 namespace Halide { 10 struct TestResult { 11 std::string op; 12 std::string error_msg; 13 }; 14 15 struct Task { 16 std::string op; 17 std::string name; 18 int vector_width; 19 Expr expr; 20 }; 21 22 class SimdOpCheckTest { 23 public: 24 std::string filter{"*"}; 25 std::string output_directory{Internal::get_test_tmp_dir()}; 26 std::vector<Task> tasks; 27 28 Target target; 29 30 ImageParam in_f32{Float(32), 1, "in_f32"}; 31 ImageParam in_f64{Float(64), 1, "in_f64"}; 32 ImageParam in_i8{Int(8), 1, "in_i8"}; 33 ImageParam in_u8{UInt(8), 1, "in_u8"}; 34 ImageParam in_i16{Int(16), 1, "in_i16"}; 35 ImageParam in_u16{UInt(16), 1, "in_u16"}; 36 ImageParam in_i32{Int(32), 1, "in_i32"}; 37 ImageParam in_u32{UInt(32), 1, "in_u32"}; 38 ImageParam in_i64{Int(64), 1, "in_i64"}; 39 ImageParam in_u64{UInt(64), 1, "in_u64"}; 40 41 const std::vector<ImageParam> image_params{in_f32, in_f64, in_i8, in_u8, in_i16, in_u16, in_i32, in_u32, in_i64, in_u64}; 42 const std::vector<Argument> arg_types{in_f32, in_f64, in_i8, in_u8, in_i16, in_u16, in_i32, in_u32, in_i64, in_u64}; 43 int W; 44 int H; 45 SimdOpCheckTest(const Target t,int w,int h)46 SimdOpCheckTest(const Target t, int w, int h) 47 : target(t), W(w), H(h) { 48 target = target 49 .with_feature(Target::NoBoundsQuery) 50 .with_feature(Target::NoAsserts) 51 .with_feature(Target::NoRuntime) 52 .with_feature(Target::DisableLLVMLoopOpt); 53 num_threads = Internal::ThreadPool<void>::num_processors_online(); 54 } 55 virtual ~SimdOpCheckTest() = default; get_num_threads()56 size_t get_num_threads() const { 57 return num_threads; 58 } 59 set_num_threads(size_t n)60 void set_num_threads(size_t n) { 61 num_threads = n; 62 } can_run_code()63 bool can_run_code() const { 64 // Assume we are configured to run wasm if requested 65 // (we'll fail further downstream if not) 66 if (target.arch == Target::WebAssembly) { 67 return true; 68 } 69 // If we can (target matches host), run the error checking Halide::Func. 70 Target host_target = get_host_target(); 71 bool can_run_the_code = 72 (target.arch == host_target.arch && 73 target.bits == host_target.bits && 74 target.os == host_target.os); 75 // A bunch of feature flags also need to match between the 76 // compiled code and the host in order to run the code. 77 for (Target::Feature f : {Target::SSE41, Target::AVX, 78 Target::AVX2, Target::AVX512, 79 Target::FMA, Target::FMA4, Target::F16C, 80 Target::VSX, Target::POWER_ARCH_2_07, 81 Target::ARMv7s, Target::NoNEON, 82 Target::WasmSimd128}) { 83 if (target.has_feature(f) != host_target.has_feature(f)) { 84 can_run_the_code = false; 85 } 86 } 87 return can_run_the_code; 88 } 89 90 // Check if pattern p matches str, allowing for wildcards (*). wildcard_match(const char * p,const char * str)91 bool wildcard_match(const char *p, const char *str) const { 92 // Match all non-wildcard characters. 93 while (*p && *str && *p == *str && *p != '*') { 94 str++; 95 p++; 96 } 97 98 if (!*p) { 99 return *str == 0; 100 } else if (*p == '*') { 101 p++; 102 do { 103 if (wildcard_match(p, str)) { 104 return true; 105 } 106 } while (*str++); 107 } else if (*p == ' ') { // ignore whitespace in pattern 108 p++; 109 if (wildcard_match(p, str)) { 110 return true; 111 } 112 } else if (*str == ' ') { // ignore whitespace in string 113 str++; 114 if (wildcard_match(p, str)) { 115 return true; 116 } 117 } 118 return !*p; 119 } 120 wildcard_match(const std::string & p,const std::string & str)121 bool wildcard_match(const std::string &p, const std::string &str) const { 122 return wildcard_match(p.c_str(), str.c_str()); 123 } 124 125 // Check if a substring of str matches a pattern p. wildcard_search(const std::string & p,const std::string & str)126 bool wildcard_search(const std::string &p, const std::string &str) const { 127 return wildcard_match("*" + p + "*", str); 128 } 129 check_one(const std::string & op,const std::string & name,int vector_width,Expr e)130 TestResult check_one(const std::string &op, const std::string &name, int vector_width, Expr e) { 131 std::ostringstream error_msg; 132 133 class HasInlineReduction : public Internal::IRVisitor { 134 using Internal::IRVisitor::visit; 135 void visit(const Internal::Call *op) override { 136 if (op->call_type == Internal::Call::Halide) { 137 Internal::Function f(op->func); 138 if (f.has_update_definition()) { 139 inline_reduction = f; 140 result = true; 141 } 142 } 143 IRVisitor::visit(op); 144 } 145 146 public: 147 Internal::Function inline_reduction; 148 bool result = false; 149 } has_inline_reduction; 150 e.accept(&has_inline_reduction); 151 152 // Define a vectorized Halide::Func that uses the pattern. 153 Halide::Func f(name); 154 f(x, y) = e; 155 f.bound(x, 0, W).vectorize(x, vector_width); 156 f.compute_root(); 157 158 // Include a scalar version 159 Halide::Func f_scalar("scalar_" + name); 160 f_scalar(x, y) = e; 161 f_scalar.bound(x, 0, W); 162 f_scalar.compute_root(); 163 164 if (has_inline_reduction.result) { 165 // If there's an inline reduction, we want to vectorize it 166 // over the RVar. 167 Var xo, xi; 168 RVar rxi; 169 Func g{has_inline_reduction.inline_reduction}; 170 171 // Do the reduction separately in f_scalar 172 g.clone_in(f_scalar); 173 174 g.compute_at(f, x) 175 .update() 176 .split(x, xo, xi, vector_width) 177 .fuse(g.rvars()[0], xi, rxi) 178 .atomic() 179 .vectorize(rxi); 180 } 181 182 // The output to the pipeline is the maximum absolute difference as a double. 183 RDom r_check(0, W, 0, H); 184 Halide::Func error("error_" + name); 185 error() = Halide::cast<double>(maximum(absd(f(r_check.x, r_check.y), f_scalar(r_check.x, r_check.y)))); 186 187 setup_images(); 188 { 189 // Compile just the vector Func to assembly. 190 std::string asm_filename = output_directory + "check_" + name + ".s"; 191 f.compile_to_assembly(asm_filename, arg_types, target); 192 193 std::ifstream asm_file; 194 asm_file.open(asm_filename); 195 196 bool found_it = false; 197 198 std::ostringstream msg; 199 msg << op << " did not generate for target=" << target.to_string() << " vector_width=" << vector_width << ". Instead we got:\n"; 200 201 std::string line; 202 while (getline(asm_file, line)) { 203 msg << line << "\n"; 204 205 // Check for the op in question 206 found_it |= wildcard_search(op, line) && !wildcard_search("_" + op, line); 207 } 208 209 if (!found_it) { 210 error_msg << "Failed: " << msg.str() << "\n"; 211 } 212 213 asm_file.close(); 214 } 215 216 // Also compile the error checking Func (to be sure it compiles without error) 217 std::string fn_name = "test_" + name; 218 error.compile_to_file(output_directory + fn_name, arg_types, fn_name, target); 219 220 bool can_run_the_code = can_run_code(); 221 if (can_run_the_code) { 222 Target run_target = target 223 .without_feature(Target::NoRuntime) 224 .without_feature(Target::NoAsserts) 225 .without_feature(Target::NoBoundsQuery); 226 227 error.infer_input_bounds({}, run_target); 228 // Fill the inputs with noise 229 std::mt19937 rng(123); 230 for (auto p : image_params) { 231 Halide::Buffer<> buf = p.get(); 232 if (!buf.defined()) continue; 233 assert(buf.data()); 234 Type t = buf.type(); 235 // For floats/doubles, we only use values that aren't 236 // subject to rounding error that may differ between 237 // vectorized and non-vectorized versions 238 if (t == Float(32)) { 239 buf.as<float>().for_each_value([&](float &f) { f = (rng() & 0xfff) / 8.0f - 0xff; }); 240 } else if (t == Float(64)) { 241 buf.as<double>().for_each_value([&](double &f) { f = (rng() & 0xfff) / 8.0 - 0xff; }); 242 } else { 243 // Random bits is fine 244 for (uint32_t *ptr = (uint32_t *)buf.data(); 245 ptr != (uint32_t *)buf.data() + buf.size_in_bytes() / 4; 246 ptr++) { 247 // Never use the top four bits, to avoid 248 // signed integer overflow. 249 *ptr = ((uint32_t)rng()) & 0x0fffffff; 250 } 251 } 252 } 253 Realization r = error.realize(); 254 double e = Buffer<double>(r[0])(); 255 // Use a very loose tolerance for floating point tests. The 256 // kinds of bugs we're looking for are codegen bugs that 257 // return the wrong value entirely, not floating point 258 // accuracy differences between vectors and scalars. 259 if (e > 0.001) { 260 error_msg << "The vector and scalar versions of " << name << " disagree. Maximum error: " << e << "\n"; 261 262 std::string error_filename = output_directory + "error_" + name + ".s"; 263 error.compile_to_assembly(error_filename, arg_types, target); 264 265 std::ifstream error_file; 266 error_file.open(error_filename); 267 268 error_msg << "Error assembly: \n"; 269 std::string line; 270 while (getline(error_file, line)) { 271 error_msg << line << "\n"; 272 } 273 274 error_file.close(); 275 } 276 } 277 278 return {op, error_msg.str()}; 279 } 280 check(std::string op,int vector_width,Expr e)281 void check(std::string op, int vector_width, Expr e) { 282 // Make a name for the test by uniquing then sanitizing the op name 283 std::string name = "op_" + op; 284 for (size_t i = 0; i < name.size(); i++) { 285 if (!isalnum(name[i])) name[i] = '_'; 286 } 287 288 name += "_" + std::to_string(tasks.size()); 289 290 // Bail out after generating the unique_name, so that names are 291 // unique across different processes and don't depend on filter 292 // settings. 293 if (!wildcard_match(filter, op)) return; 294 295 tasks.emplace_back(Task{op, name, vector_width, e}); 296 } 297 virtual void add_tests() = 0; setup_images()298 virtual void setup_images() { 299 for (auto p : image_params) { 300 p.reset(); 301 } 302 } test_all()303 virtual bool test_all() { 304 /* First add some tests based on the target */ 305 add_tests(); 306 Internal::ThreadPool<TestResult> pool(num_threads); 307 std::vector<std::future<TestResult>> futures; 308 for (const Task &task : tasks) { 309 futures.push_back(pool.async([this, task]() { 310 return check_one(task.op, task.name, task.vector_width, task.expr); 311 })); 312 } 313 314 bool success = true; 315 for (auto &f : futures) { 316 const TestResult &result = f.get(); 317 std::cout << result.op << "\n"; 318 if (!result.error_msg.empty()) { 319 std::cerr << result.error_msg; 320 success = false; 321 } 322 } 323 324 return success; 325 } 326 327 private: 328 size_t num_threads; 329 const Halide::Var x{"x"}, y{"y"}; 330 }; 331 } // namespace Halide 332 #endif // SIMD_OP_CHECK_H 333