1 //===- SwitchLoweringUtils.h - Switch Lowering ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H 10 #define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H 11 12 #include "llvm/ADT/SmallVector.h" 13 #include "llvm/CodeGen/SelectionDAGNodes.h" 14 #include "llvm/CodeGen/TargetLowering.h" 15 #include "llvm/IR/Constants.h" 16 #include "llvm/Support/BranchProbability.h" 17 18 namespace llvm { 19 20 class FunctionLoweringInfo; 21 class MachineBasicBlock; 22 23 namespace SwitchCG { 24 25 enum CaseClusterKind { 26 /// A cluster of adjacent case labels with the same destination, or just one 27 /// case. 28 CC_Range, 29 /// A cluster of cases suitable for jump table lowering. 30 CC_JumpTable, 31 /// A cluster of cases suitable for bit test lowering. 32 CC_BitTests 33 }; 34 35 /// A cluster of case labels. 36 struct CaseCluster { 37 CaseClusterKind Kind; 38 const ConstantInt *Low, *High; 39 union { 40 MachineBasicBlock *MBB; 41 unsigned JTCasesIndex; 42 unsigned BTCasesIndex; 43 }; 44 BranchProbability Prob; 45 46 static CaseCluster range(const ConstantInt *Low, const ConstantInt *High, 47 MachineBasicBlock *MBB, BranchProbability Prob) { 48 CaseCluster C; 49 C.Kind = CC_Range; 50 C.Low = Low; 51 C.High = High; 52 C.MBB = MBB; 53 C.Prob = Prob; 54 return C; 55 } 56 57 static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High, 58 unsigned JTCasesIndex, BranchProbability Prob) { 59 CaseCluster C; 60 C.Kind = CC_JumpTable; 61 C.Low = Low; 62 C.High = High; 63 C.JTCasesIndex = JTCasesIndex; 64 C.Prob = Prob; 65 return C; 66 } 67 68 static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High, 69 unsigned BTCasesIndex, BranchProbability Prob) { 70 CaseCluster C; 71 C.Kind = CC_BitTests; 72 C.Low = Low; 73 C.High = High; 74 C.BTCasesIndex = BTCasesIndex; 75 C.Prob = Prob; 76 return C; 77 } 78 }; 79 80 using CaseClusterVector = std::vector<CaseCluster>; 81 using CaseClusterIt = CaseClusterVector::iterator; 82 83 /// Sort Clusters and merge adjacent cases. 84 void sortAndRangeify(CaseClusterVector &Clusters); 85 86 struct CaseBits { 87 uint64_t Mask = 0; 88 MachineBasicBlock *BB = nullptr; 89 unsigned Bits = 0; 90 BranchProbability ExtraProb; 91 92 CaseBits() = default; 93 CaseBits(uint64_t mask, MachineBasicBlock *bb, unsigned bits, 94 BranchProbability Prob) 95 : Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {} 96 }; 97 98 using CaseBitsVector = std::vector<CaseBits>; 99 100 /// This structure is used to communicate between SelectionDAGBuilder and 101 /// SDISel for the code generation of additional basic blocks needed by 102 /// multi-case switch statements. 103 struct CaseBlock { 104 // For the GISel interface. 105 struct PredInfoPair { 106 CmpInst::Predicate Pred; 107 // Set when no comparison should be emitted. 108 bool NoCmp; 109 }; 110 union { 111 // The condition code to use for the case block's setcc node. 112 // Besides the integer condition codes, this can also be SETTRUE, in which 113 // case no comparison gets emitted. 114 ISD::CondCode CC; 115 struct PredInfoPair PredInfo; 116 }; 117 118 // The LHS/MHS/RHS of the comparison to emit. 119 // Emit by default LHS op RHS. MHS is used for range comparisons: 120 // If MHS is not null: (LHS <= MHS) and (MHS <= RHS). 121 const Value *CmpLHS, *CmpMHS, *CmpRHS; 122 123 // The block to branch to if the setcc is true/false. 124 MachineBasicBlock *TrueBB, *FalseBB; 125 126 // The block into which to emit the code for the setcc and branches. 127 MachineBasicBlock *ThisBB; 128 129 /// The debug location of the instruction this CaseBlock was 130 /// produced from. 131 SDLoc DL; 132 DebugLoc DbgLoc; 133 134 // Branch weights. 135 BranchProbability TrueProb, FalseProb; 136 137 // Constructor for SelectionDAG. 138 CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs, 139 const Value *cmpmiddle, MachineBasicBlock *truebb, 140 MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl, 141 BranchProbability trueprob = BranchProbability::getUnknown(), 142 BranchProbability falseprob = BranchProbability::getUnknown()) 143 : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs), 144 TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl), 145 TrueProb(trueprob), FalseProb(falseprob) {} 146 147 // Constructor for GISel. 148 CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs, 149 const Value *cmprhs, const Value *cmpmiddle, 150 MachineBasicBlock *truebb, MachineBasicBlock *falsebb, 151 MachineBasicBlock *me, DebugLoc dl, 152 BranchProbability trueprob = BranchProbability::getUnknown(), 153 BranchProbability falseprob = BranchProbability::getUnknown()) 154 : PredInfo({pred, nocmp}), CmpLHS(cmplhs), CmpMHS(cmpmiddle), 155 CmpRHS(cmprhs), TrueBB(truebb), FalseBB(falsebb), ThisBB(me), 156 DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob) {} 157 }; 158 159 struct JumpTable { 160 /// The virtual register containing the index of the jump table entry 161 /// to jump to. 162 unsigned Reg; 163 /// The JumpTableIndex for this jump table in the function. 164 unsigned JTI; 165 /// The MBB into which to emit the code for the indirect jump. 166 MachineBasicBlock *MBB; 167 /// The MBB of the default bb, which is a successor of the range 168 /// check MBB. This is when updating PHI nodes in successors. 169 MachineBasicBlock *Default; 170 171 JumpTable(unsigned R, unsigned J, MachineBasicBlock *M, MachineBasicBlock *D) 172 : Reg(R), JTI(J), MBB(M), Default(D) {} 173 }; 174 struct JumpTableHeader { 175 APInt First; 176 APInt Last; 177 const Value *SValue; 178 MachineBasicBlock *HeaderBB; 179 bool Emitted; 180 bool OmitRangeCheck; 181 182 JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H, 183 bool E = false) 184 : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H), 185 Emitted(E), OmitRangeCheck(false) {} 186 }; 187 using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>; 188 189 struct BitTestCase { 190 uint64_t Mask; 191 MachineBasicBlock *ThisBB; 192 MachineBasicBlock *TargetBB; 193 BranchProbability ExtraProb; 194 195 BitTestCase(uint64_t M, MachineBasicBlock *T, MachineBasicBlock *Tr, 196 BranchProbability Prob) 197 : Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {} 198 }; 199 200 using BitTestInfo = SmallVector<BitTestCase, 3>; 201 202 struct BitTestBlock { 203 APInt First; 204 APInt Range; 205 const Value *SValue; 206 unsigned Reg; 207 MVT RegVT; 208 bool Emitted; 209 bool ContiguousRange; 210 MachineBasicBlock *Parent; 211 MachineBasicBlock *Default; 212 BitTestInfo Cases; 213 BranchProbability Prob; 214 BranchProbability DefaultProb; 215 216 BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT, bool E, 217 bool CR, MachineBasicBlock *P, MachineBasicBlock *D, 218 BitTestInfo C, BranchProbability Pr) 219 : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg), 220 RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D), 221 Cases(std::move(C)), Prob(Pr) {} 222 }; 223 224 /// Return the range of value within a range. 225 uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First, 226 unsigned Last); 227 228 /// Return the number of cases within a range. 229 uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases, 230 unsigned First, unsigned Last); 231 232 struct SwitchWorkListItem { 233 MachineBasicBlock *MBB; 234 CaseClusterIt FirstCluster; 235 CaseClusterIt LastCluster; 236 const ConstantInt *GE; 237 const ConstantInt *LT; 238 BranchProbability DefaultProb; 239 }; 240 using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>; 241 242 class SwitchLowering { 243 public: 244 SwitchLowering(FunctionLoweringInfo &funcinfo) : FuncInfo(funcinfo) {} 245 246 void init(const TargetLowering &tli, const TargetMachine &tm, 247 const DataLayout &dl) { 248 TLI = &tli; 249 TM = &tm; 250 DL = &dl; 251 } 252 253 /// Vector of CaseBlock structures used to communicate SwitchInst code 254 /// generation information. 255 std::vector<CaseBlock> SwitchCases; 256 257 /// Vector of JumpTable structures used to communicate SwitchInst code 258 /// generation information. 259 std::vector<JumpTableBlock> JTCases; 260 261 /// Vector of BitTestBlock structures used to communicate SwitchInst code 262 /// generation information. 263 std::vector<BitTestBlock> BitTestCases; 264 265 void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI, 266 MachineBasicBlock *DefaultMBB); 267 268 bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First, 269 unsigned Last, const SwitchInst *SI, 270 MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster); 271 272 273 void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI); 274 275 /// Build a bit test cluster from Clusters[First..Last]. Returns false if it 276 /// decides it's not a good idea. 277 bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last, 278 const SwitchInst *SI, CaseCluster &BTCluster); 279 280 virtual void addSuccessorWithProb( 281 MachineBasicBlock *Src, MachineBasicBlock *Dst, 282 BranchProbability Prob = BranchProbability::getUnknown()) = 0; 283 284 virtual ~SwitchLowering() = default; 285 286 private: 287 const TargetLowering *TLI; 288 const TargetMachine *TM; 289 const DataLayout *DL; 290 FunctionLoweringInfo &FuncInfo; 291 }; 292 293 } // namespace SwitchCG 294 } // namespace llvm 295 296 #endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H 297 298