1 #include <iostream>
2 #include <sstream>
3 
4 #include "CSE.h"
5 #include "CodeGen_ARM.h"
6 #include "CodeGen_Internal.h"
7 #include "ConciseCasts.h"
8 #include "Debug.h"
9 #include "IREquality.h"
10 #include "IRMatch.h"
11 #include "IROperator.h"
12 #include "IRPrinter.h"
13 #include "LLVM_Headers.h"
14 #include "Simplify.h"
15 #include "Util.h"
16 
17 namespace Halide {
18 namespace Internal {
19 
20 using std::ostringstream;
21 using std::pair;
22 using std::string;
23 using std::vector;
24 
25 using namespace Halide::ConciseCasts;
26 using namespace llvm;
27 
CodeGen_ARM::CodeGen_ARM(Target target)
29     : CodeGen_Posix(target) {
30     if (target.bits == 32) {
31 #if !defined(WITH_ARM)
32         user_error << "arm not enabled for this build of Halide.";
33 #endif
        user_assert(llvm_ARM_enabled) << "llvm build not configured with ARM target enabled.\n";
35     } else {
36 #if !defined(WITH_AARCH64)
37         user_error << "aarch64 not enabled for this build of Halide.";
38 #endif
39         user_assert(llvm_AArch64_enabled) << "llvm build not configured with AArch64 target enabled.\n";
40     }
41 
42     // Generate the cast patterns that can take vector types.  We need
43     // to iterate over all 64 and 128 bit integer types relevant for
44     // neon.
45     Type types[] = {Int(8, 8), Int(8, 16), UInt(8, 8), UInt(8, 16),
46                     Int(16, 4), Int(16, 8), UInt(16, 4), UInt(16, 8),
47                     Int(32, 2), Int(32, 4), UInt(32, 2), UInt(32, 4)};
48     for (size_t i = 0; i < sizeof(types) / sizeof(types[0]); i++) {
49         Type t = types[i];
50 
51         int intrin_lanes = t.lanes();
52         std::ostringstream oss;
53         oss << ".v" << intrin_lanes << "i" << t.bits();
54         string t_str = oss.str();
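        // t_str becomes the LLVM intrinsic type suffix for this vector type, e.g. ".v8i16".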
55 
56         // For the 128-bit versions, we want to match any vector width.
57         if (t.bits() * t.lanes() == 128) {
58             t = t.with_lanes(0);
59         }
60 
61         // Wider versions of the type
62         Type w = t.with_bits(t.bits() * 2);
63         Type ws = Int(t.bits() * 2, t.lanes());
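        // ws is always signed: saturating subtracts widen to a signed type (see below).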
64 
65         // Vector wildcard for this type
66         Expr vector = Variable::make(t, "*");
67         Expr w_vector = Variable::make(w, "*");
68         Expr ws_vector = Variable::make(ws, "*");
69 
70         // Bounds of the type stored in the wider vector type
71         Expr tmin = simplify(cast(w, t.min()));
72         Expr tmax = simplify(cast(w, t.max()));
73         Expr tsmin = simplify(cast(ws, t.min()));
74         Expr tsmax = simplify(cast(ws, t.max()));
75 
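        // Pattern::NarrowArgs means the wildcards are matched at the wider
        // type and then losslessly narrowed back to t before calling the
        // intrinsic (see the handling in visit(const Cast *) below).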
76         Pattern p("", "", intrin_lanes, Expr(), Pattern::NarrowArgs);
77 
78         // Rounding-up averaging
79         if (t.is_int()) {
80             p.intrin32 = "llvm.arm.neon.vrhadds" + t_str;
81             p.intrin64 = "llvm.aarch64.neon.srhadd" + t_str;
82         } else {
83             p.intrin32 = "llvm.arm.neon.vrhaddu" + t_str;
84             p.intrin64 = "llvm.aarch64.neon.urhadd" + t_str;
85         }
86 
87         p.pattern = cast(t, (w_vector + w_vector + 1) / 2);
88         casts.push_back(p);
89         p.pattern = cast(t, (w_vector + (w_vector + 1)) / 2);
90         casts.push_back(p);
91         p.pattern = cast(t, ((w_vector + 1) + w_vector) / 2);
92         casts.push_back(p);
93 
94         // Rounding down averaging
95         if (t.is_int()) {
96             p.intrin32 = "llvm.arm.neon.vhadds" + t_str;
97             p.intrin64 = "llvm.aarch64.neon.shadd" + t_str;
98         } else {
99             p.intrin32 = "llvm.arm.neon.vhaddu" + t_str;
100             p.intrin64 = "llvm.aarch64.neon.uhadd" + t_str;
101         }
102         p.pattern = cast(t, (w_vector + w_vector) / 2);
103         casts.push_back(p);
104 
105         // Halving subtract
106         if (t.is_int()) {
107             p.intrin32 = "llvm.arm.neon.vhsubs" + t_str;
108             p.intrin64 = "llvm.aarch64.neon.shsub" + t_str;
109         } else {
110             p.intrin32 = "llvm.arm.neon.vhsubu" + t_str;
111             p.intrin64 = "llvm.aarch64.neon.uhsub" + t_str;
112         }
113         p.pattern = cast(t, (w_vector - w_vector) / 2);
114         casts.push_back(p);
115 
116         // Saturating add
117 #if LLVM_VERSION >= 100
118         if (t.is_int()) {
119             p.intrin32 = "llvm.sadd.sat" + t_str;
120             p.intrin64 = "llvm.sadd.sat" + t_str;
121         } else {
122             p.intrin32 = "llvm.uadd.sat" + t_str;
123             p.intrin64 = "llvm.uadd.sat" + t_str;
124         }
125 #else
126         if (t.is_int()) {
127             p.intrin32 = "llvm.arm.neon.vqadds" + t_str;
128             p.intrin64 = "llvm.aarch64.neon.sqadd" + t_str;
129         } else {
130             p.intrin32 = "llvm.arm.neon.vqaddu" + t_str;
131             p.intrin64 = "llvm.aarch64.neon.uqadd" + t_str;
132         }
133 #endif
134         p.pattern = cast(t, clamp(w_vector + w_vector, tmin, tmax));
135         casts.push_back(p);
136 
        // In the unsigned case, the lower bound of the clamp is unnecessary
138         if (t.is_uint()) {
139             p.pattern = cast(t, min(w_vector + w_vector, tmax));
140             casts.push_back(p);
141         }
142 
143         // Saturating subtract
144         // N.B. Saturating subtracts always widen to a signed type
145 #if LLVM_VERSION >= 100
146         if (t.is_int()) {
147             p.intrin32 = "llvm.ssub.sat" + t_str;
148             p.intrin64 = "llvm.ssub.sat" + t_str;
149         } else {
150             p.intrin32 = "llvm.usub.sat" + t_str;
151             p.intrin64 = "llvm.usub.sat" + t_str;
152         }
153 #else
154         if (t.is_int()) {
155             p.intrin32 = "llvm.arm.neon.vqsubs" + t_str;
156             p.intrin64 = "llvm.aarch64.neon.sqsub" + t_str;
157         } else {
158             p.intrin32 = "llvm.arm.neon.vqsubu" + t_str;
159             p.intrin64 = "llvm.aarch64.neon.uqsub" + t_str;
160         }
161 #endif
162         p.pattern = cast(t, clamp(ws_vector - ws_vector, tsmin, tsmax));
163         casts.push_back(p);
164 
165         // In the unsigned case, we may detect that the top of the clamp is unnecessary
166         if (t.is_uint()) {
167             p.pattern = cast(t, max(ws_vector - ws_vector, 0));
168             casts.push_back(p);
169         }
170     }
171 
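    // Saturating rounding doubling multiply-high: the fixed-point forms below
    // match the semantics of vqrdmulh/sqrdmulh.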
172     casts.emplace_back("vqrdmulh.v4i16", "sqrdmulh.v4i16", 4,
173                        i16_sat((wild_i32x4 * wild_i32x4 + (1 << 14)) / (1 << 15)),
174                        Pattern::NarrowArgs);
175     casts.emplace_back("vqrdmulh.v8i16", "sqrdmulh.v8i16", 8,
176                        i16_sat((wild_i32x_ * wild_i32x_ + (1 << 14)) / (1 << 15)),
177                        Pattern::NarrowArgs);
178     casts.emplace_back("vqrdmulh.v2i32", "sqrdmulh.v2i32", 2,
179                        i32_sat((wild_i64x2 * wild_i64x2 + (1 << 30)) / Expr(int64_t(1) << 31)),
180                        Pattern::NarrowArgs);
181     casts.emplace_back("vqrdmulh.v4i32", "sqrdmulh.v4i32", 4,
182                        i32_sat((wild_i64x_ * wild_i64x_ + (1 << 30)) / Expr(int64_t(1) << 31)),
183                        Pattern::NarrowArgs);
184 
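    // Saturating narrowing right shifts. Pattern::RightShift means the
    // denominator must be a constant power of two; visit(const Cast *) turns
    // it into an immediate shift amount when calling the intrinsic.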
185     casts.emplace_back("vqshiftns.v8i8", "sqshrn.v8i8", 8, i8_sat(wild_i16x_ / wild_i16x_), Pattern::RightShift);
186     casts.emplace_back("vqshiftns.v4i16", "sqshrn.v4i16", 4, i16_sat(wild_i32x_ / wild_i32x_), Pattern::RightShift);
187     casts.emplace_back("vqshiftns.v2i32", "sqshrn.v2i32", 2, i32_sat(wild_i64x_ / wild_i64x_), Pattern::RightShift);
188     casts.emplace_back("vqshiftnu.v8i8", "uqshrn.v8i8", 8, u8_sat(wild_u16x_ / wild_u16x_), Pattern::RightShift);
189     casts.emplace_back("vqshiftnu.v4i16", "uqshrn.v4i16", 4, u16_sat(wild_u32x_ / wild_u32x_), Pattern::RightShift);
190     casts.emplace_back("vqshiftnu.v2i32", "uqshrn.v2i32", 2, u32_sat(wild_u64x_ / wild_u64x_), Pattern::RightShift);
191     casts.emplace_back("vqshiftnsu.v8i8", "sqshrun.v8i8", 8, u8_sat(wild_i16x_ / wild_i16x_), Pattern::RightShift);
192     casts.emplace_back("vqshiftnsu.v4i16", "sqshrun.v4i16", 4, u16_sat(wild_i32x_ / wild_i32x_), Pattern::RightShift);
193     casts.emplace_back("vqshiftnsu.v2i32", "sqshrun.v2i32", 2, u32_sat(wild_i64x_ / wild_i64x_), Pattern::RightShift);
194 
195     // Where a 64-bit and 128-bit version exist, we use the 64-bit
    // version only when the args are 64 bits wide.
197     casts.emplace_back("vqshifts.v8i8", "sqshl.v8i8", 8, i8_sat(i16(wild_i8x8) * wild_i16x8), Pattern::LeftShift);
198     casts.emplace_back("vqshifts.v4i16", "sqshl.v4i16", 4, i16_sat(i32(wild_i16x4) * wild_i32x4), Pattern::LeftShift);
199     casts.emplace_back("vqshifts.v2i32", "sqshl.v2i32", 2, i32_sat(i64(wild_i32x2) * wild_i64x2), Pattern::LeftShift);
200     casts.emplace_back("vqshiftu.v8i8", "uqshl.v8i8", 8, u8_sat(u16(wild_u8x8) * wild_u16x8), Pattern::LeftShift);
201     casts.emplace_back("vqshiftu.v4i16", "uqshl.v4i16", 4, u16_sat(u32(wild_u16x4) * wild_u32x4), Pattern::LeftShift);
202     casts.emplace_back("vqshiftu.v2i32", "uqshl.v2i32", 2, u32_sat(u64(wild_u32x2) * wild_u64x2), Pattern::LeftShift);
203     casts.emplace_back("vqshiftsu.v8i8", "sqshlu.v8i8", 8, u8_sat(i16(wild_i8x8) * wild_i16x8), Pattern::LeftShift);
204     casts.emplace_back("vqshiftsu.v4i16", "sqshlu.v4i16", 4, u16_sat(i32(wild_i16x4) * wild_i32x4), Pattern::LeftShift);
205     casts.emplace_back("vqshiftsu.v2i32", "sqshlu.v2i32", 2, u32_sat(i64(wild_i32x2) * wild_i64x2), Pattern::LeftShift);
206 
207     // We use the 128-bit version for all other vector widths.
208     casts.emplace_back("vqshifts.v16i8", "sqshl.v16i8", 16, i8_sat(i16(wild_i8x_) * wild_i16x_), Pattern::LeftShift);
209     casts.emplace_back("vqshifts.v8i16", "sqshl.v8i16", 8, i16_sat(i32(wild_i16x_) * wild_i32x_), Pattern::LeftShift);
210     casts.emplace_back("vqshifts.v4i32", "sqshl.v4i32", 4, i32_sat(i64(wild_i32x_) * wild_i64x_), Pattern::LeftShift);
211     casts.emplace_back("vqshiftu.v16i8", "uqshl.v16i8", 16, u8_sat(u16(wild_u8x_) * wild_u16x_), Pattern::LeftShift);
212     casts.emplace_back("vqshiftu.v8i16", "uqshl.v8i16", 8, u16_sat(u32(wild_u16x_) * wild_u32x_), Pattern::LeftShift);
213     casts.emplace_back("vqshiftu.v4i32", "uqshl.v4i32", 4, u32_sat(u64(wild_u32x_) * wild_u64x_), Pattern::LeftShift);
214     casts.emplace_back("vqshiftsu.v16i8", "sqshlu.v16i8", 16, u8_sat(i16(wild_i8x_) * wild_i16x_), Pattern::LeftShift);
215     casts.emplace_back("vqshiftsu.v8i16", "sqshlu.v8i16", 8, u16_sat(i32(wild_i16x_) * wild_i32x_), Pattern::LeftShift);
216     casts.emplace_back("vqshiftsu.v4i32", "sqshlu.v4i32", 4, u32_sat(i64(wild_i32x_) * wild_i64x_), Pattern::LeftShift);
217 
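    // Saturating narrowing casts with no shift (vqmovn / sqxtn and friends).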
218     casts.emplace_back("vqmovns.v8i8", "sqxtn.v8i8", 8, i8_sat(wild_i16x_));
219     casts.emplace_back("vqmovns.v4i16", "sqxtn.v4i16", 4, i16_sat(wild_i32x_));
220     casts.emplace_back("vqmovns.v2i32", "sqxtn.v2i32", 2, i32_sat(wild_i64x_));
221     casts.emplace_back("vqmovnu.v8i8", "uqxtn.v8i8", 8, u8_sat(wild_u16x_));
222     casts.emplace_back("vqmovnu.v4i16", "uqxtn.v4i16", 4, u16_sat(wild_u32x_));
223     casts.emplace_back("vqmovnu.v2i32", "uqxtn.v2i32", 2, u32_sat(wild_u64x_));
224     casts.emplace_back("vqmovnsu.v8i8", "sqxtun.v8i8", 8, u8_sat(wild_i16x_));
225     casts.emplace_back("vqmovnsu.v4i16", "sqxtun.v4i16", 4, u16_sat(wild_i32x_));
226     casts.emplace_back("vqmovnsu.v2i32", "sqxtun.v2i32", 2, u32_sat(wild_i64x_));
227 
228     // Overflow for int32 is not defined by Halide, so for those we can take
229     // advantage of special add-and-halve instructions.
230     //
231     // 64-bit averaging round-down
232     averagings.emplace_back("vhadds.v2i32", "shadd.v2i32", 2, (wild_i32x2 + wild_i32x2));
233 
234     // 128-bit
235     averagings.emplace_back("vhadds.v4i32", "shadd.v4i32", 4, (wild_i32x_ + wild_i32x_));
236 
237     // 64-bit halving subtract
238     averagings.emplace_back("vhsubs.v2i32", "shsub.v2i32", 2, (wild_i32x2 - wild_i32x2));
239 
240     // 128-bit
241     averagings.emplace_back("vhsubs.v4i32", "shsub.v4i32", 4, (wild_i32x_ - wild_i32x_));
242 
243     // 64-bit saturating negation
244     negations.emplace_back("vqneg.v8i8", "sqneg.v8i8", 8, -max(wild_i8x8, -127));
245     negations.emplace_back("vqneg.v4i16", "sqneg.v4i16", 4, -max(wild_i16x4, -32767));
246     negations.emplace_back("vqneg.v2i32", "sqneg.v2i32", 2, -max(wild_i32x2, -(0x7fffffff)));
247 
248     // 128-bit
249     negations.emplace_back("vqneg.v16i8", "sqneg.v16i8", 16, -max(wild_i8x_, -127));
250     negations.emplace_back("vqneg.v8i16", "sqneg.v8i16", 8, -max(wild_i16x_, -32767));
251     negations.emplace_back("vqneg.v4i32", "sqneg.v4i32", 4, -max(wild_i32x_, -(0x7fffffff)));
252 
253     // Widening multiplies.
254     multiplies.emplace_back("vmulls.v2i64", "smull.v2i64", 2,
255                             wild_i64x_ * wild_i64x_,
256                             Pattern::NarrowArgs);
257     multiplies.emplace_back("vmullu.v2i64", "umull.v2i64", 2,
258                             wild_u64x_ * wild_u64x_,
259                             Pattern::NarrowArgs);
260     multiplies.emplace_back("vmulls.v4i32", "smull.v4i32", 4,
261                             wild_i32x_ * wild_i32x_,
262                             Pattern::NarrowArgs);
263     multiplies.emplace_back("vmullu.v4i32", "umull.v4i32", 4,
264                             wild_u32x_ * wild_u32x_,
265                             Pattern::NarrowArgs);
266     multiplies.emplace_back("vmulls.v8i16", "smull.v8i16", 8,
267                             wild_i16x_ * wild_i16x_,
268                             Pattern::NarrowArgs);
269     multiplies.emplace_back("vmullu.v8i16", "umull.v8i16", 8,
270                             wild_u16x_ * wild_u16x_,
271                             Pattern::NarrowArgs);
272 }
273 
Value *CodeGen_ARM::call_pattern(const Pattern &p, Type t, const vector<Expr> &args) {
275     if (target.bits == 32) {
276         return call_intrin(t, p.intrin_lanes, p.intrin32, args);
277     } else {
278         return call_intrin(t, p.intrin_lanes, p.intrin64, args);
279     }
280 }
281 
Value *CodeGen_ARM::call_pattern(const Pattern &p, llvm::Type *t, const vector<llvm::Value *> &args) {
283     if (target.bits == 32) {
284         return call_intrin(t, p.intrin_lanes, p.intrin32, args);
285     } else {
286         return call_intrin(t, p.intrin_lanes, p.intrin64, args);
287     }
288 }
289 
void CodeGen_ARM::visit(const Cast *op) {
291     if (neon_intrinsics_disabled()) {
292         CodeGen_Posix::visit(op);
293         return;
294     }
295 
296     Type t = op->type;
297 
298     vector<Expr> matches;
299 
300     for (size_t i = 0; i < casts.size(); i++) {
301         const Pattern &pattern = casts[i];
        //debug(4) << "Trying pattern: " << pattern.intrin32 << " " << pattern.pattern << "\n";
303         if (expr_match(pattern.pattern, op, matches)) {
304 
305             //debug(4) << "Match!\n";
306             if (pattern.type == Pattern::Simple) {
307                 value = call_pattern(pattern, t, matches);
308                 return;
309             } else if (pattern.type == Pattern::NarrowArgs) {
310                 // Try to narrow all of the args.
311                 bool all_narrow = true;
312                 for (size_t i = 0; i < matches.size(); i++) {
313                     internal_assert(matches[i].type().bits() == t.bits() * 2);
314                     internal_assert(matches[i].type().lanes() == t.lanes());
                    // debug(4) << "Attempting to narrow " << matches[i] << " to " << t << "\n";
316                     matches[i] = lossless_cast(t, matches[i]);
317                     if (!matches[i].defined()) {
318                         // debug(4) << "failed\n";
319                         all_narrow = false;
320                     } else {
321                         // debug(4) << "success: " << matches[i] << "\n";
322                         internal_assert(matches[i].type() == t);
323                     }
324                 }
325 
326                 if (all_narrow) {
327                     value = call_pattern(pattern, t, matches);
328                     return;
329                 }
330             } else {  // must be a shift
331                 Expr constant = matches[1];
332                 int shift_amount;
333                 bool power_of_two = is_const_power_of_two_integer(constant, &shift_amount);
334                 if (power_of_two && shift_amount < matches[0].type().bits()) {
335                     if (target.bits == 32 && pattern.type == Pattern::RightShift) {
336                         // The arm32 llvm backend wants right shifts to come in as negative values.
337                         shift_amount = -shift_amount;
338                     }
339                     Value *shift = nullptr;
340                     // The arm64 llvm backend wants i32 constants for right shifts.
341                     if (target.bits == 64 && pattern.type == Pattern::RightShift) {
342                         shift = ConstantInt::get(i32_t, shift_amount);
343                     } else {
344                         shift = ConstantInt::get(llvm_type_of(matches[0].type()), shift_amount);
345                     }
346                     value = call_pattern(pattern, llvm_type_of(t),
347                                          {codegen(matches[0]), shift});
348                     return;
349                 }
350             }
351         }
352     }
353 
354     // Catch extract-high-half-of-signed integer pattern and convert
355     // it to extract-high-half-of-unsigned-integer. llvm peephole
    // optimization recognizes logical shift right but not arithmetic
357     // shift right for this pattern. This matters for vaddhn of signed
358     // integers.
359     if (t.is_vector() &&
360         (t.is_int() || t.is_uint()) &&
361         op->value.type().is_int() &&
362         t.bits() == op->value.type().bits() / 2) {
363         const Div *d = op->value.as<Div>();
364         if (d && is_const(d->b, int64_t(1) << t.bits())) {
365             Type unsigned_type = UInt(t.bits() * 2, t.lanes());
366             Expr replacement = cast(t,
367                                     cast(unsigned_type, d->a) /
368                                         cast(unsigned_type, d->b));
369             replacement.accept(this);
370             return;
371         }
372     }
373 
374     // Catch widening of absolute difference
375     if (t.is_vector() &&
376         (t.is_int() || t.is_uint()) &&
377         (op->value.type().is_int() || op->value.type().is_uint()) &&
378         t.bits() == op->value.type().bits() * 2) {
379         Expr a, b;
380         const Call *c = op->value.as<Call>();
381         if (c && c->is_intrinsic(Call::absd)) {
382             ostringstream ss;
383             int intrin_lanes = 128 / t.bits();
384             ss << "vabdl_" << (c->args[0].type().is_int() ? "i" : "u") << t.bits() / 2 << "x" << intrin_lanes;
385             value = call_intrin(t, intrin_lanes, ss.str(), c->args);
386             return;
387         }
388     }
389 
390     CodeGen_Posix::visit(op);
391 }
392 
void CodeGen_ARM::visit(const Mul *op) {
394     if (neon_intrinsics_disabled()) {
395         CodeGen_Posix::visit(op);
396         return;
397     }
398 
399     // We only have peephole optimizations for int vectors for now
400     if (op->type.is_scalar() || op->type.is_float()) {
401         CodeGen_Posix::visit(op);
402         return;
403     }
404 
405     Type t = op->type;
406     vector<Expr> matches;
407 
408     int shift_amount = 0;
409     if (is_const_power_of_two_integer(op->b, &shift_amount)) {
410         // Let LLVM handle these.
411         CodeGen_Posix::visit(op);
412         return;
413     }
414 
415     // LLVM really struggles to generate mlal unless we generate mull intrinsics
416     // for the multiplication part first.
417     for (size_t i = 0; i < multiplies.size(); i++) {
418         const Pattern &pattern = multiplies[i];
        //debug(4) << "Trying pattern: " << pattern.intrin32 << " " << pattern.pattern << "\n";
420         if (expr_match(pattern.pattern, op, matches)) {
421 
422             //debug(4) << "Match!\n";
423             if (pattern.type == Pattern::Simple) {
424                 value = call_pattern(pattern, t, matches);
425                 return;
426             } else if (pattern.type == Pattern::NarrowArgs) {
427                 Type narrow_t = t.with_bits(t.bits() / 2);
428                 // Try to narrow all of the args.
429                 bool all_narrow = true;
430                 for (size_t i = 0; i < matches.size(); i++) {
431                     internal_assert(matches[i].type().bits() == t.bits());
432                     internal_assert(matches[i].type().lanes() == t.lanes());
                    // debug(4) << "Attempting to narrow " << matches[i] << " to " << narrow_t << "\n";
434                     matches[i] = lossless_cast(narrow_t, matches[i]);
435                     if (!matches[i].defined()) {
436                         // debug(4) << "failed\n";
437                         all_narrow = false;
438                     } else {
439                         // debug(4) << "success: " << matches[i] << "\n";
440                         internal_assert(matches[i].type() == narrow_t);
441                     }
442                 }
443 
444                 if (all_narrow) {
445                     value = call_pattern(pattern, t, matches);
446                     return;
447                 }
448             }
449         }
450     }
451 
452     // Vector multiplies by 3, 5, 7, 9 should do shift-and-add or
453     // shift-and-sub instead to reduce register pressure (the
454     // shift is an immediate)
455     // TODO: Verify this is still good codegen.
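    // e.g. x * 7 is rewritten to x * 8 - x, and the remaining power-of-two
    // multiply becomes an immediate shift.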
456     if (is_const(op->b, 3)) {
457         value = codegen(op->a * 2 + op->a);
458         return;
459     } else if (is_const(op->b, 5)) {
460         value = codegen(op->a * 4 + op->a);
461         return;
462     } else if (is_const(op->b, 7)) {
463         value = codegen(op->a * 8 - op->a);
464         return;
465     } else if (is_const(op->b, 9)) {
466         value = codegen(op->a * 8 + op->a);
467         return;
468     }
469 
470     CodeGen_Posix::visit(op);
471 }
472 
void CodeGen_ARM::visit(const Div *op) {
474     if (!neon_intrinsics_disabled() &&
475         op->type.is_vector() && is_two(op->b) &&
476         (op->a.as<Add>() || op->a.as<Sub>())) {
477         vector<Expr> matches;
478         for (size_t i = 0; i < averagings.size(); i++) {
479             if (expr_match(averagings[i].pattern, op->a, matches)) {
480                 value = call_pattern(averagings[i], op->type, matches);
481                 return;
482             }
483         }
484     }
485     CodeGen_Posix::visit(op);
486 }
487 
void CodeGen_ARM::visit(const Sub *op) {
489     if (neon_intrinsics_disabled()) {
490         CodeGen_Posix::visit(op);
491         return;
492     }
493 
494     vector<Expr> matches;
495     for (size_t i = 0; i < negations.size(); i++) {
496         if (op->type.is_vector() &&
497             expr_match(negations[i].pattern, op, matches)) {
498             value = call_pattern(negations[i], op->type, matches);
499             return;
500         }
501     }
502 
503     // llvm will generate floating point negate instructions if we ask for (-0.0f)-x
504     if (op->type.is_float() &&
505         op->type.bits() >= 32 &&
506         is_zero(op->a)) {
507         Constant *a;
508         if (op->type.bits() == 32) {
509             a = ConstantFP::getNegativeZero(f32_t);
510         } else if (op->type.bits() == 64) {
511             a = ConstantFP::getNegativeZero(f64_t);
512         } else {
513             a = nullptr;
514             internal_error << "Unknown bit width for floating point type: " << op->type << "\n";
515         }
516 
517         Value *b = codegen(op->b);
518 
519         if (op->type.lanes() > 1) {
520             a = ConstantVector::getSplat(element_count(op->type.lanes()), a);
521         }
522         value = builder->CreateFSub(a, b);
523         return;
524     }
525 
526     CodeGen_Posix::visit(op);
527 }
528 
void CodeGen_ARM::visit(const Min *op) {
530     if (neon_intrinsics_disabled()) {
531         CodeGen_Posix::visit(op);
532         return;
533     }
534 
535     if (op->type == Float(32)) {
536         // Use a 2-wide vector instead
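        // (the NEON min intrinsics are vector-only, so do the min on a
        // two-lane vector and extract lane 0 of the result).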
537         Value *undef = UndefValue::get(f32x2);
538         Constant *zero = ConstantInt::get(i32_t, 0);
539         Value *a = codegen(op->a);
540         Value *a_wide = builder->CreateInsertElement(undef, a, zero);
541         Value *b = codegen(op->b);
542         Value *b_wide = builder->CreateInsertElement(undef, b, zero);
543         Value *wide_result;
544         if (target.bits == 32) {
545             wide_result = call_intrin(f32x2, 2, "llvm.arm.neon.vmins.v2f32", {a_wide, b_wide});
546         } else {
547             wide_result = call_intrin(f32x2, 2, "llvm.aarch64.neon.fmin.v2f32", {a_wide, b_wide});
548         }
549         value = builder->CreateExtractElement(wide_result, zero);
550         return;
551     }
552 
553     struct {
554         Type t;
555         const char *op;
556     } patterns[] = {
557         {UInt(8, 8), "v8i8"},
558         {UInt(16, 4), "v4i16"},
559         {UInt(32, 2), "v2i32"},
560         {Int(8, 8), "v8i8"},
561         {Int(16, 4), "v4i16"},
562         {Int(32, 2), "v2i32"},
563         {Float(32, 2), "v2f32"},
564         {UInt(8, 16), "v16i8"},
565         {UInt(16, 8), "v8i16"},
566         {UInt(32, 4), "v4i32"},
567         {Int(8, 16), "v16i8"},
568         {Int(16, 8), "v8i16"},
569         {Int(32, 4), "v4i32"},
570         {Float(32, 4), "v4f32"}};
571 
572     for (size_t i = 0; i < sizeof(patterns) / sizeof(patterns[0]); i++) {
573         bool match = op->type == patterns[i].t;
574 
575         // The 128-bit versions are also used for other vector widths.
576         if (op->type.is_vector() && patterns[i].t.lanes() * patterns[i].t.bits() == 128) {
577             match = match || (op->type.element_of() == patterns[i].t.element_of());
578         }
579 
580         if (match) {
581             string intrin;
582             if (target.bits == 32) {
583                 intrin = (string("llvm.arm.neon.") + (op->type.is_uint() ? "vminu." : "vmins.")) + patterns[i].op;
584             } else {
585                 intrin = "llvm.aarch64.neon.";
586                 if (op->type.is_int()) {
587                     intrin += "smin.";
588                 } else if (op->type.is_float()) {
589                     intrin += "fmin.";
590                 } else {
591                     intrin += "umin.";
592                 }
593                 intrin += patterns[i].op;
594             }
595             value = call_intrin(op->type, patterns[i].t.lanes(), intrin, {op->a, op->b});
596             return;
597         }
598     }
599 
600     CodeGen_Posix::visit(op);
601 }
602 
void CodeGen_ARM::visit(const Max *op) {
604     if (neon_intrinsics_disabled()) {
605         CodeGen_Posix::visit(op);
606         return;
607     }
608 
609     if (op->type == Float(32)) {
610         // Use a 2-wide vector instead
611         Value *undef = UndefValue::get(f32x2);
612         Constant *zero = ConstantInt::get(i32_t, 0);
613         Value *a = codegen(op->a);
614         Value *a_wide = builder->CreateInsertElement(undef, a, zero);
615         Value *b = codegen(op->b);
616         Value *b_wide = builder->CreateInsertElement(undef, b, zero);
617         Value *wide_result;
618         if (target.bits == 32) {
619             wide_result = call_intrin(f32x2, 2, "llvm.arm.neon.vmaxs.v2f32", {a_wide, b_wide});
620         } else {
621             wide_result = call_intrin(f32x2, 2, "llvm.aarch64.neon.fmax.v2f32", {a_wide, b_wide});
622         }
623         value = builder->CreateExtractElement(wide_result, zero);
624         return;
625     }
626 
627     struct {
628         Type t;
629         const char *op;
630     } patterns[] = {
631         {UInt(8, 8), "v8i8"},
632         {UInt(16, 4), "v4i16"},
633         {UInt(32, 2), "v2i32"},
634         {Int(8, 8), "v8i8"},
635         {Int(16, 4), "v4i16"},
636         {Int(32, 2), "v2i32"},
637         {Float(32, 2), "v2f32"},
638         {UInt(8, 16), "v16i8"},
639         {UInt(16, 8), "v8i16"},
640         {UInt(32, 4), "v4i32"},
641         {Int(8, 16), "v16i8"},
642         {Int(16, 8), "v8i16"},
643         {Int(32, 4), "v4i32"},
644         {Float(32, 4), "v4f32"}};
645 
646     for (size_t i = 0; i < sizeof(patterns) / sizeof(patterns[0]); i++) {
647         bool match = op->type == patterns[i].t;
648 
649         // The 128-bit versions are also used for other vector widths.
650         if (op->type.is_vector() && patterns[i].t.lanes() * patterns[i].t.bits() == 128) {
651             match = match || (op->type.element_of() == patterns[i].t.element_of());
652         }
653 
654         if (match) {
655             string intrin;
656             if (target.bits == 32) {
657                 intrin = (string("llvm.arm.neon.") + (op->type.is_uint() ? "vmaxu." : "vmaxs.")) + patterns[i].op;
658             } else {
659                 intrin = "llvm.aarch64.neon.";
660                 if (op->type.is_int()) {
661                     intrin += "smax.";
662                 } else if (op->type.is_float()) {
663                     intrin += "fmax.";
664                 } else {
665                     intrin += "umax.";
666                 }
667                 intrin += patterns[i].op;
668             }
669             value = call_intrin(op->type, patterns[i].t.lanes(), intrin, {op->a, op->b});
670             return;
671         }
672     }
673 
674     CodeGen_Posix::visit(op);
675 }
676 
void CodeGen_ARM::visit(const Store *op) {
678     // Predicated store
679     if (!is_one(op->predicate)) {
680         CodeGen_Posix::visit(op);
681         return;
682     }
683 
684     if (neon_intrinsics_disabled()) {
685         CodeGen_Posix::visit(op);
686         return;
687     }
688 
689     // A dense store of an interleaving can be done using a vst2 intrinsic
690     const Ramp *ramp = op->index.as<Ramp>();
691 
692     // We only deal with ramps here
693     if (!ramp) {
694         CodeGen_Posix::visit(op);
695         return;
696     }
697 
698     // First dig through let expressions
699     Expr rhs = op->value;
700     vector<pair<string, Expr>> lets;
701     while (const Let *let = rhs.as<Let>()) {
702         rhs = let->body;
703         lets.emplace_back(let->name, let->value);
704     }
705     const Shuffle *shuffle = rhs.as<Shuffle>();
706 
707     // Interleaving store instructions only exist for certain types.
708     bool type_ok_for_vst = false;
709     Type intrin_type = Handle();
710     if (shuffle) {
711         Type t = shuffle->vectors[0].type();
712         intrin_type = t;
713         Type elt = t.element_of();
714         int vec_bits = t.bits() * t.lanes();
715         if (elt == Float(32) ||
716             elt == Int(8) || elt == Int(16) || elt == Int(32) ||
717             elt == UInt(8) || elt == UInt(16) || elt == UInt(32)) {
718             if (vec_bits % 128 == 0) {
719                 type_ok_for_vst = true;
720                 intrin_type = intrin_type.with_lanes(128 / t.bits());
721             } else if (vec_bits % 64 == 0) {
722                 type_ok_for_vst = true;
723                 intrin_type = intrin_type.with_lanes(64 / t.bits());
724             }
725         }
726     }
727 
728     if (is_one(ramp->stride) &&
729         shuffle && shuffle->is_interleave() &&
730         type_ok_for_vst &&
731         2 <= shuffle->vectors.size() && shuffle->vectors.size() <= 4) {
732 
733         const int num_vecs = shuffle->vectors.size();
734         vector<Value *> args(num_vecs);
735 
736         Type t = shuffle->vectors[0].type();
737 
738         // Assume element-aligned.
739         int alignment = t.bytes();
740 
741         // Codegen the lets
742         for (size_t i = 0; i < lets.size(); i++) {
743             sym_push(lets[i].first, codegen(lets[i].second));
744         }
745 
746         // Codegen all the vector args.
747         for (int i = 0; i < num_vecs; ++i) {
748             args[i] = codegen(shuffle->vectors[i]);
749         }
750 
751         // Declare the function
752         std::ostringstream instr;
753         vector<llvm::Type *> arg_types;
754         if (target.bits == 32) {
755             instr << "llvm.arm.neon.vst"
756                   << num_vecs
757                   << ".p0i8"
758                   << ".v"
759                   << intrin_type.lanes()
760                   << (t.is_float() ? 'f' : 'i')
761                   << t.bits();
762             arg_types = vector<llvm::Type *>(num_vecs + 2, llvm_type_of(intrin_type));
763             arg_types.front() = i8_t->getPointerTo();
764             arg_types.back() = i32_t;
765         } else {
766             instr << "llvm.aarch64.neon.st"
767                   << num_vecs
768                   << ".v"
769                   << intrin_type.lanes()
770                   << (t.is_float() ? 'f' : 'i')
771                   << t.bits()
772                   << ".p0"
773                   << (t.is_float() ? 'f' : 'i')
774                   << t.bits();
775             arg_types = vector<llvm::Type *>(num_vecs + 1, llvm_type_of(intrin_type));
776             arg_types.back() = llvm_type_of(intrin_type.element_of())->getPointerTo();
777         }
778         llvm::FunctionType *fn_type = FunctionType::get(llvm::Type::getVoidTy(*context), arg_types, false);
779         llvm::FunctionCallee fn = module->getOrInsertFunction(instr.str(), fn_type);
780         internal_assert(fn);
781 
782         // How many vst instructions do we need to generate?
783         int slices = t.lanes() / intrin_type.lanes();
784 
785         internal_assert(slices >= 1);
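        // Each iteration stores intrin_type.lanes() elements from each of the
        // num_vecs input vectors, covering intrin_type.lanes() * num_vecs
        // contiguous elements of the output.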
786         for (int i = 0; i < t.lanes(); i += intrin_type.lanes()) {
787             Expr slice_base = simplify(ramp->base + i * num_vecs);
788             Expr slice_ramp = Ramp::make(slice_base, ramp->stride, intrin_type.lanes() * num_vecs);
789             Value *ptr = codegen_buffer_pointer(op->name, shuffle->vectors[0].type().element_of(), slice_base);
790 
791             vector<Value *> slice_args = args;
792             // Take a slice of each arg
793             for (int j = 0; j < num_vecs; j++) {
794                 slice_args[j] = slice_vector(slice_args[j], i, intrin_type.lanes());
795             }
796 
797             if (target.bits == 32) {
798                 // The arm32 versions take an i8*, regardless of the type stored.
799                 ptr = builder->CreatePointerCast(ptr, i8_t->getPointerTo());
800                 // Set the pointer argument
801                 slice_args.insert(slice_args.begin(), ptr);
802                 // Set the alignment argument
803                 slice_args.push_back(ConstantInt::get(i32_t, alignment));
804             } else {
805                 // Set the pointer argument
806                 slice_args.push_back(ptr);
807             }
808 
809             CallInst *store = builder->CreateCall(fn, slice_args);
810             add_tbaa_metadata(store, op->name, slice_ramp);
811         }
812 
813         // pop the lets from the symbol table
814         for (size_t i = 0; i < lets.size(); i++) {
815             sym_pop(lets[i].first);
816         }
817 
818         return;
819     }
820 
821     // If the stride is one or minus one, we can deal with that using vanilla codegen
822     const IntImm *stride = ramp->stride.as<IntImm>();
823     if (stride && (stride->value == 1 || stride->value == -1)) {
824         CodeGen_Posix::visit(op);
825         return;
826     }
827 
828     // We have builtins for strided stores with fixed but unknown stride, but they use inline assembly
829     if (target.bits != 64 /* Not yet implemented for aarch64 */) {
830         ostringstream builtin;
831         builtin << "strided_store_"
832                 << (op->value.type().is_float() ? "f" : "i")
833                 << op->value.type().bits()
834                 << "x" << op->value.type().lanes();
835 
836         llvm::Function *fn = module->getFunction(builtin.str());
837         if (fn) {
838             Value *base = codegen_buffer_pointer(op->name, op->value.type().element_of(), ramp->base);
839             Value *stride = codegen(ramp->stride * op->value.type().bytes());
840             Value *val = codegen(op->value);
841             debug(4) << "Creating call to " << builtin.str() << "\n";
842             Value *store_args[] = {base, stride, val};
843             Instruction *store = builder->CreateCall(fn, store_args);
844             (void)store;
845             add_tbaa_metadata(store, op->name, op->index);
846             return;
847         }
848     }
849 
850     CodeGen_Posix::visit(op);
851 }
852 
void CodeGen_ARM::visit(const Load *op) {
854     // Predicated load
855     if (!is_one(op->predicate)) {
856         CodeGen_Posix::visit(op);
857         return;
858     }
859 
860     if (neon_intrinsics_disabled()) {
861         CodeGen_Posix::visit(op);
862         return;
863     }
864 
865     const Ramp *ramp = op->index.as<Ramp>();
866 
867     // We only deal with ramps here
868     if (!ramp) {
869         CodeGen_Posix::visit(op);
870         return;
871     }
872 
873     const IntImm *stride = ramp ? ramp->stride.as<IntImm>() : nullptr;
874 
875     // If the stride is one or minus one, we can deal with that using vanilla codegen
876     if (stride && (stride->value == 1 || stride->value == -1)) {
877         CodeGen_Posix::visit(op);
878         return;
879     }
880 
881     // Strided loads with known stride
882     if (stride && stride->value >= 2 && stride->value <= 4) {
883         // Check alignment on the base. Attempt to shift to an earlier
884         // address if it simplifies the expression. This makes
        // adjacent strided loads share a vldN op.
886         Expr base = ramp->base;
887         int offset = 0;
888         ModulusRemainder mod_rem = modulus_remainder(ramp->base);
889 
890         const Add *add = base.as<Add>();
891         const IntImm *add_b = add ? add->b.as<IntImm>() : nullptr;
892 
893         if ((mod_rem.modulus % stride->value) == 0) {
894             offset = mod_rem.remainder % stride->value;
895         } else if ((mod_rem.modulus == 1) && add_b) {
896             offset = add_b->value % stride->value;
897             if (offset < 0) {
898                 offset += stride->value;
899             }
900         }
901 
902         if (offset) {
903             base = simplify(base - offset);
904             mod_rem.remainder -= offset;
905             if (mod_rem.modulus) {
906                 mod_rem.remainder = mod_imp(mod_rem.remainder, mod_rem.modulus);
907             }
908         }
909 
910         int alignment = op->type.bytes();
911         alignment *= gcd(mod_rem.modulus, mod_rem.remainder);
912         // Maximum stack alignment on arm is 16 bytes, so we should
913         // never claim alignment greater than that.
914         alignment = gcd(alignment, 16);
915         internal_assert(alignment > 0);
916 
917         // Decide what width to slice things into. If not a multiple
918         // of 64 or 128 bits, then we can't safely slice it up into
        // some number of vlds, so we hand it over to the base class.
920         int bit_width = op->type.bits() * op->type.lanes();
921         int intrin_lanes = 0;
922         if (bit_width % 128 == 0) {
923             intrin_lanes = 128 / op->type.bits();
924         } else if (bit_width % 64 == 0) {
925             intrin_lanes = 64 / op->type.bits();
926         } else {
927             CodeGen_Posix::visit(op);
928             return;
929         }
930 
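        // Load a dense vector covering intrin_lanes * stride elements, then
        // shuffle out every stride-th lane (starting at the offset) to get the
        // strided result.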
931         llvm::Type *load_return_type = llvm_type_of(op->type.with_lanes(intrin_lanes * stride->value));
932         llvm::Type *load_return_pointer_type = load_return_type->getPointerTo();
933         Value *undef = UndefValue::get(load_return_type);
934         SmallVector<Constant *, 256> constants;
935         for (int j = 0; j < intrin_lanes; j++) {
936             Constant *constant = ConstantInt::get(i32_t, j * stride->value + offset);
937             constants.push_back(constant);
938         }
939         Constant *constantsV = ConstantVector::get(constants);
940 
941         vector<Value *> results;
942         for (int i = 0; i < op->type.lanes(); i += intrin_lanes) {
943             Expr slice_base = simplify(base + i * ramp->stride);
944             Expr slice_ramp = Ramp::make(slice_base, ramp->stride, intrin_lanes);
945             Value *ptr = codegen_buffer_pointer(op->name, op->type.element_of(), slice_base);
946             Value *bitcastI = builder->CreateBitOrPointerCast(ptr, load_return_pointer_type);
947             LoadInst *loadI = cast<LoadInst>(builder->CreateLoad(bitcastI));
948 #if LLVM_VERSION >= 110
949             loadI->setAlignment(Align(alignment));
950 #elif LLVM_VERSION >= 100
951             loadI->setAlignment(MaybeAlign(alignment));
952 #else
953             loadI->setAlignment(alignment);
954 #endif
955             add_tbaa_metadata(loadI, op->name, slice_ramp);
956             Value *shuffleInstr = builder->CreateShuffleVector(loadI, undef, constantsV);
957             results.push_back(shuffleInstr);
958         }
959 
960         // Concat the results
961         value = concat_vectors(results);
962         return;
963     }
964 
965     // We have builtins for strided loads with fixed but unknown stride, but they use inline assembly.
966     if (target.bits != 64 /* Not yet implemented for aarch64 */) {
967         ostringstream builtin;
968         builtin << "strided_load_"
969                 << (op->type.is_float() ? "f" : "i")
970                 << op->type.bits()
971                 << "x" << op->type.lanes();
972 
973         llvm::Function *fn = module->getFunction(builtin.str());
974         if (fn) {
975             Value *base = codegen_buffer_pointer(op->name, op->type.element_of(), ramp->base);
976             Value *stride = codegen(ramp->stride * op->type.bytes());
977             debug(4) << "Creating call to " << builtin.str() << "\n";
978             Value *args[] = {base, stride};
979             Instruction *load = builder->CreateCall(fn, args, builtin.str());
980             add_tbaa_metadata(load, op->name, op->index);
981             value = load;
982             return;
983         }
984     }
985 
986     CodeGen_Posix::visit(op);
987 }
988 
void CodeGen_ARM::visit(const Call *op) {
990     if (op->is_intrinsic(Call::abs) && op->type.is_uint()) {
991         internal_assert(op->args.size() == 1);
992         // If the arg is a subtract with narrowable args, we can use vabdl.
993         const Sub *sub = op->args[0].as<Sub>();
994         if (sub) {
995             Expr a = sub->a, b = sub->b;
996             Type narrow = UInt(a.type().bits() / 2, a.type().lanes());
997             Expr na = lossless_cast(narrow, a);
998             Expr nb = lossless_cast(narrow, b);
999 
            // Also try a signed narrowing
1001             if (!na.defined() || !nb.defined()) {
1002                 narrow = Int(narrow.bits(), narrow.lanes());
1003                 na = lossless_cast(narrow, a);
1004                 nb = lossless_cast(narrow, b);
1005             }
1006 
1007             if (na.defined() && nb.defined()) {
1008                 Expr absd = Call::make(UInt(narrow.bits(), narrow.lanes()), Call::absd,
1009                                        {na, nb}, Call::PureIntrinsic);
1010 
1011                 absd = Cast::make(op->type, absd);
1012                 codegen(absd);
1013                 return;
1014             }
1015         }
1016     } else if (op->is_intrinsic(Call::sorted_avg)) {
1017         Type ty = op->type;
1018         Type wide_ty = ty.with_bits(ty.bits() * 2);
1019         // This will codegen to vhaddu (arm32) or uhadd (arm64).
1020         value = codegen(cast(ty, (cast(wide_ty, op->args[0]) + cast(wide_ty, op->args[1])) / 2));
1021         return;
1022     }
1023 
1024     CodeGen_Posix::visit(op);
1025 }
1026 
void CodeGen_ARM::visit(const LT *op) {
1028 #if LLVM_VERSION >= 100
1029     if (op->a.type().is_float() && op->type.is_vector()) {
1030         // Fast-math flags confuse LLVM's aarch64 backend, so
1031         // temporarily clear them for this instruction.
1032         // See https://bugs.llvm.org/show_bug.cgi?id=45036
1033         llvm::IRBuilderBase::FastMathFlagGuard guard(*builder);
1034         builder->clearFastMathFlags();
1035         CodeGen_Posix::visit(op);
1036         return;
1037     }
1038 #endif
1039 
1040     CodeGen_Posix::visit(op);
1041 }
1042 
void CodeGen_ARM::visit(const LE *op) {
1044 #if LLVM_VERSION >= 100
1045     if (op->a.type().is_float() && op->type.is_vector()) {
1046         // Fast-math flags confuse LLVM's aarch64 backend, so
1047         // temporarily clear them for this instruction.
1048         // See https://bugs.llvm.org/show_bug.cgi?id=45036
1049         llvm::IRBuilderBase::FastMathFlagGuard guard(*builder);
1050         builder->clearFastMathFlags();
1051         CodeGen_Posix::visit(op);
1052         return;
1053     }
1054 #endif
1055 
1056     CodeGen_Posix::visit(op);
1057 }
1058 
void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init) {
1060     if (neon_intrinsics_disabled() ||
1061         op->op == VectorReduce::Or ||
1062         op->op == VectorReduce::And ||
1063         op->op == VectorReduce::Mul ||
1064         // LLVM 9 has bugs in the arm backend for vector reduce
1065         // ops. See https://github.com/halide/Halide/issues/5081
1066         !(LLVM_VERSION >= 100)) {
1067         CodeGen_Posix::codegen_vector_reduce(op, init);
1068         return;
1069     }
1070 
1071     // ARM has a variety of pairwise reduction ops for +, min,
1072     // max. The versions that do not widen take two 64-bit args and
1073     // return one 64-bit vector of the same type. The versions that
1074     // widen take one arg and return something with half the vector
1075     // lanes and double the bit-width.
1076 
1077     int factor = op->value.type().lanes() / op->type.lanes();
1078 
1079     // These are the types for which we have reduce intrinsics in the
1080     // runtime.
1081     bool have_reduce_intrinsic = (op->type.is_int() ||
1082                                   op->type.is_uint() ||
1083                                   op->type.is_float());
1084 
1085     // We don't have 16-bit float or bfloat horizontal ops
1086     if (op->type.is_bfloat() || (op->type.is_float() && op->type.bits() < 32)) {
1087         have_reduce_intrinsic = false;
1088     }
1089 
1090     // Only aarch64 has float64 horizontal ops
1091     if (target.bits == 32 && op->type.element_of() == Float(64)) {
1092         have_reduce_intrinsic = false;
1093     }
1094 
1095     // For 64-bit integers, we only have addition, not min/max
1096     if (op->type.bits() == 64 &&
1097         !op->type.is_float() &&
1098         op->op != VectorReduce::Add) {
1099         have_reduce_intrinsic = false;
1100     }
1101 
1102     // We only have intrinsics that reduce by a factor of two
1103     if (factor != 2) {
1104         have_reduce_intrinsic = false;
1105     }
1106 
1107     if (have_reduce_intrinsic) {
1108         Expr arg = op->value;
1109         if (op->op == VectorReduce::Add &&
1110             op->type.bits() >= 16 &&
1111             !op->type.is_float()) {
1112             Type narrower_type = arg.type().with_bits(arg.type().bits() / 2);
1113             Expr narrower = lossless_cast(narrower_type, arg);
1114             if (!narrower.defined() && arg.type().is_int()) {
1115                 // We can also safely accumulate from a uint into a
1116                 // wider int, because the addition uses at most one
1117                 // extra bit.
1118                 narrower = lossless_cast(narrower_type.with_code(Type::UInt), arg);
1119             }
1120             if (narrower.defined()) {
1121                 arg = narrower;
1122             }
1123         }
1124         int output_bits;
1125         if (target.bits == 32 && arg.type().bits() == op->type.bits()) {
1126             // For the non-widening version, the output must be 64-bit
1127             output_bits = 64;
1128         } else if (op->type.bits() * op->type.lanes() <= 64) {
1129             // No point using the 128-bit version of the instruction if the output is narrow.
1130             output_bits = 64;
1131         } else {
1132             output_bits = 128;
1133         }
1134 
1135         const int output_lanes = output_bits / op->type.bits();
1136         Type intrin_type = op->type.with_lanes(output_lanes);
1137         Type arg_type = arg.type().with_lanes(output_lanes * 2);
1138         if (op->op == VectorReduce::Add &&
1139             arg.type().bits() == op->type.bits() &&
1140             arg_type.is_uint()) {
1141             // For non-widening additions, there is only a signed
1142             // version (because it's equivalent).
1143             arg_type = arg_type.with_code(Type::Int);
1144             intrin_type = intrin_type.with_code(Type::Int);
1145         } else if (arg.type().is_uint() && intrin_type.is_int()) {
1146             // Use the uint version
1147             intrin_type = intrin_type.with_code(Type::UInt);
1148         }
1149 
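        // Build the name of the pairwise reduction helper provided by the
        // runtime, keyed on the op, the result type, and the argument type.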
1150         std::stringstream ss;
1151         vector<Expr> args;
1152         ss << "pairwise_" << op->op << "_" << intrin_type << "_" << arg_type;
1153         Expr accumulator = init;
1154         if (op->op == VectorReduce::Add &&
1155             accumulator.defined() &&
1156             arg_type.bits() < intrin_type.bits()) {
1157             // We can use the accumulating variant
1158             ss << "_accumulate";
1159             args.push_back(init);
1160             accumulator = Expr();
1161         }
1162         args.push_back(arg);
1163         value = call_intrin(op->type, output_lanes, ss.str(), args);
1164 
1165         if (accumulator.defined()) {
1166             // We still have an initial value to take care of
1167             string n = unique_name('t');
1168             sym_push(n, value);
1169             Expr v = Variable::make(accumulator.type(), n);
1170             switch (op->op) {
1171             case VectorReduce::Add:
1172                 accumulator += v;
1173                 break;
1174             case VectorReduce::Min:
1175                 accumulator = min(accumulator, v);
1176                 break;
1177             case VectorReduce::Max:
1178                 accumulator = max(accumulator, v);
1179                 break;
1180             default:
1181                 internal_error << "unreachable";
1182             }
1183             codegen(accumulator);
1184             sym_pop(n);
1185         }
1186 
1187         return;
1188     }
1189 
1190     // Pattern-match 8-bit dot product instructions available on newer
1191     // ARM cores.
1192     if (target.has_feature(Target::ARMDotProd) &&
1193         factor % 4 == 0 &&
1194         op->op == VectorReduce::Add &&
1195         target.bits == 64 &&
1196         (op->type.element_of() == Int(32) ||
1197          op->type.element_of() == UInt(32))) {
1198         const Mul *mul = op->value.as<Mul>();
1199         if (mul) {
1200             const int input_lanes = mul->type.lanes();
1201             Expr a = lossless_cast(UInt(8, input_lanes), mul->a);
1202             Expr b = lossless_cast(UInt(8, input_lanes), mul->b);
1203             if (!a.defined()) {
1204                 a = lossless_cast(Int(8, input_lanes), mul->a);
1205                 b = lossless_cast(Int(8, input_lanes), mul->b);
1206             }
1207             if (a.defined() && b.defined()) {
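                // udot/sdot take an i32 accumulator vector plus two i8
                // vectors, and accumulate a 4-way dot product into each
                // 32-bit lane.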
1208                 if (factor != 4) {
1209                     Expr equiv = VectorReduce::make(op->op, op->value, input_lanes / 4);
1210                     equiv = VectorReduce::make(op->op, equiv, op->type.lanes());
1211                     codegen_vector_reduce(equiv.as<VectorReduce>(), init);
1212                     return;
1213                 }
1214                 Expr i = init;
1215                 if (!i.defined()) {
1216                     i = make_zero(op->type);
1217                 }
1218                 vector<Expr> args{i, a, b};
1219                 if (op->type.lanes() <= 2) {
1220                     if (op->type.is_uint()) {
1221                         value = call_intrin(op->type, 2, "llvm.aarch64.neon.udot.v2i32.v8i8", args);
1222                     } else {
1223                         value = call_intrin(op->type, 2, "llvm.aarch64.neon.sdot.v2i32.v8i8", args);
1224                     }
1225                 } else {
1226                     if (op->type.is_uint()) {
1227                         value = call_intrin(op->type, 4, "llvm.aarch64.neon.udot.v4i32.v16i8", args);
1228                     } else {
1229                         value = call_intrin(op->type, 4, "llvm.aarch64.neon.sdot.v4i32.v16i8", args);
1230                     }
1231                 }
1232                 return;
1233             }
1234         }
1235     }
1236 
1237     CodeGen_Posix::codegen_vector_reduce(op, init);
1238 }
1239 
string CodeGen_ARM::mcpu() const {
1241     if (target.bits == 32) {
1242         if (target.has_feature(Target::ARMv7s)) {
1243             return "swift";
1244         } else {
1245             return "cortex-a9";
1246         }
1247     } else {
1248         if (target.os == Target::IOS) {
1249             return "cyclone";
1250         } else if (target.os == Target::OSX) {
1251             return "apple-a12";
1252         } else {
1253             return "generic";
1254         }
1255     }
1256 }
1257 
string CodeGen_ARM::mattrs() const {
1259     if (target.bits == 32) {
1260         if (target.has_feature(Target::ARMv7s)) {
1261             return "+neon";
1262         }
1263         if (!target.has_feature(Target::NoNEON)) {
1264             return "+neon";
1265         } else {
1266             return "-neon";
1267         }
1268     } else {
1269         // TODO: Should Halide's SVE flags be 64-bit only?
1270         string arch_flags;
1271         if (target.has_feature(Target::SVE2)) {
1272             arch_flags = "+sve2";
1273         } else if (target.has_feature(Target::SVE)) {
1274             arch_flags = "+sve";
1275         }
1276 
1277         if (target.has_feature(Target::ARMDotProd)) {
1278             arch_flags += "+dotprod";
1279         }
1280 
1281         if (target.os == Target::IOS || target.os == Target::OSX) {
1282             return arch_flags + "+reserve-x18";
1283         } else {
1284             return arch_flags;
1285         }
1286     }
1287 }
1288 
bool CodeGen_ARM::use_soft_float_abi() const {
1290     // One expects the flag is irrelevant on 64-bit, but we'll make the logic
1291     // exhaustive anyway. It is not clear the armv7s case is necessary either.
1292     return target.has_feature(Target::SoftFloatABI) ||
1293            (target.bits == 32 &&
1294             ((target.os == Target::Android) ||
1295              (target.os == Target::IOS && !target.has_feature(Target::ARMv7s))));
1296 }
1297 
int CodeGen_ARM::native_vector_bits() const {
1299     return 128;
1300 }
1301 
1302 }  // namespace Internal
1303 }  // namespace Halide
1304