1 //===- DFAJumpThreading.cpp - Threads a switch statement inside a loop ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Transform each threading path to effectively jump thread the DFA. For
10 // example, the CFG below could be transformed as follows, where the cloned
11 // blocks unconditionally branch to the next correct case based on what is
12 // identified in the analysis.
13 //
14 //          sw.bb                        sw.bb
15 //        /   |   \                    /   |   \
16 //   case1  case2  case3          case1  case2  case3
17 //        \   |   /                 |      |      |
18 //       determinator            det.2   det.3  det.1
19 //        br sw.bb                /        |        \
20 //                          sw.bb.2     sw.bb.3     sw.bb.1
//                           br case2    br case3    br case1
22 //
23 // Definitions and Terminology:
24 //
25 // * Threading path:
26 //   a list of basic blocks, the exit state, and the block that determines
27 //   the next state, for which the following notation will be used:
28 //   < path of BBs that form a cycle > [ state, determinator ]
29 //
30 // * Predictable switch:
31 //   The switch variable is always a known constant so that all conditional
32 //   jumps based on switch variable can be converted to unconditional jump.
33 //
34 // * Determinator:
35 //   The basic block that determines the next state of the DFA.
36 //
37 // Representing the optimization in C-like pseudocode: the code pattern on the
38 // left could functionally be transformed to the right pattern if the switch
39 // condition is predictable.
40 //
41 //  X = A                       goto A
42 //  for (...)                   A:
43 //    switch (X)                  ...
44 //      case A                    goto B
45 //        X = B                 B:
46 //      case B                    ...
47 //        X = C                   goto C
48 //
49 // The pass first checks that switch variable X is decided by the control flow
50 // path taken in the loop; for example, in case B, the next value of X is
51 // decided to be C. It then enumerates through all paths in the loop and labels
52 // the basic blocks where the next state is decided.
53 //
54 // Using this information it creates new paths that unconditionally branch to
55 // the next case. This involves cloning code, so it only gets triggered if the
56 // amount of code duplicated is below a threshold.
57 //
58 //===----------------------------------------------------------------------===//
59 
60 #include "llvm/Transforms/Scalar/DFAJumpThreading.h"
61 #include "llvm/ADT/APInt.h"
62 #include "llvm/ADT/DenseMap.h"
63 #include "llvm/ADT/SmallSet.h"
64 #include "llvm/ADT/Statistic.h"
65 #include "llvm/Analysis/AssumptionCache.h"
66 #include "llvm/Analysis/CodeMetrics.h"
67 #include "llvm/Analysis/DomTreeUpdater.h"
68 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
69 #include "llvm/Analysis/TargetTransformInfo.h"
70 #include "llvm/IR/CFG.h"
71 #include "llvm/IR/Constants.h"
72 #include "llvm/IR/IntrinsicInst.h"
73 #include "llvm/InitializePasses.h"
74 #include "llvm/Pass.h"
75 #include "llvm/Support/CommandLine.h"
76 #include "llvm/Support/Debug.h"
77 #include "llvm/Transforms/Scalar.h"
78 #include "llvm/Transforms/Utils/Cloning.h"
79 #include "llvm/Transforms/Utils/SSAUpdaterBulk.h"
80 #include "llvm/Transforms/Utils/ValueMapper.h"
81 #include <algorithm>
82 #include <deque>
83 
84 #ifdef EXPENSIVE_CHECKS
85 #include "llvm/IR/Verifier.h"
86 #endif
87 
88 using namespace llvm;
89 
90 #define DEBUG_TYPE "dfa-jump-threading"
91 
// Pass statistics; each counter carries the description given as its second
// argument.
STATISTIC(NumTransforms, "Number of transformations done");
STATISTIC(NumCloned, "Number of blocks cloned");
STATISTIC(NumPaths, "Number of individual paths threaded");

// Hidden boolean flag, off by default.
static cl::opt<bool>
    ClViewCfgBefore("dfa-jump-view-cfg-before",
                    cl::desc("View the CFG before DFA Jump Threading"),
                    cl::Hidden, cl::init(false));

// Hidden option, default 20. AllSwitchPaths::paths() stops recursing once its
// depth counter exceeds this value.
static cl::opt<unsigned> MaxPathLength(
    "dfa-max-path-length",
    cl::desc("Max number of blocks searched to find a threading path"),
    cl::Hidden, cl::init(20));

// Hidden option, default 200. AllSwitchPaths::paths() returns early once it
// has collected this many paths.
static cl::opt<unsigned> MaxNumPaths(
    "dfa-max-num-paths",
    cl::desc("Max number of paths enumerated around a switch"),
    cl::Hidden, cl::init(200));

// Hidden option, default 50.
static cl::opt<unsigned>
    CostThreshold("dfa-cost-threshold",
                  cl::desc("Maximum cost accepted for the transformation"),
                  cl::Hidden, cl::init(50));
