//===- RISCVTargetTransformInfo.h - RISC-V specific TTI ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file defines a TargetTransformInfo::Concept conforming object specific
/// to the RISC-V target machine. It uses the target's detailed information to
/// provide more precise answers to certain TTI queries, while letting the
/// target independent and default TTI implementations handle the rest.
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
#define LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H

#include "RISCVSubtarget.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/IR/Function.h"
#include <optional>

namespace llvm {

class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
  using BaseT = BasicTTIImplBase<RISCVTTIImpl>;
  using TTI = TargetTransformInfo;

  friend BaseT;

  const RISCVSubtarget *ST;
  const RISCVTargetLowering *TLI;

  const RISCVSubtarget *getST() const { return ST; }
  const RISCVTargetLowering *getTLI() const { return TLI; }

  /// This function returns an estimate for VL to be used in VL-based terms
  /// of the cost model.  For fixed-length vectors, this is simply the
  /// vector length.  For scalable vectors, we return results consistent
  /// with getVScaleForTuning, under the assumption that clients are also
  /// using that when comparing costs between scalar and vector representation.
  /// This does unfortunately mean that we can both undershoot and overshoot
  /// the true cost significantly if getVScaleForTuning is wildly off for the
  /// actual target hardware.
  unsigned getEstimatedVLFor(VectorType *Ty);
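  // Illustrative: for a fixed <8 x i32> this returns 8; for <vscale x 4 x i32>
  // with getVScaleForTuning() == 2, it returns an estimate consistent with
  // 4 * 2 = 8 lanes.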

  /// Return the cost of LMUL. The larger the LMUL, the higher the cost.
  InstructionCost getLMULCost(MVT VT);
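  // Sketch of the intent: an LMUL=8 type such as nxv8i64 spans eight vector
  // registers, so it is modeled as roughly eight times as expensive as an
  // LMUL=1 type such as nxv1i64. The exact scaling is implementation-defined.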

public:
  explicit RISCVTTIImpl(const RISCVTargetMachine *TM, const Function &F)
      : BaseT(TM, F.getParent()->getDataLayout()), ST(TM->getSubtargetImpl(F)),
        TLI(ST->getTargetLowering()) {}

  /// Return the cost of materializing an immediate for a value operand of
  /// a store instruction.
  InstructionCost getStoreImmCost(Type *VecTy, TTI::OperandValueInfo OpInfo,
                                  TTI::TargetCostKind CostKind);

  InstructionCost getIntImmCost(const APInt &Imm, Type *Ty,
                                TTI::TargetCostKind CostKind);
  InstructionCost getIntImmCostInst(unsigned Opcode, unsigned Idx,
                                    const APInt &Imm, Type *Ty,
                                    TTI::TargetCostKind CostKind,
                                    Instruction *Inst = nullptr);
  InstructionCost getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
                                      const APInt &Imm, Type *Ty,
                                      TTI::TargetCostKind CostKind);

  TargetTransformInfo::PopcntSupportKind getPopcntSupport(unsigned TyWidth);

  bool shouldExpandReduction(const IntrinsicInst *II) const;
  bool supportsScalableVectors() const { return ST->hasVInstructions(); }
  bool enableScalableVectorization() const { return ST->hasVInstructions(); }
  PredicationStyle emitGetActiveLaneMask() const {
    return ST->hasVInstructions() ? PredicationStyle::Data
                                  : PredicationStyle::None;
  }
  std::optional<unsigned> getMaxVScale() const;
  std::optional<unsigned> getVScaleForTuning() const;

  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const;

  unsigned getRegUsageForType(Type *Ty);

  unsigned getMaximumVF(unsigned ElemWidth, unsigned Opcode) const;

  bool preferEpilogueVectorization() const {
    // Epilogue vectorization is usually unprofitable - tail folding or
    // a smaller VF would have been better.  This is a blunt hammer - we
    // should re-examine this once vectorization is better tuned.
    return false;
  }

  InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
                                        Align Alignment, unsigned AddressSpace,
                                        TTI::TargetCostKind CostKind);

  void getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
                               TTI::UnrollingPreferences &UP,
                               OptimizationRemarkEmitter *ORE);

  void getPeelingPreferences(Loop *L, ScalarEvolution &SE,
                             TTI::PeelingPreferences &PP);

  unsigned getMinVectorRegisterBitWidth() const {
    return ST->useRVVForFixedLengthVectors() ? 16 : 0;
  }

  InstructionCost getSpliceCost(VectorType *Tp, int Index);
  InstructionCost getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp,
                                 ArrayRef<int> Mask,
                                 TTI::TargetCostKind CostKind, int Index,
                                 VectorType *SubTp,
                                 ArrayRef<const Value *> Args = std::nullopt);

  InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                        TTI::TargetCostKind CostKind);

  InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
                                         const Value *Ptr, bool VariableMask,
                                         Align Alignment,
                                         TTI::TargetCostKind CostKind,
                                         const Instruction *I);

  InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
                                   TTI::CastContextHint CCH,
                                   TTI::TargetCostKind CostKind,
                                   const Instruction *I = nullptr);

  InstructionCost getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
                                         bool IsUnsigned,
                                         TTI::TargetCostKind CostKind);

  InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
                                             std::optional<FastMathFlags> FMF,
                                             TTI::TargetCostKind CostKind);

  InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned,
                                           Type *ResTy, VectorType *ValTy,
                                           std::optional<FastMathFlags> FMF,
                                           TTI::TargetCostKind CostKind);

  InstructionCost
  getMemoryOpCost(unsigned Opcode, Type *Src, MaybeAlign Alignment,
                  unsigned AddressSpace, TTI::TargetCostKind CostKind,
                  TTI::OperandValueInfo OpdInfo = {TTI::OK_AnyValue, TTI::OP_None},
                  const Instruction *I = nullptr);

  InstructionCost getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
                                     CmpInst::Predicate VecPred,
                                     TTI::TargetCostKind CostKind,
                                     const Instruction *I = nullptr);

  using BaseT::getVectorInstrCost;
  InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
                                     TTI::TargetCostKind CostKind,
                                     unsigned Index, Value *Op0, Value *Op1);

  InstructionCost getArithmeticInstrCost(
      unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
      TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None},
      TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
      ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
      const Instruction *CxtI = nullptr);

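  // Per RISCVTargetLowering::isLegalElementTypeForRVV, the legal element types
  // are (roughly): pointers, i8/i16/i32, i64 given 64-bit vector element
  // support, and f16/f32/f64 given the matching FP vector support.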
  bool isElementTypeLegalForScalableVector(Type *Ty) const {
    return TLI->isLegalElementTypeForRVV(Ty);
  }

  bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) {
    if (!ST->hasVInstructions())
      return false;

    // Only support fixed vectors if we know the minimum vector size.
    if (isa<FixedVectorType>(DataType) && !ST->useRVVForFixedLengthVectors())
      return false;

    // Don't allow elements larger than the ELEN.
    // FIXME: How to limit for scalable vectors?
    if (isa<FixedVectorType>(DataType) &&
        DataType->getScalarSizeInBits() > ST->getELEN())
      return false;

    if (Alignment <
        DL.getTypeStoreSize(DataType->getScalarType()).getFixedValue())
      return false;

    return TLI->isLegalElementTypeForRVV(DataType->getScalarType());
  }
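  // Illustrative: with V and fixed-length vector lowering enabled, a masked
  // load of <4 x i32> with align(4) passes every check above, while the same
  // load with align(2) fails the element-alignment check, and an element type
  // wider than ELEN is rejected outright.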

  bool isLegalMaskedLoad(Type *DataType, Align Alignment) {
    return isLegalMaskedLoadStore(DataType, Alignment);
  }
  bool isLegalMaskedStore(Type *DataType, Align Alignment) {
    return isLegalMaskedLoadStore(DataType, Alignment);
  }

  bool isLegalMaskedGatherScatter(Type *DataType, Align Alignment) {
    if (!ST->hasVInstructions())
      return false;

    // Only support fixed vectors if we know the minimum vector size.
    if (isa<FixedVectorType>(DataType) && !ST->useRVVForFixedLengthVectors())
      return false;

    // Don't allow elements larger than the ELEN.
    // FIXME: How to limit for scalable vectors?
    if (isa<FixedVectorType>(DataType) &&
        DataType->getScalarSizeInBits() > ST->getELEN())
      return false;

    if (Alignment <
        DL.getTypeStoreSize(DataType->getScalarType()).getFixedValue())
      return false;

    return TLI->isLegalElementTypeForRVV(DataType->getScalarType());
  }

  bool isLegalMaskedGather(Type *DataType, Align Alignment) {
    return isLegalMaskedGatherScatter(DataType, Alignment);
  }
  bool isLegalMaskedScatter(Type *DataType, Align Alignment) {
    return isLegalMaskedGatherScatter(DataType, Alignment);
  }

  bool forceScalarizeMaskedGather(VectorType *VTy, Align Alignment) {
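    // A gather takes a vector of pointers, which on RV64 means 64-bit
    // (EEW=64) elements in the index vector.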
    // Scalarize masked gather for RV64 if EEW=64 indices aren't supported.
    return ST->is64Bit() && !ST->hasVInstructionsI64();
  }

  bool forceScalarizeMaskedScatter(VectorType *VTy, Align Alignment) {
    // Scalarize masked scatter for RV64 if EEW=64 indices aren't supported.
    return ST->is64Bit() && !ST->hasVInstructionsI64();
  }

  /// \returns How the target needs this vector-predicated operation to be
  /// transformed.
  TargetTransformInfo::VPLegalization
  getVPLegalizationStrategy(const VPIntrinsic &PI) const {
    using VPLegalization = TargetTransformInfo::VPLegalization;
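    // RVV has no vector multiply reduction instruction, so a vp.reduce.mul
    // over elements wider than i1 is converted to an unpredicated reduction;
    // for i1 the product of the lanes is simply their AND.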
    if (!ST->hasVInstructions() ||
        (PI.getIntrinsicID() == Intrinsic::vp_reduce_mul &&
         cast<VectorType>(PI.getArgOperand(1)->getType())
                 ->getElementType()
                 ->getIntegerBitWidth() != 1))
      return VPLegalization(VPLegalization::Discard, VPLegalization::Convert);
    return VPLegalization(VPLegalization::Legal, VPLegalization::Legal);
  }

  bool isLegalToVectorizeReduction(const RecurrenceDescriptor &RdxDesc,
                                   ElementCount VF) const {
    if (!VF.isScalable())
      return true;

    Type *Ty = RdxDesc.getRecurrenceType();
    if (!TLI->isLegalElementTypeForRVV(Ty))
      return false;

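    // Kinds without a matching RVV reduction fall through to the default
    // below and are rejected; e.g. an integer Mul reduction, since RVV has
    // no vector multiply reduction instruction.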
    switch (RdxDesc.getRecurrenceKind()) {
    case RecurKind::Add:
    case RecurKind::FAdd:
    case RecurKind::And:
    case RecurKind::Or:
    case RecurKind::Xor:
    case RecurKind::SMin:
    case RecurKind::SMax:
    case RecurKind::UMin:
    case RecurKind::UMax:
    case RecurKind::FMin:
    case RecurKind::FMax:
    case RecurKind::SelectICmp:
    case RecurKind::SelectFCmp:
    case RecurKind::FMulAdd:
      return true;
    default:
      return false;
    }
  }

  unsigned getMaxInterleaveFactor(unsigned VF) {
    // If the loop will not be vectorized, don't interleave the loop.
    // Let the regular unroll pass unroll the loop.
    return VF == 1 ? 1 : ST->getMaxInterleaveFactor();
  }

  enum RISCVRegisterClass { GPRRC, FPRRC, VRRC };
  unsigned getNumberOfRegisters(unsigned ClassID) const {
    switch (ClassID) {
    case RISCVRegisterClass::GPRRC:
      // 31 = 32 GPR - x0 (zero register)
      // FIXME: Should we exclude fixed registers like SP, TP or GP?
      return 31;
    case RISCVRegisterClass::FPRRC:
      if (ST->hasStdExtF())
        return 32;
      return 0;
    case RISCVRegisterClass::VRRC:
      // Although there are 32 vector registers, v0 is special in that it is the
      // only register that can be used to hold a mask.
      // FIXME: Should we conservatively return 31 as the number of usable
      // vector registers?
      return ST->hasVInstructions() ? 32 : 0;
    }
    llvm_unreachable("unknown register class");
  }

  unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const {
    if (Vector)
      return RISCVRegisterClass::VRRC;
    if (!Ty)
      return RISCVRegisterClass::GPRRC;

    Type *ScalarTy = Ty->getScalarType();
    if ((ScalarTy->isHalfTy() && ST->hasStdExtZfhOrZfhmin()) ||
        (ScalarTy->isFloatTy() && ST->hasStdExtF()) ||
        (ScalarTy->isDoubleTy() && ST->hasStdExtD())) {
      return RISCVRegisterClass::FPRRC;
    }

    return RISCVRegisterClass::GPRRC;
  }
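  // Illustrative: with the D extension a scalar double maps to FPRRC, while
  // a half without Zfh/Zfhmin falls back to GPRRC; any vector type maps to
  // VRRC regardless of its element type.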

  const char *getRegisterClassName(unsigned ClassID) const {
    switch (ClassID) {
    case RISCVRegisterClass::GPRRC:
      return "RISCV::GPRRC";
    case RISCVRegisterClass::FPRRC:
      return "RISCV::FPRRC";
    case RISCVRegisterClass::VRRC:
      return "RISCV::VRRC";
    }
    llvm_unreachable("unknown register class");
  }

  bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                     const TargetTransformInfo::LSRCost &C2);
};

} // end namespace llvm

#endif // LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H