1 //===- GVNSink.cpp - sink expressions into successors ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file GVNSink.cpp
10 /// This pass attempts to sink instructions into successors, reducing static
11 /// instruction count and enabling if-conversion.
12 ///
13 /// We use a variant of global value numbering to decide what can be sunk.
14 /// Consider:
15 ///
16 /// [ %a1 = add i32 %b, 1  ]   [ %c1 = add i32 %d, 1  ]
17 /// [ %a2 = xor i32 %a1, 1 ]   [ %c2 = xor i32 %c1, 1 ]
18 ///                  \           /
19 ///            [ %e = phi i32 %a2, %c2 ]
20 ///            [ add i32 %e, 4         ]
21 ///
22 ///
23 /// GVN would number %a1 and %c1 differently because they compute different
24 /// results - the VN of an instruction is a function of its opcode and the
25 /// transitive closure of its operands. This is the key property for hoisting
26 /// and CSE.
27 ///
/// What we want when sinking, however, is a numbering that is a function of
/// the *uses* of an instruction, which allows us to answer the question "if I
/// replace %a1 with %c1, will it contribute in an equivalent way to all
/// successive instructions?". The ValueTable class in this file provides this
/// mapping.
33 //
34 //===----------------------------------------------------------------------===//
35 
36 #include "llvm/ADT/ArrayRef.h"
37 #include "llvm/ADT/DenseMap.h"
38 #include "llvm/ADT/DenseSet.h"
39 #include "llvm/ADT/Hashing.h"
40 #include "llvm/ADT/PostOrderIterator.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/SmallPtrSet.h"
43 #include "llvm/ADT/SmallVector.h"
44 #include "llvm/ADT/Statistic.h"
45 #include "llvm/Analysis/GlobalsModRef.h"
46 #include "llvm/IR/BasicBlock.h"
47 #include "llvm/IR/CFG.h"
48 #include "llvm/IR/Constants.h"
49 #include "llvm/IR/Function.h"
50 #include "llvm/IR/InstrTypes.h"
51 #include "llvm/IR/Instruction.h"
52 #include "llvm/IR/Instructions.h"
53 #include "llvm/IR/PassManager.h"
54 #include "llvm/IR/Type.h"
55 #include "llvm/IR/Use.h"
56 #include "llvm/IR/Value.h"
57 #include "llvm/Support/Allocator.h"
58 #include "llvm/Support/ArrayRecycler.h"
59 #include "llvm/Support/AtomicOrdering.h"
60 #include "llvm/Support/Casting.h"
61 #include "llvm/Support/Compiler.h"
62 #include "llvm/Support/Debug.h"
63 #include "llvm/Support/raw_ostream.h"
64 #include "llvm/Transforms/Scalar/GVN.h"
65 #include "llvm/Transforms/Scalar/GVNExpression.h"
66 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
67 #include "llvm/Transforms/Utils/Local.h"
68 #include <algorithm>
69 #include <cassert>
70 #include <cstddef>
71 #include <cstdint>
72 #include <iterator>
73 #include <utility>
74 
75 using namespace llvm;
76 
// Category name for this pass's LLVM_DEBUG output and statistics.
#define DEBUG_TYPE "gvn-sink"

