1 //===- SwitchLoweringUtils.h - Switch Lowering ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
10 #define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
11 
12 #include "llvm/ADT/SmallVector.h"
13 #include "llvm/CodeGen/SelectionDAGNodes.h"
14 #include "llvm/CodeGen/TargetLowering.h"
15 #include "llvm/IR/Constants.h"
16 #include "llvm/Support/BranchProbability.h"
17 
18 namespace llvm {
19 
20 class FunctionLoweringInfo;
21 class MachineBasicBlock;
22 class BlockFrequencyInfo;
23 
24 namespace SwitchCG {
25 
/// The lowering strategy selected for a CaseCluster.
enum CaseClusterKind {
  /// A cluster of adjacent case labels with the same destination, or just one
  /// case.
  CC_Range = 0,
  /// A cluster of cases suitable for jump table lowering.
  CC_JumpTable = 1,
  /// A cluster of cases suitable for bit test lowering.
  CC_BitTests = 2,
};
35 
36 /// A cluster of case labels.
37 struct CaseCluster {
38   CaseClusterKind Kind;
39   const ConstantInt *Low, *High;
40   union {
41     MachineBasicBlock *MBB;
42     unsigned JTCasesIndex;
43     unsigned BTCasesIndex;
44   };
45   BranchProbability Prob;
46 
47   static CaseCluster range(const ConstantInt *Low, const ConstantInt *High,
48                            MachineBasicBlock *MBB, BranchProbability Prob) {
49     CaseCluster C;
50     C.Kind = CC_Range;
51     C.Low = Low;
52     C.High = High;
53     C.MBB = MBB;
54     C.Prob = Prob;
55     return C;
56   }
57 
58   static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High,
59                                unsigned JTCasesIndex, BranchProbability Prob) {
60     CaseCluster C;
61     C.Kind = CC_JumpTable;
62     C.Low = Low;
63     C.High = High;
64     C.JTCasesIndex = JTCasesIndex;
65     C.Prob = Prob;
66     return C;
67   }
68 
69   static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High,
70                               unsigned BTCasesIndex, BranchProbability Prob) {
71     CaseCluster C;
72     C.Kind = CC_BitTests;
73     C.Low = Low;
74     C.High = High;
75     C.BTCasesIndex = BTCasesIndex;
76     C.Prob = Prob;
77     return C;
78   }
79 };
80 
81 using CaseClusterVector = std::vector<CaseCluster>;
82 using CaseClusterIt = CaseClusterVector::iterator;
83 
84 /// Sort Clusters and merge adjacent cases.
85 void sortAndRangeify(CaseClusterVector &Clusters);
86 
87 struct CaseBits {
88   uint64_t Mask = 0;
89   MachineBasicBlock *BB = nullptr;
90   unsigned Bits = 0;
91   BranchProbability ExtraProb;
92 
93   CaseBits() = default;
94   CaseBits(uint64_t mask, MachineBasicBlock *bb, unsigned bits,
95            BranchProbability Prob)
96       : Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {}
97 };
98 
99 using CaseBitsVector = std::vector<CaseBits>;
100 
101 /// This structure is used to communicate between SelectionDAGBuilder and
102 /// SDISel for the code generation of additional basic blocks needed by
103 /// multi-case switch statements.
104 struct CaseBlock {
105   // For the GISel interface.
106   struct PredInfoPair {
107     CmpInst::Predicate Pred;
108     // Set when no comparison should be emitted.
109     bool NoCmp;
110   };
111   union {
112     // The condition code to use for the case block's setcc node.
113     // Besides the integer condition codes, this can also be SETTRUE, in which
114     // case no comparison gets emitted.
115     ISD::CondCode CC;
116     struct PredInfoPair PredInfo;
117   };
118 
119   // The LHS/MHS/RHS of the comparison to emit.
120   // Emit by default LHS op RHS. MHS is used for range comparisons:
121   // If MHS is not null: (LHS <= MHS) and (MHS <= RHS).
122   const Value *CmpLHS, *CmpMHS, *CmpRHS;
123 
124   // The block to branch to if the setcc is true/false.
125   MachineBasicBlock *TrueBB, *FalseBB;
126 
127   // The block into which to emit the code for the setcc and branches.
128   MachineBasicBlock *ThisBB;
129 
130   /// The debug location of the instruction this CaseBlock was
131   /// produced from.
132   SDLoc DL;
133   DebugLoc DbgLoc;
134 
135   // Branch weights.
136   BranchProbability TrueProb, FalseProb;
137 
138   // Constructor for SelectionDAG.
139   CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
140             const Value *cmpmiddle, MachineBasicBlock *truebb,
141             MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl,
142             BranchProbability trueprob = BranchProbability::getUnknown(),
143             BranchProbability falseprob = BranchProbability::getUnknown())
144       : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
145         TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl),
146         TrueProb(trueprob), FalseProb(falseprob) {}
147 
148   // Constructor for GISel.
149   CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs,
150             const Value *cmprhs, const Value *cmpmiddle,
151             MachineBasicBlock *truebb, MachineBasicBlock *falsebb,
152             MachineBasicBlock *me, DebugLoc dl,
153             BranchProbability trueprob = BranchProbability::getUnknown(),
154             BranchProbability falseprob = BranchProbability::getUnknown())
155       : PredInfo({pred, nocmp}), CmpLHS(cmplhs), CmpMHS(cmpmiddle),
156         CmpRHS(cmprhs), TrueBB(truebb), FalseBB(falsebb), ThisBB(me),
157         DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob) {}
158 };
159 
160 struct JumpTable {
161   /// The virtual register containing the index of the jump table entry
162   /// to jump to.
163   unsigned Reg;
164   /// The JumpTableIndex for this jump table in the function.
165   unsigned JTI;
166   /// The MBB into which to emit the code for the indirect jump.
167   MachineBasicBlock *MBB;
168   /// The MBB of the default bb, which is a successor of the range
169   /// check MBB.  This is when updating PHI nodes in successors.
170   MachineBasicBlock *Default;
171 
172   JumpTable(unsigned R, unsigned J, MachineBasicBlock *M, MachineBasicBlock *D)
173       : Reg(R), JTI(J), MBB(M), Default(D) {}
174 };
175 struct JumpTableHeader {
176   APInt First;
177   APInt Last;
178   const Value *SValue;
179   MachineBasicBlock *HeaderBB;
180   bool Emitted;
181   bool OmitRangeCheck;
182 
183   JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H,
184                   bool E = false)
185       : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H),
186         Emitted(E), OmitRangeCheck(false) {}
187 };
188 using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>;
189 
190 struct BitTestCase {
191   uint64_t Mask;
192   MachineBasicBlock *ThisBB;
193   MachineBasicBlock *TargetBB;
194   BranchProbability ExtraProb;
195 
196   BitTestCase(uint64_t M, MachineBasicBlock *T, MachineBasicBlock *Tr,
197               BranchProbability Prob)
198       : Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {}
199 };
200 
201 using BitTestInfo = SmallVector<BitTestCase, 3>;
202 
203 struct BitTestBlock {
204   APInt First;
205   APInt Range;
206   const Value *SValue;
207   unsigned Reg;
208   MVT RegVT;
209   bool Emitted;
210   bool ContiguousRange;
211   MachineBasicBlock *Parent;
212   MachineBasicBlock *Default;
213   BitTestInfo Cases;
214   BranchProbability Prob;
215   BranchProbability DefaultProb;
216   bool OmitRangeCheck;
217 
218   BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT, bool E,
219                bool CR, MachineBasicBlock *P, MachineBasicBlock *D,
220                BitTestInfo C, BranchProbability Pr)
221       : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg),
222         RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D),
223         Cases(std::move(C)), Prob(Pr), OmitRangeCheck(false) {}
224 };
225 
226 /// Return the range of values within a range.
227 uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First,
228                            unsigned Last);
229 
230 /// Return the number of cases within a range.
231 uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
232                               unsigned First, unsigned Last);
233 
234 struct SwitchWorkListItem {
235   MachineBasicBlock *MBB;
236   CaseClusterIt FirstCluster;
237   CaseClusterIt LastCluster;
238   const ConstantInt *GE;
239   const ConstantInt *LT;
240   BranchProbability DefaultProb;
241 };
242 using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>;
243 
244 class SwitchLowering {
245 public:
246   SwitchLowering(FunctionLoweringInfo &funcinfo) : FuncInfo(funcinfo) {}
247 
248   void init(const TargetLowering &tli, const TargetMachine &tm,
249             const DataLayout &dl) {
250     TLI = &tli;
251     TM = &tm;
252     DL = &dl;
253   }
254 
255   /// Vector of CaseBlock structures used to communicate SwitchInst code
256   /// generation information.
257   std::vector<CaseBlock> SwitchCases;
258 
259   /// Vector of JumpTable structures used to communicate SwitchInst code
260   /// generation information.
261   std::vector<JumpTableBlock> JTCases;
262 
263   /// Vector of BitTestBlock structures used to communicate SwitchInst code
264   /// generation information.
265   std::vector<BitTestBlock> BitTestCases;
266 
267   void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
268                       MachineBasicBlock *DefaultMBB,
269                       ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI);
270 
271   bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
272                       unsigned Last, const SwitchInst *SI,
273                       MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);
274 
275 
276   void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI);
277 
278   /// Build a bit test cluster from Clusters[First..Last]. Returns false if it
279   /// decides it's not a good idea.
280   bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,
281                      const SwitchInst *SI, CaseCluster &BTCluster);
282 
283   virtual void addSuccessorWithProb(
284       MachineBasicBlock *Src, MachineBasicBlock *Dst,
285       BranchProbability Prob = BranchProbability::getUnknown()) = 0;
286 
287   virtual ~SwitchLowering() = default;
288 
289 private:
290   const TargetLowering *TLI;
291   const TargetMachine *TM;
292   const DataLayout *DL;
293   FunctionLoweringInfo &FuncInfo;
294 };
295 
296 } // namespace SwitchCG
297 } // namespace llvm
298 
299 #endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
300