1 #include <iostream>
2 #include <sstream>
3
4 #include "CSE.h"
5 #include "CodeGen_ARM.h"
6 #include "CodeGen_Internal.h"
7 #include "ConciseCasts.h"
8 #include "Debug.h"
9 #include "IREquality.h"
10 #include "IRMatch.h"
11 #include "IROperator.h"
12 #include "IRPrinter.h"
13 #include "LLVM_Headers.h"
14 #include "Simplify.h"
15 #include "Util.h"
16
17 namespace Halide {
18 namespace Internal {
19
20 using std::ostringstream;
21 using std::pair;
22 using std::string;
23 using std::vector;
24
25 using namespace Halide::ConciseCasts;
26 using namespace llvm;
27
28 CodeGen_ARM::CodeGen_ARM(Target target)
29 : CodeGen_Posix(target) {
30 if (target.bits == 32) {
31 #if !defined(WITH_ARM)
32 user_error << "arm not enabled for this build of Halide.\n";
33 #endif
34 user_assert(llvm_ARM_enabled) << "llvm build not configured with ARM target enabled.\n";
35 } else {
36 #if !defined(WITH_AARCH64)
37 user_error << "aarch64 not enabled for this build of Halide.\n";
38 #endif
39 user_assert(llvm_AArch64_enabled) << "llvm build not configured with AArch64 target enabled.\n";
40 }
41
42 // Generate the cast patterns that can take vector types. We need
43 // to iterate over all 64- and 128-bit integer types relevant for
44 // NEON.
45 Type types[] = {Int(8, 8), Int(8, 16), UInt(8, 8), UInt(8, 16),
46 Int(16, 4), Int(16, 8), UInt(16, 4), UInt(16, 8),
47 Int(32, 2), Int(32, 4), UInt(32, 2), UInt(32, 4)};
48 for (size_t i = 0; i < sizeof(types) / sizeof(types[0]); i++) {
49 Type t = types[i];
50
51 int intrin_lanes = t.lanes();
52 std::ostringstream oss;
53 oss << ".v" << intrin_lanes << "i" << t.bits();
54 string t_str = oss.str();
55
56 // For the 128-bit versions, we want to match any vector width.
57 if (t.bits() * t.lanes() == 128) {
58 t = t.with_lanes(0);
59 }
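// (With zero lanes, the type acts as a lane-count wildcard when these
// patterns are matched, which is how the 128-bit entries also cover
// wider vectors.)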
60
61 // Wider versions of the type
62 Type w = t.with_bits(t.bits() * 2);
63 Type ws = Int(t.bits() * 2, t.lanes());
64
65 // Vector wildcard for this type
66 Expr vector = Variable::make(t, "*");
67 Expr w_vector = Variable::make(w, "*");
68 Expr ws_vector = Variable::make(ws, "*");
69
70 // Bounds of the type stored in the wider vector type
71 Expr tmin = simplify(cast(w, t.min()));
72 Expr tmax = simplify(cast(w, t.max()));
73 Expr tsmin = simplify(cast(ws, t.min()));
74 Expr tsmax = simplify(cast(ws, t.max()));
75
76 Pattern p("", "", intrin_lanes, Expr(), Pattern::NarrowArgs);
77
78 // Rounding-up averaging
79 if (t.is_int()) {
80 p.intrin32 = "llvm.arm.neon.vrhadds" + t_str;
81 p.intrin64 = "llvm.aarch64.neon.srhadd" + t_str;
82 } else {
83 p.intrin32 = "llvm.arm.neon.vrhaddu" + t_str;
84 p.intrin64 = "llvm.aarch64.neon.urhadd" + t_str;
85 }
86
87 p.pattern = cast(t, (w_vector + w_vector + 1) / 2);
88 casts.push_back(p);
89 p.pattern = cast(t, (w_vector + (w_vector + 1)) / 2);
90 casts.push_back(p);
91 p.pattern = cast(t, ((w_vector + 1) + w_vector) / 2);
92 casts.push_back(p);
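// e.g. i8((i16(a) + i16(b) + 1) / 2), once the arguments are narrowed
// back down to 8 bits, becomes a single rounding halving add.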
93
94 // Rounding down averaging
95 if (t.is_int()) {
96 p.intrin32 = "llvm.arm.neon.vhadds" + t_str;
97 p.intrin64 = "llvm.aarch64.neon.shadd" + t_str;
98 } else {
99 p.intrin32 = "llvm.arm.neon.vhaddu" + t_str;
100 p.intrin64 = "llvm.aarch64.neon.uhadd" + t_str;
101 }
102 p.pattern = cast(t, (w_vector + w_vector) / 2);
103 casts.push_back(p);
104
105 // Halving subtract
106 if (t.is_int()) {
107 p.intrin32 = "llvm.arm.neon.vhsubs" + t_str;
108 p.intrin64 = "llvm.aarch64.neon.shsub" + t_str;
109 } else {
110 p.intrin32 = "llvm.arm.neon.vhsubu" + t_str;
111 p.intrin64 = "llvm.aarch64.neon.uhsub" + t_str;
112 }
113 p.pattern = cast(t, (w_vector - w_vector) / 2);
114 casts.push_back(p);
115
116 // Saturating add
117 #if LLVM_VERSION >= 100
118 if (t.is_int()) {
119 p.intrin32 = "llvm.sadd.sat" + t_str;
120 p.intrin64 = "llvm.sadd.sat" + t_str;
121 } else {
122 p.intrin32 = "llvm.uadd.sat" + t_str;
123 p.intrin64 = "llvm.uadd.sat" + t_str;
124 }
125 #else
126 if (t.is_int()) {
127 p.intrin32 = "llvm.arm.neon.vqadds" + t_str;
128 p.intrin64 = "llvm.aarch64.neon.sqadd" + t_str;
129 } else {
130 p.intrin32 = "llvm.arm.neon.vqaddu" + t_str;
131 p.intrin64 = "llvm.aarch64.neon.uqadd" + t_str;
132 }
133 #endif
134 p.pattern = cast(t, clamp(w_vector + w_vector, tmin, tmax));
135 casts.push_back(p);
136
137 // In the unsigned case, the lower bound of the clamp is unnecessary
138 if (t.is_uint()) {
139 p.pattern = cast(t, min(w_vector + w_vector, tmax));
140 casts.push_back(p);
141 }
142
143 // Saturating subtract
144 // N.B. Saturating subtracts always widen to a signed type
145 #if LLVM_VERSION >= 100
146 if (t.is_int()) {
147 p.intrin32 = "llvm.ssub.sat" + t_str;
148 p.intrin64 = "llvm.ssub.sat" + t_str;
149 } else {
150 p.intrin32 = "llvm.usub.sat" + t_str;
151 p.intrin64 = "llvm.usub.sat" + t_str;
152 }
153 #else
154 if (t.is_int()) {
155 p.intrin32 = "llvm.arm.neon.vqsubs" + t_str;
156 p.intrin64 = "llvm.aarch64.neon.sqsub" + t_str;
157 } else {
158 p.intrin32 = "llvm.arm.neon.vqsubu" + t_str;
159 p.intrin64 = "llvm.aarch64.neon.uqsub" + t_str;
160 }
161 #endif
162 p.pattern = cast(t, clamp(ws_vector - ws_vector, tsmin, tsmax));
163 casts.push_back(p);
164
165 // In the unsigned case, we may detect that the top of the clamp is unnecessary
166 if (t.is_uint()) {
167 p.pattern = cast(t, max(ws_vector - ws_vector, 0));
168 casts.push_back(p);
169 }
170 }
171
172 casts.emplace_back("vqrdmulh.v4i16", "sqrdmulh.v4i16", 4,
173 i16_sat((wild_i32x4 * wild_i32x4 + (1 << 14)) / (1 << 15)),
174 Pattern::NarrowArgs);
175 casts.emplace_back("vqrdmulh.v8i16", "sqrdmulh.v8i16", 8,
176 i16_sat((wild_i32x_ * wild_i32x_ + (1 << 14)) / (1 << 15)),
177 Pattern::NarrowArgs);
178 casts.emplace_back("vqrdmulh.v2i32", "sqrdmulh.v2i32", 2,
179 i32_sat((wild_i64x2 * wild_i64x2 + (1 << 30)) / Expr(int64_t(1) << 31)),
180 Pattern::NarrowArgs);
181 casts.emplace_back("vqrdmulh.v4i32", "sqrdmulh.v4i32", 4,
182 i32_sat((wild_i64x_ * wild_i64x_ + (1 << 30)) / Expr(int64_t(1) << 31)),
183 Pattern::NarrowArgs);
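// Saturating rounding doubling multiply-high: for 16-bit lanes this is
// the Q15 fixed-point product sat((a * b + (1 << 14)) >> 15), computed
// in the widened type and then narrowed back down.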
184
185 casts.emplace_back("vqshiftns.v8i8", "sqshrn.v8i8", 8, i8_sat(wild_i16x_ / wild_i16x_), Pattern::RightShift);
186 casts.emplace_back("vqshiftns.v4i16", "sqshrn.v4i16", 4, i16_sat(wild_i32x_ / wild_i32x_), Pattern::RightShift);
187 casts.emplace_back("vqshiftns.v2i32", "sqshrn.v2i32", 2, i32_sat(wild_i64x_ / wild_i64x_), Pattern::RightShift);
188 casts.emplace_back("vqshiftnu.v8i8", "uqshrn.v8i8", 8, u8_sat(wild_u16x_ / wild_u16x_), Pattern::RightShift);
189 casts.emplace_back("vqshiftnu.v4i16", "uqshrn.v4i16", 4, u16_sat(wild_u32x_ / wild_u32x_), Pattern::RightShift);
190 casts.emplace_back("vqshiftnu.v2i32", "uqshrn.v2i32", 2, u32_sat(wild_u64x_ / wild_u64x_), Pattern::RightShift);
191 casts.emplace_back("vqshiftnsu.v8i8", "sqshrun.v8i8", 8, u8_sat(wild_i16x_ / wild_i16x_), Pattern::RightShift);
192 casts.emplace_back("vqshiftnsu.v4i16", "sqshrun.v4i16", 4, u16_sat(wild_i32x_ / wild_i32x_), Pattern::RightShift);
193 casts.emplace_back("vqshiftnsu.v2i32", "sqshrun.v2i32", 2, u32_sat(wild_i64x_ / wild_i64x_), Pattern::RightShift);
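// The RightShift patterns match division by a constant power of two;
// visit(Cast) below turns the divisor into the immediate shift, so e.g.
// i8_sat(i16_vector / 16) becomes a saturating shift-right-narrow by 4.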
194
195 // Where both a 64-bit and a 128-bit version exist, we use the 64-bit
196 // version only when the args are 64 bits wide.
197 casts.emplace_back("vqshifts.v8i8", "sqshl.v8i8", 8, i8_sat(i16(wild_i8x8) * wild_i16x8), Pattern::LeftShift);
198 casts.emplace_back("vqshifts.v4i16", "sqshl.v4i16", 4, i16_sat(i32(wild_i16x4) * wild_i32x4), Pattern::LeftShift);
199 casts.emplace_back("vqshifts.v2i32", "sqshl.v2i32", 2, i32_sat(i64(wild_i32x2) * wild_i64x2), Pattern::LeftShift);
200 casts.emplace_back("vqshiftu.v8i8", "uqshl.v8i8", 8, u8_sat(u16(wild_u8x8) * wild_u16x8), Pattern::LeftShift);
201 casts.emplace_back("vqshiftu.v4i16", "uqshl.v4i16", 4, u16_sat(u32(wild_u16x4) * wild_u32x4), Pattern::LeftShift);
202 casts.emplace_back("vqshiftu.v2i32", "uqshl.v2i32", 2, u32_sat(u64(wild_u32x2) * wild_u64x2), Pattern::LeftShift);
203 casts.emplace_back("vqshiftsu.v8i8", "sqshlu.v8i8", 8, u8_sat(i16(wild_i8x8) * wild_i16x8), Pattern::LeftShift);
204 casts.emplace_back("vqshiftsu.v4i16", "sqshlu.v4i16", 4, u16_sat(i32(wild_i16x4) * wild_i32x4), Pattern::LeftShift);
205 casts.emplace_back("vqshiftsu.v2i32", "sqshlu.v2i32", 2, u32_sat(i64(wild_i32x2) * wild_i64x2), Pattern::LeftShift);
206
207 // We use the 128-bit version for all other vector widths.
208 casts.emplace_back("vqshifts.v16i8", "sqshl.v16i8", 16, i8_sat(i16(wild_i8x_) * wild_i16x_), Pattern::LeftShift);
209 casts.emplace_back("vqshifts.v8i16", "sqshl.v8i16", 8, i16_sat(i32(wild_i16x_) * wild_i32x_), Pattern::LeftShift);
210 casts.emplace_back("vqshifts.v4i32", "sqshl.v4i32", 4, i32_sat(i64(wild_i32x_) * wild_i64x_), Pattern::LeftShift);
211 casts.emplace_back("vqshiftu.v16i8", "uqshl.v16i8", 16, u8_sat(u16(wild_u8x_) * wild_u16x_), Pattern::LeftShift);
212 casts.emplace_back("vqshiftu.v8i16", "uqshl.v8i16", 8, u16_sat(u32(wild_u16x_) * wild_u32x_), Pattern::LeftShift);
213 casts.emplace_back("vqshiftu.v4i32", "uqshl.v4i32", 4, u32_sat(u64(wild_u32x_) * wild_u64x_), Pattern::LeftShift);
214 casts.emplace_back("vqshiftsu.v16i8", "sqshlu.v16i8", 16, u8_sat(i16(wild_i8x_) * wild_i16x_), Pattern::LeftShift);
215 casts.emplace_back("vqshiftsu.v8i16", "sqshlu.v8i16", 8, u16_sat(i32(wild_i16x_) * wild_i32x_), Pattern::LeftShift);
216 casts.emplace_back("vqshiftsu.v4i32", "sqshlu.v4i32", 4, u32_sat(i64(wild_i32x_) * wild_i64x_), Pattern::LeftShift);
217
218 casts.emplace_back("vqmovns.v8i8", "sqxtn.v8i8", 8, i8_sat(wild_i16x_));
219 casts.emplace_back("vqmovns.v4i16", "sqxtn.v4i16", 4, i16_sat(wild_i32x_));
220 casts.emplace_back("vqmovns.v2i32", "sqxtn.v2i32", 2, i32_sat(wild_i64x_));
221 casts.emplace_back("vqmovnu.v8i8", "uqxtn.v8i8", 8, u8_sat(wild_u16x_));
222 casts.emplace_back("vqmovnu.v4i16", "uqxtn.v4i16", 4, u16_sat(wild_u32x_));
223 casts.emplace_back("vqmovnu.v2i32", "uqxtn.v2i32", 2, u32_sat(wild_u64x_));
224 casts.emplace_back("vqmovnsu.v8i8", "sqxtun.v8i8", 8, u8_sat(wild_i16x_));
225 casts.emplace_back("vqmovnsu.v4i16", "sqxtun.v4i16", 4, u16_sat(wild_i32x_));
226 casts.emplace_back("vqmovnsu.v2i32", "sqxtun.v2i32", 2, u32_sat(wild_i64x_));
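// Plain saturating narrowing casts, e.g. i8_sat(i16_vector) maps to
// vqmovn on arm32 and sqxtn on aarch64.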
227
228 // Overflow for int32 is not defined by Halide, so for those we can take
229 // advantage of special add-and-halve instructions.
230 //
231 // 64-bit averaging round-down
232 averagings.emplace_back("vhadds.v2i32", "shadd.v2i32", 2, (wild_i32x2 + wild_i32x2));
233
234 // 128-bit
235 averagings.emplace_back("vhadds.v4i32", "shadd.v4i32", 4, (wild_i32x_ + wild_i32x_));
236
237 // 64-bit halving subtract
238 averagings.emplace_back("vhsubs.v2i32", "shsub.v2i32", 2, (wild_i32x2 - wild_i32x2));
239
240 // 128-bit
241 averagings.emplace_back("vhsubs.v4i32", "shsub.v4i32", 4, (wild_i32x_ - wild_i32x_));
242
243 // 64-bit saturating negation
244 negations.emplace_back("vqneg.v8i8", "sqneg.v8i8", 8, -max(wild_i8x8, -127));
245 negations.emplace_back("vqneg.v4i16", "sqneg.v4i16", 4, -max(wild_i16x4, -32767));
246 negations.emplace_back("vqneg.v2i32", "sqneg.v2i32", 2, -max(wild_i32x2, -(0x7fffffff)));
247
248 // 128-bit
249 negations.emplace_back("vqneg.v16i8", "sqneg.v16i8", 16, -max(wild_i8x_, -127));
250 negations.emplace_back("vqneg.v8i16", "sqneg.v8i16", 8, -max(wild_i16x_, -32767));
251 negations.emplace_back("vqneg.v4i32", "sqneg.v4i32", 4, -max(wild_i32x_, -(0x7fffffff)));
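// Clamping at -127 (not -128) before negating keeps the negation free of
// overflow, which matches the saturating behaviour of vqneg/sqneg.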
252
253 // Widening multiplies.
254 multiplies.emplace_back("vmulls.v2i64", "smull.v2i64", 2,
255 wild_i64x_ * wild_i64x_,
256 Pattern::NarrowArgs);
257 multiplies.emplace_back("vmullu.v2i64", "umull.v2i64", 2,
258 wild_u64x_ * wild_u64x_,
259 Pattern::NarrowArgs);
260 multiplies.emplace_back("vmulls.v4i32", "smull.v4i32", 4,
261 wild_i32x_ * wild_i32x_,
262 Pattern::NarrowArgs);
263 multiplies.emplace_back("vmullu.v4i32", "umull.v4i32", 4,
264 wild_u32x_ * wild_u32x_,
265 Pattern::NarrowArgs);
266 multiplies.emplace_back("vmulls.v8i16", "smull.v8i16", 8,
267 wild_i16x_ * wild_i16x_,
268 Pattern::NarrowArgs);
269 multiplies.emplace_back("vmullu.v8i16", "umull.v8i16", 8,
270 wild_u16x_ * wild_u16x_,
271 Pattern::NarrowArgs);
272 }
273
274 Value *CodeGen_ARM::call_pattern(const Pattern &p, Type t, const vector<Expr> &args) {
275 if (target.bits == 32) {
276 return call_intrin(t, p.intrin_lanes, p.intrin32, args);
277 } else {
278 return call_intrin(t, p.intrin_lanes, p.intrin64, args);
279 }
280 }
281
282 Value *CodeGen_ARM::call_pattern(const Pattern &p, llvm::Type *t, const vector<llvm::Value *> &args) {
283 if (target.bits == 32) {
284 return call_intrin(t, p.intrin_lanes, p.intrin32, args);
285 } else {
286 return call_intrin(t, p.intrin_lanes, p.intrin64, args);
287 }
288 }
289
290 void CodeGen_ARM::visit(const Cast *op) {
291 if (neon_intrinsics_disabled()) {
292 CodeGen_Posix::visit(op);
293 return;
294 }
295
296 Type t = op->type;
297
298 vector<Expr> matches;
299
300 for (size_t i = 0; i < casts.size(); i++) {
301 const Pattern &pattern = casts[i];
302 //debug(4) << "Trying pattern: " << patterns[i].intrin << " " << patterns[i].pattern << "\n";
303 if (expr_match(pattern.pattern, op, matches)) {
304
305 //debug(4) << "Match!\n";
306 if (pattern.type == Pattern::Simple) {
307 value = call_pattern(pattern, t, matches);
308 return;
309 } else if (pattern.type == Pattern::NarrowArgs) {
310 // Try to narrow all of the args.
311 bool all_narrow = true;
312 for (size_t i = 0; i < matches.size(); i++) {
313 internal_assert(matches[i].type().bits() == t.bits() * 2);
314 internal_assert(matches[i].type().lanes() == t.lanes());
315 // debug(4) << "Attempting to narrow " << matches[i] << " to " << t << "\n";
316 matches[i] = lossless_cast(t, matches[i]);
317 if (!matches[i].defined()) {
318 // debug(4) << "failed\n";
319 all_narrow = false;
320 } else {
321 // debug(4) << "success: " << matches[i] << "\n";
322 internal_assert(matches[i].type() == t);
323 }
324 }
325
326 if (all_narrow) {
327 value = call_pattern(pattern, t, matches);
328 return;
329 }
330 } else { // must be a shift
331 Expr constant = matches[1];
332 int shift_amount;
333 bool power_of_two = is_const_power_of_two_integer(constant, &shift_amount);
334 if (power_of_two && shift_amount < matches[0].type().bits()) {
335 if (target.bits == 32 && pattern.type == Pattern::RightShift) {
336 // The arm32 llvm backend wants right shifts to come in as negative values.
337 shift_amount = -shift_amount;
338 }
339 Value *shift = nullptr;
340 // The arm64 llvm backend wants i32 constants for right shifts.
341 if (target.bits == 64 && pattern.type == Pattern::RightShift) {
342 shift = ConstantInt::get(i32_t, shift_amount);
343 } else {
344 shift = ConstantInt::get(llvm_type_of(matches[0].type()), shift_amount);
345 }
346 value = call_pattern(pattern, llvm_type_of(t),
347 {codegen(matches[0]), shift});
348 return;
349 }
350 }
351 }
352 }
353
354 // Catch extract-high-half-of-signed integer pattern and convert
355 // it to extract-high-half-of-unsigned-integer. llvm peephole
356 // optimization recognizes logical shift right but not arithmetic
357 // shift right for this pattern. This matters for vaddhn of signed
358 // integers.
359 if (t.is_vector() &&
360 (t.is_int() || t.is_uint()) &&
361 op->value.type().is_int() &&
362 t.bits() == op->value.type().bits() / 2) {
363 const Div *d = op->value.as<Div>();
364 if (d && is_const(d->b, int64_t(1) << t.bits())) {
365 Type unsigned_type = UInt(t.bits() * 2, t.lanes());
366 Expr replacement = cast(t,
367 cast(unsigned_type, d->a) /
368 cast(unsigned_type, d->b));
369 replacement.accept(this);
370 return;
371 }
372 }
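// e.g. i16(i32_vector / 65536) is re-expressed as an unsigned division,
// which lowers to a logical shift right that LLVM can then fold into
// instructions like vaddhn.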
373
374 // Catch widening of absolute difference
375 if (t.is_vector() &&
376 (t.is_int() || t.is_uint()) &&
377 (op->value.type().is_int() || op->value.type().is_uint()) &&
378 t.bits() == op->value.type().bits() * 2) {
379 Expr a, b;
380 const Call *c = op->value.as<Call>();
381 if (c && c->is_intrinsic(Call::absd)) {
382 ostringstream ss;
383 int intrin_lanes = 128 / t.bits();
384 ss << "vabdl_" << (c->args[0].type().is_int() ? "i" : "u") << t.bits() / 2 << "x" << intrin_lanes;
385 value = call_intrin(t, intrin_lanes, ss.str(), c->args);
386 return;
387 }
388 }
389
390 CodeGen_Posix::visit(op);
391 }
392
393 void CodeGen_ARM::visit(const Mul *op) {
394 if (neon_intrinsics_disabled()) {
395 CodeGen_Posix::visit(op);
396 return;
397 }
398
399 // We only have peephole optimizations for int vectors for now
400 if (op->type.is_scalar() || op->type.is_float()) {
401 CodeGen_Posix::visit(op);
402 return;
403 }
404
405 Type t = op->type;
406 vector<Expr> matches;
407
408 int shift_amount = 0;
409 if (is_const_power_of_two_integer(op->b, &shift_amount)) {
410 // Let LLVM handle these.
411 CodeGen_Posix::visit(op);
412 return;
413 }
414
415 // LLVM really struggles to generate mlal unless we generate mull intrinsics
416 // for the multiplication part first.
417 for (size_t i = 0; i < multiplies.size(); i++) {
418 const Pattern &pattern = multiplies[i];
419 //debug(4) << "Trying pattern: " << patterns[i].intrin << " " << patterns[i].pattern << "\n";
420 if (expr_match(pattern.pattern, op, matches)) {
421
422 //debug(4) << "Match!\n";
423 if (pattern.type == Pattern::Simple) {
424 value = call_pattern(pattern, t, matches);
425 return;
426 } else if (pattern.type == Pattern::NarrowArgs) {
427 Type narrow_t = t.with_bits(t.bits() / 2);
428 // Try to narrow all of the args.
429 bool all_narrow = true;
430 for (size_t i = 0; i < matches.size(); i++) {
431 internal_assert(matches[i].type().bits() == t.bits());
432 internal_assert(matches[i].type().lanes() == t.lanes());
433 // debug(4) << "Attempting to narrow " << matches[i] << " to " << t << "\n";
434 matches[i] = lossless_cast(narrow_t, matches[i]);
435 if (!matches[i].defined()) {
436 // debug(4) << "failed\n";
437 all_narrow = false;
438 } else {
439 // debug(4) << "success: " << matches[i] << "\n";
440 internal_assert(matches[i].type() == narrow_t);
441 }
442 }
443
444 if (all_narrow) {
445 value = call_pattern(pattern, t, matches);
446 return;
447 }
448 }
449 }
450 }
451
452 // Vector multiplies by 3, 5, 7, 9 should do shift-and-add or
453 // shift-and-sub instead to reduce register pressure (the
454 // shift is an immediate)
455 // TODO: Verify this is still good codegen.
456 if (is_const(op->b, 3)) {
457 value = codegen(op->a * 2 + op->a);
458 return;
459 } else if (is_const(op->b, 5)) {
460 value = codegen(op->a * 4 + op->a);
461 return;
462 } else if (is_const(op->b, 7)) {
463 value = codegen(op->a * 8 - op->a);
464 return;
465 } else if (is_const(op->b, 9)) {
466 value = codegen(op->a * 8 + op->a);
467 return;
468 }
469
470 CodeGen_Posix::visit(op);
471 }
472
473 void CodeGen_ARM::visit(const Div *op) {
474 if (!neon_intrinsics_disabled() &&
475 op->type.is_vector() && is_two(op->b) &&
476 (op->a.as<Add>() || op->a.as<Sub>())) {
477 vector<Expr> matches;
478 for (size_t i = 0; i < averagings.size(); i++) {
479 if (expr_match(averagings[i].pattern, op->a, matches)) {
480 value = call_pattern(averagings[i], op->type, matches);
481 return;
482 }
483 }
484 }
485 CodeGen_Posix::visit(op);
486 }
487
488 void CodeGen_ARM::visit(const Sub *op) {
489 if (neon_intrinsics_disabled()) {
490 CodeGen_Posix::visit(op);
491 return;
492 }
493
494 vector<Expr> matches;
495 for (size_t i = 0; i < negations.size(); i++) {
496 if (op->type.is_vector() &&
497 expr_match(negations[i].pattern, op, matches)) {
498 value = call_pattern(negations[i], op->type, matches);
499 return;
500 }
501 }
502
503 // llvm will generate floating point negate instructions if we ask for (-0.0f)-x
504 if (op->type.is_float() &&
505 op->type.bits() >= 32 &&
506 is_zero(op->a)) {
507 Constant *a;
508 if (op->type.bits() == 32) {
509 a = ConstantFP::getNegativeZero(f32_t);
510 } else if (op->type.bits() == 64) {
511 a = ConstantFP::getNegativeZero(f64_t);
512 } else {
513 a = nullptr;
514 internal_error << "Unknown bit width for floating point type: " << op->type << "\n";
515 }
516
517 Value *b = codegen(op->b);
518
519 if (op->type.lanes() > 1) {
520 a = ConstantVector::getSplat(element_count(op->type.lanes()), a);
521 }
522 value = builder->CreateFSub(a, b);
523 return;
524 }
525
526 CodeGen_Posix::visit(op);
527 }
528
529 void CodeGen_ARM::visit(const Min *op) {
530 if (neon_intrinsics_disabled()) {
531 CodeGen_Posix::visit(op);
532 return;
533 }
534
535 if (op->type == Float(32)) {
536 // Use a 2-wide vector instead
537 Value *undef = UndefValue::get(f32x2);
538 Constant *zero = ConstantInt::get(i32_t, 0);
539 Value *a = codegen(op->a);
540 Value *a_wide = builder->CreateInsertElement(undef, a, zero);
541 Value *b = codegen(op->b);
542 Value *b_wide = builder->CreateInsertElement(undef, b, zero);
543 Value *wide_result;
544 if (target.bits == 32) {
545 wide_result = call_intrin(f32x2, 2, "llvm.arm.neon.vmins.v2f32", {a_wide, b_wide});
546 } else {
547 wide_result = call_intrin(f32x2, 2, "llvm.aarch64.neon.fmin.v2f32", {a_wide, b_wide});
548 }
549 value = builder->CreateExtractElement(wide_result, zero);
550 return;
551 }
552
553 struct {
554 Type t;
555 const char *op;
556 } patterns[] = {
557 {UInt(8, 8), "v8i8"},
558 {UInt(16, 4), "v4i16"},
559 {UInt(32, 2), "v2i32"},
560 {Int(8, 8), "v8i8"},
561 {Int(16, 4), "v4i16"},
562 {Int(32, 2), "v2i32"},
563 {Float(32, 2), "v2f32"},
564 {UInt(8, 16), "v16i8"},
565 {UInt(16, 8), "v8i16"},
566 {UInt(32, 4), "v4i32"},
567 {Int(8, 16), "v16i8"},
568 {Int(16, 8), "v8i16"},
569 {Int(32, 4), "v4i32"},
570 {Float(32, 4), "v4f32"}};
571
572 for (size_t i = 0; i < sizeof(patterns) / sizeof(patterns[0]); i++) {
573 bool match = op->type == patterns[i].t;
574
575 // The 128-bit versions are also used for other vector widths.
576 if (op->type.is_vector() && patterns[i].t.lanes() * patterns[i].t.bits() == 128) {
577 match = match || (op->type.element_of() == patterns[i].t.element_of());
578 }
579
580 if (match) {
581 string intrin;
582 if (target.bits == 32) {
583 intrin = (string("llvm.arm.neon.") + (op->type.is_uint() ? "vminu." : "vmins.")) + patterns[i].op;
584 } else {
585 intrin = "llvm.aarch64.neon.";
586 if (op->type.is_int()) {
587 intrin += "smin.";
588 } else if (op->type.is_float()) {
589 intrin += "fmin.";
590 } else {
591 intrin += "umin.";
592 }
593 intrin += patterns[i].op;
594 }
595 value = call_intrin(op->type, patterns[i].t.lanes(), intrin, {op->a, op->b});
596 return;
597 }
598 }
599
600 CodeGen_Posix::visit(op);
601 }
602
603 void CodeGen_ARM::visit(const Max *op) {
604 if (neon_intrinsics_disabled()) {
605 CodeGen_Posix::visit(op);
606 return;
607 }
608
609 if (op->type == Float(32)) {
610 // Use a 2-wide vector instead
611 Value *undef = UndefValue::get(f32x2);
612 Constant *zero = ConstantInt::get(i32_t, 0);
613 Value *a = codegen(op->a);
614 Value *a_wide = builder->CreateInsertElement(undef, a, zero);
615 Value *b = codegen(op->b);
616 Value *b_wide = builder->CreateInsertElement(undef, b, zero);
617 Value *wide_result;
618 if (target.bits == 32) {
619 wide_result = call_intrin(f32x2, 2, "llvm.arm.neon.vmaxs.v2f32", {a_wide, b_wide});
620 } else {
621 wide_result = call_intrin(f32x2, 2, "llvm.aarch64.neon.fmax.v2f32", {a_wide, b_wide});
622 }
623 value = builder->CreateExtractElement(wide_result, zero);
624 return;
625 }
626
627 struct {
628 Type t;
629 const char *op;
630 } patterns[] = {
631 {UInt(8, 8), "v8i8"},
632 {UInt(16, 4), "v4i16"},
633 {UInt(32, 2), "v2i32"},
634 {Int(8, 8), "v8i8"},
635 {Int(16, 4), "v4i16"},
636 {Int(32, 2), "v2i32"},
637 {Float(32, 2), "v2f32"},
638 {UInt(8, 16), "v16i8"},
639 {UInt(16, 8), "v8i16"},
640 {UInt(32, 4), "v4i32"},
641 {Int(8, 16), "v16i8"},
642 {Int(16, 8), "v8i16"},
643 {Int(32, 4), "v4i32"},
644 {Float(32, 4), "v4f32"}};
645
646 for (size_t i = 0; i < sizeof(patterns) / sizeof(patterns[0]); i++) {
647 bool match = op->type == patterns[i].t;
648
649 // The 128-bit versions are also used for other vector widths.
650 if (op->type.is_vector() && patterns[i].t.lanes() * patterns[i].t.bits() == 128) {
651 match = match || (op->type.element_of() == patterns[i].t.element_of());
652 }
653
654 if (match) {
655 string intrin;
656 if (target.bits == 32) {
657 intrin = (string("llvm.arm.neon.") + (op->type.is_uint() ? "vmaxu." : "vmaxs.")) + patterns[i].op;
658 } else {
659 intrin = "llvm.aarch64.neon.";
660 if (op->type.is_int()) {
661 intrin += "smax.";
662 } else if (op->type.is_float()) {
663 intrin += "fmax.";
664 } else {
665 intrin += "umax.";
666 }
667 intrin += patterns[i].op;
668 }
669 value = call_intrin(op->type, patterns[i].t.lanes(), intrin, {op->a, op->b});
670 return;
671 }
672 }
673
674 CodeGen_Posix::visit(op);
675 }
676
677 void CodeGen_ARM::visit(const Store *op) {
678 // Predicated store
679 if (!is_one(op->predicate)) {
680 CodeGen_Posix::visit(op);
681 return;
682 }
683
684 if (neon_intrinsics_disabled()) {
685 CodeGen_Posix::visit(op);
686 return;
687 }
688
689 // A dense store of an interleaving can be done using a vst2 intrinsic
690 const Ramp *ramp = op->index.as<Ramp>();
691
692 // We only deal with ramps here
693 if (!ramp) {
694 CodeGen_Posix::visit(op);
695 return;
696 }
697
698 // First dig through let expressions
699 Expr rhs = op->value;
700 vector<pair<string, Expr>> lets;
701 while (const Let *let = rhs.as<Let>()) {
702 rhs = let->body;
703 lets.emplace_back(let->name, let->value);
704 }
705 const Shuffle *shuffle = rhs.as<Shuffle>();
706
707 // Interleaving store instructions only exist for certain types.
708 bool type_ok_for_vst = false;
709 Type intrin_type = Handle();
710 if (shuffle) {
711 Type t = shuffle->vectors[0].type();
712 intrin_type = t;
713 Type elt = t.element_of();
714 int vec_bits = t.bits() * t.lanes();
715 if (elt == Float(32) ||
716 elt == Int(8) || elt == Int(16) || elt == Int(32) ||
717 elt == UInt(8) || elt == UInt(16) || elt == UInt(32)) {
718 if (vec_bits % 128 == 0) {
719 type_ok_for_vst = true;
720 intrin_type = intrin_type.with_lanes(128 / t.bits());
721 } else if (vec_bits % 64 == 0) {
722 type_ok_for_vst = true;
723 intrin_type = intrin_type.with_lanes(64 / t.bits());
724 }
725 }
726 }
727
728 if (is_one(ramp->stride) &&
729 shuffle && shuffle->is_interleave() &&
730 type_ok_for_vst &&
731 2 <= shuffle->vectors.size() && shuffle->vectors.size() <= 4) {
732
733 const int num_vecs = shuffle->vectors.size();
734 vector<Value *> args(num_vecs);
735
736 Type t = shuffle->vectors[0].type();
737
738 // Assume element-aligned.
739 int alignment = t.bytes();
740
741 // Codegen the lets
742 for (size_t i = 0; i < lets.size(); i++) {
743 sym_push(lets[i].first, codegen(lets[i].second));
744 }
745
746 // Codegen all the vector args.
747 for (int i = 0; i < num_vecs; ++i) {
748 args[i] = codegen(shuffle->vectors[i]);
749 }
750
751 // Declare the function
752 std::ostringstream instr;
753 vector<llvm::Type *> arg_types;
754 if (target.bits == 32) {
755 instr << "llvm.arm.neon.vst"
756 << num_vecs
757 << ".p0i8"
758 << ".v"
759 << intrin_type.lanes()
760 << (t.is_float() ? 'f' : 'i')
761 << t.bits();
762 arg_types = vector<llvm::Type *>(num_vecs + 2, llvm_type_of(intrin_type));
763 arg_types.front() = i8_t->getPointerTo();
764 arg_types.back() = i32_t;
765 } else {
766 instr << "llvm.aarch64.neon.st"
767 << num_vecs
768 << ".v"
769 << intrin_type.lanes()
770 << (t.is_float() ? 'f' : 'i')
771 << t.bits()
772 << ".p0"
773 << (t.is_float() ? 'f' : 'i')
774 << t.bits();
775 arg_types = vector<llvm::Type *>(num_vecs + 1, llvm_type_of(intrin_type));
776 arg_types.back() = llvm_type_of(intrin_type.element_of())->getPointerTo();
777 }
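// For an interleave of two float32x4 vectors this produces, e.g.,
// "llvm.arm.neon.vst2.p0i8.v4f32" on arm32 and
// "llvm.aarch64.neon.st2.v4f32.p0f32" on aarch64.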
778 llvm::FunctionType *fn_type = FunctionType::get(llvm::Type::getVoidTy(*context), arg_types, false);
779 llvm::FunctionCallee fn = module->getOrInsertFunction(instr.str(), fn_type);
780 internal_assert(fn);
781
782 // How many vst instructions do we need to generate?
783 int slices = t.lanes() / intrin_type.lanes();
784
785 internal_assert(slices >= 1);
786 for (int i = 0; i < t.lanes(); i += intrin_type.lanes()) {
787 Expr slice_base = simplify(ramp->base + i * num_vecs);
788 Expr slice_ramp = Ramp::make(slice_base, ramp->stride, intrin_type.lanes() * num_vecs);
789 Value *ptr = codegen_buffer_pointer(op->name, shuffle->vectors[0].type().element_of(), slice_base);
790
791 vector<Value *> slice_args = args;
792 // Take a slice of each arg
793 for (int j = 0; j < num_vecs; j++) {
794 slice_args[j] = slice_vector(slice_args[j], i, intrin_type.lanes());
795 }
796
797 if (target.bits == 32) {
798 // The arm32 versions take an i8*, regardless of the type stored.
799 ptr = builder->CreatePointerCast(ptr, i8_t->getPointerTo());
800 // Set the pointer argument
801 slice_args.insert(slice_args.begin(), ptr);
802 // Set the alignment argument
803 slice_args.push_back(ConstantInt::get(i32_t, alignment));
804 } else {
805 // Set the pointer argument
806 slice_args.push_back(ptr);
807 }
808
809 CallInst *store = builder->CreateCall(fn, slice_args);
810 add_tbaa_metadata(store, op->name, slice_ramp);
811 }
812
813 // pop the lets from the symbol table
814 for (size_t i = 0; i < lets.size(); i++) {
815 sym_pop(lets[i].first);
816 }
817
818 return;
819 }
820
821 // If the stride is one or minus one, we can deal with that using vanilla codegen
822 const IntImm *stride = ramp->stride.as<IntImm>();
823 if (stride && (stride->value == 1 || stride->value == -1)) {
824 CodeGen_Posix::visit(op);
825 return;
826 }
827
828 // We have builtins for strided stores with fixed but unknown stride, but they use inline assembly
829 if (target.bits != 64 /* Not yet implemented for aarch64 */) {
830 ostringstream builtin;
831 builtin << "strided_store_"
832 << (op->value.type().is_float() ? "f" : "i")
833 << op->value.type().bits()
834 << "x" << op->value.type().lanes();
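// e.g. "strided_store_f32x4" for a vector of four floats.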
835
836 llvm::Function *fn = module->getFunction(builtin.str());
837 if (fn) {
838 Value *base = codegen_buffer_pointer(op->name, op->value.type().element_of(), ramp->base);
839 Value *stride = codegen(ramp->stride * op->value.type().bytes());
840 Value *val = codegen(op->value);
841 debug(4) << "Creating call to " << builtin.str() << "\n";
842 Value *store_args[] = {base, stride, val};
843 Instruction *store = builder->CreateCall(fn, store_args);
844 (void)store;
845 add_tbaa_metadata(store, op->name, op->index);
846 return;
847 }
848 }
849
850 CodeGen_Posix::visit(op);
851 }
852
853 void CodeGen_ARM::visit(const Load *op) {
854 // Predicated load
855 if (!is_one(op->predicate)) {
856 CodeGen_Posix::visit(op);
857 return;
858 }
859
860 if (neon_intrinsics_disabled()) {
861 CodeGen_Posix::visit(op);
862 return;
863 }
864
865 const Ramp *ramp = op->index.as<Ramp>();
866
867 // We only deal with ramps here
868 if (!ramp) {
869 CodeGen_Posix::visit(op);
870 return;
871 }
872
873 const IntImm *stride = ramp ? ramp->stride.as<IntImm>() : nullptr;
874
875 // If the stride is one or minus one, we can deal with that using vanilla codegen
876 if (stride && (stride->value == 1 || stride->value == -1)) {
877 CodeGen_Posix::visit(op);
878 return;
879 }
880
881 // Strided loads with known stride
882 if (stride && stride->value >= 2 && stride->value <= 4) {
883 // Check alignment on the base. Attempt to shift to an earlier
884 // address if it simplifies the expression. This makes
885 // adjacent strided loads share a vldN op.
886 Expr base = ramp->base;
887 int offset = 0;
888 ModulusRemainder mod_rem = modulus_remainder(ramp->base);
889
890 const Add *add = base.as<Add>();
891 const IntImm *add_b = add ? add->b.as<IntImm>() : nullptr;
892
893 if ((mod_rem.modulus % stride->value) == 0) {
894 offset = mod_rem.remainder % stride->value;
895 } else if ((mod_rem.modulus == 1) && add_b) {
896 offset = add_b->value % stride->value;
897 if (offset < 0) {
898 offset += stride->value;
899 }
900 }
901
902 if (offset) {
903 base = simplify(base - offset);
904 mod_rem.remainder -= offset;
905 if (mod_rem.modulus) {
906 mod_rem.remainder = mod_imp(mod_rem.remainder, mod_rem.modulus);
907 }
908 }
909
910 int alignment = op->type.bytes();
911 alignment *= gcd(mod_rem.modulus, mod_rem.remainder);
912 // Maximum stack alignment on arm is 16 bytes, so we should
913 // never claim alignment greater than that.
914 alignment = gcd(alignment, 16);
915 internal_assert(alignment > 0);
916
917 // Decide what width to slice things into. If not a multiple
918 // of 64 or 128 bits, then we can't safely slice it up into
919 // some number of vlds, so we hand it over to the base class.
920 int bit_width = op->type.bits() * op->type.lanes();
921 int intrin_lanes = 0;
922 if (bit_width % 128 == 0) {
923 intrin_lanes = 128 / op->type.bits();
924 } else if (bit_width % 64 == 0) {
925 intrin_lanes = 64 / op->type.bits();
926 } else {
927 CodeGen_Posix::visit(op);
928 return;
929 }
930
931 llvm::Type *load_return_type = llvm_type_of(op->type.with_lanes(intrin_lanes * stride->value));
932 llvm::Type *load_return_pointer_type = load_return_type->getPointerTo();
933 Value *undef = UndefValue::get(load_return_type);
934 SmallVector<Constant *, 256> constants;
935 for (int j = 0; j < intrin_lanes; j++) {
936 Constant *constant = ConstantInt::get(i32_t, j * stride->value + offset);
937 constants.push_back(constant);
938 }
939 Constant *constantsV = ConstantVector::get(constants);
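// Each slice is loaded as one contiguous vector spanning all of the
// interleaved elements, and the shuffle mask above then picks out every
// stride-th lane starting from the adjusted offset.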
940
941 vector<Value *> results;
942 for (int i = 0; i < op->type.lanes(); i += intrin_lanes) {
943 Expr slice_base = simplify(base + i * ramp->stride);
944 Expr slice_ramp = Ramp::make(slice_base, ramp->stride, intrin_lanes);
945 Value *ptr = codegen_buffer_pointer(op->name, op->type.element_of(), slice_base);
946 Value *bitcastI = builder->CreateBitOrPointerCast(ptr, load_return_pointer_type);
947 LoadInst *loadI = cast<LoadInst>(builder->CreateLoad(bitcastI));
948 #if LLVM_VERSION >= 110
949 loadI->setAlignment(Align(alignment));
950 #elif LLVM_VERSION >= 100
951 loadI->setAlignment(MaybeAlign(alignment));
952 #else
953 loadI->setAlignment(alignment);
954 #endif
955 add_tbaa_metadata(loadI, op->name, slice_ramp);
956 Value *shuffleInstr = builder->CreateShuffleVector(loadI, undef, constantsV);
957 results.push_back(shuffleInstr);
958 }
959
960 // Concat the results
961 value = concat_vectors(results);
962 return;
963 }
964
965 // We have builtins for strided loads with fixed but unknown stride, but they use inline assembly.
966 if (target.bits != 64 /* Not yet implemented for aarch64 */) {
967 ostringstream builtin;
968 builtin << "strided_load_"
969 << (op->type.is_float() ? "f" : "i")
970 << op->type.bits()
971 << "x" << op->type.lanes();
972
973 llvm::Function *fn = module->getFunction(builtin.str());
974 if (fn) {
975 Value *base = codegen_buffer_pointer(op->name, op->type.element_of(), ramp->base);
976 Value *stride = codegen(ramp->stride * op->type.bytes());
977 debug(4) << "Creating call to " << builtin.str() << "\n";
978 Value *args[] = {base, stride};
979 Instruction *load = builder->CreateCall(fn, args, builtin.str());
980 add_tbaa_metadata(load, op->name, op->index);
981 value = load;
982 return;
983 }
984 }
985
986 CodeGen_Posix::visit(op);
987 }
988
989 void CodeGen_ARM::visit(const Call *op) {
990 if (op->is_intrinsic(Call::abs) && op->type.is_uint()) {
991 internal_assert(op->args.size() == 1);
992 // If the arg is a subtract with narrowable args, we can use vabdl.
993 const Sub *sub = op->args[0].as<Sub>();
994 if (sub) {
995 Expr a = sub->a, b = sub->b;
996 Type narrow = UInt(a.type().bits() / 2, a.type().lanes());
997 Expr na = lossless_cast(narrow, a);
998 Expr nb = lossless_cast(narrow, b);
999
1000 // Also try an unsigned narrowing
1001 if (!na.defined() || !nb.defined()) {
1002 narrow = Int(narrow.bits(), narrow.lanes());
1003 na = lossless_cast(narrow, a);
1004 nb = lossless_cast(narrow, b);
1005 }
1006
1007 if (na.defined() && nb.defined()) {
1008 Expr absd = Call::make(UInt(narrow.bits(), narrow.lanes()), Call::absd,
1009 {na, nb}, Call::PureIntrinsic);
1010
1011 absd = Cast::make(op->type, absd);
1012 codegen(absd);
1013 return;
1014 }
1015 }
1016 } else if (op->is_intrinsic(Call::sorted_avg)) {
1017 Type ty = op->type;
1018 Type wide_ty = ty.with_bits(ty.bits() * 2);
1019 // This will codegen to vhaddu (arm32) or uhadd (arm64).
1020 value = codegen(cast(ty, (cast(wide_ty, op->args[0]) + cast(wide_ty, op->args[1])) / 2));
1021 return;
1022 }
1023
1024 CodeGen_Posix::visit(op);
1025 }
1026
1027 void CodeGen_ARM::visit(const LT *op) {
1028 #if LLVM_VERSION >= 100
1029 if (op->a.type().is_float() && op->type.is_vector()) {
1030 // Fast-math flags confuse LLVM's aarch64 backend, so
1031 // temporarily clear them for this instruction.
1032 // See https://bugs.llvm.org/show_bug.cgi?id=45036
1033 llvm::IRBuilderBase::FastMathFlagGuard guard(*builder);
1034 builder->clearFastMathFlags();
1035 CodeGen_Posix::visit(op);
1036 return;
1037 }
1038 #endif
1039
1040 CodeGen_Posix::visit(op);
1041 }
1042
1043 void CodeGen_ARM::visit(const LE *op) {
1044 #if LLVM_VERSION >= 100
1045 if (op->a.type().is_float() && op->type.is_vector()) {
1046 // Fast-math flags confuse LLVM's aarch64 backend, so
1047 // temporarily clear them for this instruction.
1048 // See https://bugs.llvm.org/show_bug.cgi?id=45036
1049 llvm::IRBuilderBase::FastMathFlagGuard guard(*builder);
1050 builder->clearFastMathFlags();
1051 CodeGen_Posix::visit(op);
1052 return;
1053 }
1054 #endif
1055
1056 CodeGen_Posix::visit(op);
1057 }
1058
1059 void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init) {
1060 if (neon_intrinsics_disabled() ||
1061 op->op == VectorReduce::Or ||
1062 op->op == VectorReduce::And ||
1063 op->op == VectorReduce::Mul ||
1064 // LLVM 9 has bugs in the arm backend for vector reduce
1065 // ops. See https://github.com/halide/Halide/issues/5081
1066 !(LLVM_VERSION >= 100)) {
1067 CodeGen_Posix::codegen_vector_reduce(op, init);
1068 return;
1069 }
1070
1071 // ARM has a variety of pairwise reduction ops for +, min,
1072 // max. The versions that do not widen take two 64-bit args and
1073 // return one 64-bit vector of the same type. The versions that
1074 // widen take one arg and return something with half the vector
1075 // lanes and double the bit-width.
1076
1077 int factor = op->value.type().lanes() / op->type.lanes();
1078
1079 // These are the types for which we have reduce intrinsics in the
1080 // runtime.
1081 bool have_reduce_intrinsic = (op->type.is_int() ||
1082 op->type.is_uint() ||
1083 op->type.is_float());
1084
1085 // We don't have 16-bit float or bfloat horizontal ops
1086 if (op->type.is_bfloat() || (op->type.is_float() && op->type.bits() < 32)) {
1087 have_reduce_intrinsic = false;
1088 }
1089
1090 // Only aarch64 has float64 horizontal ops
1091 if (target.bits == 32 && op->type.element_of() == Float(64)) {
1092 have_reduce_intrinsic = false;
1093 }
1094
1095 // For 64-bit integers, we only have addition, not min/max
1096 if (op->type.bits() == 64 &&
1097 !op->type.is_float() &&
1098 op->op != VectorReduce::Add) {
1099 have_reduce_intrinsic = false;
1100 }
1101
1102 // We only have intrinsics that reduce by a factor of two
1103 if (factor != 2) {
1104 have_reduce_intrinsic = false;
1105 }
1106
1107 if (have_reduce_intrinsic) {
1108 Expr arg = op->value;
1109 if (op->op == VectorReduce::Add &&
1110 op->type.bits() >= 16 &&
1111 !op->type.is_float()) {
1112 Type narrower_type = arg.type().with_bits(arg.type().bits() / 2);
1113 Expr narrower = lossless_cast(narrower_type, arg);
1114 if (!narrower.defined() && arg.type().is_int()) {
1115 // We can also safely accumulate from a uint into a
1116 // wider int, because the addition uses at most one
1117 // extra bit.
1118 narrower = lossless_cast(narrower_type.with_code(Type::UInt), arg);
1119 }
1120 if (narrower.defined()) {
1121 arg = narrower;
1122 }
1123 }
1124 int output_bits;
1125 if (target.bits == 32 && arg.type().bits() == op->type.bits()) {
1126 // For the non-widening version, the output must be 64-bit
1127 output_bits = 64;
1128 } else if (op->type.bits() * op->type.lanes() <= 64) {
1129 // No point using the 128-bit version of the instruction if the output is narrow.
1130 output_bits = 64;
1131 } else {
1132 output_bits = 128;
1133 }
1134
1135 const int output_lanes = output_bits / op->type.bits();
1136 Type intrin_type = op->type.with_lanes(output_lanes);
1137 Type arg_type = arg.type().with_lanes(output_lanes * 2);
1138 if (op->op == VectorReduce::Add &&
1139 arg.type().bits() == op->type.bits() &&
1140 arg_type.is_uint()) {
1141 // For non-widening additions, there is only a signed
1142 // version (because it's equivalent).
1143 arg_type = arg_type.with_code(Type::Int);
1144 intrin_type = intrin_type.with_code(Type::Int);
1145 } else if (arg.type().is_uint() && intrin_type.is_int()) {
1146 // Use the uint version
1147 intrin_type = intrin_type.with_code(Type::UInt);
1148 }
1149
1150 std::stringstream ss;
1151 vector<Expr> args;
1152 ss << "pairwise_" << op->op << "_" << intrin_type << "_" << arg_type;
1153 Expr accumulator = init;
1154 if (op->op == VectorReduce::Add &&
1155 accumulator.defined() &&
1156 arg_type.bits() < intrin_type.bits()) {
1157 // We can use the accumulating variant
1158 ss << "_accumulate";
1159 args.push_back(init);
1160 accumulator = Expr();
1161 }
1162 args.push_back(arg);
1163 value = call_intrin(op->type, output_lanes, ss.str(), args);
1164
1165 if (accumulator.defined()) {
1166 // We still have an initial value to take care of
1167 string n = unique_name('t');
1168 sym_push(n, value);
1169 Expr v = Variable::make(accumulator.type(), n);
1170 switch (op->op) {
1171 case VectorReduce::Add:
1172 accumulator += v;
1173 break;
1174 case VectorReduce::Min:
1175 accumulator = min(accumulator, v);
1176 break;
1177 case VectorReduce::Max:
1178 accumulator = max(accumulator, v);
1179 break;
1180 default:
1181 internal_error << "unreachable";
1182 }
1183 codegen(accumulator);
1184 sym_pop(n);
1185 }
1186
1187 return;
1188 }
1189
1190 // Pattern-match 8-bit dot product instructions available on newer
1191 // ARM cores.
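// e.g. a sum over groups of four of i32(u8_a) * i32(u8_b) maps to
// udot.v4i32.v16i8 (sdot for signed operands).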
1192 if (target.has_feature(Target::ARMDotProd) &&
1193 factor % 4 == 0 &&
1194 op->op == VectorReduce::Add &&
1195 target.bits == 64 &&
1196 (op->type.element_of() == Int(32) ||
1197 op->type.element_of() == UInt(32))) {
1198 const Mul *mul = op->value.as<Mul>();
1199 if (mul) {
1200 const int input_lanes = mul->type.lanes();
1201 Expr a = lossless_cast(UInt(8, input_lanes), mul->a);
1202 Expr b = lossless_cast(UInt(8, input_lanes), mul->b);
1203 if (!a.defined()) {
1204 a = lossless_cast(Int(8, input_lanes), mul->a);
1205 b = lossless_cast(Int(8, input_lanes), mul->b);
1206 }
1207 if (a.defined() && b.defined()) {
1208 if (factor != 4) {
1209 Expr equiv = VectorReduce::make(op->op, op->value, input_lanes / 4);
1210 equiv = VectorReduce::make(op->op, equiv, op->type.lanes());
1211 codegen_vector_reduce(equiv.as<VectorReduce>(), init);
1212 return;
1213 }
1214 Expr i = init;
1215 if (!i.defined()) {
1216 i = make_zero(op->type);
1217 }
1218 vector<Expr> args{i, a, b};
1219 if (op->type.lanes() <= 2) {
1220 if (op->type.is_uint()) {
1221 value = call_intrin(op->type, 2, "llvm.aarch64.neon.udot.v2i32.v8i8", args);
1222 } else {
1223 value = call_intrin(op->type, 2, "llvm.aarch64.neon.sdot.v2i32.v8i8", args);
1224 }
1225 } else {
1226 if (op->type.is_uint()) {
1227 value = call_intrin(op->type, 4, "llvm.aarch64.neon.udot.v4i32.v16i8", args);
1228 } else {
1229 value = call_intrin(op->type, 4, "llvm.aarch64.neon.sdot.v4i32.v16i8", args);
1230 }
1231 }
1232 return;
1233 }
1234 }
1235 }
1236
1237 CodeGen_Posix::codegen_vector_reduce(op, init);
1238 }
1239
1240 string CodeGen_ARM::mcpu() const {
1241 if (target.bits == 32) {
1242 if (target.has_feature(Target::ARMv7s)) {
1243 return "swift";
1244 } else {
1245 return "cortex-a9";
1246 }
1247 } else {
1248 if (target.os == Target::IOS) {
1249 return "cyclone";
1250 } else if (target.os == Target::OSX) {
1251 return "apple-a12";
1252 } else {
1253 return "generic";
1254 }
1255 }
1256 }
1257
1258 string CodeGen_ARM::mattrs() const {
1259 if (target.bits == 32) {
1260 if (target.has_feature(Target::ARMv7s)) {
1261 return "+neon";
1262 }
1263 if (!target.has_feature(Target::NoNEON)) {
1264 return "+neon";
1265 } else {
1266 return "-neon";
1267 }
1268 } else {
1269 // TODO: Should Halide's SVE flags be 64-bit only?
1270 string arch_flags;
1271 if (target.has_feature(Target::SVE2)) {
1272 arch_flags = "+sve2";
1273 } else if (target.has_feature(Target::SVE)) {
1274 arch_flags = "+sve";
1275 }
1276
1277 if (target.has_feature(Target::ARMDotProd)) {
1278 arch_flags += "+dotprod";
1279 }
1280
1281 if (target.os == Target::IOS || target.os == Target::OSX) {
1282 return arch_flags + "+reserve-x18";
1283 } else {
1284 return arch_flags;
1285 }
1286 }
1287 }
1288
1289 bool CodeGen_ARM::use_soft_float_abi() const {
1290 // One expects the flag is irrelevant on 64-bit, but we'll make the logic
1291 // exhaustive anyway. It is not clear the armv7s case is necessary either.
1292 return target.has_feature(Target::SoftFloatABI) ||
1293 (target.bits == 32 &&
1294 ((target.os == Target::Android) ||
1295 (target.os == Target::IOS && !target.has_feature(Target::ARMv7s))));
1296 }
1297
1298 int CodeGen_ARM::native_vector_bits() const {
1299 return 128;
1300 }
1301
1302 } // namespace Internal
1303 } // namespace Halide
1304