115 
116 namespace {
117 
118 class SelectInstToUnfold {
119   SelectInst *SI;
120   PHINode *SIUse;
121 
122 public:
123   SelectInstToUnfold(SelectInst *SI, PHINode *SIUse) : SI(SI), SIUse(SIUse) {}
124 
125   SelectInst *getInst() { return SI; }
126   PHINode *getUse() { return SIUse; }
127 
128   explicit operator bool() const { return SI && SIUse; }
129 };
130 
// Forward declaration: DFAJumpThreading::unfoldSelectInstrs() below calls
// unfold() before its definition appears later in this file.
void unfold(DomTreeUpdater *DTU, SelectInstToUnfold SIToUnfold,
            std::vector<SelectInstToUnfold> *NewSIsToUnfold,
            std::vector<BasicBlock *> *NewBBs);
134 
135 class DFAJumpThreading {
136 public:
137   DFAJumpThreading(AssumptionCache *AC, DominatorTree *DT,
138                    TargetTransformInfo *TTI, OptimizationRemarkEmitter *ORE)
139       : AC(AC), DT(DT), TTI(TTI), ORE(ORE) {}
140 
141   bool run(Function &F);
142 
143 private:
144   void
145   unfoldSelectInstrs(DominatorTree *DT,
146                      const SmallVector<SelectInstToUnfold, 4> &SelectInsts) {
147     DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
148     SmallVector<SelectInstToUnfold, 4> Stack;
149     for (SelectInstToUnfold SIToUnfold : SelectInsts)
150       Stack.push_back(SIToUnfold);
151 
152     while (!Stack.empty()) {
153       SelectInstToUnfold SIToUnfold = Stack.pop_back_val();
154 
155       std::vector<SelectInstToUnfold> NewSIsToUnfold;
156       std::vector<BasicBlock *> NewBBs;
157       unfold(&DTU, SIToUnfold, &NewSIsToUnfold, &NewBBs);
158 
159       // Put newly discovered select instructions into the work list.
160       for (const SelectInstToUnfold &NewSIToUnfold : NewSIsToUnfold)
161         Stack.push_back(NewSIToUnfold);
162     }
163   }
164 
165   AssumptionCache *AC;
166   DominatorTree *DT;
167   TargetTransformInfo *TTI;
168   OptimizationRemarkEmitter *ORE;
169 };
170 
171 class DFAJumpThreadingLegacyPass : public FunctionPass {
172 public:
173   static char ID; // Pass identification
174   DFAJumpThreadingLegacyPass() : FunctionPass(ID) {}
175 
176   void getAnalysisUsage(AnalysisUsage &AU) const override {
177     AU.addRequired<AssumptionCacheTracker>();
178     AU.addRequired<DominatorTreeWrapperPass>();
179     AU.addPreserved<DominatorTreeWrapperPass>();
180     AU.addRequired<TargetTransformInfoWrapperPass>();
181     AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
182   }
183 
184   bool runOnFunction(Function &F) override {
185     if (skipFunction(F))
186       return false;
187 
188     AssumptionCache *AC =
189         &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
190     DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
191     TargetTransformInfo *TTI =
192         &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
193     OptimizationRemarkEmitter *ORE =
194         &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
195 
196     return DFAJumpThreading(AC, DT, TTI, ORE).run(F);
197   }
198 };
199 } // end anonymous namespace
200 
// Definition of the static ID member declared in DFAJumpThreadingLegacyPass.
char DFAJumpThreadingLegacyPass::ID = 0;
// Register the pass under the name "dfa-jump-threading" and list its
// dependencies: the same four analyses that getAnalysisUsage() marks as
// required.
INITIALIZE_PASS_BEGIN(DFAJumpThreadingLegacyPass, "dfa-jump-threading",
                      "DFA Jump Threading", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_END(DFAJumpThreadingLegacyPass, "dfa-jump-threading",
                    "DFA Jump Threading", false, false)
210 
211 // Public interface to the DFA Jump Threading pass
212 FunctionPass *llvm::createDFAJumpThreadingPass() {
213   return new DFAJumpThreadingLegacyPass();
214 }
215 
216 namespace {
217 
218 /// Create a new basic block and sink \p SIToSink into it.
219 void createBasicBlockAndSinkSelectInst(
220     DomTreeUpdater *DTU, SelectInst *SI, PHINode *SIUse, SelectInst *SIToSink,
221     BasicBlock *EndBlock, StringRef NewBBName, BasicBlock **NewBlock,
222     BranchInst **NewBranch, std::vector<SelectInstToUnfold> *NewSIsToUnfold,
223     std::vector<BasicBlock *> *NewBBs) {
224   assert(SIToSink->hasOneUse());
225   assert(NewBlock);
226   assert(NewBranch);
227   *NewBlock = BasicBlock::Create(SI->getContext(), NewBBName,
228                                  EndBlock->getParent(), EndBlock);
229   NewBBs->push_back(*NewBlock);
230   *NewBranch = BranchInst::Create(EndBlock, *NewBlock);
231   SIToSink->moveBefore(*NewBranch);
232   NewSIsToUnfold->push_back(SelectInstToUnfold(SIToSink, SIUse));
233   DTU->applyUpdates({{DominatorTree::Insert, *NewBlock, EndBlock}});
234 }
235 
236 /// Unfold the select instruction held in \p SIToUnfold by replacing it with
237 /// control flow.
238 ///
239 /// Put newly discovered select instructions into \p NewSIsToUnfold. Put newly
240 /// created basic blocks into \p NewBBs.
241 ///
242 /// TODO: merge it with CodeGenPrepare::optimizeSelectInst() if possible.
243 void unfold(DomTreeUpdater *DTU, SelectInstToUnfold SIToUnfold,
244             std::vector<SelectInstToUnfold> *NewSIsToUnfold,
245             std::vector<BasicBlock *> *NewBBs) {
246   SelectInst *SI = SIToUnfold.getInst();
247   PHINode *SIUse = SIToUnfold.getUse();
248   BasicBlock *StartBlock = SI->getParent();
249   BasicBlock *EndBlock = SIUse->getParent();
250   BranchInst *StartBlockTerm =
251       dyn_cast<BranchInst>(StartBlock->getTerminator());
252 
253   assert(StartBlockTerm && StartBlockTerm->isUnconditional());
254   assert(SI->hasOneUse());
255 
256   // These are the new basic blocks for the conditional branch.
257   // At least one will become an actual new basic block.
258   BasicBlock *TrueBlock = nullptr;
259   BasicBlock *FalseBlock = nullptr;
260   BranchInst *TrueBranch = nullptr;
261   BranchInst *FalseBranch = nullptr;
262 
263   // Sink select instructions to be able to unfold them later.
264   if (SelectInst *SIOp = dyn_cast<SelectInst>(SI->getTrueValue())) {
265     createBasicBlockAndSinkSelectInst(DTU, SI, SIUse, SIOp, EndBlock,
266                                       "si.unfold.true", &TrueBlock, &TrueBranch,
267                                       NewSIsToUnfold, NewBBs);
268   }
269   if (SelectInst *SIOp = dyn_cast<SelectInst>(SI->getFalseValue())) {
270     createBasicBlockAndSinkSelectInst(DTU, SI, SIUse, SIOp, EndBlock,
271                                       "si.unfold.false", &FalseBlock,
272                                       &FalseBranch, NewSIsToUnfold, NewBBs);
273   }
274 
275   // If there was nothing to sink, then arbitrarily choose the 'false' side
276   // for a new input value to the PHI.
277   if (!TrueBlock && !FalseBlock) {
278     FalseBlock = BasicBlock::Create(SI->getContext(), "si.unfold.false",
279                                     EndBlock->getParent(), EndBlock);
280     NewBBs->push_back(FalseBlock);
281     BranchInst::Create(EndBlock, FalseBlock);
282     DTU->applyUpdates({{DominatorTree::Insert, FalseBlock, EndBlock}});
283   }
284 
285   // Insert the real conditional branch based on the original condition.
286   // If we did not create a new block for one of the 'true' or 'false' paths
287   // of the condition, it means that side of the branch goes to the end block
288   // directly and the path originates from the start block from the point of
289   // view of the new PHI.
290   BasicBlock *TT = EndBlock;
291   BasicBlock *FT = EndBlock;
292   if (TrueBlock && FalseBlock) {
293     // A diamond.
294     TT = TrueBlock;
295     FT = FalseBlock;
296 
297     // Update the phi node of SI.
298     SIUse->removeIncomingValue(StartBlock, /* DeletePHIIfEmpty = */ false);
299     SIUse->addIncoming(SI->getTrueValue(), TrueBlock);
300     SIUse->addIncoming(SI->getFalseValue(), FalseBlock);
301 
302     // Update any other PHI nodes in EndBlock.
303     for (PHINode &Phi : EndBlock->phis()) {
304       if (&Phi != SIUse) {
305         Phi.addIncoming(Phi.getIncomingValueForBlock(StartBlock), TrueBlock);
306         Phi.addIncoming(Phi.getIncomingValueForBlock(StartBlock), FalseBlock);
307       }
308     }
309   } else {
310     BasicBlock *NewBlock = nullptr;
311     Value *SIOp1 = SI->getTrueValue();
312     Value *SIOp2 = SI->getFalseValue();
313 
314     // A triangle pointing right.
315     if (!TrueBlock) {
316       NewBlock = FalseBlock;
317       FT = FalseBlock;
318     }
319     // A triangle pointing left.
320     else {
321       NewBlock = TrueBlock;
322       TT = TrueBlock;
323       std::swap(SIOp1, SIOp2);
324     }
325 
326     // Update the phi node of SI.
327     for (unsigned Idx = 0; Idx < SIUse->getNumIncomingValues(); ++Idx) {
328       if (SIUse->getIncomingBlock(Idx) == StartBlock)
329         SIUse->setIncomingValue(Idx, SIOp1);
330     }
331     SIUse->addIncoming(SIOp2, NewBlock);
332 
333     // Update any other PHI nodes in EndBlock.
334     for (auto II = EndBlock->begin(); PHINode *Phi = dyn_cast<PHINode>(II);
335          ++II) {
336       if (Phi != SIUse)
337         Phi->addIncoming(Phi->getIncomingValueForBlock(StartBlock), NewBlock);
338     }
339   }
340   StartBlockTerm->eraseFromParent();
341   BranchInst::Create(TT, FT, SI->getCondition(), StartBlock);
342   DTU->applyUpdates({{DominatorTree::Insert, StartBlock, TT},
343                      {DominatorTree::Insert, StartBlock, FT}});
344 
345   // The select is now dead.
346   SI->eraseFromParent();
347 }
348 
/// Pairs a basic block with a state value. Kept as a plain aggregate; code in
/// this file builds it with brace initialization ({BB, State}).
struct ClonedBlock {
  BasicBlock *BB;
  uint64_t State; ///< \p State corresponds to the next value of a switch
                  ///< statement.
};
353 
354 typedef std::deque<BasicBlock *> PathType;
355 typedef std::vector<PathType> PathsType;
356 typedef SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
357 typedef std::vector<ClonedBlock> CloneList;
358 
359 // This data structure keeps track of all blocks that have been cloned.  If two
360 // different ThreadingPaths clone the same block for a certain state it should
361 // be reused, and it can be looked up in this map.
362 typedef DenseMap<BasicBlock *, CloneList> DuplicateBlockMap;
363 
364 // This map keeps track of all the new definitions for an instruction. This
365 // information is needed when restoring SSA form after cloning blocks.
366 typedef MapVector<Instruction *, std::vector<Instruction *>> DefMap;
367 
368 inline raw_ostream &operator<<(raw_ostream &OS, const PathType &Path) {
369   OS << "< ";
370   for (const BasicBlock *BB : Path) {
371     std::string BBName;
372     if (BB->hasName())
373       raw_string_ostream(BBName) << BB->getName();
374     else
375       raw_string_ostream(BBName) << BB;
376     OS << BBName << " ";
377   }
378   OS << ">";
379   return OS;
380 }
381 
382 /// ThreadingPath is a path in the control flow of a loop that can be threaded
383 /// by cloning necessary basic blocks and replacing conditional branches with
384 /// unconditional ones. A threading path includes a list of basic blocks, the
385 /// exit state, and the block that determines the next state.
386 struct ThreadingPath {
387   /// Exit value is DFA's exit state for the given path.
388   uint64_t getExitValue() const { return ExitVal; }
389   void setExitValue(const ConstantInt *V) {
390     ExitVal = V->getZExtValue();
391     IsExitValSet = true;
392   }
393   bool isExitValueSet() const { return IsExitValSet; }
394 
395   /// Determinator is the basic block that determines the next state of the DFA.
396   const BasicBlock *getDeterminatorBB() const { return DBB; }
397   void setDeterminator(const BasicBlock *BB) { DBB = BB; }
398 
399   /// Path is a list of basic blocks.
400   const PathType &getPath() const { return Path; }
401   void setPath(const PathType &NewPath) { Path = NewPath; }
402 
403   void print(raw_ostream &OS) const {
404     OS << Path << " [ " << ExitVal << ", " << DBB->getName() << " ]";
405   }
406 
407 private:
408   PathType Path;
409   uint64_t ExitVal;
410   const BasicBlock *DBB = nullptr;
411   bool IsExitValSet = false;
412 };
413 
#ifndef NDEBUG
/// Stream a ThreadingPath by delegating to ThreadingPath::print(). Only
/// compiled when NDEBUG is not defined.
inline raw_ostream &operator<<(raw_ostream &OS, const ThreadingPath &TPath) {
  TPath.print(OS);
  return OS;
}
#endif
420 
421 struct MainSwitch {
422   MainSwitch(SwitchInst *SI, OptimizationRemarkEmitter *ORE) {
423     if (isCandidate(SI)) {
424       Instr = SI;
425     } else {
426       ORE->emit([&]() {
427         return OptimizationRemarkMissed(DEBUG_TYPE, "SwitchNotPredictable", SI)
428                << "Switch instruction is not predictable.";
429       });
430     }
431   }
432 
433   virtual ~MainSwitch() = default;
434 
435   SwitchInst *getInstr() const { return Instr; }
436   const SmallVector<SelectInstToUnfold, 4> getSelectInsts() {
437     return SelectInsts;
438   }
439 
440 private:
441   /// Do a use-def chain traversal starting from the switch condition to see if
442   /// \p SI is a potential condidate.
443   ///
444   /// Also, collect select instructions to unfold.
445   bool isCandidate(const SwitchInst *SI) {
446     std::deque<Value *> Q;
447     SmallSet<Value *, 16> SeenValues;
448     SelectInsts.clear();
449 
450     Value *SICond = SI->getCondition();
451     LLVM_DEBUG(dbgs() << "\tSICond: " << *SICond << "\n");
452     if (!isa<PHINode>(SICond))
453       return false;
454 
455     addToQueue(SICond, Q, SeenValues);
456 
457     while (!Q.empty()) {
458       Value *Current = Q.front();
459       Q.pop_front();
460 
461       if (auto *Phi = dyn_cast<PHINode>(Current)) {
462         for (Value *Incoming : Phi->incoming_values()) {
463           addToQueue(Incoming, Q, SeenValues);
464         }
465         LLVM_DEBUG(dbgs() << "\tphi: " << *Phi << "\n");
466       } else if (SelectInst *SelI = dyn_cast<SelectInst>(Current)) {
467         if (!isValidSelectInst(SelI))
468           return false;
469         addToQueue(SelI->getTrueValue(), Q, SeenValues);
470         addToQueue(SelI->getFalseValue(), Q, SeenValues);
471         LLVM_DEBUG(dbgs() << "\tselect: " << *SelI << "\n");
472         if (auto *SelIUse = dyn_cast<PHINode>(SelI->user_back()))
473           SelectInsts.push_back(SelectInstToUnfold(SelI, SelIUse));
474       } else if (isa<Constant>(Current)) {
475         LLVM_DEBUG(dbgs() << "\tconst: " << *Current << "\n");
476         continue;
477       } else {
478         LLVM_DEBUG(dbgs() << "\tother: " << *Current << "\n");
479         // Allow unpredictable values. The hope is that those will be the
480         // initial switch values that can be ignored (they will hit the
481         // unthreaded switch) but this assumption will get checked later after
482         // paths have been enumerated (in function getStateDefMap).
483         continue;
484       }
485     }
486 
487     return true;
488   }
489 
490   void addToQueue(Value *Val, std::deque<Value *> &Q,
491                   SmallSet<Value *, 16> &SeenValues) {
492     if (SeenValues.contains(Val))
493       return;
494     Q.push_back(Val);
495     SeenValues.insert(Val);
496   }
497 
498   bool isValidSelectInst(SelectInst *SI) {
499     if (!SI->hasOneUse())
500       return false;
501 
502     Instruction *SIUse = dyn_cast<Instruction>(SI->user_back());
503     // The use of the select inst should be either a phi or another select.
504     if (!SIUse && !(isa<PHINode>(SIUse) || isa<SelectInst>(SIUse)))
505       return false;
506 
507     BasicBlock *SIBB = SI->getParent();
508 
509     // Currently, we can only expand select instructions in basic blocks with
510     // one successor.
511     BranchInst *SITerm = dyn_cast<BranchInst>(SIBB->getTerminator());
512     if (!SITerm || !SITerm->isUnconditional())
513       return false;
514 
515     if (isa<PHINode>(SIUse) &&
516         SIBB->getSingleSuccessor() != cast<Instruction>(SIUse)->getParent())
517       return false;
518 
519     // If select will not be sunk during unfolding, and it is in the same basic
520     // block as another state defining select, then cannot unfold both.
521     for (SelectInstToUnfold SIToUnfold : SelectInsts) {
522       SelectInst *PrevSI = SIToUnfold.getInst();
523       if (PrevSI->getTrueValue() != SI && PrevSI->getFalseValue() != SI &&
524           PrevSI->getParent() == SI->getParent())
525         return false;
526     }
527 
528     return true;
529   }
530 
531   SwitchInst *Instr = nullptr;
532   SmallVector<SelectInstToUnfold, 4> SelectInsts;
533 };
534 
535 struct AllSwitchPaths {
536   AllSwitchPaths(const MainSwitch *MSwitch, OptimizationRemarkEmitter *ORE)
537       : Switch(MSwitch->getInstr()), SwitchBlock(Switch->getParent()),
538         ORE(ORE) {}
539 
540   std::vector<ThreadingPath> &getThreadingPaths() { return TPaths; }
541   unsigned getNumThreadingPaths() { return TPaths.size(); }
542   SwitchInst *getSwitchInst() { return Switch; }
543   BasicBlock *getSwitchBlock() { return SwitchBlock; }
544 
545   void run() {
546     VisitedBlocks Visited;
547     PathsType LoopPaths = paths(SwitchBlock, Visited, /* PathDepth = */ 1);
548     StateDefMap StateDef = getStateDefMap(LoopPaths);
549 
550     if (StateDef.empty()) {
551       ORE->emit([&]() {
552         return OptimizationRemarkMissed(DEBUG_TYPE, "SwitchNotPredictable",
553                                         Switch)
554                << "Switch instruction is not predictable.";
555       });
556       return;
557     }
558 
559     for (PathType Path : LoopPaths) {
560       ThreadingPath TPath;
561 
562       const BasicBlock *PrevBB = Path.back();
563       for (const BasicBlock *BB : Path) {
564         if (StateDef.count(BB) != 0) {
565           const PHINode *Phi = dyn_cast<PHINode>(StateDef[BB]);
566           assert(Phi && "Expected a state-defining instr to be a phi node.");
567 
568           const Value *V = Phi->getIncomingValueForBlock(PrevBB);
569           if (const ConstantInt *C = dyn_cast<const ConstantInt>(V)) {
570             TPath.setExitValue(C);
571             TPath.setDeterminator(BB);
572             TPath.setPath(Path);
573           }
574         }
575 
576         // Switch block is the determinator, this is the final exit value.
577         if (TPath.isExitValueSet() && BB == Path.front())
578           break;
579 
580         PrevBB = BB;
581       }
582 
583       if (TPath.isExitValueSet() && isSupported(TPath))
584         TPaths.push_back(TPath);
585     }
586   }
587 
588 private:
589   // Value: an instruction that defines a switch state;
590   // Key: the parent basic block of that instruction.
591   typedef DenseMap<const BasicBlock *, const PHINode *> StateDefMap;
592 
593   PathsType paths(BasicBlock *BB, VisitedBlocks &Visited,
594                   unsigned PathDepth) const {
595     PathsType Res;
596 
597     // Stop exploring paths after visiting MaxPathLength blocks
598     if (PathDepth > MaxPathLength) {
599       ORE->emit([&]() {
600         return OptimizationRemarkAnalysis(DEBUG_TYPE, "MaxPathLengthReached",
601                                           Switch)
602                << "Exploration stopped after visiting MaxPathLength="
603                << ore::NV("MaxPathLength", MaxPathLength) << " blocks.";
604       });
605       return Res;
606     }
607 
608     Visited.insert(BB);
609 
610     // Some blocks have multiple edges to the same successor, and this set
611     // is used to prevent a duplicate path from being generated
612     SmallSet<BasicBlock *, 4> Successors;
613     for (BasicBlock *Succ : successors(BB)) {
614       if (!Successors.insert(Succ).second)
615         continue;
616 
617       // Found a cycle through the SwitchBlock
618       if (Succ == SwitchBlock) {
619         Res.push_back({BB});
620         continue;
621       }
622 
623       // We have encountered a cycle, do not get caught in it
624       if (Visited.contains(Succ))
625         continue;
626 
627       PathsType SuccPaths = paths(Succ, Visited, PathDepth + 1);
628       for (PathType Path : SuccPaths) {
629         PathType NewPath(Path);
630         NewPath.push_front(BB);
631         Res.push_back(NewPath);
632         if (Res.size() >= MaxNumPaths) {
633           return Res;
634         }
635       }
636     }
637     // This block could now be visited again from a different predecessor. Note
638     // that this will result in exponential runtime. Subpaths could possibly be
639     // cached but it takes a lot of memory to store them.
640     Visited.erase(BB);
641     return Res;
642   }
643 
644   /// Walk the use-def chain and collect all the state-defining instructions.
645   ///
646   /// Return an empty map if unpredictable values encountered inside the basic
647   /// blocks of \p LoopPaths.
648   StateDefMap getStateDefMap(const PathsType &LoopPaths) const {
649     StateDefMap Res;
650 
651     // Basic blocks belonging to any of the loops around the switch statement.
652     SmallPtrSet<BasicBlock *, 16> LoopBBs;
653     for (const PathType &Path : LoopPaths) {
654       for (BasicBlock *BB : Path)
655         LoopBBs.insert(BB);
656     }
657 
658     Value *FirstDef = Switch->getOperand(0);
659 
660     assert(isa<PHINode>(FirstDef) && "The first definition must be a phi.");
661 
662     SmallVector<PHINode *, 8> Stack;
663     Stack.push_back(dyn_cast<PHINode>(FirstDef));
664     SmallSet<Value *, 16> SeenValues;
665 
666     while (!Stack.empty()) {
667       PHINode *CurPhi = Stack.pop_back_val();
668 
669       Res[CurPhi->getParent()] = CurPhi;
670       SeenValues.insert(CurPhi);
671 
672       for (BasicBlock *IncomingBB : CurPhi->blocks()) {
673         Value *Incoming = CurPhi->getIncomingValueForBlock(IncomingBB);
674         bool IsOutsideLoops = LoopBBs.count(IncomingBB) == 0;
675         if (Incoming == FirstDef || isa<ConstantInt>(Incoming) ||
676             SeenValues.contains(Incoming) || IsOutsideLoops) {
677           continue;
678         }
679 
680         // Any unpredictable value inside the loops means we must bail out.
681         if (!isa<PHINode>(Incoming))
682           return StateDefMap();
683 
684         Stack.push_back(cast<PHINode>(Incoming));
685       }
686     }
687 
688     return Res;
689   }
690 
691   /// The determinator BB should precede the switch-defining BB.
692   ///
693   /// Otherwise, it is possible that the state defined in the determinator block
694   /// defines the state for the next iteration of the loop, rather than for the
695   /// current one.
696   ///
697   /// Currently supported paths:
698   /// \code
699   /// < switch bb1 determ def > [ 42, determ ]
700   /// < switch_and_def bb1 determ > [ 42, determ ]
701   /// < switch_and_def_and_determ bb1 > [ 42, switch_and_def_and_determ ]
702   /// \endcode
703   ///
704   /// Unsupported paths:
705   /// \code
706   /// < switch bb1 def determ > [ 43, determ ]
707   /// < switch_and_determ bb1 def > [ 43, switch_and_determ ]
708   /// \endcode
709   bool isSupported(const ThreadingPath &TPath) {
710     Instruction *SwitchCondI = dyn_cast<Instruction>(Switch->getCondition());
711     assert(SwitchCondI);
712     if (!SwitchCondI)
713       return false;
714 
715     const BasicBlock *SwitchCondDefBB = SwitchCondI->getParent();
716     const BasicBlock *SwitchCondUseBB = Switch->getParent();
717     const BasicBlock *DeterminatorBB = TPath.getDeterminatorBB();
718 
719     assert(
720         SwitchCondUseBB == TPath.getPath().front() &&
721         "The first BB in a threading path should have the switch instruction");
722     if (SwitchCondUseBB != TPath.getPath().front())
723       return false;
724 
725     // Make DeterminatorBB the first element in Path.
726     PathType Path = TPath.getPath();
727     auto ItDet = std::find(Path.begin(), Path.end(), DeterminatorBB);
728     std::rotate(Path.begin(), ItDet, Path.end());
729 
730     bool IsDetBBSeen = false;
731     bool IsDefBBSeen = false;
732     bool IsUseBBSeen = false;
733     for (BasicBlock *BB : Path) {
734       if (BB == DeterminatorBB)
735         IsDetBBSeen = true;
736       if (BB == SwitchCondDefBB)
737         IsDefBBSeen = true;
738       if (BB == SwitchCondUseBB)
739         IsUseBBSeen = true;
740       if (IsDetBBSeen && IsUseBBSeen && !IsDefBBSeen)
741         return false;
742     }
743 
744     return true;
745   }
746 
747   SwitchInst *Switch;
748   BasicBlock *SwitchBlock;
749   OptimizationRemarkEmitter *ORE;
750   std::vector<ThreadingPath> TPaths;
751 };
752 
753 struct TransformDFA {
  /// Store the path analysis results and the analyses used by the transform.
  /// \p EphValues is received by value and copied into the member of the same
  /// name.
  TransformDFA(AllSwitchPaths *SwitchPaths, DominatorTree *DT,
               AssumptionCache *AC, TargetTransformInfo *TTI,
               OptimizationRemarkEmitter *ORE,
               SmallPtrSet<const Value *, 32> EphValues)
      : SwitchPaths(SwitchPaths), DT(DT), AC(AC), TTI(TTI), ORE(ORE),
        EphValues(EphValues) {}
760 
761   void run() {
762     if (isLegalAndProfitableToTransform()) {
763       createAllExitPaths();
764       NumTransforms++;
765     }
766   }
767 
768 private:
769   /// This function performs both a legality check and profitability check at
770   /// the same time since it is convenient to do so. It iterates through all
771   /// blocks that will be cloned, and keeps track of the duplication cost. It
772   /// also returns false if it is illegal to clone some required block.
773   bool isLegalAndProfitableToTransform() {
774     CodeMetrics Metrics;
775     SwitchInst *Switch = SwitchPaths->getSwitchInst();
776 
777     // Note that DuplicateBlockMap is not being used as intended here. It is
778     // just being used to ensure (BB, State) pairs are only counted once.
779     DuplicateBlockMap DuplicateMap;
780 
781     for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
782       PathType PathBBs = TPath.getPath();
783       uint64_t NextState = TPath.getExitValue();
784       const BasicBlock *Determinator = TPath.getDeterminatorBB();
785 
786       // Update Metrics for the Switch block, this is always cloned
787       BasicBlock *BB = SwitchPaths->getSwitchBlock();
788       BasicBlock *VisitedBB = getClonedBB(BB, NextState, DuplicateMap);
789       if (!VisitedBB) {
790         Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
791         DuplicateMap[BB].push_back({BB, NextState});
792       }
793 
794       // If the Switch block is the Determinator, then we can continue since
795       // this is the only block that is cloned and we already counted for it.
796       if (PathBBs.front() == Determinator)
797         continue;
798 
799       // Otherwise update Metrics for all blocks that will be cloned. If any
800       // block is already cloned and would be reused, don't double count it.
801       auto DetIt = std::find(PathBBs.begin(), PathBBs.end(), Determinator);
802       for (auto BBIt = DetIt; BBIt != PathBBs.end(); BBIt++) {
803         BB = *BBIt;
804         VisitedBB = getClonedBB(BB, NextState, DuplicateMap);
805         if (VisitedBB)
806           continue;
807         Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
808         DuplicateMap[BB].push_back({BB, NextState});
809       }
810 
811       if (Metrics.notDuplicatable) {
812         LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, contains "
813                           << "non-duplicatable instructions.\n");
814         ORE->emit([&]() {
815           return OptimizationRemarkMissed(DEBUG_TYPE, "NonDuplicatableInst",
816                                           Switch)
817                  << "Contains non-duplicatable instructions.";
818         });
819         return false;
820       }
821 
822       if (Metrics.convergent) {
823         LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, contains "
824                           << "convergent instructions.\n");
825         ORE->emit([&]() {
826           return OptimizationRemarkMissed(DEBUG_TYPE, "ConvergentInst", Switch)
827                  << "Contains convergent instructions.";
828         });
829         return false;
830       }
831 
832       if (!Metrics.NumInsts.isValid()) {
833         LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, contains "
834                           << "instructions with invalid cost.\n");
835         ORE->emit([&]() {
836           return OptimizationRemarkMissed(DEBUG_TYPE, "ConvergentInst", Switch)
837                  << "Contains instructions with invalid cost.";
838         });
839         return false;
840       }
841     }
842 
843     unsigned DuplicationCost = 0;
844 
845     unsigned JumpTableSize = 0;
846     TTI->getEstimatedNumberOfCaseClusters(*Switch, JumpTableSize, nullptr,
847                                           nullptr);
848     if (JumpTableSize == 0) {
849       // Factor in the number of conditional branches reduced from jump
850       // threading. Assume that lowering the switch block is implemented by
851       // using binary search, hence the LogBase2().
852       unsigned CondBranches =
853           APInt(32, Switch->getNumSuccessors()).ceilLogBase2();
854       DuplicationCost = *Metrics.NumInsts.getValue() / CondBranches;
855     } else {
856       // Compared with jump tables, the DFA optimizer removes an indirect branch
857       // on each loop iteration, thus making branch prediction more precise. The
858       // more branch targets there are, the more likely it is for the branch
859       // predictor to make a mistake, and the more benefit there is in the DFA
860       // optimizer. Thus, the more branch targets there are, the lower is the
861       // cost of the DFA opt.
862       DuplicationCost = *Metrics.NumInsts.getValue() / JumpTableSize;
863     }
864 
865     LLVM_DEBUG(dbgs() << "\nDFA Jump Threading: Cost to jump thread block "
866                       << SwitchPaths->getSwitchBlock()->getName()
867                       << " is: " << DuplicationCost << "\n\n");
868 
869     if (DuplicationCost > CostThreshold) {
870       LLVM_DEBUG(dbgs() << "Not jump threading, duplication cost exceeds the "
871                         << "cost threshold.\n");
872       ORE->emit([&]() {
873         return OptimizationRemarkMissed(DEBUG_TYPE, "NotProfitable", Switch)
874                << "Duplication cost exceeds the cost threshold (cost="
875                << ore::NV("Cost", DuplicationCost)
876                << ", threshold=" << ore::NV("Threshold", CostThreshold) << ").";
877       });
878       return false;
879     }
880 
881     ORE->emit([&]() {
882       return OptimizationRemark(DEBUG_TYPE, "JumpThreaded", Switch)
883              << "Switch statement jump-threaded.";
884     });
885 
886     return true;
887   }
888 
889   /// Transform each threading path to effectively jump thread the DFA.
890   void createAllExitPaths() {
891     DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Eager);
892 
893     // Move the switch block to the end of the path, since it will be duplicated
894     BasicBlock *SwitchBlock = SwitchPaths->getSwitchBlock();
895     for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
896       LLVM_DEBUG(dbgs() << TPath << "\n");
897       PathType NewPath(TPath.getPath());
898       NewPath.push_back(SwitchBlock);
899       TPath.setPath(NewPath);
900     }
901 
902     // Transform the ThreadingPaths and keep track of the cloned values
903     DuplicateBlockMap DuplicateMap;
904     DefMap NewDefs;
905 
906     SmallSet<BasicBlock *, 16> BlocksToClean;
907     for (BasicBlock *BB : successors(SwitchBlock))
908       BlocksToClean.insert(BB);
909 
910     for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
911       createExitPath(NewDefs, TPath, DuplicateMap, BlocksToClean, &DTU);
912       NumPaths++;
913     }
914 
915     // After all paths are cloned, now update the last successor of the cloned
916     // path so it skips over the switch statement
917     for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths())
918       updateLastSuccessor(TPath, DuplicateMap, &DTU);
919 
920     // For each instruction that was cloned and used outside, update its uses
921     updateSSA(NewDefs);
922 
923     // Clean PHI Nodes for the newly created blocks
924     for (BasicBlock *BB : BlocksToClean)
925       cleanPhiNodes(BB);
926   }
927 
928   /// For a specific ThreadingPath \p Path, create an exit path starting from
929   /// the determinator block.
930   ///
931   /// To remember the correct destination, we have to duplicate blocks
932   /// corresponding to each state. Also update the terminating instruction of
933   /// the predecessors, and phis in the successor blocks.
934   void createExitPath(DefMap &NewDefs, ThreadingPath &Path,
935                       DuplicateBlockMap &DuplicateMap,
936                       SmallSet<BasicBlock *, 16> &BlocksToClean,
937                       DomTreeUpdater *DTU) {
938     uint64_t NextState = Path.getExitValue();
939     const BasicBlock *Determinator = Path.getDeterminatorBB();
940     PathType PathBBs = Path.getPath();
941 
942     // Don't select the placeholder block in front
943     if (PathBBs.front() == Determinator)
944       PathBBs.pop_front();
945 
946     auto DetIt = std::find(PathBBs.begin(), PathBBs.end(), Determinator);
947     auto Prev = std::prev(DetIt);
948     BasicBlock *PrevBB = *Prev;
949     for (auto BBIt = DetIt; BBIt != PathBBs.end(); BBIt++) {
950       BasicBlock *BB = *BBIt;
951       BlocksToClean.insert(BB);
952 
953       // We already cloned BB for this NextState, now just update the branch
954       // and continue.
955       BasicBlock *NextBB = getClonedBB(BB, NextState, DuplicateMap);
956       if (NextBB) {
957         updatePredecessor(PrevBB, BB, NextBB, DTU);
958         PrevBB = NextBB;
959         continue;
960       }
961 
962       // Clone the BB and update the successor of Prev to jump to the new block
963       BasicBlock *NewBB = cloneBlockAndUpdatePredecessor(
964           BB, PrevBB, NextState, DuplicateMap, NewDefs, DTU);
965       DuplicateMap[BB].push_back({NewBB, NextState});
966       BlocksToClean.insert(NewBB);
967       PrevBB = NewBB;
968     }
969   }
970 
971   /// Restore SSA form after cloning blocks.
972   ///
973   /// Each cloned block creates new defs for a variable, and the uses need to be
974   /// updated to reflect this. The uses may be replaced with a cloned value, or
975   /// some derived phi instruction. Note that all uses of a value defined in the
976   /// same block were already remapped when cloning the block.
977   void updateSSA(DefMap &NewDefs) {
978     SSAUpdaterBulk SSAUpdate;
979     SmallVector<Use *, 16> UsesToRename;
980 
981     for (auto KV : NewDefs) {
982       Instruction *I = KV.first;
983       BasicBlock *BB = I->getParent();
984       std::vector<Instruction *> Cloned = KV.second;
985 
986       // Scan all uses of this instruction to see if it is used outside of its
987       // block, and if so, record them in UsesToRename.
988       for (Use &U : I->uses()) {
989         Instruction *User = cast<Instruction>(U.getUser());
990         if (PHINode *UserPN = dyn_cast<PHINode>(User)) {
991           if (UserPN->getIncomingBlock(U) == BB)
992             continue;
993         } else if (User->getParent() == BB) {
994           continue;
995         }
996 
997         UsesToRename.push_back(&U);
998       }
999 
1000       // If there are no uses outside the block, we're done with this
1001       // instruction.
1002       if (UsesToRename.empty())
1003         continue;
1004       LLVM_DEBUG(dbgs() << "DFA-JT: Renaming non-local uses of: " << *I
1005                         << "\n");
1006 
1007       // We found a use of I outside of BB.  Rename all uses of I that are
1008       // outside its block to be uses of the appropriate PHI node etc.  See
1009       // ValuesInBlocks with the values we know.
1010       unsigned VarNum = SSAUpdate.AddVariable(I->getName(), I->getType());
1011       SSAUpdate.AddAvailableValue(VarNum, BB, I);
1012       for (Instruction *New : Cloned)
1013         SSAUpdate.AddAvailableValue(VarNum, New->getParent(), New);
1014 
1015       while (!UsesToRename.empty())
1016         SSAUpdate.AddUse(VarNum, UsesToRename.pop_back_val());
1017 
1018       LLVM_DEBUG(dbgs() << "\n");
1019     }
1020     // SSAUpdater handles phi placement and renaming uses with the appropriate
1021     // value.
1022     SSAUpdate.RewriteAllUses(DT);
1023   }
1024 
1025   /// Clones a basic block, and adds it to the CFG.
1026   ///
1027   /// This function also includes updating phi nodes in the successors of the
1028   /// BB, and remapping uses that were defined locally in the cloned BB.
1029   BasicBlock *cloneBlockAndUpdatePredecessor(BasicBlock *BB, BasicBlock *PrevBB,
1030                                              uint64_t NextState,
1031                                              DuplicateBlockMap &DuplicateMap,
1032                                              DefMap &NewDefs,
1033                                              DomTreeUpdater *DTU) {
1034     ValueToValueMapTy VMap;
1035     BasicBlock *NewBB = CloneBasicBlock(
1036         BB, VMap, ".jt" + std::to_string(NextState), BB->getParent());
1037     NewBB->moveAfter(BB);
1038     NumCloned++;
1039 
1040     for (Instruction &I : *NewBB) {
1041       // Do not remap operands of PHINode in case a definition in BB is an
1042       // incoming value to a phi in the same block. This incoming value will
1043       // be renamed later while restoring SSA.
1044       if (isa<PHINode>(&I))
1045         continue;
1046       RemapInstruction(&I, VMap,
1047                        RF_IgnoreMissingLocals | RF_NoModuleLevelChanges);
1048       if (AssumeInst *II = dyn_cast<AssumeInst>(&I))
1049         AC->registerAssumption(II);
1050     }
1051 
1052     updateSuccessorPhis(BB, NewBB, NextState, VMap, DuplicateMap);
1053     updatePredecessor(PrevBB, BB, NewBB, DTU);
1054     updateDefMap(NewDefs, VMap);
1055 
1056     // Add all successors to the DominatorTree
1057     SmallPtrSet<BasicBlock *, 4> SuccSet;
1058     for (auto *SuccBB : successors(NewBB)) {
1059       if (SuccSet.insert(SuccBB).second)
1060         DTU->applyUpdates({{DominatorTree::Insert, NewBB, SuccBB}});
1061     }
1062     SuccSet.clear();
1063     return NewBB;
1064   }
1065 
1066   /// Update the phi nodes in BB's successors.
1067   ///
1068   /// This means creating a new incoming value from NewBB with the new
1069   /// instruction wherever there is an incoming value from BB.
1070   void updateSuccessorPhis(BasicBlock *BB, BasicBlock *ClonedBB,
1071                            uint64_t NextState, ValueToValueMapTy &VMap,
1072                            DuplicateBlockMap &DuplicateMap) {
1073     std::vector<BasicBlock *> BlocksToUpdate;
1074 
1075     // If BB is the last block in the path, we can simply update the one case
1076     // successor that will be reached.
1077     if (BB == SwitchPaths->getSwitchBlock()) {
1078       SwitchInst *Switch = SwitchPaths->getSwitchInst();
1079       BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState);
1080       BlocksToUpdate.push_back(NextCase);
1081       BasicBlock *ClonedSucc = getClonedBB(NextCase, NextState, DuplicateMap);
1082       if (ClonedSucc)
1083         BlocksToUpdate.push_back(ClonedSucc);
1084     }
1085     // Otherwise update phis in all successors.
1086     else {
1087       for (BasicBlock *Succ : successors(BB)) {
1088         BlocksToUpdate.push_back(Succ);
1089 
1090         // Check if a successor has already been cloned for the particular exit
1091         // value. In this case if a successor was already cloned, the phi nodes
1092         // in the cloned block should be updated directly.
1093         BasicBlock *ClonedSucc = getClonedBB(Succ, NextState, DuplicateMap);
1094         if (ClonedSucc)
1095           BlocksToUpdate.push_back(ClonedSucc);
1096       }
1097     }
1098 
1099     // If there is a phi with an incoming value from BB, create a new incoming
1100     // value for the new predecessor ClonedBB. The value will either be the same
1101     // value from BB or a cloned value.
1102     for (BasicBlock *Succ : BlocksToUpdate) {
1103       for (auto II = Succ->begin(); PHINode *Phi = dyn_cast<PHINode>(II);
1104            ++II) {
1105         Value *Incoming = Phi->getIncomingValueForBlock(BB);
1106         if (Incoming) {
1107           if (isa<Constant>(Incoming)) {
1108             Phi->addIncoming(Incoming, ClonedBB);
1109             continue;
1110           }
1111           Value *ClonedVal = VMap[Incoming];
1112           if (ClonedVal)
1113             Phi->addIncoming(ClonedVal, ClonedBB);
1114           else
1115             Phi->addIncoming(Incoming, ClonedBB);
1116         }
1117       }
1118     }
1119   }
1120 
1121   /// Sets the successor of PrevBB to be NewBB instead of OldBB. Note that all
1122   /// other successors are kept as well.
1123   void updatePredecessor(BasicBlock *PrevBB, BasicBlock *OldBB,
1124                          BasicBlock *NewBB, DomTreeUpdater *DTU) {
1125     // When a path is reused, there is a chance that predecessors were already
1126     // updated before. Check if the predecessor needs to be updated first.
1127     if (!isPredecessor(OldBB, PrevBB))
1128       return;
1129 
1130     Instruction *PrevTerm = PrevBB->getTerminator();
1131     for (unsigned Idx = 0; Idx < PrevTerm->getNumSuccessors(); Idx++) {
1132       if (PrevTerm->getSuccessor(Idx) == OldBB) {
1133         OldBB->removePredecessor(PrevBB, /* KeepOneInputPHIs = */ true);
1134         PrevTerm->setSuccessor(Idx, NewBB);
1135       }
1136     }
1137     DTU->applyUpdates({{DominatorTree::Delete, PrevBB, OldBB},
1138                        {DominatorTree::Insert, PrevBB, NewBB}});
1139   }
1140 
1141   /// Add new value mappings to the DefMap to keep track of all new definitions
1142   /// for a particular instruction. These will be used while updating SSA form.
1143   void updateDefMap(DefMap &NewDefs, ValueToValueMapTy &VMap) {
1144     SmallVector<std::pair<Instruction *, Instruction *>> NewDefsVector;
1145     NewDefsVector.reserve(VMap.size());
1146 
1147     for (auto Entry : VMap) {
1148       Instruction *Inst =
1149           dyn_cast<Instruction>(const_cast<Value *>(Entry.first));
1150       if (!Inst || !Entry.second || isa<BranchInst>(Inst) ||
1151           isa<SwitchInst>(Inst)) {
1152         continue;
1153       }
1154 
1155       Instruction *Cloned = dyn_cast<Instruction>(Entry.second);
1156       if (!Cloned)
1157         continue;
1158 
1159       NewDefsVector.push_back({Inst, Cloned});
1160     }
1161 
1162     // Sort the defs to get deterministic insertion order into NewDefs.
1163     sort(NewDefsVector, [](const auto &LHS, const auto &RHS) {
1164       if (LHS.first == RHS.first)
1165         return LHS.second->comesBefore(RHS.second);
1166       return LHS.first->comesBefore(RHS.first);
1167     });
1168 
1169     for (const auto &KV : NewDefsVector)
1170       NewDefs[KV.first].push_back(KV.second);
1171   }
1172 
1173   /// Update the last branch of a particular cloned path to point to the correct
1174   /// case successor.
1175   ///
1176   /// Note that this is an optional step and would have been done in later
1177   /// optimizations, but it makes the CFG significantly easier to work with.
1178   void updateLastSuccessor(ThreadingPath &TPath,
1179                            DuplicateBlockMap &DuplicateMap,
1180                            DomTreeUpdater *DTU) {
1181     uint64_t NextState = TPath.getExitValue();
1182     BasicBlock *BB = TPath.getPath().back();
1183     BasicBlock *LastBlock = getClonedBB(BB, NextState, DuplicateMap);
1184 
1185     // Note multiple paths can end at the same block so check that it is not
1186     // updated yet
1187     if (!isa<SwitchInst>(LastBlock->getTerminator()))
1188       return;
1189     SwitchInst *Switch = cast<SwitchInst>(LastBlock->getTerminator());
1190     BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState);
1191 
1192     std::vector<DominatorTree::UpdateType> DTUpdates;
1193     SmallPtrSet<BasicBlock *, 4> SuccSet;
1194     for (BasicBlock *Succ : successors(LastBlock)) {
1195       if (Succ != NextCase && SuccSet.insert(Succ).second)
1196         DTUpdates.push_back({DominatorTree::Delete, LastBlock, Succ});
1197     }
1198 
1199     Switch->eraseFromParent();
1200     BranchInst::Create(NextCase, LastBlock);
1201 
1202     DTU->applyUpdates(DTUpdates);
1203   }
1204 
1205   /// After cloning blocks, some of the phi nodes have extra incoming values
1206   /// that are no longer used. This function removes them.
1207   void cleanPhiNodes(BasicBlock *BB) {
1208     // If BB is no longer reachable, remove any remaining phi nodes
1209     if (pred_empty(BB)) {
1210       std::vector<PHINode *> PhiToRemove;
1211       for (auto II = BB->begin(); PHINode *Phi = dyn_cast<PHINode>(II); ++II) {
1212         PhiToRemove.push_back(Phi);
1213       }
1214       for (PHINode *PN : PhiToRemove) {
1215         PN->replaceAllUsesWith(PoisonValue::get(PN->getType()));
1216         PN->eraseFromParent();
1217       }
1218       return;
1219     }
1220 
1221     // Remove any incoming values that come from an invalid predecessor
1222     for (auto II = BB->begin(); PHINode *Phi = dyn_cast<PHINode>(II); ++II) {
1223       std::vector<BasicBlock *> BlocksToRemove;
1224       for (BasicBlock *IncomingBB : Phi->blocks()) {
1225         if (!isPredecessor(BB, IncomingBB))
1226           BlocksToRemove.push_back(IncomingBB);
1227       }
1228       for (BasicBlock *BB : BlocksToRemove)
1229         Phi->removeIncomingValue(BB);
1230     }
1231   }
1232 
1233   /// Checks if BB was already cloned for a particular next state value. If it
1234   /// was then it returns this cloned block, and otherwise null.
1235   BasicBlock *getClonedBB(BasicBlock *BB, uint64_t NextState,
1236                           DuplicateBlockMap &DuplicateMap) {
1237     CloneList ClonedBBs = DuplicateMap[BB];
1238 
1239     // Find an entry in the CloneList with this NextState. If it exists then
1240     // return the corresponding BB
1241     auto It = llvm::find_if(ClonedBBs, [NextState](const ClonedBlock &C) {
1242       return C.State == NextState;
1243     });
1244     return It != ClonedBBs.end() ? (*It).BB : nullptr;
1245   }
1246 
1247   /// Helper to get the successor corresponding to a particular case value for
1248   /// a switch statement.
1249   BasicBlock *getNextCaseSuccessor(SwitchInst *Switch, uint64_t NextState) {
1250     BasicBlock *NextCase = nullptr;
1251     for (auto Case : Switch->cases()) {
1252       if (Case.getCaseValue()->getZExtValue() == NextState) {
1253         NextCase = Case.getCaseSuccessor();
1254         break;
1255       }
1256     }
1257     if (!NextCase)
1258       NextCase = Switch->getDefaultDest();
1259     return NextCase;
1260   }
1261 
1262   /// Returns true if IncomingBB is a predecessor of BB.
1263   bool isPredecessor(BasicBlock *BB, BasicBlock *IncomingBB) {
1264     return llvm::is_contained(predecessors(BB), IncomingBB);
1265   }
1266 
1267   AllSwitchPaths *SwitchPaths;
1268   DominatorTree *DT;
1269   AssumptionCache *AC;
1270   TargetTransformInfo *TTI;
1271   OptimizationRemarkEmitter *ORE;
1272   SmallPtrSet<const Value *, 32> EphValues;
1273   std::vector<ThreadingPath> TPaths;
1274 };
1275 
1276 bool DFAJumpThreading::run(Function &F) {
1277   LLVM_DEBUG(dbgs() << "\nDFA Jump threading: " << F.getName() << "\n");
1278 
1279   if (F.hasOptSize()) {
1280     LLVM_DEBUG(dbgs() << "Skipping due to the 'minsize' attribute\n");
1281     return false;
1282   }
1283 
1284   if (ClViewCfgBefore)
1285     F.viewCFG();
1286 
1287   SmallVector<AllSwitchPaths, 2> ThreadableLoops;
1288   bool MadeChanges = false;
1289 
1290   for (BasicBlock &BB : F) {
1291     auto *SI = dyn_cast<SwitchInst>(BB.getTerminator());
1292     if (!SI)
1293       continue;
1294 
1295     LLVM_DEBUG(dbgs() << "\nCheck if SwitchInst in BB " << BB.getName()
1296                       << " is a candidate\n");
1297     MainSwitch Switch(SI, ORE);
1298 
1299     if (!Switch.getInstr())
1300       continue;
1301 
1302     LLVM_DEBUG(dbgs() << "\nSwitchInst in BB " << BB.getName() << " is a "
1303                       << "candidate for jump threading\n");
1304     LLVM_DEBUG(SI->dump());
1305 
1306     unfoldSelectInstrs(DT, Switch.getSelectInsts());
1307     if (!Switch.getSelectInsts().empty())
1308       MadeChanges = true;
1309 
1310     AllSwitchPaths SwitchPaths(&Switch, ORE);
1311     SwitchPaths.run();
1312 
1313     if (SwitchPaths.getNumThreadingPaths() > 0) {
1314       ThreadableLoops.push_back(SwitchPaths);
1315 
1316       // For the time being limit this optimization to occurring once in a
1317       // function since it can change the CFG significantly. This is not a
1318       // strict requirement but it can cause buggy behavior if there is an
1319       // overlap of blocks in different opportunities. There is a lot of room to
1320       // experiment with catching more opportunities here.
1321       break;
1322     }
1323   }
1324 
1325   SmallPtrSet<const Value *, 32> EphValues;
1326   if (ThreadableLoops.size() > 0)
1327     CodeMetrics::collectEphemeralValues(&F, AC, EphValues);
1328 
1329   for (AllSwitchPaths SwitchPaths : ThreadableLoops) {
1330     TransformDFA Transform(&SwitchPaths, DT, AC, TTI, ORE, EphValues);
1331     Transform.run();
1332     MadeChanges = true;
1333   }
1334 
1335 #ifdef EXPENSIVE_CHECKS
1336   assert(DT->verify(DominatorTree::VerificationLevel::Full));
1337   verifyFunction(F, &dbgs());
1338 #endif
1339 
1340   return MadeChanges;
1341 }
1342 
1343 } // end anonymous namespace
1344 
1345 /// Integrate with the new Pass Manager
1346 PreservedAnalyses DFAJumpThreadingPass::run(Function &F,
1347                                             FunctionAnalysisManager &AM) {
1348   AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
1349   DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
1350   TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
1351   OptimizationRemarkEmitter ORE(&F);
1352 
1353   if (!DFAJumpThreading(&AC, &DT, &TTI, &ORE).run(F))
1354     return PreservedAnalyses::all();
1355 
1356   PreservedAnalyses PA;
1357   PA.preserve<DominatorTreeAnalysis>();
1358   return PA;
1359 }
1360