#include <iostream>

#include "CodeGen_X86.h"
#include "ConciseCasts.h"
#include "Debug.h"
#include "IRMatch.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "JITModule.h"
#include "LLVM_Headers.h"
#include "Param.h"
#include "Util.h"
#include "Var.h"

namespace Halide {
namespace Internal {

using std::string;
using std::vector;

using namespace Halide::ConciseCasts;
using namespace llvm;

namespace {
// Populate feature flags in a target according to those implied by
// existing flags, so that instruction patterns can just check for the
// oldest feature flag that supports an instruction.
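//
// For example, a target with AVX2 also gets AVX and SSE41 set, so a
// pattern gated on SSE41 fires on any AVX2 target without having to
// list every newer feature explicitly.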
Target complete_x86_target(Target t) {
    if (t.has_feature(Target::AVX512_Cannonlake) ||
        t.has_feature(Target::AVX512_Skylake) ||
        t.has_feature(Target::AVX512_KNL)) {
        t.set_feature(Target::AVX2);
    }
    if (t.has_feature(Target::AVX2)) {
        t.set_feature(Target::AVX);
    }
    if (t.has_feature(Target::AVX)) {
        t.set_feature(Target::SSE41);
    }
    return t;
}
}  // namespace

CodeGen_X86::CodeGen_X86(Target t)
    : CodeGen_Posix(complete_x86_target(t)) {

#if !defined(WITH_X86)
    user_error << "x86 not enabled for this build of Halide.\n";
#endif

    user_assert(llvm_X86_enabled) << "llvm build not configured with X86 target enabled.\n";
}

namespace {

// i32(i16_a)*i32(i16_b) +/- i32(i16_c)*i32(i16_d) can be done by
// interleaving a, c, and b, d, and then using pmaddwd. We
// recognize it here, and implement it in the initial module.
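//
// Illustrative sketch (not the actual codegen path): with
// a, b, c, d : int16x4,
//   i32(a) * i32(b) + i32(c) * i32(d)
// can be computed as
//   pmaddwd(interleave(a, c), interleave(b, d))
// since pmaddwd multiplies adjacent 16-bit elements and adds each
// pair of 32-bit products into a single 32-bit lane.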
bool should_use_pmaddwd(const Expr &a, const Expr &b, vector<Expr> &result) {
    Type t = a.type();
    internal_assert(b.type() == t);

    const Mul *ma = a.as<Mul>();
    const Mul *mb = b.as<Mul>();

    if (!(ma && mb && t.is_int() && t.bits() == 32 && (t.lanes() >= 4))) {
        return false;
    }

    Type narrow = t.with_bits(16);
    vector<Expr> args = {lossless_cast(narrow, ma->a),
                         lossless_cast(narrow, ma->b),
                         lossless_cast(narrow, mb->a),
                         lossless_cast(narrow, mb->b)};
    if (!args[0].defined() || !args[1].defined() ||
        !args[2].defined() || !args[3].defined()) {
        return false;
    }

    result.swap(args);
    return true;
}

}  // namespace

void CodeGen_X86::visit(const Add *op) {
    vector<Expr> matches;
    if (should_use_pmaddwd(op->a, op->b, matches)) {
        codegen(Call::make(op->type, "pmaddwd", matches, Call::Extern));
    } else {
        CodeGen_Posix::visit(op);
    }
}

void CodeGen_X86::visit(const Sub *op) {
    vector<Expr> matches;
    if (should_use_pmaddwd(op->a, op->b, matches)) {
        // Negate one of the factors in the second expression
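        // so that a*b - c*d becomes a*b + (-c)*d (or a*b + c*(-d)),
        // which the pmaddwd pattern still covers. Negating the
        // constant operand when there is one lets the negation
        // constant-fold instead of costing an instruction.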
        if (is_const(matches[2])) {
            matches[2] = -matches[2];
        } else {
            matches[3] = -matches[3];
        }
        codegen(Call::make(op->type, "pmaddwd", matches, Call::Extern));
    } else {
        CodeGen_Posix::visit(op);
    }
}

void CodeGen_X86::visit(const Mul *op) {

#if LLVM_VERSION < 110
    // Widening integer multiply of non-power-of-two vector sizes is
    // broken in older llvms for older x86:
    // https://bugs.llvm.org/show_bug.cgi?id=44976
    const int lanes = op->type.lanes();
    if (!target.has_feature(Target::SSE41) &&
        (lanes & (lanes - 1)) &&
        (op->type.bits() >= 32) &&
        !op->type.is_float()) {
        // Any fancy shuffles to pad or slice into smaller vectors
        // just get undone by LLVM and retrigger the bug. Just
        // scalarize.
        vector<Expr> result;
        for (int i = 0; i < lanes; i++) {
            result.emplace_back(Shuffle::make_extract_element(op->a, i) *
                                Shuffle::make_extract_element(op->b, i));
        }
        codegen(Shuffle::make_concat(result));
        return;
    }
#endif

    return CodeGen_Posix::visit(op);
}

void CodeGen_X86::visit(const GT *op) {
    Type t = op->a.type();

    if (t.is_vector() &&
        upgrade_type_for_arithmetic(t) == t) {
        // Non-native vector widths get legalized poorly by llvm. We
        // split it up into native-width slices ourselves.
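        //
        // For example, a float32x12 comparison on a 128-bit target
        // becomes three 4-lane compares whose results are
        // concatenated back together.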

        int slice_size = vector_lanes_for_slice(t);

        Value *a = codegen(op->a), *b = codegen(op->b);
        vector<Value *> result;
        for (int i = 0; i < op->type.lanes(); i += slice_size) {
            Value *sa = slice_vector(a, i, slice_size);
            Value *sb = slice_vector(b, i, slice_size);
            Value *slice_value;
            if (t.is_float()) {
                slice_value = builder->CreateFCmpOGT(sa, sb);
            } else if (t.is_int()) {
                slice_value = builder->CreateICmpSGT(sa, sb);
            } else {
                slice_value = builder->CreateICmpUGT(sa, sb);
            }
            result.push_back(slice_value);
        }

        value = concat_vectors(result);
        value = slice_vector(value, 0, t.lanes());
    } else {
        CodeGen_Posix::visit(op);
    }
}

void CodeGen_X86::visit(const EQ *op) {
    Type t = op->a.type();

    if (t.is_vector() &&
        upgrade_type_for_arithmetic(t) == t) {
        // Non-native vector widths get legalized poorly by llvm. We
        // split it up into native-width slices ourselves.

        int slice_size = vector_lanes_for_slice(t);

        Value *a = codegen(op->a), *b = codegen(op->b);
        vector<Value *> result;
        for (int i = 0; i < op->type.lanes(); i += slice_size) {
            Value *sa = slice_vector(a, i, slice_size);
            Value *sb = slice_vector(b, i, slice_size);
            Value *slice_value;
            if (t.is_float()) {
                slice_value = builder->CreateFCmpOEQ(sa, sb);
            } else {
                slice_value = builder->CreateICmpEQ(sa, sb);
            }
            result.push_back(slice_value);
        }

        value = concat_vectors(result);
        value = slice_vector(value, 0, t.lanes());
    } else {
        CodeGen_Posix::visit(op);
    }
}

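// The remaining comparisons are rewritten in terms of GT and EQ so
// that they also take the slicing paths above.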
void CodeGen_X86::visit(const LT *op) {
    codegen(op->b > op->a);
}

void CodeGen_X86::visit(const LE *op) {
    codegen(!(op->a > op->b));
}

void CodeGen_X86::visit(const GE *op) {
    codegen(!(op->b > op->a));
}

void CodeGen_X86::visit(const NE *op) {
    codegen(!(op->a == op->b));
}

void CodeGen_X86::visit(const Select *op) {
    if (op->condition.type().is_vector()) {
        // LLVM handles selects on vector conditions much better at native width
        Value *cond = codegen(op->condition);
        Value *true_val = codegen(op->true_value);
        Value *false_val = codegen(op->false_value);
        Type t = op->true_value.type();
        int slice_size = vector_lanes_for_slice(t);

        vector<Value *> result;
        for (int i = 0; i < t.lanes(); i += slice_size) {
            Value *st = slice_vector(true_val, i, slice_size);
            Value *sf = slice_vector(false_val, i, slice_size);
            Value *sc = slice_vector(cond, i, slice_size);
            Value *slice_value = builder->CreateSelect(sc, st, sf);
            result.push_back(slice_value);
        }

        value = concat_vectors(result);
        value = slice_vector(value, 0, t.lanes());
    } else {
        CodeGen_Posix::visit(op);
    }
}

void CodeGen_X86::visit(const Cast *op) {

    if (!op->type.is_vector()) {
        // We only have peephole optimizations for vectors in here.
        CodeGen_Posix::visit(op);
        return;
    }

    vector<Expr> matches;

    struct Pattern {
        Target::Feature feature;
        bool wide_op;
        Type type;
        int min_lanes;
        string intrin;
        Expr pattern;
    };
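
    // How the fields are used by the matching loop below: a pattern
    // fires when `pattern` matches the cast, the target has `feature`
    // (Target::FeatureEnd is used in this table to mean no feature
    // requirement), and the expression has at least `min_lanes` lanes
    // (e.g. 17 restricts a 32-lane AVX2 intrinsic to vectors wider
    // than 16 lanes). For wide_op patterns the operands are matched
    // at double width and must losslessly narrow to the result type.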

    static Pattern patterns[] = {
        {Target::AVX2, true, Int(8, 32), 17, "llvm.sadd.sat.v32i8",
         i8_sat(wild_i16x_ + wild_i16x_)},
        {Target::FeatureEnd, true, Int(8, 16), 9, "llvm.sadd.sat.v16i8",
         i8_sat(wild_i16x_ + wild_i16x_)},
        {Target::FeatureEnd, true, Int(8, 8), 0, "llvm.sadd.sat.v8i8",
         i8_sat(wild_i16x_ + wild_i16x_)},
        {Target::AVX2, true, Int(8, 32), 17, "llvm.ssub.sat.v32i8",
         i8_sat(wild_i16x_ - wild_i16x_)},
        {Target::FeatureEnd, true, Int(8, 16), 9, "llvm.ssub.sat.v16i8",
         i8_sat(wild_i16x_ - wild_i16x_)},
        {Target::FeatureEnd, true, Int(8, 8), 0, "llvm.ssub.sat.v8i8",
         i8_sat(wild_i16x_ - wild_i16x_)},
        {Target::AVX2, true, Int(16, 16), 9, "llvm.sadd.sat.v16i16",
         i16_sat(wild_i32x_ + wild_i32x_)},
        {Target::FeatureEnd, true, Int(16, 8), 0, "llvm.sadd.sat.v8i16",
         i16_sat(wild_i32x_ + wild_i32x_)},
        {Target::AVX2, true, Int(16, 16), 9, "llvm.ssub.sat.v16i16",
         i16_sat(wild_i32x_ - wild_i32x_)},
        {Target::FeatureEnd, true, Int(16, 8), 0, "llvm.ssub.sat.v8i16",
         i16_sat(wild_i32x_ - wild_i32x_)},

        // Some of the instructions referred to below only appear with
        // AVX2, but LLVM generates better AVX code if you give it
        // full 256-bit vectors and let it do the slicing up into
        // individual instructions itself. This is why we use
        // Target::AVX instead of Target::AVX2 as the feature flag
        // requirement.
        {Target::AVX, true, UInt(8, 32), 17, "paddusbx32",
         u8_sat(wild_u16x_ + wild_u16x_)},
        {Target::FeatureEnd, true, UInt(8, 16), 0, "paddusbx16",
         u8_sat(wild_u16x_ + wild_u16x_)},
        {Target::AVX, true, UInt(8, 32), 17, "psubusbx32",
         u8(max(wild_i16x_ - wild_i16x_, 0))},
        {Target::FeatureEnd, true, UInt(8, 16), 0, "psubusbx16",
         u8(max(wild_i16x_ - wild_i16x_, 0))},
        {Target::AVX, true, UInt(16, 16), 9, "padduswx16",
         u16_sat(wild_u32x_ + wild_u32x_)},
        {Target::FeatureEnd, true, UInt(16, 8), 0, "padduswx8",
         u16_sat(wild_u32x_ + wild_u32x_)},
        {Target::AVX, true, UInt(16, 16), 9, "psubuswx16",
         u16(max(wild_i32x_ - wild_i32x_, 0))},
        {Target::FeatureEnd, true, UInt(16, 8), 0, "psubuswx8",
         u16(max(wild_i32x_ - wild_i32x_, 0))},

        // Only use the avx2 version if we have > 8 lanes
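        // i16((i32(a) * i32(b)) / 65536) is the high 16 bits of the
        // widening multiply, i.e. pmulhw; the (x + 16384) / 32768
        // variant is the rounding form, pmulhrsw.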
        {Target::AVX2, true, Int(16, 16), 9, "llvm.x86.avx2.pmulh.w",
         i16((wild_i32x_ * wild_i32x_) / 65536)},
        {Target::AVX2, true, UInt(16, 16), 9, "llvm.x86.avx2.pmulhu.w",
         u16((wild_u32x_ * wild_u32x_) / 65536)},
        {Target::AVX2, true, Int(16, 16), 9, "llvm.x86.avx2.pmul.hr.sw",
         i16(((wild_i32x_ * wild_i32x_) + 16384) / 32768)},

        {Target::FeatureEnd, true, Int(16, 8), 0, "llvm.x86.sse2.pmulh.w",
         i16((wild_i32x_ * wild_i32x_) / 65536)},
        {Target::FeatureEnd, true, UInt(16, 8), 0, "llvm.x86.sse2.pmulhu.w",
         u16((wild_u32x_ * wild_u32x_) / 65536)},
        {Target::SSE41, true, Int(16, 8), 0, "llvm.x86.ssse3.pmul.hr.sw.128",
         i16(((wild_i32x_ * wild_i32x_) + 16384) / 32768)},
        // LLVM 6.0+ requires using helpers from x86.ll and x86_avx.ll
        {Target::AVX2, true, UInt(8, 32), 17, "pavgbx32",
         u8(((wild_u16x_ + wild_u16x_) + 1) / 2)},
        {Target::FeatureEnd, true, UInt(8, 16), 0, "pavgbx16",
         u8(((wild_u16x_ + wild_u16x_) + 1) / 2)},
        {Target::AVX2, true, UInt(16, 16), 9, "pavgwx16",
         u16(((wild_u32x_ + wild_u32x_) + 1) / 2)},
        {Target::FeatureEnd, true, UInt(16, 8), 0, "pavgwx8",
         u16(((wild_u32x_ + wild_u32x_) + 1) / 2)},
        {Target::AVX2, false, Int(16, 16), 9, "packssdwx16",
         i16_sat(wild_i32x_)},
        {Target::FeatureEnd, false, Int(16, 8), 0, "packssdwx8",
         i16_sat(wild_i32x_)},
        {Target::AVX2, false, Int(8, 32), 17, "packsswbx32",
         i8_sat(wild_i16x_)},
        {Target::FeatureEnd, false, Int(8, 16), 0, "packsswbx16",
         i8_sat(wild_i16x_)},
        {Target::AVX2, false, UInt(8, 32), 17, "packuswbx32",
         u8_sat(wild_i16x_)},
        {Target::FeatureEnd, false, UInt(8, 16), 0, "packuswbx16",
         u8_sat(wild_i16x_)},
        {Target::AVX2, false, UInt(16, 16), 9, "packusdwx16",
         u16_sat(wild_i32x_)},
        {Target::SSE41, false, UInt(16, 8), 0, "packusdwx8",
         u16_sat(wild_i32x_)}};

    for (size_t i = 0; i < sizeof(patterns) / sizeof(patterns[0]); i++) {
        const Pattern &pattern = patterns[i];

        if (!target.has_feature(pattern.feature)) {
            continue;
        }

        if (op->type.lanes() < pattern.min_lanes) {
            continue;
        }

        if (expr_match(pattern.pattern, op, matches)) {
            bool match = true;
            if (pattern.wide_op) {
                // Try to narrow the matches to the target type.
                for (size_t j = 0; j < matches.size(); j++) {
                    matches[j] = lossless_cast(op->type, matches[j]);
                    if (!matches[j].defined()) {
                        match = false;
                    }
                }
            }
            if (match) {
                value = call_intrin(op->type, pattern.type.lanes(), pattern.intrin, matches);
                return;
            }
        }
    }

    // Workaround for https://llvm.org/bugs/show_bug.cgi?id=24512
    // LLVM uses a numerically unstable method for vector
    // uint32->float conversion before AVX.
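    //
    // Instead compute 2 * float(i32(x / 2)) + float(i32(x % 2)).
    // Both pieces are non-negative and fit in int32, so the signed
    // conversions are well-defined; e.g. x = 0xFFFFFFFF becomes
    // 2 * float(0x7FFFFFFF) + 1.0f.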
    if (op->value.type().element_of() == UInt(32) &&
        op->type.is_float() &&
        op->type.is_vector() &&
        !target.has_feature(Target::AVX)) {
        Type signed_type = Int(32, op->type.lanes());

        // Convert the top 31 bits to float using the signed version
        Expr top_bits = cast(signed_type, op->value / 2);
        top_bits = cast(op->type, top_bits);

        // Convert the bottom bit
        Expr bottom_bit = cast(signed_type, op->value % 2);
        bottom_bit = cast(op->type, bottom_bit);

        // Recombine as floats
        codegen(top_bits + top_bits + bottom_bit);
        return;
    }

    CodeGen_Posix::visit(op);
}

void CodeGen_X86::visit(const Call *op) {
    if (op->is_intrinsic(Call::mulhi_shr) &&
        op->type.is_vector() && op->type.bits() == 16) {
        internal_assert(op->args.size() == 3);
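        // Lower 16-bit mulhi_shr as the high half of a widening
        // multiply; the "/ 65536" form is picked up as
        // pmulhw/pmulhuw by the patterns in visit(const Cast *),
        // and any remaining shift is applied afterwards.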
        Expr p;
        if (op->type.is_uint()) {
            p = u16(u32(op->args[0]) * u32(op->args[1]) / 65536);
        } else {
            p = i16(i32(op->args[0]) * i32(op->args[1]) / 65536);
        }
        const UIntImm *shift = op->args[2].as<UIntImm>();
        internal_assert(shift != nullptr) << "Third argument to mulhi_shr intrinsic must be an unsigned integer immediate.\n";
        if (shift->value != 0) {
            p = p >> shift->value;
        }
        value = codegen(p);
        return;
    }

    CodeGen_Posix::visit(op);
}

void CodeGen_X86::visit(const VectorReduce *op) {
    const int factor = op->value.type().lanes() / op->type.lanes();

    // Match pmaddwd. X86 doesn't have many horizontal reduction ops,
    // and the ones that exist are hit by llvm automatically using the
    // base class lowering of VectorReduce (see
    // test/correctness/simd_op_check.cpp).
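    //
    // A factor-2 add reduction over a widening 16-bit multiply is
    //   out[i] = i32(a[2*i]) * i32(b[2*i]) + i32(a[2*i+1]) * i32(b[2*i+1]),
    // which is exactly what pmaddwd computes, with no interleaving
    // required.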
    if (const Mul *mul = op->value.as<Mul>()) {
        Type narrower = Int(16, mul->type.lanes());
        Expr a = lossless_cast(narrower, mul->a);
        Expr b = lossless_cast(narrower, mul->b);
        if (op->type.is_int() &&
            op->type.bits() == 32 &&
            a.defined() &&
            b.defined() &&
            factor == 2 &&
            op->op == VectorReduce::Add) {
            if (target.has_feature(Target::AVX2) && op->type.lanes() > 4) {
                value = call_intrin(op->type, 8, "llvm.x86.avx2.pmadd.wd", {a, b});
            } else {
                value = call_intrin(op->type, 4, "llvm.x86.sse2.pmadd.wd", {a, b});
            }
            return;
        }
    }

    CodeGen_Posix::visit(op);
}

string CodeGen_X86::mcpu() const {
    if (target.has_feature(Target::AVX512_Cannonlake)) return "cannonlake";
    if (target.has_feature(Target::AVX512_Skylake)) return "skylake-avx512";
    if (target.has_feature(Target::AVX512_KNL)) return "knl";
    if (target.has_feature(Target::AVX2)) return "haswell";
    if (target.has_feature(Target::AVX)) return "corei7-avx";
    // We want SSE4.1 but not SSE4.2, hence "penryn" rather than "corei7"
    if (target.has_feature(Target::SSE41)) return "penryn";
    // Default should not include SSSE3, hence "k8" rather than "core2"
    return "k8";
}

string CodeGen_X86::mattrs() const {
    std::string features;
    std::string separator;
    if (target.has_feature(Target::FMA)) {
        features += "+fma";
        separator = ",";
    }
    if (target.has_feature(Target::FMA4)) {
        features += separator + "+fma4";
        separator = ",";
    }
    if (target.has_feature(Target::F16C)) {
        features += separator + "+f16c";
        separator = ",";
    }
    if (target.has_feature(Target::AVX512) ||
        target.has_feature(Target::AVX512_KNL) ||
        target.has_feature(Target::AVX512_Skylake) ||
        target.has_feature(Target::AVX512_Cannonlake)) {
        features += separator + "+avx512f,+avx512cd";
        separator = ",";
        if (target.has_feature(Target::AVX512_KNL)) {
            features += ",+avx512pf,+avx512er";
        }
        if (target.has_feature(Target::AVX512_Skylake) ||
            target.has_feature(Target::AVX512_Cannonlake)) {
            features += ",+avx512vl,+avx512bw,+avx512dq";
        }
        if (target.has_feature(Target::AVX512_Cannonlake)) {
            features += ",+avx512ifma,+avx512vbmi";
        }
    }
    return features;
}

bool CodeGen_X86::use_soft_float_abi() const {
    return false;
}

int CodeGen_X86::native_vector_bits() const {
    if (target.has_feature(Target::AVX512) ||
        target.has_feature(Target::AVX512_Skylake) ||
        target.has_feature(Target::AVX512_KNL) ||
        target.has_feature(Target::AVX512_Cannonlake)) {
        return 512;
    } else if (target.has_feature(Target::AVX) ||
               target.has_feature(Target::AVX2)) {
        return 256;
    } else {
        return 128;
    }
}

int CodeGen_X86::vector_lanes_for_slice(const Type &t) const {
    // We don't want to pad all the way out to natural_vector_size,
    // because llvm generates crappy code. Better to use a smaller
    // type if we can.
    int vec_bits = t.lanes() * t.bits();
    int natural_vec_bits = target.natural_vector_size(t) * t.bits();
    // clang-format off
    int slice_bits = ((vec_bits > 256 && natural_vec_bits > 256) ? 512 :
                      (vec_bits > 128 && natural_vec_bits > 128) ? 256 :
                      128);
    // clang-format on
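    // For example, an int32x6 vector on a plain SSE4.1 target
    // (128-bit natural width) is sliced into 4-lane pieces rather
    // than being padded out to 8 lanes.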
    return slice_bits / t.bits();
}

llvm::Type *CodeGen_X86::llvm_type_of(const Type &t) const {
    if (t.is_float() && t.bits() < 32) {
        // LLVM as of August 2019 has all sorts of issues in the x86
        // backend for half types. It injects expensive calls to
        // convert between float and half for seemingly no reason
        // (e.g. to do a select), and bitcasting to int16 doesn't
        // help, because it simplifies away the bitcast for you.
        // See: https://bugs.llvm.org/show_bug.cgi?id=43065
        // and: https://github.com/halide/Halide/issues/4166
        return llvm_type_of(t.with_code(halide_type_uint));
    } else {
        return CodeGen_Posix::llvm_type_of(t);
    }
}

}  // namespace Internal
}  // namespace Halide