// Incremented by sinkLastInstruction with the number of duplicate copies it
// erases each time a group of equivalent instructions is sunk.
STATISTIC(NumRemoved, "Number of instructions removed");
80 
81 namespace llvm {
82 namespace GVNExpression {
83 
84 LLVM_DUMP_METHOD void Expression::dump() const {
85   print(dbgs());
86   dbgs() << "\n";
87 }
88 
89 } // end namespace GVNExpression
90 } // end namespace llvm
91 
92 namespace {
93 
94 static bool isMemoryInst(const Instruction *I) {
95   return isa<LoadInst>(I) || isa<StoreInst>(I) ||
96          (isa<InvokeInst>(I) && !cast<InvokeInst>(I)->doesNotAccessMemory()) ||
97          (isa<CallInst>(I) && !cast<CallInst>(I)->doesNotAccessMemory());
98 }
99 
100 /// Iterates through instructions in a set of blocks in reverse order from the
101 /// first non-terminator. For example (assume all blocks have size n):
102 ///   LockstepReverseIterator I([B1, B2, B3]);
103 ///   *I-- = [B1[n], B2[n], B3[n]];
104 ///   *I-- = [B1[n-1], B2[n-1], B3[n-1]];
105 ///   *I-- = [B1[n-2], B2[n-2], B3[n-2]];
106 ///   ...
107 ///
108 /// It continues until all blocks have been exhausted. Use \c getActiveBlocks()
109 /// to
110 /// determine which blocks are still going and the order they appear in the
111 /// list returned by operator*.
112 class LockstepReverseIterator {
113   ArrayRef<BasicBlock *> Blocks;
114   SmallSetVector<BasicBlock *, 4> ActiveBlocks;
115   SmallVector<Instruction *, 4> Insts;
116   bool Fail;
117 
118 public:
119   LockstepReverseIterator(ArrayRef<BasicBlock *> Blocks) : Blocks(Blocks) {
120     reset();
121   }
122 
123   void reset() {
124     Fail = false;
125     ActiveBlocks.clear();
126     for (BasicBlock *BB : Blocks)
127       ActiveBlocks.insert(BB);
128     Insts.clear();
129     for (BasicBlock *BB : Blocks) {
130       if (BB->size() <= 1) {
131         // Block wasn't big enough - only contained a terminator.
132         ActiveBlocks.remove(BB);
133         continue;
134       }
135       Insts.push_back(BB->getTerminator()->getPrevNode());
136     }
137     if (Insts.empty())
138       Fail = true;
139   }
140 
141   bool isValid() const { return !Fail; }
142   ArrayRef<Instruction *> operator*() const { return Insts; }
143 
144   // Note: This needs to return a SmallSetVector as the elements of
145   // ActiveBlocks will be later copied to Blocks using std::copy. The
146   // resultant order of elements in Blocks needs to be deterministic.
147   // Using SmallPtrSet instead causes non-deterministic order while
148   // copying. And we cannot simply sort Blocks as they need to match the
149   // corresponding Values.
150   SmallSetVector<BasicBlock *, 4> &getActiveBlocks() { return ActiveBlocks; }
151 
152   void restrictToBlocks(SmallSetVector<BasicBlock *, 4> &Blocks) {
153     for (auto II = Insts.begin(); II != Insts.end();) {
154       if (!Blocks.contains((*II)->getParent())) {
155         ActiveBlocks.remove((*II)->getParent());
156         II = Insts.erase(II);
157       } else {
158         ++II;
159       }
160     }
161   }
162 
163   void operator--() {
164     if (Fail)
165       return;
166     SmallVector<Instruction *, 4> NewInsts;
167     for (auto *Inst : Insts) {
168       if (Inst == &Inst->getParent()->front())
169         ActiveBlocks.remove(Inst->getParent());
170       else
171         NewInsts.push_back(Inst->getPrevNode());
172     }
173     if (NewInsts.empty()) {
174       Fail = true;
175       return;
176     }
177     Insts = NewInsts;
178   }
179 };
180 
181 //===----------------------------------------------------------------------===//
182 
183 /// Candidate solution for sinking. There may be different ways to
184 /// sink instructions, differing in the number of instructions sunk,
185 /// the number of predecessors sunk from and the number of PHIs
186 /// required.
187 struct SinkingInstructionCandidate {
188   unsigned NumBlocks;
189   unsigned NumInstructions;
190   unsigned NumPHIs;
191   unsigned NumMemoryInsts;
192   int Cost = -1;
193   SmallVector<BasicBlock *, 4> Blocks;
194 
195   void calculateCost(unsigned NumOrigPHIs, unsigned NumOrigBlocks) {
196     unsigned NumExtraPHIs = NumPHIs - NumOrigPHIs;
197     unsigned SplitEdgeCost = (NumOrigBlocks > NumBlocks) ? 2 : 0;
198     Cost = (NumInstructions * (NumBlocks - 1)) -
199            (NumExtraPHIs *
200             NumExtraPHIs) // PHIs are expensive, so make sure they're worth it.
201            - SplitEdgeCost;
202   }
203 
204   bool operator>(const SinkingInstructionCandidate &Other) const {
205     return Cost > Other.Cost;
206   }
207 };
208 
209 #ifndef NDEBUG
210 raw_ostream &operator<<(raw_ostream &OS, const SinkingInstructionCandidate &C) {
211   OS << "<Candidate Cost=" << C.Cost << " #Blocks=" << C.NumBlocks
212      << " #Insts=" << C.NumInstructions << " #PHIs=" << C.NumPHIs << ">";
213   return OS;
214 }
215 #endif
216 
217 //===----------------------------------------------------------------------===//
218 
219 /// Describes a PHI node that may or may not exist. These track the PHIs
220 /// that must be created if we sunk a sequence of instructions. It provides
221 /// a hash function for efficient equality comparisons.
222 class ModelledPHI {
223   SmallVector<Value *, 4> Values;
224   SmallVector<BasicBlock *, 4> Blocks;
225 
226 public:
227   ModelledPHI() = default;
228 
229   ModelledPHI(const PHINode *PN) {
230     // BasicBlock comes first so we sort by basic block pointer order, then by value pointer order.
231     SmallVector<std::pair<BasicBlock *, Value *>, 4> Ops;
232     for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I)
233       Ops.push_back({PN->getIncomingBlock(I), PN->getIncomingValue(I)});
234     llvm::sort(Ops);
235     for (auto &P : Ops) {
236       Blocks.push_back(P.first);
237       Values.push_back(P.second);
238     }
239   }
240 
241   /// Create a dummy ModelledPHI that will compare unequal to any other ModelledPHI
242   /// without the same ID.
243   /// \note This is specifically for DenseMapInfo - do not use this!
244   static ModelledPHI createDummy(size_t ID) {
245     ModelledPHI M;
246     M.Values.push_back(reinterpret_cast<Value*>(ID));
247     return M;
248   }
249 
250   /// Create a PHI from an array of incoming values and incoming blocks.
251   template <typename VArray, typename BArray>
252   ModelledPHI(const VArray &V, const BArray &B) {
253     llvm::copy(V, std::back_inserter(Values));
254     llvm::copy(B, std::back_inserter(Blocks));
255   }
256 
257   /// Create a PHI from [I[OpNum] for I in Insts].
258   template <typename BArray>
259   ModelledPHI(ArrayRef<Instruction *> Insts, unsigned OpNum, const BArray &B) {
260     llvm::copy(B, std::back_inserter(Blocks));
261     for (auto *I : Insts)
262       Values.push_back(I->getOperand(OpNum));
263   }
264 
265   /// Restrict the PHI's contents down to only \c NewBlocks.
266   /// \c NewBlocks must be a subset of \c this->Blocks.
267   void restrictToBlocks(const SmallSetVector<BasicBlock *, 4> &NewBlocks) {
268     auto BI = Blocks.begin();
269     auto VI = Values.begin();
270     while (BI != Blocks.end()) {
271       assert(VI != Values.end());
272       if (!NewBlocks.contains(*BI)) {
273         BI = Blocks.erase(BI);
274         VI = Values.erase(VI);
275       } else {
276         ++BI;
277         ++VI;
278       }
279     }
280     assert(Blocks.size() == NewBlocks.size());
281   }
282 
283   ArrayRef<Value *> getValues() const { return Values; }
284 
285   bool areAllIncomingValuesSame() const {
286     return llvm::all_equal(Values);
287   }
288 
289   bool areAllIncomingValuesSameType() const {
290     return llvm::all_of(
291         Values, [&](Value *V) { return V->getType() == Values[0]->getType(); });
292   }
293 
294   bool areAnyIncomingValuesConstant() const {
295     return llvm::any_of(Values, [&](Value *V) { return isa<Constant>(V); });
296   }
297 
298   // Hash functor
299   unsigned hash() const {
300       return (unsigned)hash_combine_range(Values.begin(), Values.end());
301   }
302 
303   bool operator==(const ModelledPHI &Other) const {
304     return Values == Other.Values && Blocks == Other.Blocks;
305   }
306 };
307 
308 template <typename ModelledPHI> struct DenseMapInfo {
309   static inline ModelledPHI &getEmptyKey() {
310     static ModelledPHI Dummy = ModelledPHI::createDummy(0);
311     return Dummy;
312   }
313 
314   static inline ModelledPHI &getTombstoneKey() {
315     static ModelledPHI Dummy = ModelledPHI::createDummy(1);
316     return Dummy;
317   }
318 
319   static unsigned getHashValue(const ModelledPHI &V) { return V.hash(); }
320 
321   static bool isEqual(const ModelledPHI &LHS, const ModelledPHI &RHS) {
322     return LHS == RHS;
323   }
324 };
325 
326 using ModelledPHISet = DenseSet<ModelledPHI, DenseMapInfo<ModelledPHI>>;
327 
328 //===----------------------------------------------------------------------===//
329 //                             ValueTable
330 //===----------------------------------------------------------------------===//
331 // This is a value number table where the value number is a function of the
332 // *uses* of a value, rather than its operands. Thus, if VN(A) == VN(B) we know
333 // that the program would be equivalent if we replaced A with PHI(A, B).
334 //===----------------------------------------------------------------------===//
335 
336 /// A GVN expression describing how an instruction is used. The operands
337 /// field of BasicExpression is used to store uses, not operands.
338 ///
339 /// This class also contains fields for discriminators used when determining
340 /// equivalence of instructions with sideeffects.
341 class InstructionUseExpr : public GVNExpression::BasicExpression {
342   unsigned MemoryUseOrder = -1;
343   bool Volatile = false;
344   ArrayRef<int> ShuffleMask;
345 
346 public:
347   InstructionUseExpr(Instruction *I, ArrayRecycler<Value *> &R,
348                      BumpPtrAllocator &A)
349       : GVNExpression::BasicExpression(I->getNumUses()) {
350     allocateOperands(R, A);
351     setOpcode(I->getOpcode());
352     setType(I->getType());
353 
354     if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(I))
355       ShuffleMask = SVI->getShuffleMask().copy(A);
356 
357     for (auto &U : I->uses())
358       op_push_back(U.getUser());
359     llvm::sort(op_begin(), op_end());
360   }
361 
362   void setMemoryUseOrder(unsigned MUO) { MemoryUseOrder = MUO; }
363   void setVolatile(bool V) { Volatile = V; }
364 
365   hash_code getHashValue() const override {
366     return hash_combine(GVNExpression::BasicExpression::getHashValue(),
367                         MemoryUseOrder, Volatile, ShuffleMask);
368   }
369 
370   template <typename Function> hash_code getHashValue(Function MapFn) {
371     hash_code H = hash_combine(getOpcode(), getType(), MemoryUseOrder, Volatile,
372                                ShuffleMask);
373     for (auto *V : operands())
374       H = hash_combine(H, MapFn(V));
375     return H;
376   }
377 };
378 
379 using BasicBlocksSet = SmallPtrSet<const BasicBlock *, 32>;
380 
381 class ValueTable {
382   DenseMap<Value *, uint32_t> ValueNumbering;
383   DenseMap<GVNExpression::Expression *, uint32_t> ExpressionNumbering;
384   DenseMap<size_t, uint32_t> HashNumbering;
385   BumpPtrAllocator Allocator;
386   ArrayRecycler<Value *> Recycler;
387   uint32_t nextValueNumber = 1;
388   BasicBlocksSet ReachableBBs;
389 
390   /// Create an expression for I based on its opcode and its uses. If I
391   /// touches or reads memory, the expression is also based upon its memory
392   /// order - see \c getMemoryUseOrder().
393   InstructionUseExpr *createExpr(Instruction *I) {
394     InstructionUseExpr *E =
395         new (Allocator) InstructionUseExpr(I, Recycler, Allocator);
396     if (isMemoryInst(I))
397       E->setMemoryUseOrder(getMemoryUseOrder(I));
398 
399     if (CmpInst *C = dyn_cast<CmpInst>(I)) {
400       CmpInst::Predicate Predicate = C->getPredicate();
401       E->setOpcode((C->getOpcode() << 8) | Predicate);
402     }
403     return E;
404   }
405 
406   /// Helper to compute the value number for a memory instruction
407   /// (LoadInst/StoreInst), including checking the memory ordering and
408   /// volatility.
409   template <class Inst> InstructionUseExpr *createMemoryExpr(Inst *I) {
410     if (isStrongerThanUnordered(I->getOrdering()) || I->isAtomic())
411       return nullptr;
412     InstructionUseExpr *E = createExpr(I);
413     E->setVolatile(I->isVolatile());
414     return E;
415   }
416 
417 public:
418   ValueTable() = default;
419 
420   /// Set basic blocks reachable from entry block.
421   void setReachableBBs(const BasicBlocksSet &ReachableBBs) {
422     this->ReachableBBs = ReachableBBs;
423   }
424 
425   /// Returns the value number for the specified value, assigning
426   /// it a new number if it did not have one before.
427   uint32_t lookupOrAdd(Value *V) {
428     auto VI = ValueNumbering.find(V);
429     if (VI != ValueNumbering.end())
430       return VI->second;
431 
432     if (!isa<Instruction>(V)) {
433       ValueNumbering[V] = nextValueNumber;
434       return nextValueNumber++;
435     }
436 
437     Instruction *I = cast<Instruction>(V);
438     if (!ReachableBBs.contains(I->getParent()))
439       return ~0U;
440 
441     InstructionUseExpr *exp = nullptr;
442     switch (I->getOpcode()) {
443     case Instruction::Load:
444       exp = createMemoryExpr(cast<LoadInst>(I));
445       break;
446     case Instruction::Store:
447       exp = createMemoryExpr(cast<StoreInst>(I));
448       break;
449     case Instruction::Call:
450     case Instruction::Invoke:
451     case Instruction::FNeg:
452     case Instruction::Add:
453     case Instruction::FAdd:
454     case Instruction::Sub:
455     case Instruction::FSub:
456     case Instruction::Mul:
457     case Instruction::FMul:
458     case Instruction::UDiv:
459     case Instruction::SDiv:
460     case Instruction::FDiv:
461     case Instruction::URem:
462     case Instruction::SRem:
463     case Instruction::FRem:
464     case Instruction::Shl:
465     case Instruction::LShr:
466     case Instruction::AShr:
467     case Instruction::And:
468     case Instruction::Or:
469     case Instruction::Xor:
470     case Instruction::ICmp:
471     case Instruction::FCmp:
472     case Instruction::Trunc:
473     case Instruction::ZExt:
474     case Instruction::SExt:
475     case Instruction::FPToUI:
476     case Instruction::FPToSI:
477     case Instruction::UIToFP:
478     case Instruction::SIToFP:
479     case Instruction::FPTrunc:
480     case Instruction::FPExt:
481     case Instruction::PtrToInt:
482     case Instruction::IntToPtr:
483     case Instruction::BitCast:
484     case Instruction::AddrSpaceCast:
485     case Instruction::Select:
486     case Instruction::ExtractElement:
487     case Instruction::InsertElement:
488     case Instruction::ShuffleVector:
489     case Instruction::InsertValue:
490     case Instruction::GetElementPtr:
491       exp = createExpr(I);
492       break;
493     default:
494       break;
495     }
496 
497     if (!exp) {
498       ValueNumbering[V] = nextValueNumber;
499       return nextValueNumber++;
500     }
501 
502     uint32_t e = ExpressionNumbering[exp];
503     if (!e) {
504       hash_code H = exp->getHashValue([=](Value *V) { return lookupOrAdd(V); });
505       auto I = HashNumbering.find(H);
506       if (I != HashNumbering.end()) {
507         e = I->second;
508       } else {
509         e = nextValueNumber++;
510         HashNumbering[H] = e;
511         ExpressionNumbering[exp] = e;
512       }
513     }
514     ValueNumbering[V] = e;
515     return e;
516   }
517 
518   /// Returns the value number of the specified value. Fails if the value has
519   /// not yet been numbered.
520   uint32_t lookup(Value *V) const {
521     auto VI = ValueNumbering.find(V);
522     assert(VI != ValueNumbering.end() && "Value not numbered?");
523     return VI->second;
524   }
525 
526   /// Removes all value numberings and resets the value table.
527   void clear() {
528     ValueNumbering.clear();
529     ExpressionNumbering.clear();
530     HashNumbering.clear();
531     Recycler.clear(Allocator);
532     nextValueNumber = 1;
533   }
534 
535   /// \c Inst uses or touches memory. Return an ID describing the memory state
536   /// at \c Inst such that if getMemoryUseOrder(I1) == getMemoryUseOrder(I2),
537   /// the exact same memory operations happen after I1 and I2.
538   ///
539   /// This is a very hard problem in general, so we use domain-specific
540   /// knowledge that we only ever check for equivalence between blocks sharing a
541   /// single immediate successor that is common, and when determining if I1 ==
542   /// I2 we will have already determined that next(I1) == next(I2). This
543   /// inductive property allows us to simply return the value number of the next
544   /// instruction that defines memory.
545   uint32_t getMemoryUseOrder(Instruction *Inst) {
546     auto *BB = Inst->getParent();
547     for (auto I = std::next(Inst->getIterator()), E = BB->end();
548          I != E && !I->isTerminator(); ++I) {
549       if (!isMemoryInst(&*I))
550         continue;
551       if (isa<LoadInst>(&*I))
552         continue;
553       CallInst *CI = dyn_cast<CallInst>(&*I);
554       if (CI && CI->onlyReadsMemory())
555         continue;
556       InvokeInst *II = dyn_cast<InvokeInst>(&*I);
557       if (II && II->onlyReadsMemory())
558         continue;
559       return lookupOrAdd(&*I);
560     }
561     return 0;
562   }
563 };
564 
565 //===----------------------------------------------------------------------===//
566 
567 class GVNSink {
568 public:
569   GVNSink() = default;
570 
571   bool run(Function &F) {
572     LLVM_DEBUG(dbgs() << "GVNSink: running on function @" << F.getName()
573                       << "\n");
574 
575     unsigned NumSunk = 0;
576     ReversePostOrderTraversal<Function*> RPOT(&F);
577     VN.setReachableBBs(BasicBlocksSet(RPOT.begin(), RPOT.end()));
578     for (auto *N : RPOT)
579       NumSunk += sinkBB(N);
580 
581     return NumSunk > 0;
582   }
583 
584 private:
585   ValueTable VN;
586 
587   bool shouldAvoidSinkingInstruction(Instruction *I) {
588     // These instructions may change or break semantics if moved.
589     if (isa<PHINode>(I) || I->isEHPad() || isa<AllocaInst>(I) ||
590         I->getType()->isTokenTy())
591       return true;
592     return false;
593   }
594 
595   /// The main heuristic function. Analyze the set of instructions pointed to by
596   /// LRI and return a candidate solution if these instructions can be sunk, or
597   /// std::nullopt otherwise.
598   std::optional<SinkingInstructionCandidate> analyzeInstructionForSinking(
599       LockstepReverseIterator &LRI, unsigned &InstNum, unsigned &MemoryInstNum,
600       ModelledPHISet &NeededPHIs, SmallPtrSetImpl<Value *> &PHIContents);
601 
602   /// Create a ModelledPHI for each PHI in BB, adding to PHIs.
603   void analyzeInitialPHIs(BasicBlock *BB, ModelledPHISet &PHIs,
604                           SmallPtrSetImpl<Value *> &PHIContents) {
605     for (PHINode &PN : BB->phis()) {
606       auto MPHI = ModelledPHI(&PN);
607       PHIs.insert(MPHI);
608       for (auto *V : MPHI.getValues())
609         PHIContents.insert(V);
610     }
611   }
612 
613   /// The main instruction sinking driver. Set up state and try and sink
614   /// instructions into BBEnd from its predecessors.
615   unsigned sinkBB(BasicBlock *BBEnd);
616 
617   /// Perform the actual mechanics of sinking an instruction from Blocks into
618   /// BBEnd, which is their only successor.
619   void sinkLastInstruction(ArrayRef<BasicBlock *> Blocks, BasicBlock *BBEnd);
620 
621   /// Remove PHIs that all have the same incoming value.
622   void foldPointlessPHINodes(BasicBlock *BB) {
623     auto I = BB->begin();
624     while (PHINode *PN = dyn_cast<PHINode>(I++)) {
625       if (!llvm::all_of(PN->incoming_values(), [&](const Value *V) {
626             return V == PN->getIncomingValue(0);
627           }))
628         continue;
629       if (PN->getIncomingValue(0) != PN)
630         PN->replaceAllUsesWith(PN->getIncomingValue(0));
631       else
632         PN->replaceAllUsesWith(PoisonValue::get(PN->getType()));
633       PN->eraseFromParent();
634     }
635   }
636 };
637 
638 std::optional<SinkingInstructionCandidate>
639 GVNSink::analyzeInstructionForSinking(LockstepReverseIterator &LRI,
640                                       unsigned &InstNum,
641                                       unsigned &MemoryInstNum,
642                                       ModelledPHISet &NeededPHIs,
643                                       SmallPtrSetImpl<Value *> &PHIContents) {
644   auto Insts = *LRI;
645   LLVM_DEBUG(dbgs() << " -- Analyzing instruction set: [\n"; for (auto *I
646                                                                   : Insts) {
647     I->dump();
648   } dbgs() << " ]\n";);
649 
650   DenseMap<uint32_t, unsigned> VNums;
651   for (auto *I : Insts) {
652     uint32_t N = VN.lookupOrAdd(I);
653     LLVM_DEBUG(dbgs() << " VN=" << Twine::utohexstr(N) << " for" << *I << "\n");
654     if (N == ~0U)
655       return std::nullopt;
656     VNums[N]++;
657   }
658   unsigned VNumToSink =
659       std::max_element(VNums.begin(), VNums.end(), llvm::less_second())->first;
660 
661   if (VNums[VNumToSink] == 1)
662     // Can't sink anything!
663     return std::nullopt;
664 
665   // Now restrict the number of incoming blocks down to only those with
666   // VNumToSink.
667   auto &ActivePreds = LRI.getActiveBlocks();
668   unsigned InitialActivePredSize = ActivePreds.size();
669   SmallVector<Instruction *, 4> NewInsts;
670   for (auto *I : Insts) {
671     if (VN.lookup(I) != VNumToSink)
672       ActivePreds.remove(I->getParent());
673     else
674       NewInsts.push_back(I);
675   }
676   for (auto *I : NewInsts)
677     if (shouldAvoidSinkingInstruction(I))
678       return std::nullopt;
679 
680   // If we've restricted the incoming blocks, restrict all needed PHIs also
681   // to that set.
682   bool RecomputePHIContents = false;
683   if (ActivePreds.size() != InitialActivePredSize) {
684     ModelledPHISet NewNeededPHIs;
685     for (auto P : NeededPHIs) {
686       P.restrictToBlocks(ActivePreds);
687       NewNeededPHIs.insert(P);
688     }
689     NeededPHIs = NewNeededPHIs;
690     LRI.restrictToBlocks(ActivePreds);
691     RecomputePHIContents = true;
692   }
693 
694   // The sunk instruction's results.
695   ModelledPHI NewPHI(NewInsts, ActivePreds);
696 
697   // Does sinking this instruction render previous PHIs redundant?
698   if (NeededPHIs.erase(NewPHI))
699     RecomputePHIContents = true;
700 
701   if (RecomputePHIContents) {
702     // The needed PHIs have changed, so recompute the set of all needed
703     // values.
704     PHIContents.clear();
705     for (auto &PHI : NeededPHIs)
706       PHIContents.insert(PHI.getValues().begin(), PHI.getValues().end());
707   }
708 
709   // Is this instruction required by a later PHI that doesn't match this PHI?
710   // if so, we can't sink this instruction.
711   for (auto *V : NewPHI.getValues())
712     if (PHIContents.count(V))
713       // V exists in this PHI, but the whole PHI is different to NewPHI
714       // (else it would have been removed earlier). We cannot continue
715       // because this isn't representable.
716       return std::nullopt;
717 
718   // Which operands need PHIs?
719   // FIXME: If any of these fail, we should partition up the candidates to
720   // try and continue making progress.
721   Instruction *I0 = NewInsts[0];
722 
723   // If all instructions that are going to participate don't have the same
724   // number of operands, we can't do any useful PHI analysis for all operands.
725   auto hasDifferentNumOperands = [&I0](Instruction *I) {
726     return I->getNumOperands() != I0->getNumOperands();
727   };
728   if (any_of(NewInsts, hasDifferentNumOperands))
729     return std::nullopt;
730 
731   for (unsigned OpNum = 0, E = I0->getNumOperands(); OpNum != E; ++OpNum) {
732     ModelledPHI PHI(NewInsts, OpNum, ActivePreds);
733     if (PHI.areAllIncomingValuesSame())
734       continue;
735     if (!canReplaceOperandWithVariable(I0, OpNum))
736       // We can 't create a PHI from this instruction!
737       return std::nullopt;
738     if (NeededPHIs.count(PHI))
739       continue;
740     if (!PHI.areAllIncomingValuesSameType())
741       return std::nullopt;
742     // Don't create indirect calls! The called value is the final operand.
743     if ((isa<CallInst>(I0) || isa<InvokeInst>(I0)) && OpNum == E - 1 &&
744         PHI.areAnyIncomingValuesConstant())
745       return std::nullopt;
746 
747     NeededPHIs.reserve(NeededPHIs.size());
748     NeededPHIs.insert(PHI);
749     PHIContents.insert(PHI.getValues().begin(), PHI.getValues().end());
750   }
751 
752   if (isMemoryInst(NewInsts[0]))
753     ++MemoryInstNum;
754 
755   SinkingInstructionCandidate Cand;
756   Cand.NumInstructions = ++InstNum;
757   Cand.NumMemoryInsts = MemoryInstNum;
758   Cand.NumBlocks = ActivePreds.size();
759   Cand.NumPHIs = NeededPHIs.size();
760   append_range(Cand.Blocks, ActivePreds);
761 
762   return Cand;
763 }
764 
765 unsigned GVNSink::sinkBB(BasicBlock *BBEnd) {
766   LLVM_DEBUG(dbgs() << "GVNSink: running on basic block ";
767              BBEnd->printAsOperand(dbgs()); dbgs() << "\n");
768   SmallVector<BasicBlock *, 4> Preds;
769   for (auto *B : predecessors(BBEnd)) {
770     auto *T = B->getTerminator();
771     if (isa<BranchInst>(T) || isa<SwitchInst>(T))
772       Preds.push_back(B);
773     else
774       return 0;
775   }
776   if (Preds.size() < 2)
777     return 0;
778   llvm::sort(Preds);
779 
780   unsigned NumOrigPreds = Preds.size();
781   // We can only sink instructions through unconditional branches.
782   llvm::erase_if(Preds, [](BasicBlock *BB) {
783     return BB->getTerminator()->getNumSuccessors() != 1;
784   });
785 
786   LockstepReverseIterator LRI(Preds);
787   SmallVector<SinkingInstructionCandidate, 4> Candidates;
788   unsigned InstNum = 0, MemoryInstNum = 0;
789   ModelledPHISet NeededPHIs;
790   SmallPtrSet<Value *, 4> PHIContents;
791   analyzeInitialPHIs(BBEnd, NeededPHIs, PHIContents);
792   unsigned NumOrigPHIs = NeededPHIs.size();
793 
794   while (LRI.isValid()) {
795     auto Cand = analyzeInstructionForSinking(LRI, InstNum, MemoryInstNum,
796                                              NeededPHIs, PHIContents);
797     if (!Cand)
798       break;
799     Cand->calculateCost(NumOrigPHIs, Preds.size());
800     Candidates.emplace_back(*Cand);
801     --LRI;
802   }
803 
804   llvm::stable_sort(Candidates, std::greater<SinkingInstructionCandidate>());
805   LLVM_DEBUG(dbgs() << " -- Sinking candidates:\n"; for (auto &C
806                                                          : Candidates) dbgs()
807                                                     << "  " << C << "\n";);
808 
809   // Pick the top candidate, as long it is positive!
810   if (Candidates.empty() || Candidates.front().Cost <= 0)
811     return 0;
812   auto C = Candidates.front();
813 
814   LLVM_DEBUG(dbgs() << " -- Sinking: " << C << "\n");
815   BasicBlock *InsertBB = BBEnd;
816   if (C.Blocks.size() < NumOrigPreds) {
817     LLVM_DEBUG(dbgs() << " -- Splitting edge to ";
818                BBEnd->printAsOperand(dbgs()); dbgs() << "\n");
819     InsertBB = SplitBlockPredecessors(BBEnd, C.Blocks, ".gvnsink.split");
820     if (!InsertBB) {
821       LLVM_DEBUG(dbgs() << " -- FAILED to split edge!\n");
822       // Edge couldn't be split.
823       return 0;
824     }
825   }
826 
827   for (unsigned I = 0; I < C.NumInstructions; ++I)
828     sinkLastInstruction(C.Blocks, InsertBB);
829 
830   return C.NumInstructions;
831 }
832 
833 void GVNSink::sinkLastInstruction(ArrayRef<BasicBlock *> Blocks,
834                                   BasicBlock *BBEnd) {
835   SmallVector<Instruction *, 4> Insts;
836   for (BasicBlock *BB : Blocks)
837     Insts.push_back(BB->getTerminator()->getPrevNode());
838   Instruction *I0 = Insts.front();
839 
840   SmallVector<Value *, 4> NewOperands;
841   for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O) {
842     bool NeedPHI = llvm::any_of(Insts, [&I0, O](const Instruction *I) {
843       return I->getOperand(O) != I0->getOperand(O);
844     });
845     if (!NeedPHI) {
846       NewOperands.push_back(I0->getOperand(O));
847       continue;
848     }
849 
850     // Create a new PHI in the successor block and populate it.
851     auto *Op = I0->getOperand(O);
852     assert(!Op->getType()->isTokenTy() && "Can't PHI tokens!");
853     auto *PN = PHINode::Create(Op->getType(), Insts.size(),
854                                Op->getName() + ".sink", &BBEnd->front());
855     for (auto *I : Insts)
856       PN->addIncoming(I->getOperand(O), I->getParent());
857     NewOperands.push_back(PN);
858   }
859 
860   // Arbitrarily use I0 as the new "common" instruction; remap its operands
861   // and move it to the start of the successor block.
862   for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O)
863     I0->getOperandUse(O).set(NewOperands[O]);
864   I0->moveBefore(&*BBEnd->getFirstInsertionPt());
865 
866   // Update metadata and IR flags.
867   for (auto *I : Insts)
868     if (I != I0) {
869       combineMetadataForCSE(I0, I, true);
870       I0->andIRFlags(I);
871     }
872 
873   for (auto *I : Insts)
874     if (I != I0)
875       I->replaceAllUsesWith(I0);
876   foldPointlessPHINodes(BBEnd);
877 
878   // Finally nuke all instructions apart from the common instruction.
879   for (auto *I : Insts)
880     if (I != I0)
881       I->eraseFromParent();
882 
883   NumRemoved += Insts.size() - 1;
884 }
885 
886 } // end anonymous namespace
887 
888 PreservedAnalyses GVNSinkPass::run(Function &F, FunctionAnalysisManager &AM) {
889   GVNSink G;
890   if (!G.run(F))
891     return PreservedAnalyses::all();
892   return PreservedAnalyses::none();
893 }
894