1 //===- SwitchLoweringUtils.h - Switch Lowering ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
10 #define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
11 
12 #include "llvm/ADT/SmallVector.h"
13 #include "llvm/CodeGen/SelectionDAGNodes.h"
14 #include "llvm/CodeGen/TargetLowering.h"
15 #include "llvm/IR/Constants.h"
16 #include "llvm/Support/BranchProbability.h"
17 
18 namespace llvm {
19 
20 class FunctionLoweringInfo;
21 class MachineBasicBlock;
22 
23 namespace SwitchCG {
24 
/// The lowering strategy selected for a CaseCluster.
enum CaseClusterKind {
  /// A run of adjacent case labels that share one destination, or a single
  /// case.
  CC_Range = 0,
  /// Cases that are suitable for lowering through a jump table.
  CC_JumpTable = 1,
  /// Cases that are suitable for lowering through bit tests.
  CC_BitTests = 2
};
34 
35 /// A cluster of case labels.
36 struct CaseCluster {
37   CaseClusterKind Kind;
38   const ConstantInt *Low, *High;
39   union {
40     MachineBasicBlock *MBB;
41     unsigned JTCasesIndex;
42     unsigned BTCasesIndex;
43   };
44   BranchProbability Prob;
45 
46   static CaseCluster range(const ConstantInt *Low, const ConstantInt *High,
47                            MachineBasicBlock *MBB, BranchProbability Prob) {
48     CaseCluster C;
49     C.Kind = CC_Range;
50     C.Low = Low;
51     C.High = High;
52     C.MBB = MBB;
53     C.Prob = Prob;
54     return C;
55   }
56 
57   static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High,
58                                unsigned JTCasesIndex, BranchProbability Prob) {
59     CaseCluster C;
60     C.Kind = CC_JumpTable;
61     C.Low = Low;
62     C.High = High;
63     C.JTCasesIndex = JTCasesIndex;
64     C.Prob = Prob;
65     return C;
66   }
67 
68   static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High,
69                               unsigned BTCasesIndex, BranchProbability Prob) {
70     CaseCluster C;
71     C.Kind = CC_BitTests;
72     C.Low = Low;
73     C.High = High;
74     C.BTCasesIndex = BTCasesIndex;
75     C.Prob = Prob;
76     return C;
77   }
78 };
79 
80 using CaseClusterVector = std::vector<CaseCluster>;
81 using CaseClusterIt = CaseClusterVector::iterator;
82 
83 /// Sort Clusters and merge adjacent cases.
84 void sortAndRangeify(CaseClusterVector &Clusters);
85 
86 struct CaseBits {
87   uint64_t Mask = 0;
88   MachineBasicBlock *BB = nullptr;
89   unsigned Bits = 0;
90   BranchProbability ExtraProb;
91 
92   CaseBits() = default;
93   CaseBits(uint64_t mask, MachineBasicBlock *bb, unsigned bits,
94            BranchProbability Prob)
95       : Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {}
96 };
97 
98 using CaseBitsVector = std::vector<CaseBits>;
99 
100 /// This structure is used to communicate between SelectionDAGBuilder and
101 /// SDISel for the code generation of additional basic blocks needed by
102 /// multi-case switch statements.
103 struct CaseBlock {
104   // For the GISel interface.
105   struct PredInfoPair {
106     CmpInst::Predicate Pred;
107     // Set when no comparison should be emitted.
108     bool NoCmp;
109   };
110   union {
111     // The condition code to use for the case block's setcc node.
112     // Besides the integer condition codes, this can also be SETTRUE, in which
113     // case no comparison gets emitted.
114     ISD::CondCode CC;
115     struct PredInfoPair PredInfo;
116   };
117 
118   // The LHS/MHS/RHS of the comparison to emit.
119   // Emit by default LHS op RHS. MHS is used for range comparisons:
120   // If MHS is not null: (LHS <= MHS) and (MHS <= RHS).
121   const Value *CmpLHS, *CmpMHS, *CmpRHS;
122 
123   // The block to branch to if the setcc is true/false.
124   MachineBasicBlock *TrueBB, *FalseBB;
125 
126   // The block into which to emit the code for the setcc and branches.
127   MachineBasicBlock *ThisBB;
128 
129   /// The debug location of the instruction this CaseBlock was
130   /// produced from.
131   SDLoc DL;
132   DebugLoc DbgLoc;
133 
134   // Branch weights.
135   BranchProbability TrueProb, FalseProb;
136 
137   // Constructor for SelectionDAG.
138   CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
139             const Value *cmpmiddle, MachineBasicBlock *truebb,
140             MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl,
141             BranchProbability trueprob = BranchProbability::getUnknown(),
142             BranchProbability falseprob = BranchProbability::getUnknown())
143       : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
144         TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl),
145         TrueProb(trueprob), FalseProb(falseprob) {}
146 
147   // Constructor for GISel.
148   CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs,
149             const Value *cmprhs, const Value *cmpmiddle,
150             MachineBasicBlock *truebb, MachineBasicBlock *falsebb,
151             MachineBasicBlock *me, DebugLoc dl,
152             BranchProbability trueprob = BranchProbability::getUnknown(),
153             BranchProbability falseprob = BranchProbability::getUnknown())
154       : PredInfo({pred, nocmp}), CmpLHS(cmplhs), CmpMHS(cmpmiddle),
155         CmpRHS(cmprhs), TrueBB(truebb), FalseBB(falsebb), ThisBB(me),
156         DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob) {}
157 };
158 
159 struct JumpTable {
160   /// The virtual register containing the index of the jump table entry
161   /// to jump to.
162   unsigned Reg;
163   /// The JumpTableIndex for this jump table in the function.
164   unsigned JTI;
165   /// The MBB into which to emit the code for the indirect jump.
166   MachineBasicBlock *MBB;
167   /// The MBB of the default bb, which is a successor of the range
168   /// check MBB.  This is when updating PHI nodes in successors.
169   MachineBasicBlock *Default;
170 
171   JumpTable(unsigned R, unsigned J, MachineBasicBlock *M, MachineBasicBlock *D)
172       : Reg(R), JTI(J), MBB(M), Default(D) {}
173 };
174 struct JumpTableHeader {
175   APInt First;
176   APInt Last;
177   const Value *SValue;
178   MachineBasicBlock *HeaderBB;
179   bool Emitted;
180   bool OmitRangeCheck;
181 
182   JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H,
183                   bool E = false)
184       : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H),
185         Emitted(E), OmitRangeCheck(false) {}
186 };
187 using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>;
188 
189 struct BitTestCase {
190   uint64_t Mask;
191   MachineBasicBlock *ThisBB;
192   MachineBasicBlock *TargetBB;
193   BranchProbability ExtraProb;
194 
195   BitTestCase(uint64_t M, MachineBasicBlock *T, MachineBasicBlock *Tr,
196               BranchProbability Prob)
197       : Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {}
198 };
199 
200 using BitTestInfo = SmallVector<BitTestCase, 3>;
201 
202 struct BitTestBlock {
203   APInt First;
204   APInt Range;
205   const Value *SValue;
206   unsigned Reg;
207   MVT RegVT;
208   bool Emitted;
209   bool ContiguousRange;
210   MachineBasicBlock *Parent;
211   MachineBasicBlock *Default;
212   BitTestInfo Cases;
213   BranchProbability Prob;
214   BranchProbability DefaultProb;
215 
216   BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT, bool E,
217                bool CR, MachineBasicBlock *P, MachineBasicBlock *D,
218                BitTestInfo C, BranchProbability Pr)
219       : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg),
220         RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D),
221         Cases(std::move(C)), Prob(Pr) {}
222 };
223 
224 /// Return the range of value within a range.
225 uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First,
226                            unsigned Last);
227 
228 /// Return the number of cases within a range.
229 uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
230                               unsigned First, unsigned Last);
231 
232 struct SwitchWorkListItem {
233   MachineBasicBlock *MBB;
234   CaseClusterIt FirstCluster;
235   CaseClusterIt LastCluster;
236   const ConstantInt *GE;
237   const ConstantInt *LT;
238   BranchProbability DefaultProb;
239 };
240 using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>;
241 
242 class SwitchLowering {
243 public:
244   SwitchLowering(FunctionLoweringInfo &funcinfo) : FuncInfo(funcinfo) {}
245 
246   void init(const TargetLowering &tli, const TargetMachine &tm,
247             const DataLayout &dl) {
248     TLI = &tli;
249     TM = &tm;
250     DL = &dl;
251   }
252 
253   /// Vector of CaseBlock structures used to communicate SwitchInst code
254   /// generation information.
255   std::vector<CaseBlock> SwitchCases;
256 
257   /// Vector of JumpTable structures used to communicate SwitchInst code
258   /// generation information.
259   std::vector<JumpTableBlock> JTCases;
260 
261   /// Vector of BitTestBlock structures used to communicate SwitchInst code
262   /// generation information.
263   std::vector<BitTestBlock> BitTestCases;
264 
265   void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
266                       MachineBasicBlock *DefaultMBB);
267 
268   bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
269                       unsigned Last, const SwitchInst *SI,
270                       MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);
271 
272 
273   void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI);
274 
275   /// Build a bit test cluster from Clusters[First..Last]. Returns false if it
276   /// decides it's not a good idea.
277   bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,
278                      const SwitchInst *SI, CaseCluster &BTCluster);
279 
280   virtual void addSuccessorWithProb(
281       MachineBasicBlock *Src, MachineBasicBlock *Dst,
282       BranchProbability Prob = BranchProbability::getUnknown()) = 0;
283 
284   virtual ~SwitchLowering() = default;
285 
286 private:
287   const TargetLowering *TLI;
288   const TargetMachine *TM;
289   const DataLayout *DL;
290   FunctionLoweringInfo &FuncInfo;
291 };
292 
293 } // namespace SwitchCG
294 } // namespace llvm
295 
296 #endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
297 
298