1 //===-- AArch64TargetTransformInfo.cpp - AArch64 specific TTI -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "AArch64TargetTransformInfo.h"
10 #include "AArch64ExpandImm.h"
11 #include "MCTargetDesc/AArch64AddressingModes.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/TargetTransformInfo.h"
14 #include "llvm/CodeGen/BasicTTIImpl.h"
15 #include "llvm/CodeGen/CostTable.h"
16 #include "llvm/CodeGen/TargetLowering.h"
17 #include "llvm/IR/IntrinsicInst.h"
18 #include "llvm/IR/IntrinsicsAArch64.h"
19 #include "llvm/IR/PatternMatch.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Transforms/InstCombine/InstCombiner.h"
22 #include <algorithm>
23 using namespace llvm;
24 using namespace llvm::PatternMatch;
25
26 #define DEBUG_TYPE "aarch64tti"
27
28 static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix",
29 cl::init(true), cl::Hidden);
30
areInlineCompatible(const Function * Caller,const Function * Callee) const31 bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
32 const Function *Callee) const {
33 const TargetMachine &TM = getTLI()->getTargetMachine();
34
35 const FeatureBitset &CallerBits =
36 TM.getSubtargetImpl(*Caller)->getFeatureBits();
37 const FeatureBitset &CalleeBits =
38 TM.getSubtargetImpl(*Callee)->getFeatureBits();
39
40 // Inline a callee if its target-features are a subset of the callers
41 // target-features.
42 return (CallerBits & CalleeBits) == CalleeBits;
43 }
44
45 /// Calculate the cost of materializing a 64-bit value. This helper
46 /// method might only calculate a fraction of a larger immediate. Therefore it
47 /// is valid to return a cost of ZERO.
getIntImmCost(int64_t Val)48 InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) {
49 // Check if the immediate can be encoded within an instruction.
50 if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
51 return 0;
52
53 if (Val < 0)
54 Val = ~Val;
55
56 // Calculate how many moves we will need to materialize this constant.
57 SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
58 AArch64_IMM::expandMOVImm(Val, 64, Insn);
59 return Insn.size();
60 }
61
62 /// Calculate the cost of materializing the given constant.
getIntImmCost(const APInt & Imm,Type * Ty,TTI::TargetCostKind CostKind)63 InstructionCost AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
64 TTI::TargetCostKind CostKind) {
65 assert(Ty->isIntegerTy());
66
67 unsigned BitSize = Ty->getPrimitiveSizeInBits();
68 if (BitSize == 0)
69 return ~0U;
70
71 // Sign-extend all constants to a multiple of 64-bit.
72 APInt ImmVal = Imm;
73 if (BitSize & 0x3f)
74 ImmVal = Imm.sext((BitSize + 63) & ~0x3fU);
75
76 // Split the constant into 64-bit chunks and calculate the cost for each
77 // chunk.
78 InstructionCost Cost = 0;
79 for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
80 APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
81 int64_t Val = Tmp.getSExtValue();
82 Cost += getIntImmCost(Val);
83 }
84 // We need at least one instruction to materialze the constant.
85 return std::max<InstructionCost>(1, Cost);
86 }
87
getIntImmCostInst(unsigned Opcode,unsigned Idx,const APInt & Imm,Type * Ty,TTI::TargetCostKind CostKind,Instruction * Inst)88 InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
89 const APInt &Imm, Type *Ty,
90 TTI::TargetCostKind CostKind,
91 Instruction *Inst) {
92 assert(Ty->isIntegerTy());
93
94 unsigned BitSize = Ty->getPrimitiveSizeInBits();
95 // There is no cost model for constants with a bit size of 0. Return TCC_Free
96 // here, so that constant hoisting will ignore this constant.
97 if (BitSize == 0)
98 return TTI::TCC_Free;
99
100 unsigned ImmIdx = ~0U;
101 switch (Opcode) {
102 default:
103 return TTI::TCC_Free;
104 case Instruction::GetElementPtr:
105 // Always hoist the base address of a GetElementPtr.
106 if (Idx == 0)
107 return 2 * TTI::TCC_Basic;
108 return TTI::TCC_Free;
109 case Instruction::Store:
110 ImmIdx = 0;
111 break;
112 case Instruction::Add:
113 case Instruction::Sub:
114 case Instruction::Mul:
115 case Instruction::UDiv:
116 case Instruction::SDiv:
117 case Instruction::URem:
118 case Instruction::SRem:
119 case Instruction::And:
120 case Instruction::Or:
121 case Instruction::Xor:
122 case Instruction::ICmp:
123 ImmIdx = 1;
124 break;
125 // Always return TCC_Free for the shift value of a shift instruction.
126 case Instruction::Shl:
127 case Instruction::LShr:
128 case Instruction::AShr:
129 if (Idx == 1)
130 return TTI::TCC_Free;
131 break;
132 case Instruction::Trunc:
133 case Instruction::ZExt:
134 case Instruction::SExt:
135 case Instruction::IntToPtr:
136 case Instruction::PtrToInt:
137 case Instruction::BitCast:
138 case Instruction::PHI:
139 case Instruction::Call:
140 case Instruction::Select:
141 case Instruction::Ret:
142 case Instruction::Load:
143 break;
144 }
145
146 if (Idx == ImmIdx) {
147 int NumConstants = (BitSize + 63) / 64;
148 InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
149 return (Cost <= NumConstants * TTI::TCC_Basic)
150 ? static_cast<int>(TTI::TCC_Free)
151 : Cost;
152 }
153 return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
154 }
155
156 InstructionCost
getIntImmCostIntrin(Intrinsic::ID IID,unsigned Idx,const APInt & Imm,Type * Ty,TTI::TargetCostKind CostKind)157 AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
158 const APInt &Imm, Type *Ty,
159 TTI::TargetCostKind CostKind) {
160 assert(Ty->isIntegerTy());
161
162 unsigned BitSize = Ty->getPrimitiveSizeInBits();
163 // There is no cost model for constants with a bit size of 0. Return TCC_Free
164 // here, so that constant hoisting will ignore this constant.
165 if (BitSize == 0)
166 return TTI::TCC_Free;
167
168 // Most (all?) AArch64 intrinsics do not support folding immediates into the
169 // selected instruction, so we compute the materialization cost for the
170 // immediate directly.
171 if (IID >= Intrinsic::aarch64_addg && IID <= Intrinsic::aarch64_udiv)
172 return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
173
174 switch (IID) {
175 default:
176 return TTI::TCC_Free;
177 case Intrinsic::sadd_with_overflow:
178 case Intrinsic::uadd_with_overflow:
179 case Intrinsic::ssub_with_overflow:
180 case Intrinsic::usub_with_overflow:
181 case Intrinsic::smul_with_overflow:
182 case Intrinsic::umul_with_overflow:
183 if (Idx == 1) {
184 int NumConstants = (BitSize + 63) / 64;
185 InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
186 return (Cost <= NumConstants * TTI::TCC_Basic)
187 ? static_cast<int>(TTI::TCC_Free)
188 : Cost;
189 }
190 break;
191 case Intrinsic::experimental_stackmap:
192 if ((Idx < 2) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
193 return TTI::TCC_Free;
194 break;
195 case Intrinsic::experimental_patchpoint_void:
196 case Intrinsic::experimental_patchpoint_i64:
197 if ((Idx < 4) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
198 return TTI::TCC_Free;
199 break;
200 case Intrinsic::experimental_gc_statepoint:
201 if ((Idx < 5) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
202 return TTI::TCC_Free;
203 break;
204 }
205 return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
206 }
207
208 TargetTransformInfo::PopcntSupportKind
getPopcntSupport(unsigned TyWidth)209 AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
210 assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
211 if (TyWidth == 32 || TyWidth == 64)
212 return TTI::PSK_FastHardware;
213 // TODO: AArch64TargetLowering::LowerCTPOP() supports 128bit popcount.
214 return TTI::PSK_Software;
215 }
216
217 InstructionCost
getIntrinsicInstrCost(const IntrinsicCostAttributes & ICA,TTI::TargetCostKind CostKind)218 AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
219 TTI::TargetCostKind CostKind) {
220 auto *RetTy = ICA.getReturnType();
221 switch (ICA.getID()) {
222 case Intrinsic::umin:
223 case Intrinsic::umax: {
224 auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
225 // umin(x,y) -> sub(x,usubsat(x,y))
226 // umax(x,y) -> add(x,usubsat(y,x))
227 if (LT.second == MVT::v2i64)
228 return LT.first * 2;
229 LLVM_FALLTHROUGH;
230 }
231 case Intrinsic::smin:
232 case Intrinsic::smax: {
233 static const auto ValidMinMaxTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
234 MVT::v8i16, MVT::v2i32, MVT::v4i32};
235 auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
236 if (any_of(ValidMinMaxTys, [<](MVT M) { return M == LT.second; }))
237 return LT.first;
238 break;
239 }
240 case Intrinsic::sadd_sat:
241 case Intrinsic::ssub_sat:
242 case Intrinsic::uadd_sat:
243 case Intrinsic::usub_sat: {
244 static const auto ValidSatTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
245 MVT::v8i16, MVT::v2i32, MVT::v4i32,
246 MVT::v2i64};
247 auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
248 // This is a base cost of 1 for the vadd, plus 3 extract shifts if we
249 // need to extend the type, as it uses shr(qadd(shl, shl)).
250 unsigned Instrs =
251 LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
252 if (any_of(ValidSatTys, [<](MVT M) { return M == LT.second; }))
253 return LT.first * Instrs;
254 break;
255 }
256 case Intrinsic::abs: {
257 static const auto ValidAbsTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
258 MVT::v8i16, MVT::v2i32, MVT::v4i32,
259 MVT::v2i64};
260 auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
261 if (any_of(ValidAbsTys, [<](MVT M) { return M == LT.second; }))
262 return LT.first;
263 break;
264 }
265 case Intrinsic::experimental_stepvector: {
266 InstructionCost Cost = 1; // Cost of the `index' instruction
267 auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
268 // Legalisation of illegal vectors involves an `index' instruction plus
269 // (LT.first - 1) vector adds.
270 if (LT.first > 1) {
271 Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
272 InstructionCost AddCost =
273 getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
274 Cost += AddCost * (LT.first - 1);
275 }
276 return Cost;
277 }
278 default:
279 break;
280 }
281 return BaseT::getIntrinsicInstrCost(ICA, CostKind);
282 }
283
284 /// The function will remove redundant reinterprets casting in the presence
285 /// of the control flow
processPhiNode(InstCombiner & IC,IntrinsicInst & II)286 static Optional<Instruction *> processPhiNode(InstCombiner &IC,
287 IntrinsicInst &II) {
288 SmallVector<Instruction *, 32> Worklist;
289 auto RequiredType = II.getType();
290
291 auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
292 assert(PN && "Expected Phi Node!");
293
294 // Don't create a new Phi unless we can remove the old one.
295 if (!PN->hasOneUse())
296 return None;
297
298 for (Value *IncValPhi : PN->incoming_values()) {
299 auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
300 if (!Reinterpret ||
301 Reinterpret->getIntrinsicID() !=
302 Intrinsic::aarch64_sve_convert_to_svbool ||
303 RequiredType != Reinterpret->getArgOperand(0)->getType())
304 return None;
305 }
306
307 // Create the new Phi
308 LLVMContext &Ctx = PN->getContext();
309 IRBuilder<> Builder(Ctx);
310 Builder.SetInsertPoint(PN);
311 PHINode *NPN = Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
312 Worklist.push_back(PN);
313
314 for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
315 auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
316 NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
317 Worklist.push_back(Reinterpret);
318 }
319
320 // Cleanup Phi Node and reinterprets
321 return IC.replaceInstUsesWith(II, NPN);
322 }
323
instCombineConvertFromSVBool(InstCombiner & IC,IntrinsicInst & II)324 static Optional<Instruction *> instCombineConvertFromSVBool(InstCombiner &IC,
325 IntrinsicInst &II) {
326 // If the reinterpret instruction operand is a PHI Node
327 if (isa<PHINode>(II.getArgOperand(0)))
328 return processPhiNode(IC, II);
329
330 SmallVector<Instruction *, 32> CandidatesForRemoval;
331 Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;
332
333 const auto *IVTy = cast<VectorType>(II.getType());
334
335 // Walk the chain of conversions.
336 while (Cursor) {
337 // If the type of the cursor has fewer lanes than the final result, zeroing
338 // must take place, which breaks the equivalence chain.
339 const auto *CursorVTy = cast<VectorType>(Cursor->getType());
340 if (CursorVTy->getElementCount().getKnownMinValue() <
341 IVTy->getElementCount().getKnownMinValue())
342 break;
343
344 // If the cursor has the same type as I, it is a viable replacement.
345 if (Cursor->getType() == IVTy)
346 EarliestReplacement = Cursor;
347
348 auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);
349
350 // If this is not an SVE conversion intrinsic, this is the end of the chain.
351 if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
352 Intrinsic::aarch64_sve_convert_to_svbool ||
353 IntrinsicCursor->getIntrinsicID() ==
354 Intrinsic::aarch64_sve_convert_from_svbool))
355 break;
356
357 CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
358 Cursor = IntrinsicCursor->getOperand(0);
359 }
360
361 // If no viable replacement in the conversion chain was found, there is
362 // nothing to do.
363 if (!EarliestReplacement)
364 return None;
365
366 return IC.replaceInstUsesWith(II, EarliestReplacement);
367 }
368
instCombineSVEDup(InstCombiner & IC,IntrinsicInst & II)369 static Optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
370 IntrinsicInst &II) {
371 IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
372 if (!Pg)
373 return None;
374
375 if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
376 return None;
377
378 const auto PTruePattern =
379 cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
380 if (PTruePattern != AArch64SVEPredPattern::vl1)
381 return None;
382
383 // The intrinsic is inserting into lane zero so use an insert instead.
384 auto *IdxTy = Type::getInt64Ty(II.getContext());
385 auto *Insert = InsertElementInst::Create(
386 II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
387 Insert->insertBefore(&II);
388 Insert->takeName(&II);
389
390 return IC.replaceInstUsesWith(II, Insert);
391 }
392
instCombineSVELast(InstCombiner & IC,IntrinsicInst & II)393 static Optional<Instruction *> instCombineSVELast(InstCombiner &IC,
394 IntrinsicInst &II) {
395 Value *Pg = II.getArgOperand(0);
396 Value *Vec = II.getArgOperand(1);
397 bool IsAfter = II.getIntrinsicID() == Intrinsic::aarch64_sve_lasta;
398
399 auto *C = dyn_cast<Constant>(Pg);
400 if (IsAfter && C && C->isNullValue()) {
401 // The intrinsic is extracting lane 0 so use an extract instead.
402 auto *IdxTy = Type::getInt64Ty(II.getContext());
403 auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
404 Extract->insertBefore(&II);
405 Extract->takeName(&II);
406 return IC.replaceInstUsesWith(II, Extract);
407 }
408
409 auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
410 if (!IntrPG)
411 return None;
412
413 if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
414 return None;
415
416 const auto PTruePattern =
417 cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();
418
419 // Can the intrinsic's predicate be converted to a known constant index?
420 unsigned Idx;
421 switch (PTruePattern) {
422 default:
423 return None;
424 case AArch64SVEPredPattern::vl1:
425 Idx = 0;
426 break;
427 case AArch64SVEPredPattern::vl2:
428 Idx = 1;
429 break;
430 case AArch64SVEPredPattern::vl3:
431 Idx = 2;
432 break;
433 case AArch64SVEPredPattern::vl4:
434 Idx = 3;
435 break;
436 case AArch64SVEPredPattern::vl5:
437 Idx = 4;
438 break;
439 case AArch64SVEPredPattern::vl6:
440 Idx = 5;
441 break;
442 case AArch64SVEPredPattern::vl7:
443 Idx = 6;
444 break;
445 case AArch64SVEPredPattern::vl8:
446 Idx = 7;
447 break;
448 case AArch64SVEPredPattern::vl16:
449 Idx = 15;
450 break;
451 }
452
453 // Increment the index if extracting the element after the last active
454 // predicate element.
455 if (IsAfter)
456 ++Idx;
457
458 // Ignore extracts whose index is larger than the known minimum vector
459 // length. NOTE: This is an artificial constraint where we prefer to
460 // maintain what the user asked for until an alternative is proven faster.
461 auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
462 if (Idx >= PgVTy->getMinNumElements())
463 return None;
464
465 // The intrinsic is extracting a fixed lane so use an extract instead.
466 auto *IdxTy = Type::getInt64Ty(II.getContext());
467 auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
468 Extract->insertBefore(&II);
469 Extract->takeName(&II);
470 return IC.replaceInstUsesWith(II, Extract);
471 }
472
instCombineRDFFR(InstCombiner & IC,IntrinsicInst & II)473 static Optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
474 IntrinsicInst &II) {
475 LLVMContext &Ctx = II.getContext();
476 IRBuilder<> Builder(Ctx);
477 Builder.SetInsertPoint(&II);
478 // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
479 // can work with RDFFR_PP for ptest elimination.
480 auto *AllPat =
481 ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
482 auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
483 {II.getType()}, {AllPat});
484 auto *RDFFR =
485 Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {}, {PTrue});
486 RDFFR->takeName(&II);
487 return IC.replaceInstUsesWith(II, RDFFR);
488 }
489
490 Optional<Instruction *>
instCombineIntrinsic(InstCombiner & IC,IntrinsicInst & II) const491 AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
492 IntrinsicInst &II) const {
493 Intrinsic::ID IID = II.getIntrinsicID();
494 switch (IID) {
495 default:
496 break;
497 case Intrinsic::aarch64_sve_convert_from_svbool:
498 return instCombineConvertFromSVBool(IC, II);
499 case Intrinsic::aarch64_sve_dup:
500 return instCombineSVEDup(IC, II);
501 case Intrinsic::aarch64_sve_rdffr:
502 return instCombineRDFFR(IC, II);
503 case Intrinsic::aarch64_sve_lasta:
504 case Intrinsic::aarch64_sve_lastb:
505 return instCombineSVELast(IC, II);
506 }
507
508 return None;
509 }
510
isWideningInstruction(Type * DstTy,unsigned Opcode,ArrayRef<const Value * > Args)511 bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
512 ArrayRef<const Value *> Args) {
513
514 // A helper that returns a vector type from the given type. The number of
515 // elements in type Ty determine the vector width.
516 auto toVectorTy = [&](Type *ArgTy) {
517 return VectorType::get(ArgTy->getScalarType(),
518 cast<VectorType>(DstTy)->getElementCount());
519 };
520
521 // Exit early if DstTy is not a vector type whose elements are at least
522 // 16-bits wide.
523 if (!DstTy->isVectorTy() || DstTy->getScalarSizeInBits() < 16)
524 return false;
525
526 // Determine if the operation has a widening variant. We consider both the
527 // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
528 // instructions.
529 //
530 // TODO: Add additional widening operations (e.g., mul, shl, etc.) once we
531 // verify that their extending operands are eliminated during code
532 // generation.
533 switch (Opcode) {
534 case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
535 case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
536 break;
537 default:
538 return false;
539 }
540
541 // To be a widening instruction (either the "wide" or "long" versions), the
542 // second operand must be a sign- or zero extend having a single user. We
543 // only consider extends having a single user because they may otherwise not
544 // be eliminated.
545 if (Args.size() != 2 ||
546 (!isa<SExtInst>(Args[1]) && !isa<ZExtInst>(Args[1])) ||
547 !Args[1]->hasOneUse())
548 return false;
549 auto *Extend = cast<CastInst>(Args[1]);
550
551 // Legalize the destination type and ensure it can be used in a widening
552 // operation.
553 auto DstTyL = TLI->getTypeLegalizationCost(DL, DstTy);
554 unsigned DstElTySize = DstTyL.second.getScalarSizeInBits();
555 if (!DstTyL.second.isVector() || DstElTySize != DstTy->getScalarSizeInBits())
556 return false;
557
558 // Legalize the source type and ensure it can be used in a widening
559 // operation.
560 auto *SrcTy = toVectorTy(Extend->getSrcTy());
561 auto SrcTyL = TLI->getTypeLegalizationCost(DL, SrcTy);
562 unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
563 if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
564 return false;
565
566 // Get the total number of vector elements in the legalized types.
567 InstructionCost NumDstEls =
568 DstTyL.first * DstTyL.second.getVectorMinNumElements();
569 InstructionCost NumSrcEls =
570 SrcTyL.first * SrcTyL.second.getVectorMinNumElements();
571
572 // Return true if the legalized types have the same number of vector elements
573 // and the destination element type size is twice that of the source type.
574 return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstElTySize;
575 }
576
getCastInstrCost(unsigned Opcode,Type * Dst,Type * Src,TTI::CastContextHint CCH,TTI::TargetCostKind CostKind,const Instruction * I)577 InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
578 Type *Src,
579 TTI::CastContextHint CCH,
580 TTI::TargetCostKind CostKind,
581 const Instruction *I) {
582 int ISD = TLI->InstructionOpcodeToISD(Opcode);
583 assert(ISD && "Invalid opcode");
584
585 // If the cast is observable, and it is used by a widening instruction (e.g.,
586 // uaddl, saddw, etc.), it may be free.
587 if (I && I->hasOneUse()) {
588 auto *SingleUser = cast<Instruction>(*I->user_begin());
589 SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
590 if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands)) {
591 // If the cast is the second operand, it is free. We will generate either
592 // a "wide" or "long" version of the widening instruction.
593 if (I == SingleUser->getOperand(1))
594 return 0;
595 // If the cast is not the second operand, it will be free if it looks the
596 // same as the second operand. In this case, we will generate a "long"
597 // version of the widening instruction.
598 if (auto *Cast = dyn_cast<CastInst>(SingleUser->getOperand(1)))
599 if (I->getOpcode() == unsigned(Cast->getOpcode()) &&
600 cast<CastInst>(I)->getSrcTy() == Cast->getSrcTy())
601 return 0;
602 }
603 }
604
605 // TODO: Allow non-throughput costs that aren't binary.
606 auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost {
607 if (CostKind != TTI::TCK_RecipThroughput)
608 return Cost == 0 ? 0 : 1;
609 return Cost;
610 };
611
612 EVT SrcTy = TLI->getValueType(DL, Src);
613 EVT DstTy = TLI->getValueType(DL, Dst);
614
615 if (!SrcTy.isSimple() || !DstTy.isSimple())
616 return AdjustCost(
617 BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
618
619 static const TypeConversionCostTblEntry
620 ConversionTbl[] = {
621 { ISD::TRUNCATE, MVT::v4i16, MVT::v4i32, 1 },
622 { ISD::TRUNCATE, MVT::v4i32, MVT::v4i64, 0 },
623 { ISD::TRUNCATE, MVT::v8i8, MVT::v8i32, 3 },
624 { ISD::TRUNCATE, MVT::v16i8, MVT::v16i32, 6 },
625
626 // Truncations on nxvmiN
627 { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 1 },
628 { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 1 },
629 { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 1 },
630 { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 1 },
631 { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 1 },
632 { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 2 },
633 { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 1 },
634 { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 3 },
635 { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 5 },
636 { ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 1 },
637 { ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 1 },
638 { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 1 },
639 { ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 2 },
640 { ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 3 },
641 { ISD::TRUNCATE, MVT::nxv8i32, MVT::nxv8i64, 6 },
642
643 // The number of shll instructions for the extension.
644 { ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i16, 3 },
645 { ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i16, 3 },
646 { ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i32, 2 },
647 { ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i32, 2 },
648 { ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i8, 3 },
649 { ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i8, 3 },
650 { ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i16, 2 },
651 { ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i16, 2 },
652 { ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i8, 7 },
653 { ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i8, 7 },
654 { ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i16, 6 },
655 { ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i16, 6 },
656 { ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
657 { ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
658 { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
659 { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
660
661 // LowerVectorINT_TO_FP:
662 { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
663 { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
664 { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
665 { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
666 { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
667 { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
668
669 // Complex: to v2f32
670 { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8, 3 },
671 { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
672 { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
673 { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8, 3 },
674 { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
675 { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
676
677 // Complex: to v4f32
678 { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8, 4 },
679 { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
680 { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8, 3 },
681 { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
682
683 // Complex: to v8f32
684 { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8, 10 },
685 { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
686 { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8, 10 },
687 { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
688
689 // Complex: to v16f32
690 { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
691 { ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
692
693 // Complex: to v2f64
694 { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8, 4 },
695 { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
696 { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
697 { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i8, 4 },
698 { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
699 { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
700
701
702 // LowerVectorFP_TO_INT
703 { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f32, 1 },
704 { ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f32, 1 },
705 { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f64, 1 },
706 { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f32, 1 },
707 { ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f32, 1 },
708 { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f64, 1 },
709
710 // Complex, from v2f32: legal type is v2i32 (no cost) or v2i64 (1 ext).
711 { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f32, 2 },
712 { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 1 },
713 { ISD::FP_TO_SINT, MVT::v2i8, MVT::v2f32, 1 },
714 { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f32, 2 },
715 { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 1 },
716 { ISD::FP_TO_UINT, MVT::v2i8, MVT::v2f32, 1 },
717
718 // Complex, from v4f32: legal type is v4i16, 1 narrowing => ~2
719 { ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 2 },
720 { ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f32, 2 },
721 { ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2 },
722 { ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f32, 2 },
723
724 // Complex, from nxv2f32.
725 { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
726 { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
727 { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
728 { ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f32, 1 },
729 { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
730 { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
731 { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
732 { ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f32, 1 },
733
734 // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2.
735 { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2 },
736 { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2 },
737 { ISD::FP_TO_SINT, MVT::v2i8, MVT::v2f64, 2 },
738 { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2 },
739 { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2 },
740 { ISD::FP_TO_UINT, MVT::v2i8, MVT::v2f64, 2 },
741
742 // Complex, from nxv2f64.
743 { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
744 { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
745 { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
746 { ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f64, 1 },
747 { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
748 { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
749 { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
750 { ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f64, 1 },
751
752 // Complex, from nxv4f32.
753 { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
754 { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
755 { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
756 { ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f32, 1 },
757 { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
758 { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
759 { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
760 { ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f32, 1 },
761
762 // Complex, from nxv8f64. Illegal -> illegal conversions not required.
763 { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
764 { ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f64, 7 },
765 { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
766 { ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f64, 7 },
767
768 // Complex, from nxv4f64. Illegal -> illegal conversions not required.
769 { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
770 { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
771 { ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f64, 3 },
772 { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
773 { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
774 { ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f64, 3 },
775
776 // Complex, from nxv8f32. Illegal -> illegal conversions not required.
777 { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
778 { ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f32, 3 },
779 { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
780 { ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f32, 3 },
781
782 // Complex, from nxv8f16.
783 { ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
784 { ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
785 { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
786 { ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f16, 1 },
787 { ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
788 { ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
789 { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
790 { ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f16, 1 },
791
792 // Complex, from nxv4f16.
793 { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
794 { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
795 { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
796 { ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f16, 1 },
797 { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
798 { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
799 { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
800 { ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f16, 1 },
801
802 // Complex, from nxv2f16.
803 { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
804 { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
805 { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
806 { ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f16, 1 },
807 { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
808 { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
809 { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
810 { ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f16, 1 },
811
812 // Truncate from nxvmf32 to nxvmf16.
813 { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1 },
814 { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1 },
815 { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3 },
816
817 // Truncate from nxvmf64 to nxvmf16.
818 { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1 },
819 { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3 },
820 { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7 },
821
822 // Truncate from nxvmf64 to nxvmf32.
823 { ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1 },
824 { ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3 },
825 { ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6 },
826
827 // Extend from nxvmf16 to nxvmf32.
828 { ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1},
829 { ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
830 { ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},
831
832 // Extend from nxvmf16 to nxvmf64.
833 { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
834 { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
835 { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},
836
837 // Extend from nxvmf32 to nxvmf64.
838 { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
839 { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
840 { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6},
841
842 };
843
844 if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
845 DstTy.getSimpleVT(),
846 SrcTy.getSimpleVT()))
847 return AdjustCost(Entry->Cost);
848
849 return AdjustCost(
850 BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
851 }
852
getExtractWithExtendCost(unsigned Opcode,Type * Dst,VectorType * VecTy,unsigned Index)853 InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode,
854 Type *Dst,
855 VectorType *VecTy,
856 unsigned Index) {
857
858 // Make sure we were given a valid extend opcode.
859 assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
860 "Invalid opcode");
861
862 // We are extending an element we extract from a vector, so the source type
863 // of the extend is the element type of the vector.
864 auto *Src = VecTy->getElementType();
865
866 // Sign- and zero-extends are for integer types only.
867 assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
868
869 // Get the cost for the extract. We compute the cost (if any) for the extend
870 // below.
871 InstructionCost Cost =
872 getVectorInstrCost(Instruction::ExtractElement, VecTy, Index);
873
874 // Legalize the types.
875 auto VecLT = TLI->getTypeLegalizationCost(DL, VecTy);
876 auto DstVT = TLI->getValueType(DL, Dst);
877 auto SrcVT = TLI->getValueType(DL, Src);
878 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
879
880 // If the resulting type is still a vector and the destination type is legal,
881 // we may get the extension for free. If not, get the default cost for the
882 // extend.
883 if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
884 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
885 CostKind);
886
887 // The destination type should be larger than the element type. If not, get
888 // the default cost for the extend.
889 if (DstVT.getFixedSizeInBits() < SrcVT.getFixedSizeInBits())
890 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
891 CostKind);
892
893 switch (Opcode) {
894 default:
895 llvm_unreachable("Opcode should be either SExt or ZExt");
896
897 // For sign-extends, we only need a smov, which performs the extension
898 // automatically.
899 case Instruction::SExt:
900 return Cost;
901
902 // For zero-extends, the extend is performed automatically by a umov unless
903 // the destination type is i64 and the element type is i8 or i16.
904 case Instruction::ZExt:
905 if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
906 return Cost;
907 }
908
909 // If we are unable to perform the extend for free, get the default cost.
910 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
911 CostKind);
912 }
913
getCFInstrCost(unsigned Opcode,TTI::TargetCostKind CostKind,const Instruction * I)914 InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
915 TTI::TargetCostKind CostKind,
916 const Instruction *I) {
917 if (CostKind != TTI::TCK_RecipThroughput)
918 return Opcode == Instruction::PHI ? 0 : 1;
919 assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind");
920 // Branches are assumed to be predicted.
921 return 0;
922 }
923
getVectorInstrCost(unsigned Opcode,Type * Val,unsigned Index)924 InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
925 unsigned Index) {
926 assert(Val->isVectorTy() && "This must be a vector type");
927
928 if (Index != -1U) {
929 // Legalize the type.
930 std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Val);
931
932 // This type is legalized to a scalar type.
933 if (!LT.second.isVector())
934 return 0;
935
936 // The type may be split. Normalize the index to the new type.
937 unsigned Width = LT.second.getVectorNumElements();
938 Index = Index % Width;
939
940 // The element at index zero is already inside the vector.
941 if (Index == 0)
942 return 0;
943 }
944
945 // All other insert/extracts cost this much.
946 return ST->getVectorInsertExtractBaseCost();
947 }
948
getArithmeticInstrCost(unsigned Opcode,Type * Ty,TTI::TargetCostKind CostKind,TTI::OperandValueKind Opd1Info,TTI::OperandValueKind Opd2Info,TTI::OperandValueProperties Opd1PropInfo,TTI::OperandValueProperties Opd2PropInfo,ArrayRef<const Value * > Args,const Instruction * CxtI)949 InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
950 unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
951 TTI::OperandValueKind Opd1Info, TTI::OperandValueKind Opd2Info,
952 TTI::OperandValueProperties Opd1PropInfo,
953 TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
954 const Instruction *CxtI) {
955 // TODO: Handle more cost kinds.
956 if (CostKind != TTI::TCK_RecipThroughput)
957 return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
958 Opd2Info, Opd1PropInfo,
959 Opd2PropInfo, Args, CxtI);
960
961 // Legalize the type.
962 std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
963
964 // If the instruction is a widening instruction (e.g., uaddl, saddw, etc.),
965 // add in the widening overhead specified by the sub-target. Since the
966 // extends feeding widening instructions are performed automatically, they
967 // aren't present in the generated code and have a zero cost. By adding a
968 // widening overhead here, we attach the total cost of the combined operation
969 // to the widening instruction.
970 InstructionCost Cost = 0;
971 if (isWideningInstruction(Ty, Opcode, Args))
972 Cost += ST->getWideningBaseCost();
973
974 int ISD = TLI->InstructionOpcodeToISD(Opcode);
975
976 switch (ISD) {
977 default:
978 return Cost + BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
979 Opd2Info,
980 Opd1PropInfo, Opd2PropInfo);
981 case ISD::SDIV:
982 if (Opd2Info == TargetTransformInfo::OK_UniformConstantValue &&
983 Opd2PropInfo == TargetTransformInfo::OP_PowerOf2) {
984 // On AArch64, scalar signed division by constants power-of-two are
985 // normally expanded to the sequence ADD + CMP + SELECT + SRA.
986 // The OperandValue properties many not be same as that of previous
987 // operation; conservatively assume OP_None.
988 Cost += getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
989 Opd1Info, Opd2Info,
990 TargetTransformInfo::OP_None,
991 TargetTransformInfo::OP_None);
992 Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind,
993 Opd1Info, Opd2Info,
994 TargetTransformInfo::OP_None,
995 TargetTransformInfo::OP_None);
996 Cost += getArithmeticInstrCost(Instruction::Select, Ty, CostKind,
997 Opd1Info, Opd2Info,
998 TargetTransformInfo::OP_None,
999 TargetTransformInfo::OP_None);
1000 Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
1001 Opd1Info, Opd2Info,
1002 TargetTransformInfo::OP_None,
1003 TargetTransformInfo::OP_None);
1004 return Cost;
1005 }
1006 LLVM_FALLTHROUGH;
1007 case ISD::UDIV:
1008 if (Opd2Info == TargetTransformInfo::OK_UniformConstantValue) {
1009 auto VT = TLI->getValueType(DL, Ty);
1010 if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) {
1011 // Vector signed division by constant are expanded to the
1012 // sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division
1013 // to MULHS + SUB + SRL + ADD + SRL.
1014 InstructionCost MulCost = getArithmeticInstrCost(
1015 Instruction::Mul, Ty, CostKind, Opd1Info, Opd2Info,
1016 TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
1017 InstructionCost AddCost = getArithmeticInstrCost(
1018 Instruction::Add, Ty, CostKind, Opd1Info, Opd2Info,
1019 TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
1020 InstructionCost ShrCost = getArithmeticInstrCost(
1021 Instruction::AShr, Ty, CostKind, Opd1Info, Opd2Info,
1022 TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
1023 return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
1024 }
1025 }
1026
1027 Cost += BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
1028 Opd2Info,
1029 Opd1PropInfo, Opd2PropInfo);
1030 if (Ty->isVectorTy()) {
1031 // On AArch64, vector divisions are not supported natively and are
1032 // expanded into scalar divisions of each pair of elements.
1033 Cost += getArithmeticInstrCost(Instruction::ExtractElement, Ty, CostKind,
1034 Opd1Info, Opd2Info, Opd1PropInfo,
1035 Opd2PropInfo);
1036 Cost += getArithmeticInstrCost(Instruction::InsertElement, Ty, CostKind,
1037 Opd1Info, Opd2Info, Opd1PropInfo,
1038 Opd2PropInfo);
1039 // TODO: if one of the arguments is scalar, then it's not necessary to
1040 // double the cost of handling the vector elements.
1041 Cost += Cost;
1042 }
1043 return Cost;
1044
1045 case ISD::MUL:
1046 if (LT.second != MVT::v2i64)
1047 return (Cost + 1) * LT.first;
1048 // Since we do not have a MUL.2d instruction, a mul <2 x i64> is expensive
1049 // as elements are extracted from the vectors and the muls scalarized.
1050 // As getScalarizationOverhead is a bit too pessimistic, we estimate the
1051 // cost for a i64 vector directly here, which is:
1052 // - four i64 extracts,
1053 // - two i64 inserts, and
1054 // - two muls.
1055 // So, for a v2i64 with LT.First = 1 the cost is 8, and for a v4i64 with
1056 // LT.first = 2 the cost is 16.
1057 return LT.first * 8;
1058 case ISD::ADD:
1059 case ISD::XOR:
1060 case ISD::OR:
1061 case ISD::AND:
1062 // These nodes are marked as 'custom' for combining purposes only.
1063 // We know that they are legal. See LowerAdd in ISelLowering.
1064 return (Cost + 1) * LT.first;
1065
1066 case ISD::FADD:
1067 // These nodes are marked as 'custom' just to lower them to SVE.
1068 // We know said lowering will incur no additional cost.
1069 if (isa<FixedVectorType>(Ty) && !Ty->getScalarType()->isFP128Ty())
1070 return (Cost + 2) * LT.first;
1071
1072 return Cost + BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
1073 Opd2Info,
1074 Opd1PropInfo, Opd2PropInfo);
1075 }
1076 }
1077
getAddressComputationCost(Type * Ty,ScalarEvolution * SE,const SCEV * Ptr)1078 InstructionCost AArch64TTIImpl::getAddressComputationCost(Type *Ty,
1079 ScalarEvolution *SE,
1080 const SCEV *Ptr) {
1081 // Address computations in vectorized code with non-consecutive addresses will
1082 // likely result in more instructions compared to scalar code where the
1083 // computation can more often be merged into the index mode. The resulting
1084 // extra micro-ops can significantly decrease throughput.
1085 unsigned NumVectorInstToHideOverhead = 10;
1086 int MaxMergeDistance = 64;
1087
1088 if (Ty->isVectorTy() && SE &&
1089 !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
1090 return NumVectorInstToHideOverhead;
1091
1092 // In many cases the address computation is not merged into the instruction
1093 // addressing mode.
1094 return 1;
1095 }
1096
getCmpSelInstrCost(unsigned Opcode,Type * ValTy,Type * CondTy,CmpInst::Predicate VecPred,TTI::TargetCostKind CostKind,const Instruction * I)1097 InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
1098 Type *CondTy,
1099 CmpInst::Predicate VecPred,
1100 TTI::TargetCostKind CostKind,
1101 const Instruction *I) {
1102 // TODO: Handle other cost kinds.
1103 if (CostKind != TTI::TCK_RecipThroughput)
1104 return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
1105 I);
1106
1107 int ISD = TLI->InstructionOpcodeToISD(Opcode);
1108 // We don't lower some vector selects well that are wider than the register
1109 // width.
1110 if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
1111 // We would need this many instructions to hide the scalarization happening.
1112 const int AmortizationCost = 20;
1113
1114 // If VecPred is not set, check if we can get a predicate from the context
1115 // instruction, if its type matches the requested ValTy.
1116 if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
1117 CmpInst::Predicate CurrentPred;
1118 if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
1119 m_Value())))
1120 VecPred = CurrentPred;
1121 }
1122 // Check if we have a compare/select chain that can be lowered using CMxx &
1123 // BFI pair.
1124 if (CmpInst::isIntPredicate(VecPred)) {
1125 static const auto ValidMinMaxTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
1126 MVT::v8i16, MVT::v2i32, MVT::v4i32,
1127 MVT::v2i64};
1128 auto LT = TLI->getTypeLegalizationCost(DL, ValTy);
1129 if (any_of(ValidMinMaxTys, [<](MVT M) { return M == LT.second; }))
1130 return LT.first;
1131 }
1132
1133 static const TypeConversionCostTblEntry
1134 VectorSelectTbl[] = {
1135 { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
1136 { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
1137 { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
1138 { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
1139 { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
1140 { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
1141 };
1142
1143 EVT SelCondTy = TLI->getValueType(DL, CondTy);
1144 EVT SelValTy = TLI->getValueType(DL, ValTy);
1145 if (SelCondTy.isSimple() && SelValTy.isSimple()) {
1146 if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
1147 SelCondTy.getSimpleVT(),
1148 SelValTy.getSimpleVT()))
1149 return Entry->Cost;
1150 }
1151 }
1152 // The base case handles scalable vectors fine for now, since it treats the
1153 // cost as 1 * legalization cost.
1154 return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
1155 }
1156
1157 AArch64TTIImpl::TTI::MemCmpExpansionOptions
enableMemCmpExpansion(bool OptSize,bool IsZeroCmp) const1158 AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
1159 TTI::MemCmpExpansionOptions Options;
1160 if (ST->requiresStrictAlign()) {
1161 // TODO: Add cost modeling for strict align. Misaligned loads expand to
1162 // a bunch of instructions when strict align is enabled.
1163 return Options;
1164 }
1165 Options.AllowOverlappingLoads = true;
1166 Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
1167 Options.NumLoadsPerBlock = Options.MaxNumLoads;
1168 // TODO: Though vector loads usually perform well on AArch64, in some targets
1169 // they may wake up the FP unit, which raises the power consumption. Perhaps
1170 // they could be used with no holds barred (-O3).
1171 Options.LoadSizes = {8, 4, 2, 1};
1172 return Options;
1173 }
1174
1175 InstructionCost
getMaskedMemoryOpCost(unsigned Opcode,Type * Src,Align Alignment,unsigned AddressSpace,TTI::TargetCostKind CostKind)1176 AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
1177 Align Alignment, unsigned AddressSpace,
1178 TTI::TargetCostKind CostKind) {
1179 if (!isa<ScalableVectorType>(Src))
1180 return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
1181 CostKind);
1182 auto LT = TLI->getTypeLegalizationCost(DL, Src);
1183 return LT.first * 2;
1184 }
1185
getGatherScatterOpCost(unsigned Opcode,Type * DataTy,const Value * Ptr,bool VariableMask,Align Alignment,TTI::TargetCostKind CostKind,const Instruction * I)1186 InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
1187 unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
1188 Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
1189
1190 if (!isa<ScalableVectorType>(DataTy))
1191 return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
1192 Alignment, CostKind, I);
1193 auto *VT = cast<VectorType>(DataTy);
1194 auto LT = TLI->getTypeLegalizationCost(DL, DataTy);
1195 ElementCount LegalVF = LT.second.getVectorElementCount();
1196 Optional<unsigned> MaxNumVScale = getMaxVScale();
1197 assert(MaxNumVScale && "Expected valid max vscale value");
1198
1199 InstructionCost MemOpCost =
1200 getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind, I);
1201 unsigned MaxNumElementsPerGather =
1202 MaxNumVScale.getValue() * LegalVF.getKnownMinValue();
1203 return LT.first * MaxNumElementsPerGather * MemOpCost;
1204 }
1205
useNeonVector(const Type * Ty) const1206 bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
1207 return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
1208 }
1209
getMemoryOpCost(unsigned Opcode,Type * Ty,MaybeAlign Alignment,unsigned AddressSpace,TTI::TargetCostKind CostKind,const Instruction * I)1210 InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
1211 MaybeAlign Alignment,
1212 unsigned AddressSpace,
1213 TTI::TargetCostKind CostKind,
1214 const Instruction *I) {
1215 // Type legalization can't handle structs
1216 if (TLI->getValueType(DL, Ty, true) == MVT::Other)
1217 return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
1218 CostKind);
1219
1220 auto LT = TLI->getTypeLegalizationCost(DL, Ty);
1221
1222 // TODO: consider latency as well for TCK_SizeAndLatency.
1223 if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
1224 return LT.first;
1225
1226 if (CostKind != TTI::TCK_RecipThroughput)
1227 return 1;
1228
1229 if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
1230 LT.second.is128BitVector() && (!Alignment || *Alignment < Align(16))) {
1231 // Unaligned stores are extremely inefficient. We don't split all
1232 // unaligned 128-bit stores because the negative impact that has shown in
1233 // practice on inlined block copy code.
1234 // We make such stores expensive so that we will only vectorize if there
1235 // are 6 other instructions getting vectorized.
1236 const int AmortizationCost = 6;
1237
1238 return LT.first * 2 * AmortizationCost;
1239 }
1240
1241 if (useNeonVector(Ty) &&
1242 cast<VectorType>(Ty)->getElementType()->isIntegerTy(8)) {
1243 unsigned ProfitableNumElements;
1244 if (Opcode == Instruction::Store)
1245 // We use a custom trunc store lowering so v.4b should be profitable.
1246 ProfitableNumElements = 4;
1247 else
1248 // We scalarize the loads because there is not v.4b register and we
1249 // have to promote the elements to v.2.
1250 ProfitableNumElements = 8;
1251
1252 if (cast<FixedVectorType>(Ty)->getNumElements() < ProfitableNumElements) {
1253 unsigned NumVecElts = cast<FixedVectorType>(Ty)->getNumElements();
1254 unsigned NumVectorizableInstsToAmortize = NumVecElts * 2;
1255 // We generate 2 instructions per vector element.
1256 return NumVectorizableInstsToAmortize * NumVecElts * 2;
1257 }
1258 }
1259
1260 return LT.first;
1261 }
1262
getInterleavedMemoryOpCost(unsigned Opcode,Type * VecTy,unsigned Factor,ArrayRef<unsigned> Indices,Align Alignment,unsigned AddressSpace,TTI::TargetCostKind CostKind,bool UseMaskForCond,bool UseMaskForGaps)1263 InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
1264 unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
1265 Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
1266 bool UseMaskForCond, bool UseMaskForGaps) {
1267 assert(Factor >= 2 && "Invalid interleave factor");
1268 auto *VecVTy = cast<FixedVectorType>(VecTy);
1269
1270 if (!UseMaskForCond && !UseMaskForGaps &&
1271 Factor <= TLI->getMaxSupportedInterleaveFactor()) {
1272 unsigned NumElts = VecVTy->getNumElements();
1273 auto *SubVecTy =
1274 FixedVectorType::get(VecTy->getScalarType(), NumElts / Factor);
1275
1276 // ldN/stN only support legal vector types of size 64 or 128 in bits.
1277 // Accesses having vector types that are a multiple of 128 bits can be
1278 // matched to more than one ldN/stN instruction.
1279 if (NumElts % Factor == 0 &&
1280 TLI->isLegalInterleavedAccessType(SubVecTy, DL))
1281 return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL);
1282 }
1283
1284 return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
1285 Alignment, AddressSpace, CostKind,
1286 UseMaskForCond, UseMaskForGaps);
1287 }
1288
1289 InstructionCost
getCostOfKeepingLiveOverCall(ArrayRef<Type * > Tys)1290 AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
1291 InstructionCost Cost = 0;
1292 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1293 for (auto *I : Tys) {
1294 if (!I->isVectorTy())
1295 continue;
1296 if (I->getScalarSizeInBits() * cast<FixedVectorType>(I)->getNumElements() ==
1297 128)
1298 Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0, CostKind) +
1299 getMemoryOpCost(Instruction::Load, I, Align(128), 0, CostKind);
1300 }
1301 return Cost;
1302 }
1303
getMaxInterleaveFactor(unsigned VF)1304 unsigned AArch64TTIImpl::getMaxInterleaveFactor(unsigned VF) {
1305 return ST->getMaxInterleaveFactor();
1306 }
1307
1308 // For Falkor, we want to avoid having too many strided loads in a loop since
1309 // that can exhaust the HW prefetcher resources. We adjust the unroller
1310 // MaxCount preference below to attempt to ensure unrolling doesn't create too
1311 // many strided loads.
1312 static void
getFalkorUnrollingPreferences(Loop * L,ScalarEvolution & SE,TargetTransformInfo::UnrollingPreferences & UP)1313 getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE,
1314 TargetTransformInfo::UnrollingPreferences &UP) {
1315 enum { MaxStridedLoads = 7 };
1316 auto countStridedLoads = [](Loop *L, ScalarEvolution &SE) {
1317 int StridedLoads = 0;
1318 // FIXME? We could make this more precise by looking at the CFG and
1319 // e.g. not counting loads in each side of an if-then-else diamond.
1320 for (const auto BB : L->blocks()) {
1321 for (auto &I : *BB) {
1322 LoadInst *LMemI = dyn_cast<LoadInst>(&I);
1323 if (!LMemI)
1324 continue;
1325
1326 Value *PtrValue = LMemI->getPointerOperand();
1327 if (L->isLoopInvariant(PtrValue))
1328 continue;
1329
1330 const SCEV *LSCEV = SE.getSCEV(PtrValue);
1331 const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
1332 if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
1333 continue;
1334
1335 // FIXME? We could take pairing of unrolled load copies into account
1336 // by looking at the AddRec, but we would probably have to limit this
1337 // to loops with no stores or other memory optimization barriers.
1338 ++StridedLoads;
1339 // We've seen enough strided loads that seeing more won't make a
1340 // difference.
1341 if (StridedLoads > MaxStridedLoads / 2)
1342 return StridedLoads;
1343 }
1344 }
1345 return StridedLoads;
1346 };
1347
1348 int StridedLoads = countStridedLoads(L, SE);
1349 LLVM_DEBUG(dbgs() << "falkor-hwpf: detected " << StridedLoads
1350 << " strided loads\n");
1351 // Pick the largest power of 2 unroll count that won't result in too many
1352 // strided loads.
1353 if (StridedLoads) {
1354 UP.MaxCount = 1 << Log2_32(MaxStridedLoads / StridedLoads);
1355 LLVM_DEBUG(dbgs() << "falkor-hwpf: setting unroll MaxCount to "
1356 << UP.MaxCount << '\n');
1357 }
1358 }
1359
getUnrollingPreferences(Loop * L,ScalarEvolution & SE,TTI::UnrollingPreferences & UP)1360 void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
1361 TTI::UnrollingPreferences &UP) {
1362 // Enable partial unrolling and runtime unrolling.
1363 BaseT::getUnrollingPreferences(L, SE, UP);
1364
1365 // For inner loop, it is more likely to be a hot one, and the runtime check
1366 // can be promoted out from LICM pass, so the overhead is less, let's try
1367 // a larger threshold to unroll more loops.
1368 if (L->getLoopDepth() > 1)
1369 UP.PartialThreshold *= 2;
1370
1371 // Disable partial & runtime unrolling on -Os.
1372 UP.PartialOptSizeThreshold = 0;
1373
1374 if (ST->getProcFamily() == AArch64Subtarget::Falkor &&
1375 EnableFalkorHWPFUnrollFix)
1376 getFalkorUnrollingPreferences(L, SE, UP);
1377
1378 // Scan the loop: don't unroll loops with calls as this could prevent
1379 // inlining. Don't unroll vector loops either, as they don't benefit much from
1380 // unrolling.
  for (auto *BB : L->getBlocks()) {
    for (auto &I : *BB) {
      // Don't unroll vectorized loops.
      if (I.getType()->isVectorTy())
        return;

      if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
        if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
          if (!isLoweredToCall(F))
            continue;
        }
        return;
      }
    }
  }

  // Enable runtime unrolling for in-order models.
  // If -mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others; by
  // checking for that case, we ensure that the default behaviour is unchanged.
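  // (For example, an in-order core such as Cortex-A55 selected with
  // -mcpu=cortex-a55 gets runtime and partial unrolling with a default
  // runtime count of 4.)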
  if (ST->getProcFamily() != AArch64Subtarget::Others &&
      !ST->getSchedModel().isOutOfOrder()) {
    UP.Runtime = true;
    UP.Partial = true;
    UP.UpperBound = true;
    UP.UnrollRemainder = true;
    UP.DefaultUnrollRuntimeCount = 4;
  }
}

void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
                                           TTI::PeelingPreferences &PP) {
  BaseT::getPeelingPreferences(L, SE, PP);
}

Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
                                                         Type *ExpectedType) {
  switch (Inst->getIntrinsicID()) {
  default:
    return nullptr;
  case Intrinsic::aarch64_neon_st2:
  case Intrinsic::aarch64_neon_st3:
  case Intrinsic::aarch64_neon_st4: {
    // Create a struct type
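    // The struct mirrors what a matching ldN would return: e.g. for an st2 of
    // two <4 x i32> values, ExpectedType is { <4 x i32>, <4 x i32> } and the
    // stored operands are repacked into it so a later load can reuse them.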
    StructType *ST = dyn_cast<StructType>(ExpectedType);
    if (!ST)
      return nullptr;
    unsigned NumElts = Inst->getNumArgOperands() - 1;
    if (ST->getNumElements() != NumElts)
      return nullptr;
    for (unsigned i = 0, e = NumElts; i != e; ++i) {
      if (Inst->getArgOperand(i)->getType() != ST->getElementType(i))
        return nullptr;
    }
    Value *Res = UndefValue::get(ExpectedType);
    IRBuilder<> Builder(Inst);
    for (unsigned i = 0, e = NumElts; i != e; ++i) {
      Value *L = Inst->getArgOperand(i);
      Res = Builder.CreateInsertValue(Res, L, i);
    }
    return Res;
  }
  case Intrinsic::aarch64_neon_ld2:
  case Intrinsic::aarch64_neon_ld3:
  case Intrinsic::aarch64_neon_ld4:
    if (Inst->getType() == ExpectedType)
      return Inst;
    return nullptr;
  }
}

bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
                                        MemIntrinsicInfo &Info) {
  switch (Inst->getIntrinsicID()) {
  default:
    break;
  case Intrinsic::aarch64_neon_ld2:
  case Intrinsic::aarch64_neon_ld3:
  case Intrinsic::aarch64_neon_ld4:
    Info.ReadMem = true;
    Info.WriteMem = false;
    Info.PtrVal = Inst->getArgOperand(0);
    break;
  case Intrinsic::aarch64_neon_st2:
  case Intrinsic::aarch64_neon_st3:
  case Intrinsic::aarch64_neon_st4:
    Info.ReadMem = false;
    Info.WriteMem = true;
    Info.PtrVal = Inst->getArgOperand(Inst->getNumArgOperands() - 1);
    break;
  }

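  // MatchingId lets passes such as EarlyCSE pair up ldN/stN intrinsics that
  // access the same number of elements (e.g. an ld2 with a prior st2).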
  switch (Inst->getIntrinsicID()) {
  default:
    return false;
  case Intrinsic::aarch64_neon_ld2:
  case Intrinsic::aarch64_neon_st2:
    Info.MatchingId = VECTOR_LDST_TWO_ELEMENTS;
    break;
  case Intrinsic::aarch64_neon_ld3:
  case Intrinsic::aarch64_neon_st3:
    Info.MatchingId = VECTOR_LDST_THREE_ELEMENTS;
    break;
  case Intrinsic::aarch64_neon_ld4:
  case Intrinsic::aarch64_neon_st4:
    Info.MatchingId = VECTOR_LDST_FOUR_ELEMENTS;
    break;
  }
  return true;
}

/// See if \p I should be considered for address type promotion. We check if
/// \p I is a sext with the right type and used in memory accesses. If it is
/// used in a "complex" getelementptr, we allow it to be promoted without
/// finding other sext instructions that sign extended the same initial value.
/// A getelementptr is considered "complex" if it has more than 2 operands.
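///
/// For example, the sext feeding the last index of
///   getelementptr inbounds [64 x i32], [64 x i32]* %a, i64 0, i64 %idxprom
/// has a three-operand ("complex") GEP user, so it may be promoted without a
/// common header.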
bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
    const Instruction &I, bool &AllowPromotionWithoutCommonHeader) {
  bool Considerable = false;
  AllowPromotionWithoutCommonHeader = false;
  if (!isa<SExtInst>(&I))
    return false;
  Type *ConsideredSExtType =
      Type::getInt64Ty(I.getParent()->getParent()->getContext());
  if (I.getType() != ConsideredSExtType)
    return false;
  // See if the sext is the one with the right type and used in at least one
  // GetElementPtrInst.
  for (const User *U : I.users()) {
    if (const GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
      Considerable = true;
      // A getelementptr is considered "complex" if it has more than 2
      // operands. We will promote a SExt used in such a complex GEP, as we
      // expect some computation to be merged if it is done on 64 bits.
      if (GEPInst->getNumOperands() > 2) {
        AllowPromotionWithoutCommonHeader = true;
        break;
      }
    }
  }
  return Considerable;
}

bool AArch64TTIImpl::isLegalToVectorizeReduction(RecurrenceDescriptor RdxDesc,
                                                 ElementCount VF) const {
  if (!VF.isScalable())
    return true;

  Type *Ty = RdxDesc.getRecurrenceType();
  if (Ty->isBFloatTy() || !isLegalElementTypeForSVE(Ty))
    return false;

  switch (RdxDesc.getRecurrenceKind()) {
  case RecurKind::Add:
  case RecurKind::FAdd:
  case RecurKind::And:
  case RecurKind::Or:
  case RecurKind::Xor:
  case RecurKind::SMin:
  case RecurKind::SMax:
  case RecurKind::UMin:
  case RecurKind::UMax:
  case RecurKind::FMin:
  case RecurKind::FMax:
    return true;
  default:
    return false;
  }
}

InstructionCost
AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
                                       bool IsPairwise, bool IsUnsigned,
                                       TTI::TargetCostKind CostKind) {
  if (!isa<ScalableVectorType>(Ty))
    return BaseT::getMinMaxReductionCost(Ty, CondTy, IsPairwise, IsUnsigned,
                                         CostKind);
  assert((isa<ScalableVectorType>(Ty) && isa<ScalableVectorType>(CondTy)) &&
         "Both vectors need to be scalable");

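  // Illegal scalable types are split; each extra legalized part costs one
  // compare and one select to merge, e.g. an nxv8i32 min/max reduction splits
  // into two nxv4i32 halves combined with one cmp+sel before the final
  // reduction.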
  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
  InstructionCost LegalizationCost = 0;
  if (LT.first > 1) {
    Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
    unsigned CmpOpcode =
        Ty->isFPOrFPVectorTy() ? Instruction::FCmp : Instruction::ICmp;
    LegalizationCost =
        getCmpSelInstrCost(CmpOpcode, LegalVTy, LegalVTy,
                           CmpInst::BAD_ICMP_PREDICATE, CostKind) +
        getCmpSelInstrCost(Instruction::Select, LegalVTy, LegalVTy,
                           CmpInst::BAD_ICMP_PREDICATE, CostKind);
    LegalizationCost *= LT.first - 1;
  }

  return LegalizationCost + /*Cost of horizontal reduction*/ 2;
}

InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE(
    unsigned Opcode, VectorType *ValTy, bool IsPairwise,
    TTI::TargetCostKind CostKind) {
  assert(!IsPairwise && "Cannot be pairwise to continue");

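  // As in the min/max case, each extra legalized part costs one vector op to
  // merge, e.g. an nxv8i64 add reduction splits into four nxv2i64 parts
  // combined with three vector adds before the final reduction (cost 2).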
  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
  InstructionCost LegalizationCost = 0;
  if (LT.first > 1) {
    Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext());
    LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind);
    LegalizationCost *= LT.first - 1;
  }

  int ISD = TLI->InstructionOpcodeToISD(Opcode);
  assert(ISD && "Invalid opcode");
  // Add the final reduction cost for the legal horizontal reduction
  switch (ISD) {
  case ISD::ADD:
  case ISD::AND:
  case ISD::OR:
  case ISD::XOR:
  case ISD::FADD:
    return LegalizationCost + 2;
  default:
    return InstructionCost::getInvalid();
  }
}

InstructionCost
AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
                                           bool IsPairwiseForm,
                                           TTI::TargetCostKind CostKind) {

  if (isa<ScalableVectorType>(ValTy))
    return getArithmeticReductionCostSVE(Opcode, ValTy, IsPairwiseForm,
                                         CostKind);
  if (IsPairwiseForm)
    return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwiseForm,
                                             CostKind);

  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
  MVT MTy = LT.second;
  int ISD = TLI->InstructionOpcodeToISD(Opcode);
  assert(ISD && "Invalid opcode");

  // Horizontal adds can use the 'addv' instruction. We model the cost of these
  // instructions as normal vector adds. This is the only arithmetic vector
  // reduction operation for which we have an instruction.
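  // For example, a vecreduce.add of a v4i32 lowers to a single ADDV.4S, so it
  // is costed like one vector add; types needing legalization pay LT.first
  // times that.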
  static const CostTblEntry CostTblNoPairwise[]{
      {ISD::ADD, MVT::v8i8, 1},
      {ISD::ADD, MVT::v16i8, 1},
      {ISD::ADD, MVT::v4i16, 1},
      {ISD::ADD, MVT::v8i16, 1},
      {ISD::ADD, MVT::v4i32, 1},
  };

  if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
    return LT.first * Entry->Cost;

  return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwiseForm,
                                           CostKind);
}

InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
                                               VectorType *Tp,
                                               ArrayRef<int> Mask, int Index,
                                               VectorType *SubTp) {
  Kind = improveShuffleKindFromMask(Kind, Mask);
  if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose ||
      Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc ||
      Kind == TTI::SK_Reverse) {
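    // For example, a broadcast of one lane of a v4i32 is a single DUP, and a
    // full reverse of a v4i32 needs REV64 followed by EXT (cost 2), as listed
    // in the table below.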
    static const CostTblEntry ShuffleTbl[] = {
      // Broadcast shuffle kinds can be performed with 'dup'.
      { TTI::SK_Broadcast, MVT::v8i8,  1 },
      { TTI::SK_Broadcast, MVT::v16i8, 1 },
      { TTI::SK_Broadcast, MVT::v4i16, 1 },
      { TTI::SK_Broadcast, MVT::v8i16, 1 },
      { TTI::SK_Broadcast, MVT::v2i32, 1 },
      { TTI::SK_Broadcast, MVT::v4i32, 1 },
      { TTI::SK_Broadcast, MVT::v2i64, 1 },
      { TTI::SK_Broadcast, MVT::v2f32, 1 },
      { TTI::SK_Broadcast, MVT::v4f32, 1 },
      { TTI::SK_Broadcast, MVT::v2f64, 1 },
      // Transpose shuffle kinds can be performed with 'trn1/trn2' and
      // 'zip1/zip2' instructions.
      { TTI::SK_Transpose, MVT::v8i8,  1 },
      { TTI::SK_Transpose, MVT::v16i8, 1 },
      { TTI::SK_Transpose, MVT::v4i16, 1 },
      { TTI::SK_Transpose, MVT::v8i16, 1 },
      { TTI::SK_Transpose, MVT::v2i32, 1 },
      { TTI::SK_Transpose, MVT::v4i32, 1 },
      { TTI::SK_Transpose, MVT::v2i64, 1 },
      { TTI::SK_Transpose, MVT::v2f32, 1 },
      { TTI::SK_Transpose, MVT::v4f32, 1 },
      { TTI::SK_Transpose, MVT::v2f64, 1 },
      // Select shuffle kinds.
      // TODO: handle vXi8/vXi16.
      { TTI::SK_Select, MVT::v2i32, 1 }, // mov.
      { TTI::SK_Select, MVT::v4i32, 2 }, // rev+trn (or similar).
      { TTI::SK_Select, MVT::v2i64, 1 }, // mov.
      { TTI::SK_Select, MVT::v2f32, 1 }, // mov.
      { TTI::SK_Select, MVT::v4f32, 2 }, // rev+trn (or similar).
      { TTI::SK_Select, MVT::v2f64, 1 }, // mov.
      // PermuteSingleSrc shuffle kinds.
      // TODO: handle vXi8/vXi16.
      { TTI::SK_PermuteSingleSrc, MVT::v2i32, 1 }, // mov.
      { TTI::SK_PermuteSingleSrc, MVT::v4i32, 3 }, // perfectshuffle worst case.
      { TTI::SK_PermuteSingleSrc, MVT::v2i64, 1 }, // mov.
      { TTI::SK_PermuteSingleSrc, MVT::v2f32, 1 }, // mov.
      { TTI::SK_PermuteSingleSrc, MVT::v4f32, 3 }, // perfectshuffle worst case.
      { TTI::SK_PermuteSingleSrc, MVT::v2f64, 1 }, // mov.
      // Reverse can be lowered with `rev`.
      { TTI::SK_Reverse, MVT::v2i32, 1 }, // mov.
      { TTI::SK_Reverse, MVT::v4i32, 2 }, // REV64; EXT
      { TTI::SK_Reverse, MVT::v2i64, 1 }, // mov.
      { TTI::SK_Reverse, MVT::v2f32, 1 }, // mov.
      { TTI::SK_Reverse, MVT::v4f32, 2 }, // REV64; EXT
      { TTI::SK_Reverse, MVT::v2f64, 1 }, // mov.
      // Broadcast shuffle kinds for scalable vectors
      { TTI::SK_Broadcast, MVT::nxv16i8,  1 },
      { TTI::SK_Broadcast, MVT::nxv8i16,  1 },
      { TTI::SK_Broadcast, MVT::nxv4i32,  1 },
      { TTI::SK_Broadcast, MVT::nxv2i64,  1 },
      { TTI::SK_Broadcast, MVT::nxv2f16,  1 },
      { TTI::SK_Broadcast, MVT::nxv4f16,  1 },
      { TTI::SK_Broadcast, MVT::nxv8f16,  1 },
      { TTI::SK_Broadcast, MVT::nxv2bf16, 1 },
      { TTI::SK_Broadcast, MVT::nxv4bf16, 1 },
      { TTI::SK_Broadcast, MVT::nxv8bf16, 1 },
      { TTI::SK_Broadcast, MVT::nxv2f32,  1 },
      { TTI::SK_Broadcast, MVT::nxv4f32,  1 },
      { TTI::SK_Broadcast, MVT::nxv2f64,  1 },
      { TTI::SK_Broadcast, MVT::nxv16i1,  1 },
      { TTI::SK_Broadcast, MVT::nxv8i1,   1 },
      { TTI::SK_Broadcast, MVT::nxv4i1,   1 },
      { TTI::SK_Broadcast, MVT::nxv2i1,   1 },
      // Handle the cases for vector.reverse with scalable vectors
      { TTI::SK_Reverse, MVT::nxv16i8,  1 },
      { TTI::SK_Reverse, MVT::nxv8i16,  1 },
      { TTI::SK_Reverse, MVT::nxv4i32,  1 },
      { TTI::SK_Reverse, MVT::nxv2i64,  1 },
      { TTI::SK_Reverse, MVT::nxv2f16,  1 },
      { TTI::SK_Reverse, MVT::nxv4f16,  1 },
      { TTI::SK_Reverse, MVT::nxv8f16,  1 },
      { TTI::SK_Reverse, MVT::nxv2bf16, 1 },
      { TTI::SK_Reverse, MVT::nxv4bf16, 1 },
      { TTI::SK_Reverse, MVT::nxv8bf16, 1 },
      { TTI::SK_Reverse, MVT::nxv2f32,  1 },
      { TTI::SK_Reverse, MVT::nxv4f32,  1 },
      { TTI::SK_Reverse, MVT::nxv2f64,  1 },
      { TTI::SK_Reverse, MVT::nxv16i1,  1 },
      { TTI::SK_Reverse, MVT::nxv8i1,   1 },
      { TTI::SK_Reverse, MVT::nxv4i1,   1 },
      { TTI::SK_Reverse, MVT::nxv2i1,   1 },
    };
    std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
    if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second))
      return LT.first * Entry->Cost;
  }

  return BaseT::getShuffleCost(Kind, Tp, Mask, Index, SubTp);
}
