1 //===- RISCVTargetTransformInfo.h - RISC-V specific TTI ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// This file defines a TargetTransformInfo::Concept conforming object specific
10 /// to the RISC-V target machine. It uses the target's detailed information to
11 /// provide more precise answers to certain TTI queries, while letting the
12 /// target independent and default TTI implementations handle the rest.
13 ///
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
17 #define LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
18 
19 #include "RISCVSubtarget.h"
20 #include "RISCVTargetMachine.h"
21 #include "llvm/Analysis/TargetTransformInfo.h"
22 #include "llvm/CodeGen/BasicTTIImpl.h"
23 #include "llvm/IR/Function.h"
24 
25 namespace llvm {
26 
27 class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
28   using BaseT = BasicTTIImplBase<RISCVTTIImpl>;
29   using TTI = TargetTransformInfo;
30 
31   friend BaseT;
32 
33   const RISCVSubtarget *ST;
34   const RISCVTargetLowering *TLI;
35 
getST()36   const RISCVSubtarget *getST() const { return ST; }
getTLI()37   const RISCVTargetLowering *getTLI() const { return TLI; }
38 
39 public:
RISCVTTIImpl(const RISCVTargetMachine * TM,const Function & F)40   explicit RISCVTTIImpl(const RISCVTargetMachine *TM, const Function &F)
41       : BaseT(TM, F.getParent()->getDataLayout()), ST(TM->getSubtargetImpl(F)),
42         TLI(ST->getTargetLowering()) {}
43 
44   InstructionCost getIntImmCost(const APInt &Imm, Type *Ty,
45                                 TTI::TargetCostKind CostKind);
46   InstructionCost getIntImmCostInst(unsigned Opcode, unsigned Idx,
47                                     const APInt &Imm, Type *Ty,
48                                     TTI::TargetCostKind CostKind,
49                                     Instruction *Inst = nullptr);
50   InstructionCost getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
51                                       const APInt &Imm, Type *Ty,
52                                       TTI::TargetCostKind CostKind);
53 
54   TargetTransformInfo::PopcntSupportKind getPopcntSupport(unsigned TyWidth);
55 
56   bool shouldExpandReduction(const IntrinsicInst *II) const;
supportsScalableVectors()57   bool supportsScalableVectors() const { return ST->hasStdExtV(); }
58   Optional<unsigned> getMaxVScale() const;
59 
getRegisterBitWidth(TargetTransformInfo::RegisterKind K)60   TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
61     switch (K) {
62     case TargetTransformInfo::RGK_Scalar:
63       return TypeSize::getFixed(ST->getXLen());
64     case TargetTransformInfo::RGK_FixedWidthVector:
65       return TypeSize::getFixed(
66           ST->hasStdExtV() ? ST->getMinRVVVectorSizeInBits() : 0);
67     case TargetTransformInfo::RGK_ScalableVector:
68       return TypeSize::getScalable(
69           ST->hasStdExtV() ? ST->getMinRVVVectorSizeInBits() : 0);
70     }
71 
72     llvm_unreachable("Unsupported register kind");
73   }
74 
75   InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
76                                          const Value *Ptr, bool VariableMask,
77                                          Align Alignment,
78                                          TTI::TargetCostKind CostKind,
79                                          const Instruction *I);
80 
isLegalElementTypeForRVV(Type * ScalarTy)81   bool isLegalElementTypeForRVV(Type *ScalarTy) const {
82     if (ScalarTy->isPointerTy())
83       return true;
84 
85     if (ScalarTy->isIntegerTy(8) || ScalarTy->isIntegerTy(16) ||
86         ScalarTy->isIntegerTy(32) || ScalarTy->isIntegerTy(64))
87       return true;
88 
89     if (ScalarTy->isHalfTy())
90       return ST->hasStdExtZfh();
91     if (ScalarTy->isFloatTy())
92       return ST->hasStdExtF();
93     if (ScalarTy->isDoubleTy())
94       return ST->hasStdExtD();
95 
96     return false;
97   }
98 
isLegalMaskedLoadStore(Type * DataType,Align Alignment)99   bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) {
100     if (!ST->hasStdExtV())
101       return false;
102 
103     // Only support fixed vectors if we know the minimum vector size.
104     if (isa<FixedVectorType>(DataType) && ST->getMinRVVVectorSizeInBits() == 0)
105       return false;
106 
107     if (Alignment <
108         DL.getTypeStoreSize(DataType->getScalarType()).getFixedSize())
109       return false;
110 
111     return isLegalElementTypeForRVV(DataType->getScalarType());
112   }
113 
isLegalMaskedLoad(Type * DataType,Align Alignment)114   bool isLegalMaskedLoad(Type *DataType, Align Alignment) {
115     return isLegalMaskedLoadStore(DataType, Alignment);
116   }
isLegalMaskedStore(Type * DataType,Align Alignment)117   bool isLegalMaskedStore(Type *DataType, Align Alignment) {
118     return isLegalMaskedLoadStore(DataType, Alignment);
119   }
120 
isLegalMaskedGatherScatter(Type * DataType,Align Alignment)121   bool isLegalMaskedGatherScatter(Type *DataType, Align Alignment) {
122     if (!ST->hasStdExtV())
123       return false;
124 
125     // Only support fixed vectors if we know the minimum vector size.
126     if (isa<FixedVectorType>(DataType) && ST->getMinRVVVectorSizeInBits() == 0)
127       return false;
128 
129     if (Alignment <
130         DL.getTypeStoreSize(DataType->getScalarType()).getFixedSize())
131       return false;
132 
133     return isLegalElementTypeForRVV(DataType->getScalarType());
134   }
135 
isLegalMaskedGather(Type * DataType,Align Alignment)136   bool isLegalMaskedGather(Type *DataType, Align Alignment) {
137     return isLegalMaskedGatherScatter(DataType, Alignment);
138   }
isLegalMaskedScatter(Type * DataType,Align Alignment)139   bool isLegalMaskedScatter(Type *DataType, Align Alignment) {
140     return isLegalMaskedGatherScatter(DataType, Alignment);
141   }
142 
143   /// \returns How the target needs this vector-predicated operation to be
144   /// transformed.
145   TargetTransformInfo::VPLegalization
getVPLegalizationStrategy(const VPIntrinsic & PI)146   getVPLegalizationStrategy(const VPIntrinsic &PI) const {
147     using VPLegalization = TargetTransformInfo::VPLegalization;
148     return VPLegalization(VPLegalization::Legal, VPLegalization::Legal);
149   }
150 
isLegalToVectorizeReduction(const RecurrenceDescriptor & RdxDesc,ElementCount VF)151   bool isLegalToVectorizeReduction(const RecurrenceDescriptor &RdxDesc,
152                                    ElementCount VF) const {
153     if (!ST->hasStdExtV())
154       return false;
155 
156     if (!VF.isScalable())
157       return true;
158 
159     Type *Ty = RdxDesc.getRecurrenceType();
160     if (!isLegalElementTypeForRVV(Ty))
161       return false;
162 
163     switch (RdxDesc.getRecurrenceKind()) {
164     case RecurKind::Add:
165     case RecurKind::FAdd:
166     case RecurKind::And:
167     case RecurKind::Or:
168     case RecurKind::Xor:
169     case RecurKind::SMin:
170     case RecurKind::SMax:
171     case RecurKind::UMin:
172     case RecurKind::UMax:
173     case RecurKind::FMin:
174     case RecurKind::FMax:
175       return true;
176     default:
177       return false;
178     }
179   }
180 
getMaxInterleaveFactor(unsigned VF)181   unsigned getMaxInterleaveFactor(unsigned VF) {
182     return ST->getMaxInterleaveFactor();
183   }
184 };
185 
186 } // end namespace llvm
187 
188 #endif // LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
189