//===- TargetTransformInfoImpl.h --------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file provides helpers for the implementation of
/// a TargetTransformInfo-conforming class.
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_TARGETTRANSFORMINFOIMPL_H
#define LLVM_ANALYSIS_TARGETTRANSFORMINFOIMPL_H

#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"

namespace llvm {

/// Base class for use as a mix-in that aids implementing
/// a TargetTransformInfo-compatible class.
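///
/// Every hook below provides a conservative, target-independent default
/// answer; concrete targets override only the hooks they can answer more
/// precisely.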
class TargetTransformInfoImplBase {
protected:
  typedef TargetTransformInfo TTI;

  const DataLayout &DL;

  explicit TargetTransformInfoImplBase(const DataLayout &DL) : DL(DL) {}

public:
  // Provide value semantics. MSVC requires that we spell all of these out.
  TargetTransformInfoImplBase(const TargetTransformInfoImplBase &Arg)
      : DL(Arg.DL) {}
  TargetTransformInfoImplBase(TargetTransformInfoImplBase &&Arg) : DL(Arg.DL) {}

  const DataLayout &getDataLayout() const { return DL; }

  int getGEPCost(Type *PointeeType, const Value *Ptr,
                 ArrayRef<const Value *> Operands,
                 TTI::TargetCostKind CostKind = TTI::TCK_SizeAndLatency) {
    // In the basic model, we just assume that all-constant GEPs will be folded
    // into their uses via addressing modes.
    for (unsigned Idx = 0, Size = Operands.size(); Idx != Size; ++Idx)
      if (!isa<Constant>(Operands[Idx]))
        return TTI::TCC_Basic;

    return TTI::TCC_Free;
  }

  unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
                                            unsigned &JTSize,
                                            ProfileSummaryInfo *PSI,
                                            BlockFrequencyInfo *BFI) {
    (void)PSI;
    (void)BFI;
    JTSize = 0;
    return SI.getNumCases();
  }

  unsigned getInliningThresholdMultiplier() { return 1; }

  int getInlinerVectorBonusPercent() { return 150; }

  unsigned getMemcpyCost(const Instruction *I) { return TTI::TCC_Expensive; }

  bool hasBranchDivergence() { return false; }

  bool useGPUDivergenceAnalysis() { return false; }

  bool isSourceOfDivergence(const Value *V) { return false; }

  bool isAlwaysUniform(const Value *V) { return false; }

  unsigned getFlatAddressSpace() { return -1; }

  bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                  Intrinsic::ID IID) const {
    return false;
  }

  bool isNoopAddrSpaceCast(unsigned, unsigned) const { return false; }

  Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
                                          Value *NewV) const {
    return nullptr;
  }

  bool isLoweredToCall(const Function *F) {
    assert(F && "A concrete function must be provided to this routine.");

    // FIXME: These should almost certainly not be handled here, and instead
    // handled with the help of TLI or the target itself. This was largely
    // ported from existing analysis heuristics here so that such refactorings
    // can take place in the future.

    if (F->isIntrinsic())
      return false;

    if (F->hasLocalLinkage() || !F->hasName())
      return true;

    StringRef Name = F->getName();

    // These will all likely lower to a single selection DAG node.
    if (Name == "copysign" || Name == "copysignf" || Name == "copysignl" ||
        Name == "fabs" || Name == "fabsf" || Name == "fabsl" || Name == "sin" ||
        Name == "fmin" || Name == "fminf" || Name == "fminl" ||
        Name == "fmax" || Name == "fmaxf" || Name == "fmaxl" ||
        Name == "sinf" || Name == "sinl" || Name == "cos" || Name == "cosf" ||
        Name == "cosl" || Name == "sqrt" || Name == "sqrtf" || Name == "sqrtl")
      return false;

    // These are all likely to be optimized into something smaller.
    if (Name == "pow" || Name == "powf" || Name == "powl" || Name == "exp2" ||
        Name == "exp2l" || Name == "exp2f" || Name == "floor" ||
        Name == "floorf" || Name == "ceil" || Name == "round" ||
        Name == "ffs" || Name == "ffsl" || Name == "abs" || Name == "labs" ||
        Name == "llabs")
      return false;

    return true;
  }

  bool isHardwareLoopProfitable(Loop *L, ScalarEvolution &SE,
                                AssumptionCache &AC, TargetLibraryInfo *LibInfo,
                                HardwareLoopInfo &HWLoopInfo) {
    return false;
  }

  bool preferPredicateOverEpilogue(Loop *L, LoopInfo *LI, ScalarEvolution &SE,
                                   AssumptionCache &AC, TargetLibraryInfo *TLI,
                                   DominatorTree *DT,
                                   const LoopAccessInfo *LAI) const {
    return false;
  }

  bool emitGetActiveLaneMask() const {
    return false;
  }

  void getUnrollingPreferences(Loop *, ScalarEvolution &,
                               TTI::UnrollingPreferences &) {}

  void getPeelingPreferences(Loop *, ScalarEvolution &,
                             TTI::PeelingPreferences &) {}

  bool isLegalAddImmediate(int64_t Imm) { return false; }

  bool isLegalICmpImmediate(int64_t Imm) { return false; }

  bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
                             bool HasBaseReg, int64_t Scale, unsigned AddrSpace,
                             Instruction *I = nullptr) {
    // Guess that only reg and reg+reg addressing is allowed. This heuristic is
    // taken from the implementation of LSR.
    return !BaseGV && BaseOffset == 0 && (Scale == 0 || Scale == 1);
  }

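  // Compare two LSR cost solutions lexicographically: fewer registers always
  // wins, with ties broken by add-rec cost, IV multiplies, base adds, scale
  // cost, immediate cost and finally setup cost.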
  bool isLSRCostLess(TTI::LSRCost &C1, TTI::LSRCost &C2) {
    return std::tie(C1.NumRegs, C1.AddRecCost, C1.NumIVMuls, C1.NumBaseAdds,
                    C1.ScaleCost, C1.ImmCost, C1.SetupCost) <
           std::tie(C2.NumRegs, C2.AddRecCost, C2.NumIVMuls, C2.NumBaseAdds,
                    C2.ScaleCost, C2.ImmCost, C2.SetupCost);
  }

  bool isProfitableLSRChainElement(Instruction *I) { return false; }

  bool canMacroFuseCmp() { return false; }

  bool canSaveCmp(Loop *L, BranchInst **BI, ScalarEvolution *SE, LoopInfo *LI,
                  DominatorTree *DT, AssumptionCache *AC,
                  TargetLibraryInfo *LibInfo) {
    return false;
  }

  bool shouldFavorPostInc() const { return false; }

  bool shouldFavorBackedgeIndex(const Loop *L) const { return false; }

  bool isLegalMaskedStore(Type *DataType, Align Alignment) { return false; }

  bool isLegalMaskedLoad(Type *DataType, Align Alignment) { return false; }

  bool isLegalNTStore(Type *DataType, Align Alignment) {
    // By default, assume nontemporal memory stores are available for stores
    // that are aligned and have a size that is a power of 2.
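    // For example, a 16-byte <4 x float> store only qualifies when it is at
    // least 16-byte aligned.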
    unsigned DataSize = DL.getTypeStoreSize(DataType);
    return Alignment >= DataSize && isPowerOf2_32(DataSize);
  }

  bool isLegalNTLoad(Type *DataType, Align Alignment) {
    // By default, assume nontemporal memory loads are available for loads that
    // are aligned and have a size that is a power of 2.
    unsigned DataSize = DL.getTypeStoreSize(DataType);
    return Alignment >= DataSize && isPowerOf2_32(DataSize);
  }

  bool isLegalMaskedScatter(Type *DataType, Align Alignment) { return false; }

  bool isLegalMaskedGather(Type *DataType, Align Alignment) { return false; }

  bool isLegalMaskedCompressStore(Type *DataType) { return false; }

  bool isLegalMaskedExpandLoad(Type *DataType) { return false; }

  bool hasDivRemOp(Type *DataType, bool IsSigned) { return false; }

  bool hasVolatileVariant(Instruction *I, unsigned AddrSpace) { return false; }

  bool prefersVectorizedAddressing() { return true; }

  int getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
                           bool HasBaseReg, int64_t Scale, unsigned AddrSpace) {
    // Guess that all legal addressing modes are free.
    if (isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, Scale,
                              AddrSpace))
      return 0;
    return -1;
  }

  bool LSRWithInstrQueries() { return false; }

  bool isTruncateFree(Type *Ty1, Type *Ty2) { return false; }

  bool isProfitableToHoist(Instruction *I) { return true; }

  bool useAA() { return false; }

  bool isTypeLegal(Type *Ty) { return false; }

  bool shouldBuildLookupTables() { return true; }
  bool shouldBuildLookupTablesForConstant(Constant *C) { return true; }

  bool useColdCCForColdCall(Function &F) { return false; }

  unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
                                    bool Insert, bool Extract) {
    return 0;
  }

  unsigned getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
                                            unsigned VF) {
    return 0;
  }

  bool supportsEfficientVectorElementLoadStore() { return false; }

  bool enableAggressiveInterleaving(bool LoopHasReductions) { return false; }

  TTI::MemCmpExpansionOptions enableMemCmpExpansion(bool OptSize,
                                                    bool IsZeroCmp) const {
    return {};
  }

  bool enableInterleavedAccessVectorization() { return false; }

  bool enableMaskedInterleavedAccessVectorization() { return false; }

  bool isFPVectorizationPotentiallyUnsafe() { return false; }

  bool allowsMisalignedMemoryAccesses(LLVMContext &Context, unsigned BitWidth,
                                      unsigned AddressSpace, unsigned Alignment,
                                      bool *Fast) {
    return false;
  }

  TTI::PopcntSupportKind getPopcntSupport(unsigned IntTyWidthInBit) {
    return TTI::PSK_Software;
  }

  bool haveFastSqrt(Type *Ty) { return false; }

  bool isFCmpOrdCheaperThanFCmpZero(Type *Ty) { return true; }

  unsigned getFPOpCost(Type *Ty) { return TargetTransformInfo::TCC_Basic; }

  int getIntImmCodeSizeCost(unsigned Opcode, unsigned Idx, const APInt &Imm,
                            Type *Ty) {
    return 0;
  }

  unsigned getIntImmCost(const APInt &Imm, Type *Ty,
                         TTI::TargetCostKind CostKind) {
    return TTI::TCC_Basic;
  }

  unsigned getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Imm,
                             Type *Ty, TTI::TargetCostKind CostKind) {
    return TTI::TCC_Free;
  }

  unsigned getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
                               const APInt &Imm, Type *Ty,
                               TTI::TargetCostKind CostKind) {
    return TTI::TCC_Free;
  }

  unsigned getNumberOfRegisters(unsigned ClassID) const { return 8; }

  unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const {
    return Vector ? 1 : 0;
  };

  const char *getRegisterClassName(unsigned ClassID) const {
    switch (ClassID) {
    default:
      return "Generic::Unknown Register Class";
    case 0:
      return "Generic::ScalarRC";
    case 1:
      return "Generic::VectorRC";
    }
  }

  unsigned getRegisterBitWidth(bool Vector) const { return 32; }

  unsigned getMinVectorRegisterBitWidth() { return 128; }

  bool shouldMaximizeVectorBandwidth(bool OptSize) const { return false; }

  unsigned getMinimumVF(unsigned ElemWidth) const { return 0; }

  bool
  shouldConsiderAddressTypePromotion(const Instruction &I,
                                     bool &AllowPromotionWithoutCommonHeader) {
    AllowPromotionWithoutCommonHeader = false;
    return false;
  }

  unsigned getCacheLineSize() const { return 0; }

  llvm::Optional<unsigned>
  getCacheSize(TargetTransformInfo::CacheLevel Level) const {
    switch (Level) {
    case TargetTransformInfo::CacheLevel::L1D:
      LLVM_FALLTHROUGH;
    case TargetTransformInfo::CacheLevel::L2D:
      return llvm::Optional<unsigned>();
    }
    llvm_unreachable("Unknown TargetTransformInfo::CacheLevel");
  }

  llvm::Optional<unsigned>
  getCacheAssociativity(TargetTransformInfo::CacheLevel Level) const {
    switch (Level) {
    case TargetTransformInfo::CacheLevel::L1D:
      LLVM_FALLTHROUGH;
    case TargetTransformInfo::CacheLevel::L2D:
      return llvm::Optional<unsigned>();
    }

    llvm_unreachable("Unknown TargetTransformInfo::CacheLevel");
  }

  unsigned getPrefetchDistance() const { return 0; }
  unsigned getMinPrefetchStride(unsigned NumMemAccesses,
                                unsigned NumStridedMemAccesses,
                                unsigned NumPrefetches, bool HasCall) const {
    return 1;
  }
  unsigned getMaxPrefetchIterationsAhead() const { return UINT_MAX; }
  bool enableWritePrefetching() const { return false; }

  unsigned getMaxInterleaveFactor(unsigned VF) { return 1; }

  unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty,
                                  TTI::TargetCostKind CostKind,
                                  TTI::OperandValueKind Opd1Info,
                                  TTI::OperandValueKind Opd2Info,
                                  TTI::OperandValueProperties Opd1PropInfo,
                                  TTI::OperandValueProperties Opd2PropInfo,
                                  ArrayRef<const Value *> Args,
                                  const Instruction *CxtI = nullptr) {
    // FIXME: A number of transformation tests seem to require these values
    // which seems a little odd for how arbitrary they are.
    switch (Opcode) {
    default:
      break;
    case Instruction::FDiv:
    case Instruction::FRem:
    case Instruction::SDiv:
    case Instruction::SRem:
    case Instruction::UDiv:
    case Instruction::URem:
      // FIXME: Unlikely to be true for CodeSize.
      return TTI::TCC_Expensive;
    }
    return 1;
  }

  unsigned getShuffleCost(TTI::ShuffleKind Kind, VectorType *Ty, int Index,
                          VectorType *SubTp) {
    return 1;
  }

  unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
                            TTI::TargetCostKind CostKind,
                            const Instruction *I) {
    switch (Opcode) {
    default:
      break;
    case Instruction::IntToPtr: {
      unsigned SrcSize = Src->getScalarSizeInBits();
      if (DL.isLegalInteger(SrcSize) &&
          SrcSize <= DL.getPointerTypeSizeInBits(Dst))
        return 0;
      break;
    }
    case Instruction::PtrToInt: {
      unsigned DstSize = Dst->getScalarSizeInBits();
      if (DL.isLegalInteger(DstSize) &&
          DstSize >= DL.getPointerTypeSizeInBits(Src))
        return 0;
      break;
    }
    case Instruction::BitCast:
      if (Dst == Src || (Dst->isPointerTy() && Src->isPointerTy()))
        // Identity and pointer-to-pointer casts are free.
        return 0;
      break;
    case Instruction::Trunc:
      // trunc to a native type is free (assuming the target has compare and
      // shift-right of the same width).
      if (DL.isLegalInteger(DL.getTypeSizeInBits(Dst)))
        return 0;
      break;
    }
    return 1;
  }

  unsigned getExtractWithExtendCost(unsigned Opcode, Type *Dst,
                                    VectorType *VecTy, unsigned Index) {
    return 1;
  }

  unsigned getCFInstrCost(unsigned Opcode,
                          TTI::TargetCostKind CostKind) {
    // A phi would be free, unless we're costing the throughput because it
    // will require a register.
    if (Opcode == Instruction::PHI && CostKind != TTI::TCK_RecipThroughput)
      return 0;
    return 1;
  }

  unsigned getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
                              TTI::TargetCostKind CostKind,
                              const Instruction *I) const {
    return 1;
  }

  unsigned getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index) {
    return 1;
  }

  unsigned getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment,
                           unsigned AddressSpace, TTI::TargetCostKind CostKind,
                           const Instruction *I) const {
    return 1;
  }

  unsigned getMaskedMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment,
                                 unsigned AddressSpace,
                                 TTI::TargetCostKind CostKind) {
    return 1;
  }

  unsigned getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
                                  const Value *Ptr, bool VariableMask,
                                  Align Alignment, TTI::TargetCostKind CostKind,
                                  const Instruction *I = nullptr) {
    return 1;
  }

  unsigned getInterleavedMemoryOpCost(
      unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
      Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
      bool UseMaskForCond, bool UseMaskForGaps) {
    return 1;
  }

  unsigned getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                 TTI::TargetCostKind CostKind) {
    switch (ICA.getID()) {
    default:
      break;
    case Intrinsic::annotation:
    case Intrinsic::assume:
    case Intrinsic::sideeffect:
    case Intrinsic::dbg_declare:
    case Intrinsic::dbg_value:
    case Intrinsic::dbg_label:
    case Intrinsic::invariant_start:
    case Intrinsic::invariant_end:
    case Intrinsic::launder_invariant_group:
    case Intrinsic::strip_invariant_group:
    case Intrinsic::is_constant:
    case Intrinsic::lifetime_start:
    case Intrinsic::lifetime_end:
    case Intrinsic::objectsize:
    case Intrinsic::ptr_annotation:
    case Intrinsic::var_annotation:
    case Intrinsic::experimental_gc_result:
    case Intrinsic::experimental_gc_relocate:
    case Intrinsic::coro_alloc:
    case Intrinsic::coro_begin:
    case Intrinsic::coro_free:
    case Intrinsic::coro_end:
    case Intrinsic::coro_frame:
    case Intrinsic::coro_size:
    case Intrinsic::coro_suspend:
    case Intrinsic::coro_param:
    case Intrinsic::coro_subfn_addr:
      // These intrinsics don't actually represent code after lowering.
      return 0;
    }
    return 1;
  }

  unsigned getCallInstrCost(Function *F, Type *RetTy, ArrayRef<Type *> Tys,
                            TTI::TargetCostKind CostKind) {
    return 1;
  }

  unsigned getNumberOfParts(Type *Tp) { return 0; }

  unsigned getAddressComputationCost(Type *Tp, ScalarEvolution *,
                                     const SCEV *) {
    return 0;
  }

  unsigned getArithmeticReductionCost(unsigned, VectorType *, bool,
                                      TTI::TargetCostKind) { return 1; }

  unsigned getMinMaxReductionCost(VectorType *, VectorType *, bool, bool,
                                  TTI::TargetCostKind) { return 1; }

  unsigned getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) { return 0; }

  bool getTgtMemIntrinsic(IntrinsicInst *Inst, MemIntrinsicInfo &Info) {
    return false;
  }

  unsigned getAtomicMemIntrinsicMaxElementSize() const {
    // Note for overrides: You must ensure for all element unordered-atomic
    // memory intrinsics that all power-of-2 element sizes up to, and
    // including, the return value of this method have a corresponding
    // runtime lib call. These runtime lib call definitions can be found
    // in RuntimeLibcalls.h
    return 0;
  }

  Value *getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
                                           Type *ExpectedType) {
    return nullptr;
  }

  Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
                                  unsigned SrcAddrSpace, unsigned DestAddrSpace,
                                  unsigned SrcAlign, unsigned DestAlign) const {
    return Type::getInt8Ty(Context);
  }

  void getMemcpyLoopResidualLoweringType(
      SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
      unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
      unsigned SrcAlign, unsigned DestAlign) const {
    for (unsigned i = 0; i != RemainingBytes; ++i)
      OpsOut.push_back(Type::getInt8Ty(Context));
  }

  bool areInlineCompatible(const Function *Caller,
                           const Function *Callee) const {
    return (Caller->getFnAttribute("target-cpu") ==
            Callee->getFnAttribute("target-cpu")) &&
           (Caller->getFnAttribute("target-features") ==
            Callee->getFnAttribute("target-features"));
  }

  bool areFunctionArgsABICompatible(const Function *Caller,
                                    const Function *Callee,
                                    SmallPtrSetImpl<Argument *> &Args) const {
    return (Caller->getFnAttribute("target-cpu") ==
            Callee->getFnAttribute("target-cpu")) &&
           (Caller->getFnAttribute("target-features") ==
            Callee->getFnAttribute("target-features"));
  }

  bool isIndexedLoadLegal(TTI::MemIndexedMode Mode, Type *Ty,
                          const DataLayout &DL) const {
    return false;
  }

  bool isIndexedStoreLegal(TTI::MemIndexedMode Mode, Type *Ty,
                           const DataLayout &DL) const {
    return false;
  }

  unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const { return 128; }

  bool isLegalToVectorizeLoad(LoadInst *LI) const { return true; }

  bool isLegalToVectorizeStore(StoreInst *SI) const { return true; }

  bool isLegalToVectorizeLoadChain(unsigned ChainSizeInBytes, Align Alignment,
                                   unsigned AddrSpace) const {
    return true;
  }

  bool isLegalToVectorizeStoreChain(unsigned ChainSizeInBytes, Align Alignment,
                                    unsigned AddrSpace) const {
    return true;
  }

  unsigned getLoadVectorFactor(unsigned VF, unsigned LoadSize,
                               unsigned ChainSizeInBytes,
                               VectorType *VecTy) const {
    return VF;
  }

  unsigned getStoreVectorFactor(unsigned VF, unsigned StoreSize,
                                unsigned ChainSizeInBytes,
                                VectorType *VecTy) const {
    return VF;
  }

  bool useReductionIntrinsic(unsigned Opcode, Type *Ty,
                             TTI::ReductionFlags Flags) const {
    return false;
  }

  bool shouldExpandReduction(const IntrinsicInst *II) const { return true; }

  unsigned getGISelRematGlobalCost() const { return 1; }

  bool hasActiveVectorLength() const { return false; }

protected:
  // Obtain the minimum required size to hold the value (without the sign).
  // For a vector, it returns the minimum required size of one element.
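  // For example, for the constant vector <i32 1, i32 -4> this returns 2 and
  // sets isSigned, since -4 needs three bits as a signed value, i.e. two
  // magnitude bits plus a sign bit.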
  unsigned minRequiredElementSize(const Value *Val, bool &isSigned) {
    if (isa<ConstantDataVector>(Val) || isa<ConstantVector>(Val)) {
      const auto *VectorValue = cast<Constant>(Val);

      // For a vector we need to take the max over the min required sizes
      // of its elements.
      auto *VT = cast<VectorType>(Val->getType());

      // Assume unsigned elements
      isSigned = false;

      // The max required size is the size of the vector element type
      unsigned MaxRequiredSize =
          VT->getElementType()->getPrimitiveSizeInBits().getFixedSize();

      unsigned MinRequiredSize = 0;
      for (unsigned i = 0, e = VT->getNumElements(); i < e; ++i) {
        if (auto *IntElement =
                dyn_cast<ConstantInt>(VectorValue->getAggregateElement(i))) {
          bool signedElement = IntElement->getValue().isNegative();
          // Get the element min required size.
          unsigned ElementMinRequiredSize =
              IntElement->getValue().getMinSignedBits() - 1;
          // In case one element is signed then all the vector is signed.
          isSigned |= signedElement;
          // Save the max required bit size between all the elements.
          MinRequiredSize = std::max(MinRequiredSize, ElementMinRequiredSize);
        } else {
          // not an int constant element
          return MaxRequiredSize;
        }
      }
      return MinRequiredSize;
    }

    if (const auto *CI = dyn_cast<ConstantInt>(Val)) {
      isSigned = CI->getValue().isNegative();
      return CI->getValue().getMinSignedBits() - 1;
    }

    if (const auto *Cast = dyn_cast<SExtInst>(Val)) {
      isSigned = true;
      return Cast->getSrcTy()->getScalarSizeInBits() - 1;
    }

    if (const auto *Cast = dyn_cast<ZExtInst>(Val)) {
      isSigned = false;
      return Cast->getSrcTy()->getScalarSizeInBits();
    }

    isSigned = false;
    return Val->getType()->getScalarSizeInBits();
  }

  bool isStridedAccess(const SCEV *Ptr) {
    return Ptr && isa<SCEVAddRecExpr>(Ptr);
  }

  const SCEVConstant *getConstantStrideStep(ScalarEvolution *SE,
                                            const SCEV *Ptr) {
    if (!isStridedAccess(Ptr))
      return nullptr;
    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ptr);
    return dyn_cast<SCEVConstant>(AddRec->getStepRecurrence(*SE));
  }

  bool isConstantStridedAccessLessThan(ScalarEvolution *SE, const SCEV *Ptr,
                                       int64_t MergeDistance) {
    const SCEVConstant *Step = getConstantStrideStep(SE, Ptr);
    if (!Step)
      return false;
    APInt StrideVal = Step->getAPInt();
    if (StrideVal.getBitWidth() > 64)
      return false;
    // FIXME: Need to take absolute value for negative stride case.
    return StrideVal.getSExtValue() < MergeDistance;
  }
};

/// CRTP base class for use as a mix-in that aids implementing
/// a TargetTransformInfo-compatible class.
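///
/// The template parameter \p T is the most derived class; helpers such as
/// getUserCost() static_cast 'this' to \p T before calling the per-opcode
/// cost hooks, so overrides in the derived class are picked up without
/// virtual dispatch. As an illustrative sketch only (not a real target), a
/// user of this mix-in might look like:
///
/// \code
///   class MyTTIImpl : public TargetTransformInfoImplCRTPBase<MyTTIImpl> {
///   public:
///     explicit MyTTIImpl(const DataLayout &DL)
///         : TargetTransformInfoImplCRTPBase<MyTTIImpl>(DL) {}
///     // Hide the base's default to report a target-specific cost.
///     unsigned getVectorInstrCost(unsigned Opcode, Type *Val,
///                                 unsigned Index) {
///       return 2;
///     }
///   };
/// \endcode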
template <typename T>
class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
private:
  typedef TargetTransformInfoImplBase BaseT;

protected:
  explicit TargetTransformInfoImplCRTPBase(const DataLayout &DL) : BaseT(DL) {}

public:
  using BaseT::getGEPCost;

  int getGEPCost(Type *PointeeType, const Value *Ptr,
                 ArrayRef<const Value *> Operands,
                 TTI::TargetCostKind CostKind = TTI::TCK_SizeAndLatency) {
    assert(PointeeType && Ptr && "can't get GEPCost of nullptr");
    // TODO: will remove this when pointers have an opaque type.
    assert(Ptr->getType()->getScalarType()->getPointerElementType() ==
               PointeeType &&
           "explicit pointee type doesn't match operand's pointee type");
    auto *BaseGV = dyn_cast<GlobalValue>(Ptr->stripPointerCasts());
    bool HasBaseReg = (BaseGV == nullptr);

    auto PtrSizeBits = DL.getPointerTypeSizeInBits(Ptr->getType());
    APInt BaseOffset(PtrSizeBits, 0);
    int64_t Scale = 0;

    auto GTI = gep_type_begin(PointeeType, Operands);
    Type *TargetType = nullptr;

    // Handle the case where the GEP instruction has only the base pointer
    // operand and no indices, in which case TargetType remains nullptr.
    if (Operands.empty())
      return !BaseGV ? TTI::TCC_Free : TTI::TCC_Basic;

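    // Fold each constant (or splat-constant) index into BaseOffset; the first
    // non-constant index is modelled as a scaled register (a second one
    // immediately costs TCC_Basic). The result is then checked against the
    // target's legal addressing modes below.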
    for (auto I = Operands.begin(); I != Operands.end(); ++I, ++GTI) {
      TargetType = GTI.getIndexedType();
      // We assume that the cost of Scalar GEP with constant index and the
      // cost of Vector GEP with splat constant index are the same.
      const ConstantInt *ConstIdx = dyn_cast<ConstantInt>(*I);
      if (!ConstIdx)
        if (auto Splat = getSplatValue(*I))
          ConstIdx = dyn_cast<ConstantInt>(Splat);
      if (StructType *STy = GTI.getStructTypeOrNull()) {
        // For structures the index is always splat or scalar constant
        assert(ConstIdx && "Unexpected GEP index");
        uint64_t Field = ConstIdx->getZExtValue();
        BaseOffset += DL.getStructLayout(STy)->getElementOffset(Field);
      } else {
        int64_t ElementSize = DL.getTypeAllocSize(GTI.getIndexedType());
        if (ConstIdx) {
          BaseOffset +=
              ConstIdx->getValue().sextOrTrunc(PtrSizeBits) * ElementSize;
        } else {
          // Needs scale register.
          if (Scale != 0)
            // No addressing mode takes two scale registers.
            return TTI::TCC_Basic;
          Scale = ElementSize;
        }
      }
    }

    if (static_cast<T *>(this)->isLegalAddressingMode(
            TargetType, const_cast<GlobalValue *>(BaseGV),
            BaseOffset.sextOrTrunc(64).getSExtValue(), HasBaseReg, Scale,
            Ptr->getType()->getPointerAddressSpace()))
      return TTI::TCC_Free;
    return TTI::TCC_Basic;
  }

  int getUserCost(const User *U, ArrayRef<const Value *> Operands,
                  TTI::TargetCostKind CostKind) {
    auto *TargetTTI = static_cast<T *>(this);

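    // Dispatch on the IR opcode to the (possibly overridden) cost hooks of
    // the most derived class.
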
    // FIXME: We shouldn't have to special-case intrinsics here.
    if (CostKind == TTI::TCK_RecipThroughput) {
      if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U)) {
        IntrinsicCostAttributes CostAttrs(*II);
        return TargetTTI->getIntrinsicInstrCost(CostAttrs, CostKind);
      }
    }

    // FIXME: Unlikely to be true for anything but CodeSize.
    if (const auto *CB = dyn_cast<CallBase>(U)) {
      const Function *F = CB->getCalledFunction();
      if (F) {
        FunctionType *FTy = F->getFunctionType();
        if (Intrinsic::ID IID = F->getIntrinsicID()) {
          IntrinsicCostAttributes Attrs(IID, *CB);
          return TargetTTI->getIntrinsicInstrCost(Attrs, CostKind);
        }

        if (!TargetTTI->isLoweredToCall(F))
          return TTI::TCC_Basic; // Give a basic cost if it will be lowered

        return TTI::TCC_Basic * (FTy->getNumParams() + 1);
      }
      return TTI::TCC_Basic * (CB->arg_size() + 1);
    }

    Type *Ty = U->getType();
    Type *OpTy =
      U->getNumOperands() == 1 ? U->getOperand(0)->getType() : nullptr;
    unsigned Opcode = Operator::getOpcode(U);
    auto *I = dyn_cast<Instruction>(U);
    switch (Opcode) {
    default:
      break;
    case Instruction::Br:
    case Instruction::Ret:
    case Instruction::PHI:
      return TargetTTI->getCFInstrCost(Opcode, CostKind);
    case Instruction::ExtractValue:
    case Instruction::Freeze:
      return TTI::TCC_Free;
    case Instruction::Alloca:
      if (cast<AllocaInst>(U)->isStaticAlloca())
        return TTI::TCC_Free;
      break;
    case Instruction::GetElementPtr: {
      const GEPOperator *GEP = cast<GEPOperator>(U);
      return TargetTTI->getGEPCost(GEP->getSourceElementType(),
                                   GEP->getPointerOperand(),
                                   Operands.drop_front());
    }
    case Instruction::Add:
    case Instruction::FAdd:
    case Instruction::Sub:
    case Instruction::FSub:
    case Instruction::Mul:
    case Instruction::FMul:
    case Instruction::UDiv:
    case Instruction::SDiv:
    case Instruction::FDiv:
    case Instruction::URem:
    case Instruction::SRem:
    case Instruction::FRem:
    case Instruction::Shl:
    case Instruction::LShr:
    case Instruction::AShr:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    case Instruction::FNeg: {
      TTI::OperandValueProperties Op1VP = TTI::OP_None;
      TTI::OperandValueProperties Op2VP = TTI::OP_None;
      TTI::OperandValueKind Op1VK =
        TTI::getOperandInfo(U->getOperand(0), Op1VP);
      TTI::OperandValueKind Op2VK = Opcode != Instruction::FNeg ?
        TTI::getOperandInfo(U->getOperand(1), Op2VP) : TTI::OK_AnyValue;
      SmallVector<const Value *, 2> Operands(U->operand_values());
      return TargetTTI->getArithmeticInstrCost(Opcode, Ty, CostKind,
                                               Op1VK, Op2VK,
                                               Op1VP, Op2VP, Operands, I);
    }
    case Instruction::IntToPtr:
    case Instruction::PtrToInt:
    case Instruction::SIToFP:
    case Instruction::UIToFP:
    case Instruction::FPToUI:
    case Instruction::FPToSI:
    case Instruction::Trunc:
    case Instruction::FPTrunc:
    case Instruction::BitCast:
    case Instruction::FPExt:
    case Instruction::SExt:
    case Instruction::ZExt:
    case Instruction::AddrSpaceCast:
      return TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I);
    case Instruction::Store: {
      auto *SI = cast<StoreInst>(U);
      Type *ValTy = U->getOperand(0)->getType();
      return TargetTTI->getMemoryOpCost(Opcode, ValTy, SI->getAlign(),
                                        SI->getPointerAddressSpace(),
                                        CostKind, I);
    }
    case Instruction::Load: {
      auto *LI = cast<LoadInst>(U);
      return TargetTTI->getMemoryOpCost(Opcode, U->getType(), LI->getAlign(),
                                        LI->getPointerAddressSpace(),
                                        CostKind, I);
    }
    case Instruction::Select: {
      Type *CondTy = U->getOperand(0)->getType();
      return TargetTTI->getCmpSelInstrCost(Opcode, U->getType(), CondTy,
                                           CostKind, I);
    }
    case Instruction::ICmp:
    case Instruction::FCmp: {
      Type *ValTy = U->getOperand(0)->getType();
      return TargetTTI->getCmpSelInstrCost(Opcode, ValTy, U->getType(),
                                           CostKind, I);
    }
    case Instruction::InsertElement: {
      auto *IE = dyn_cast<InsertElementInst>(U);
      if (!IE)
        return TTI::TCC_Basic; // FIXME
      auto *CI = dyn_cast<ConstantInt>(IE->getOperand(2));
      unsigned Idx = CI ? CI->getZExtValue() : -1;
      return TargetTTI->getVectorInstrCost(Opcode, Ty, Idx);
    }
    case Instruction::ShuffleVector: {
      auto *Shuffle = dyn_cast<ShuffleVectorInst>(U);
      if (!Shuffle)
        return TTI::TCC_Basic; // FIXME
      auto *VecTy = cast<VectorType>(U->getType());
      auto *VecSrcTy = cast<VectorType>(U->getOperand(0)->getType());

      // TODO: Identify and add costs for insert subvector, etc.
      int SubIndex;
      if (Shuffle->isExtractSubvectorMask(SubIndex))
        return TargetTTI->getShuffleCost(TTI::SK_ExtractSubvector, VecSrcTy,
                                         SubIndex, VecTy);
      else if (Shuffle->changesLength())
        return CostKind == TTI::TCK_RecipThroughput ? -1 : 1;
      else if (Shuffle->isIdentity())
        return 0;
      else if (Shuffle->isReverse())
        return TargetTTI->getShuffleCost(TTI::SK_Reverse, VecTy, 0, nullptr);
      else if (Shuffle->isSelect())
        return TargetTTI->getShuffleCost(TTI::SK_Select, VecTy, 0, nullptr);
      else if (Shuffle->isTranspose())
        return TargetTTI->getShuffleCost(TTI::SK_Transpose, VecTy, 0, nullptr);
      else if (Shuffle->isZeroEltSplat())
        return TargetTTI->getShuffleCost(TTI::SK_Broadcast, VecTy, 0, nullptr);
      else if (Shuffle->isSingleSource())
        return TargetTTI->getShuffleCost(TTI::SK_PermuteSingleSrc, VecTy, 0,
                                         nullptr);

      return TargetTTI->getShuffleCost(TTI::SK_PermuteTwoSrc, VecTy, 0,
                                       nullptr);
    }
    case Instruction::ExtractElement: {
      unsigned Idx = -1;
      auto *EEI = dyn_cast<ExtractElementInst>(U);
      if (!EEI)
        return TTI::TCC_Basic; // FIXME

      auto *CI = dyn_cast<ConstantInt>(EEI->getOperand(1));
      if (CI)
        Idx = CI->getZExtValue();

      // Try to match a reduction sequence (series of shufflevector and
      // vector adds followed by an extractelement).
      unsigned ReduxOpCode;
      VectorType *ReduxType;

      switch (TTI::matchVectorSplittingReduction(EEI, ReduxOpCode,
                                                 ReduxType)) {
      case TTI::RK_Arithmetic:
        return TargetTTI->getArithmeticReductionCost(ReduxOpCode, ReduxType,
                                          /*IsPairwiseForm=*/false,
                                          CostKind);
      case TTI::RK_MinMax:
        return TargetTTI->getMinMaxReductionCost(
            ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
            /*IsPairwiseForm=*/false, /*IsUnsigned=*/false, CostKind);
      case TTI::RK_UnsignedMinMax:
        return TargetTTI->getMinMaxReductionCost(
            ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
            /*IsPairwiseForm=*/false, /*IsUnsigned=*/true, CostKind);
      case TTI::RK_None:
        break;
      }

      switch (TTI::matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) {
      case TTI::RK_Arithmetic:
        return TargetTTI->getArithmeticReductionCost(ReduxOpCode, ReduxType,
                                          /*IsPairwiseForm=*/true, CostKind);
      case TTI::RK_MinMax:
        return TargetTTI->getMinMaxReductionCost(
            ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
            /*IsPairwiseForm=*/true, /*IsUnsigned=*/false, CostKind);
      case TTI::RK_UnsignedMinMax:
        return TargetTTI->getMinMaxReductionCost(
            ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
            /*IsPairwiseForm=*/true, /*IsUnsigned=*/true, CostKind);
      case TTI::RK_None:
        break;
      }
      return TargetTTI->getVectorInstrCost(Opcode, U->getOperand(0)->getType(),
                                           Idx);
    }
    }
    // By default, just classify everything as 'basic'.
    return TTI::TCC_Basic;
  }

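  // A crude latency model: anything getUserCost considers free costs 0,
  // loads cost 4, calls that are actually lowered to calls cost 40,
  // floating-point producing instructions cost 3 and everything else 1.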
  int getInstructionLatency(const Instruction *I) {
    SmallVector<const Value *, 4> Operands(I->value_op_begin(),
                                           I->value_op_end());
    if (getUserCost(I, Operands, TTI::TCK_Latency) == TTI::TCC_Free)
      return 0;

    if (isa<LoadInst>(I))
      return 4;

    Type *DstTy = I->getType();

    // Usually an intrinsic is a simple instruction.
    // A real function call is much slower.
    if (auto *CI = dyn_cast<CallInst>(I)) {
      const Function *F = CI->getCalledFunction();
      if (!F || static_cast<T *>(this)->isLoweredToCall(F))
        return 40;
      // Some intrinsics return a value and a flag; we use the value type
      // to decide the latency.
      if (StructType *StructTy = dyn_cast<StructType>(DstTy))
        DstTy = StructTy->getElementType(0);
      // Fall through to simple instructions.
    }

    if (VectorType *VectorTy = dyn_cast<VectorType>(DstTy))
      DstTy = VectorTy->getElementType();
    if (DstTy->isFloatingPointTy())
      return 3;

    return 1;
  }
};
} // namespace llvm

#endif