1 //===- RISCVTargetTransformInfo.h - RISC-V specific TTI ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// This file defines a TargetTransformInfo::Concept conforming object specific
10 /// to the RISC-V target machine. It uses the target's detailed information to
11 /// provide more precise answers to certain TTI queries, while letting the
12 /// target independent and default TTI implementations handle the rest.
13 ///
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
17 #define LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
18 
19 #include "RISCVSubtarget.h"
20 #include "RISCVTargetMachine.h"
21 #include "llvm/Analysis/IVDescriptors.h"
22 #include "llvm/Analysis/TargetTransformInfo.h"
23 #include "llvm/CodeGen/BasicTTIImpl.h"
24 #include "llvm/IR/Function.h"
25 
26 namespace llvm {
27 
28 class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
29   using BaseT = BasicTTIImplBase<RISCVTTIImpl>;
30   using TTI = TargetTransformInfo;
31 
32   friend BaseT;
33 
34   const RISCVSubtarget *ST;
35   const RISCVTargetLowering *TLI;
36 
37   const RISCVSubtarget *getST() const { return ST; }
38   const RISCVTargetLowering *getTLI() const { return TLI; }
39 
40   unsigned getMaxVLFor(VectorType *Ty);
41 public:
42   explicit RISCVTTIImpl(const RISCVTargetMachine *TM, const Function &F)
43       : BaseT(TM, F.getParent()->getDataLayout()), ST(TM->getSubtargetImpl(F)),
44         TLI(ST->getTargetLowering()) {}
45 
46   InstructionCost getIntImmCost(const APInt &Imm, Type *Ty,
47                                 TTI::TargetCostKind CostKind);
48   InstructionCost getIntImmCostInst(unsigned Opcode, unsigned Idx,
49                                     const APInt &Imm, Type *Ty,
50                                     TTI::TargetCostKind CostKind,
51                                     Instruction *Inst = nullptr);
52   InstructionCost getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
53                                       const APInt &Imm, Type *Ty,
54                                       TTI::TargetCostKind CostKind);
55 
56   TargetTransformInfo::PopcntSupportKind getPopcntSupport(unsigned TyWidth);
57 
58   bool shouldExpandReduction(const IntrinsicInst *II) const;
59   bool supportsScalableVectors() const { return ST->hasVInstructions(); }
60   PredicationStyle emitGetActiveLaneMask() const {
61     return ST->hasVInstructions() ? PredicationStyle::Data
62                                   : PredicationStyle::None;
63   }
64   Optional<unsigned> getMaxVScale() const;
65   Optional<unsigned> getVScaleForTuning() const;
66 
67   TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const;
68 
69   unsigned getRegUsageForType(Type *Ty);
70 
71   InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
72                                         Align Alignment, unsigned AddressSpace,
73                                         TTI::TargetCostKind CostKind);
74 
75   void getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
76                                TTI::UnrollingPreferences &UP,
77                                OptimizationRemarkEmitter *ORE);
78 
79   void getPeelingPreferences(Loop *L, ScalarEvolution &SE,
80                              TTI::PeelingPreferences &PP);
81 
82   unsigned getMinVectorRegisterBitWidth() const {
83     return ST->useRVVForFixedLengthVectors() ? 16 : 0;
84   }
85 
86   InstructionCost getSpliceCost(VectorType *Tp, int Index);
87   InstructionCost getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp,
88                                  ArrayRef<int> Mask, int Index,
89                                  VectorType *SubTp,
90                                  ArrayRef<const Value *> Args = None);
91 
92   InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
93                                         TTI::TargetCostKind CostKind);
94 
95   InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
96                                          const Value *Ptr, bool VariableMask,
97                                          Align Alignment,
98                                          TTI::TargetCostKind CostKind,
99                                          const Instruction *I);
100 
101   InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
102                                    TTI::CastContextHint CCH,
103                                    TTI::TargetCostKind CostKind,
104                                    const Instruction *I = nullptr);
105 
106   InstructionCost getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
107                                          bool IsUnsigned,
108                                          TTI::TargetCostKind CostKind);
109 
110   InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
111                                              Optional<FastMathFlags> FMF,
112                                              TTI::TargetCostKind CostKind);
113 
114   bool isElementTypeLegalForScalableVector(Type *Ty) const {
115     return TLI->isLegalElementTypeForRVV(Ty);
116   }
117 
118   bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) {
119     if (!ST->hasVInstructions())
120       return false;
121 
122     // Only support fixed vectors if we know the minimum vector size.
123     if (isa<FixedVectorType>(DataType) && !ST->useRVVForFixedLengthVectors())
124       return false;
125 
126     // Don't allow elements larger than the ELEN.
127     // FIXME: How to limit for scalable vectors?
128     if (isa<FixedVectorType>(DataType) &&
129         DataType->getScalarSizeInBits() > ST->getELEN())
130       return false;
131 
132     if (Alignment <
133         DL.getTypeStoreSize(DataType->getScalarType()).getFixedSize())
134       return false;
135 
136     return TLI->isLegalElementTypeForRVV(DataType->getScalarType());
137   }
138 
139   bool isLegalMaskedLoad(Type *DataType, Align Alignment) {
140     return isLegalMaskedLoadStore(DataType, Alignment);
141   }
142   bool isLegalMaskedStore(Type *DataType, Align Alignment) {
143     return isLegalMaskedLoadStore(DataType, Alignment);
144   }
145 
146   bool isLegalMaskedGatherScatter(Type *DataType, Align Alignment) {
147     if (!ST->hasVInstructions())
148       return false;
149 
150     // Only support fixed vectors if we know the minimum vector size.
151     if (isa<FixedVectorType>(DataType) && !ST->useRVVForFixedLengthVectors())
152       return false;
153 
154     // Don't allow elements larger than the ELEN.
155     // FIXME: How to limit for scalable vectors?
156     if (isa<FixedVectorType>(DataType) &&
157         DataType->getScalarSizeInBits() > ST->getELEN())
158       return false;
159 
160     if (Alignment <
161         DL.getTypeStoreSize(DataType->getScalarType()).getFixedSize())
162       return false;
163 
164     return TLI->isLegalElementTypeForRVV(DataType->getScalarType());
165   }
166 
167   bool isLegalMaskedGather(Type *DataType, Align Alignment) {
168     return isLegalMaskedGatherScatter(DataType, Alignment);
169   }
170   bool isLegalMaskedScatter(Type *DataType, Align Alignment) {
171     return isLegalMaskedGatherScatter(DataType, Alignment);
172   }
173 
174   bool forceScalarizeMaskedGather(VectorType *VTy, Align Alignment) {
175     // Scalarize masked gather for RV64 if EEW=64 indices aren't supported.
176     return ST->is64Bit() && !ST->hasVInstructionsI64();
177   }
178 
179   bool forceScalarizeMaskedScatter(VectorType *VTy, Align Alignment) {
180     // Scalarize masked scatter for RV64 if EEW=64 indices aren't supported.
181     return ST->is64Bit() && !ST->hasVInstructionsI64();
182   }
183 
184   /// \returns How the target needs this vector-predicated operation to be
185   /// transformed.
186   TargetTransformInfo::VPLegalization
187   getVPLegalizationStrategy(const VPIntrinsic &PI) const {
188     using VPLegalization = TargetTransformInfo::VPLegalization;
189     return VPLegalization(VPLegalization::Legal, VPLegalization::Legal);
190   }
191 
192   bool isLegalToVectorizeReduction(const RecurrenceDescriptor &RdxDesc,
193                                    ElementCount VF) const {
194     if (!VF.isScalable())
195       return true;
196 
197     Type *Ty = RdxDesc.getRecurrenceType();
198     if (!TLI->isLegalElementTypeForRVV(Ty))
199       return false;
200 
201     switch (RdxDesc.getRecurrenceKind()) {
202     case RecurKind::Add:
203     case RecurKind::FAdd:
204     case RecurKind::And:
205     case RecurKind::Or:
206     case RecurKind::Xor:
207     case RecurKind::SMin:
208     case RecurKind::SMax:
209     case RecurKind::UMin:
210     case RecurKind::UMax:
211     case RecurKind::FMin:
212     case RecurKind::FMax:
213       return true;
214     default:
215       return false;
216     }
217   }
218 
219   unsigned getMaxInterleaveFactor(unsigned VF) {
220     // If the loop will not be vectorized, don't interleave the loop.
221     // Let regular unroll to unroll the loop.
222     return VF == 1 ? 1 : ST->getMaxInterleaveFactor();
223   }
224 
225   enum RISCVRegisterClass { GPRRC, FPRRC, VRRC };
226   unsigned getNumberOfRegisters(unsigned ClassID) const {
227     switch (ClassID) {
228     case RISCVRegisterClass::GPRRC:
229       // 31 = 32 GPR - x0 (zero register)
230       // FIXME: Should we exclude fixed registers like SP, TP or GP?
231       return 31;
232     case RISCVRegisterClass::FPRRC:
233       if (ST->hasStdExtF())
234         return 32;
235       return 0;
236     case RISCVRegisterClass::VRRC:
237       // Although there are 32 vector registers, v0 is special in that it is the
238       // only register that can be used to hold a mask.
239       // FIXME: Should we conservatively return 31 as the number of usable
240       // vector registers?
241       return ST->hasVInstructions() ? 32 : 0;
242     }
243     llvm_unreachable("unknown register class");
244   }
245 
246   unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const {
247     if (Vector)
248       return RISCVRegisterClass::VRRC;
249     if (!Ty)
250       return RISCVRegisterClass::GPRRC;
251 
252     Type *ScalarTy = Ty->getScalarType();
253     if ((ScalarTy->isHalfTy() && ST->hasStdExtZfh()) ||
254         (ScalarTy->isFloatTy() && ST->hasStdExtF()) ||
255         (ScalarTy->isDoubleTy() && ST->hasStdExtD())) {
256       return RISCVRegisterClass::FPRRC;
257     }
258 
259     return RISCVRegisterClass::GPRRC;
260   }
261 
262   const char *getRegisterClassName(unsigned ClassID) const {
263     switch (ClassID) {
264     case RISCVRegisterClass::GPRRC:
265       return "RISCV::GPRRC";
266     case RISCVRegisterClass::FPRRC:
267       return "RISCV::FPRRC";
268     case RISCVRegisterClass::VRRC:
269       return "RISCV::VRRC";
270     }
271     llvm_unreachable("unknown register class");
272   }
273 };
274 
275 } // end namespace llvm
276 
277 #endif // LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
278