1 //===- GVNSink.cpp - sink expressions into successors ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file GVNSink.cpp
10 /// This pass attempts to sink instructions into successors, reducing static
11 /// instruction count and enabling if-conversion.
12 ///
13 /// We use a variant of global value numbering to decide what can be sunk.
14 /// Consider:
15 ///
16 /// [ %a1 = add i32 %b, 1  ]   [ %c1 = add i32 %d, 1  ]
17 /// [ %a2 = xor i32 %a1, 1 ]   [ %c2 = xor i32 %c1, 1 ]
18 ///                  \           /
19 ///            [ %e = phi i32 %a2, %c2 ]
20 ///            [ add i32 %e, 4         ]
21 ///
22 ///
23 /// GVN would number %a1 and %c1 differently because they compute different
24 /// results - the VN of an instruction is a function of its opcode and the
25 /// transitive closure of its operands. This is the key property for hoisting
26 /// and CSE.
27 ///
28 /// What we want when sinking however is for a numbering that is a function of
29 /// the *uses* of an instruction, which allows us to answer the question "if I
30 /// replace %a1 with %c1, will it contribute in an equivalent way to all
/// successive instructions?". The ValueTable class in this file provides this
/// mapping.
33 //
34 //===----------------------------------------------------------------------===//
35 
36 #include "llvm/ADT/ArrayRef.h"
37 #include "llvm/ADT/DenseMap.h"
38 #include "llvm/ADT/DenseSet.h"
39 #include "llvm/ADT/Hashing.h"
40 #include "llvm/ADT/PostOrderIterator.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/SmallPtrSet.h"
43 #include "llvm/ADT/SmallVector.h"
44 #include "llvm/ADT/Statistic.h"
45 #include "llvm/Analysis/GlobalsModRef.h"
46 #include "llvm/IR/BasicBlock.h"
47 #include "llvm/IR/CFG.h"
48 #include "llvm/IR/Constants.h"
49 #include "llvm/IR/Function.h"
50 #include "llvm/IR/InstrTypes.h"
51 #include "llvm/IR/Instruction.h"
52 #include "llvm/IR/Instructions.h"
53 #include "llvm/IR/PassManager.h"
54 #include "llvm/IR/Type.h"
55 #include "llvm/IR/Use.h"
56 #include "llvm/IR/Value.h"
57 #include "llvm/InitializePasses.h"
58 #include "llvm/Pass.h"
59 #include "llvm/Support/Allocator.h"
60 #include "llvm/Support/ArrayRecycler.h"
61 #include "llvm/Support/AtomicOrdering.h"
62 #include "llvm/Support/Casting.h"
63 #include "llvm/Support/Compiler.h"
64 #include "llvm/Support/Debug.h"
65 #include "llvm/Support/raw_ostream.h"
66 #include "llvm/Transforms/Scalar.h"
67 #include "llvm/Transforms/Scalar/GVN.h"
68 #include "llvm/Transforms/Scalar/GVNExpression.h"
69 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
70 #include "llvm/Transforms/Utils/Local.h"
71 #include <algorithm>
72 #include <cassert>
73 #include <cstddef>
74 #include <cstdint>
75 #include <iterator>
76 #include <utility>
77 
78 using namespace llvm;
79 
80 #define DEBUG_TYPE "gvn-sink"
81 
82 STATISTIC(NumRemoved, "Number of instructions removed");
83 
84 namespace llvm {
85 namespace GVNExpression {
86 
87 LLVM_DUMP_METHOD void Expression::dump() const {
88   print(dbgs());
89   dbgs() << "\n";
90 }
91 
92 } // end namespace GVNExpression
93 } // end namespace llvm
94 
95 namespace {
96 
97 static bool isMemoryInst(const Instruction *I) {
98   return isa<LoadInst>(I) || isa<StoreInst>(I) ||
99          (isa<InvokeInst>(I) && !cast<InvokeInst>(I)->doesNotAccessMemory()) ||
100          (isa<CallInst>(I) && !cast<CallInst>(I)->doesNotAccessMemory());
101 }
102 
103 /// Iterates through instructions in a set of blocks in reverse order from the
104 /// first non-terminator. For example (assume all blocks have size n):
105 ///   LockstepReverseIterator I([B1, B2, B3]);
106 ///   *I-- = [B1[n], B2[n], B3[n]];
107 ///   *I-- = [B1[n-1], B2[n-1], B3[n-1]];
108 ///   *I-- = [B1[n-2], B2[n-2], B3[n-2]];
109 ///   ...
110 ///
111 /// It continues until all blocks have been exhausted. Use \c getActiveBlocks()
112 /// to
113 /// determine which blocks are still going and the order they appear in the
114 /// list returned by operator*.
115 class LockstepReverseIterator {
116   ArrayRef<BasicBlock *> Blocks;
117   SmallSetVector<BasicBlock *, 4> ActiveBlocks;
118   SmallVector<Instruction *, 4> Insts;
119   bool Fail;
120 
121 public:
122   LockstepReverseIterator(ArrayRef<BasicBlock *> Blocks) : Blocks(Blocks) {
123     reset();
124   }
125 
126   void reset() {
127     Fail = false;
128     ActiveBlocks.clear();
129     for (BasicBlock *BB : Blocks)
130       ActiveBlocks.insert(BB);
131     Insts.clear();
132     for (BasicBlock *BB : Blocks) {
133       if (BB->size() <= 1) {
134         // Block wasn't big enough - only contained a terminator.
135         ActiveBlocks.remove(BB);
136         continue;
137       }
138       Insts.push_back(BB->getTerminator()->getPrevNode());
139     }
140     if (Insts.empty())
141       Fail = true;
142   }
143 
144   bool isValid() const { return !Fail; }
145   ArrayRef<Instruction *> operator*() const { return Insts; }
146 
147   // Note: This needs to return a SmallSetVector as the elements of
148   // ActiveBlocks will be later copied to Blocks using std::copy. The
149   // resultant order of elements in Blocks needs to be deterministic.
150   // Using SmallPtrSet instead causes non-deterministic order while
151   // copying. And we cannot simply sort Blocks as they need to match the
152   // corresponding Values.
153   SmallSetVector<BasicBlock *, 4> &getActiveBlocks() { return ActiveBlocks; }
154 
155   void restrictToBlocks(SmallSetVector<BasicBlock *, 4> &Blocks) {
156     for (auto II = Insts.begin(); II != Insts.end();) {
157       if (!llvm::is_contained(Blocks, (*II)->getParent())) {
158         ActiveBlocks.remove((*II)->getParent());
159         II = Insts.erase(II);
160       } else {
161         ++II;
162       }
163     }
164   }
165 
166   void operator--() {
167     if (Fail)
168       return;
169     SmallVector<Instruction *, 4> NewInsts;
170     for (auto *Inst : Insts) {
171       if (Inst == &Inst->getParent()->front())
172         ActiveBlocks.remove(Inst->getParent());
173       else
174         NewInsts.push_back(Inst->getPrevNode());
175     }
176     if (NewInsts.empty()) {
177       Fail = true;
178       return;
179     }
180     Insts = NewInsts;
181   }
182 };
183 
184 //===----------------------------------------------------------------------===//
185 
186 /// Candidate solution for sinking. There may be different ways to
187 /// sink instructions, differing in the number of instructions sunk,
188 /// the number of predecessors sunk from and the number of PHIs
189 /// required.
190 struct SinkingInstructionCandidate {
191   unsigned NumBlocks;
192   unsigned NumInstructions;
193   unsigned NumPHIs;
194   unsigned NumMemoryInsts;
195   int Cost = -1;
196   SmallVector<BasicBlock *, 4> Blocks;
197 
198   void calculateCost(unsigned NumOrigPHIs, unsigned NumOrigBlocks) {
199     unsigned NumExtraPHIs = NumPHIs - NumOrigPHIs;
200     unsigned SplitEdgeCost = (NumOrigBlocks > NumBlocks) ? 2 : 0;
201     Cost = (NumInstructions * (NumBlocks - 1)) -
202            (NumExtraPHIs *
203             NumExtraPHIs) // PHIs are expensive, so make sure they're worth it.
204            - SplitEdgeCost;
205   }
206 
207   bool operator>(const SinkingInstructionCandidate &Other) const {
208     return Cost > Other.Cost;
209   }
210 };
211 
212 #ifndef NDEBUG
213 raw_ostream &operator<<(raw_ostream &OS, const SinkingInstructionCandidate &C) {
214   OS << "<Candidate Cost=" << C.Cost << " #Blocks=" << C.NumBlocks
215      << " #Insts=" << C.NumInstructions << " #PHIs=" << C.NumPHIs << ">";
216   return OS;
217 }
218 #endif
219 
220 //===----------------------------------------------------------------------===//
221 
222 /// Describes a PHI node that may or may not exist. These track the PHIs
223 /// that must be created if we sunk a sequence of instructions. It provides
224 /// a hash function for efficient equality comparisons.
225 class ModelledPHI {
226   SmallVector<Value *, 4> Values;
227   SmallVector<BasicBlock *, 4> Blocks;
228 
229 public:
230   ModelledPHI() = default;
231 
232   ModelledPHI(const PHINode *PN) {
233     // BasicBlock comes first so we sort by basic block pointer order, then by value pointer order.
234     SmallVector<std::pair<BasicBlock *, Value *>, 4> Ops;
235     for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I)
236       Ops.push_back({PN->getIncomingBlock(I), PN->getIncomingValue(I)});
237     llvm::sort(Ops);
238     for (auto &P : Ops) {
239       Blocks.push_back(P.first);
240       Values.push_back(P.second);
241     }
242   }
243 
244   /// Create a dummy ModelledPHI that will compare unequal to any other ModelledPHI
245   /// without the same ID.
246   /// \note This is specifically for DenseMapInfo - do not use this!
247   static ModelledPHI createDummy(size_t ID) {
248     ModelledPHI M;
249     M.Values.push_back(reinterpret_cast<Value*>(ID));
250     return M;
251   }
252 
253   /// Create a PHI from an array of incoming values and incoming blocks.
254   template <typename VArray, typename BArray>
255   ModelledPHI(const VArray &V, const BArray &B) {
256     llvm::copy(V, std::back_inserter(Values));
257     llvm::copy(B, std::back_inserter(Blocks));
258   }
259 
260   /// Create a PHI from [I[OpNum] for I in Insts].
261   template <typename BArray>
262   ModelledPHI(ArrayRef<Instruction *> Insts, unsigned OpNum, const BArray &B) {
263     llvm::copy(B, std::back_inserter(Blocks));
264     for (auto *I : Insts)
265       Values.push_back(I->getOperand(OpNum));
266   }
267 
268   /// Restrict the PHI's contents down to only \c NewBlocks.
269   /// \c NewBlocks must be a subset of \c this->Blocks.
270   void restrictToBlocks(const SmallSetVector<BasicBlock *, 4> &NewBlocks) {
271     auto BI = Blocks.begin();
272     auto VI = Values.begin();
273     while (BI != Blocks.end()) {
274       assert(VI != Values.end());
275       if (!llvm::is_contained(NewBlocks, *BI)) {
276         BI = Blocks.erase(BI);
277         VI = Values.erase(VI);
278       } else {
279         ++BI;
280         ++VI;
281       }
282     }
283     assert(Blocks.size() == NewBlocks.size());
284   }
285 
286   ArrayRef<Value *> getValues() const { return Values; }
287 
288   bool areAllIncomingValuesSame() const {
289     return llvm::all_equal(Values);
290   }
291 
292   bool areAllIncomingValuesSameType() const {
293     return llvm::all_of(
294         Values, [&](Value *V) { return V->getType() == Values[0]->getType(); });
295   }
296 
297   bool areAnyIncomingValuesConstant() const {
298     return llvm::any_of(Values, [&](Value *V) { return isa<Constant>(V); });
299   }
300 
301   // Hash functor
302   unsigned hash() const {
303       return (unsigned)hash_combine_range(Values.begin(), Values.end());
304   }
305 
306   bool operator==(const ModelledPHI &Other) const {
307     return Values == Other.Values && Blocks == Other.Blocks;
308   }
309 };
310 
311 template <typename ModelledPHI> struct DenseMapInfo {
312   static inline ModelledPHI &getEmptyKey() {
313     static ModelledPHI Dummy = ModelledPHI::createDummy(0);
314     return Dummy;
315   }
316 
317   static inline ModelledPHI &getTombstoneKey() {
318     static ModelledPHI Dummy = ModelledPHI::createDummy(1);
319     return Dummy;
320   }
321 
322   static unsigned getHashValue(const ModelledPHI &V) { return V.hash(); }
323 
324   static bool isEqual(const ModelledPHI &LHS, const ModelledPHI &RHS) {
325     return LHS == RHS;
326   }
327 };
328 
329 using ModelledPHISet = DenseSet<ModelledPHI, DenseMapInfo<ModelledPHI>>;
330 
331 //===----------------------------------------------------------------------===//
332 //                             ValueTable
333 //===----------------------------------------------------------------------===//
334 // This is a value number table where the value number is a function of the
335 // *uses* of a value, rather than its operands. Thus, if VN(A) == VN(B) we know
336 // that the program would be equivalent if we replaced A with PHI(A, B).
337 //===----------------------------------------------------------------------===//
338 
339 /// A GVN expression describing how an instruction is used. The operands
340 /// field of BasicExpression is used to store uses, not operands.
341 ///
342 /// This class also contains fields for discriminators used when determining
343 /// equivalence of instructions with sideeffects.
344 class InstructionUseExpr : public GVNExpression::BasicExpression {
345   unsigned MemoryUseOrder = -1;
346   bool Volatile = false;
347   ArrayRef<int> ShuffleMask;
348 
349 public:
350   InstructionUseExpr(Instruction *I, ArrayRecycler<Value *> &R,
351                      BumpPtrAllocator &A)
352       : GVNExpression::BasicExpression(I->getNumUses()) {
353     allocateOperands(R, A);
354     setOpcode(I->getOpcode());
355     setType(I->getType());
356 
357     if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(I))
358       ShuffleMask = SVI->getShuffleMask().copy(A);
359 
360     for (auto &U : I->uses())
361       op_push_back(U.getUser());
362     llvm::sort(op_begin(), op_end());
363   }
364 
365   void setMemoryUseOrder(unsigned MUO) { MemoryUseOrder = MUO; }
366   void setVolatile(bool V) { Volatile = V; }
367 
368   hash_code getHashValue() const override {
369     return hash_combine(GVNExpression::BasicExpression::getHashValue(),
370                         MemoryUseOrder, Volatile, ShuffleMask);
371   }
372 
373   template <typename Function> hash_code getHashValue(Function MapFn) {
374     hash_code H = hash_combine(getOpcode(), getType(), MemoryUseOrder, Volatile,
375                                ShuffleMask);
376     for (auto *V : operands())
377       H = hash_combine(H, MapFn(V));
378     return H;
379   }
380 };
381 
382 using BasicBlocksSet = SmallPtrSet<const BasicBlock *, 32>;
383 
384 class ValueTable {
385   DenseMap<Value *, uint32_t> ValueNumbering;
386   DenseMap<GVNExpression::Expression *, uint32_t> ExpressionNumbering;
387   DenseMap<size_t, uint32_t> HashNumbering;
388   BumpPtrAllocator Allocator;
389   ArrayRecycler<Value *> Recycler;
390   uint32_t nextValueNumber = 1;
391   BasicBlocksSet ReachableBBs;
392 
393   /// Create an expression for I based on its opcode and its uses. If I
394   /// touches or reads memory, the expression is also based upon its memory
395   /// order - see \c getMemoryUseOrder().
396   InstructionUseExpr *createExpr(Instruction *I) {
397     InstructionUseExpr *E =
398         new (Allocator) InstructionUseExpr(I, Recycler, Allocator);
399     if (isMemoryInst(I))
400       E->setMemoryUseOrder(getMemoryUseOrder(I));
401 
402     if (CmpInst *C = dyn_cast<CmpInst>(I)) {
403       CmpInst::Predicate Predicate = C->getPredicate();
404       E->setOpcode((C->getOpcode() << 8) | Predicate);
405     }
406     return E;
407   }
408 
409   /// Helper to compute the value number for a memory instruction
410   /// (LoadInst/StoreInst), including checking the memory ordering and
411   /// volatility.
412   template <class Inst> InstructionUseExpr *createMemoryExpr(Inst *I) {
413     if (isStrongerThanUnordered(I->getOrdering()) || I->isAtomic())
414       return nullptr;
415     InstructionUseExpr *E = createExpr(I);
416     E->setVolatile(I->isVolatile());
417     return E;
418   }
419 
420 public:
421   ValueTable() = default;
422 
423   /// Set basic blocks reachable from entry block.
424   void setReachableBBs(const BasicBlocksSet &ReachableBBs) {
425     this->ReachableBBs = ReachableBBs;
426   }
427 
428   /// Returns the value number for the specified value, assigning
429   /// it a new number if it did not have one before.
430   uint32_t lookupOrAdd(Value *V) {
431     auto VI = ValueNumbering.find(V);
432     if (VI != ValueNumbering.end())
433       return VI->second;
434 
435     if (!isa<Instruction>(V)) {
436       ValueNumbering[V] = nextValueNumber;
437       return nextValueNumber++;
438     }
439 
440     Instruction *I = cast<Instruction>(V);
441     if (!ReachableBBs.contains(I->getParent()))
442       return ~0U;
443 
444     InstructionUseExpr *exp = nullptr;
445     switch (I->getOpcode()) {
446     case Instruction::Load:
447       exp = createMemoryExpr(cast<LoadInst>(I));
448       break;
449     case Instruction::Store:
450       exp = createMemoryExpr(cast<StoreInst>(I));
451       break;
452     case Instruction::Call:
453     case Instruction::Invoke:
454     case Instruction::FNeg:
455     case Instruction::Add:
456     case Instruction::FAdd:
457     case Instruction::Sub:
458     case Instruction::FSub:
459     case Instruction::Mul:
460     case Instruction::FMul:
461     case Instruction::UDiv:
462     case Instruction::SDiv:
463     case Instruction::FDiv:
464     case Instruction::URem:
465     case Instruction::SRem:
466     case Instruction::FRem:
467     case Instruction::Shl:
468     case Instruction::LShr:
469     case Instruction::AShr:
470     case Instruction::And:
471     case Instruction::Or:
472     case Instruction::Xor:
473     case Instruction::ICmp:
474     case Instruction::FCmp:
475     case Instruction::Trunc:
476     case Instruction::ZExt:
477     case Instruction::SExt:
478     case Instruction::FPToUI:
479     case Instruction::FPToSI:
480     case Instruction::UIToFP:
481     case Instruction::SIToFP:
482     case Instruction::FPTrunc:
483     case Instruction::FPExt:
484     case Instruction::PtrToInt:
485     case Instruction::IntToPtr:
486     case Instruction::BitCast:
487     case Instruction::AddrSpaceCast:
488     case Instruction::Select:
489     case Instruction::ExtractElement:
490     case Instruction::InsertElement:
491     case Instruction::ShuffleVector:
492     case Instruction::InsertValue:
493     case Instruction::GetElementPtr:
494       exp = createExpr(I);
495       break;
496     default:
497       break;
498     }
499 
500     if (!exp) {
501       ValueNumbering[V] = nextValueNumber;
502       return nextValueNumber++;
503     }
504 
505     uint32_t e = ExpressionNumbering[exp];
506     if (!e) {
507       hash_code H = exp->getHashValue([=](Value *V) { return lookupOrAdd(V); });
508       auto I = HashNumbering.find(H);
509       if (I != HashNumbering.end()) {
510         e = I->second;
511       } else {
512         e = nextValueNumber++;
513         HashNumbering[H] = e;
514         ExpressionNumbering[exp] = e;
515       }
516     }
517     ValueNumbering[V] = e;
518     return e;
519   }
520 
521   /// Returns the value number of the specified value. Fails if the value has
522   /// not yet been numbered.
523   uint32_t lookup(Value *V) const {
524     auto VI = ValueNumbering.find(V);
525     assert(VI != ValueNumbering.end() && "Value not numbered?");
526     return VI->second;
527   }
528 
529   /// Removes all value numberings and resets the value table.
530   void clear() {
531     ValueNumbering.clear();
532     ExpressionNumbering.clear();
533     HashNumbering.clear();
534     Recycler.clear(Allocator);
535     nextValueNumber = 1;
536   }
537 
538   /// \c Inst uses or touches memory. Return an ID describing the memory state
539   /// at \c Inst such that if getMemoryUseOrder(I1) == getMemoryUseOrder(I2),
540   /// the exact same memory operations happen after I1 and I2.
541   ///
542   /// This is a very hard problem in general, so we use domain-specific
543   /// knowledge that we only ever check for equivalence between blocks sharing a
544   /// single immediate successor that is common, and when determining if I1 ==
545   /// I2 we will have already determined that next(I1) == next(I2). This
546   /// inductive property allows us to simply return the value number of the next
547   /// instruction that defines memory.
548   uint32_t getMemoryUseOrder(Instruction *Inst) {
549     auto *BB = Inst->getParent();
550     for (auto I = std::next(Inst->getIterator()), E = BB->end();
551          I != E && !I->isTerminator(); ++I) {
552       if (!isMemoryInst(&*I))
553         continue;
554       if (isa<LoadInst>(&*I))
555         continue;
556       CallInst *CI = dyn_cast<CallInst>(&*I);
557       if (CI && CI->onlyReadsMemory())
558         continue;
559       InvokeInst *II = dyn_cast<InvokeInst>(&*I);
560       if (II && II->onlyReadsMemory())
561         continue;
562       return lookupOrAdd(&*I);
563     }
564     return 0;
565   }
566 };
567 
568 //===----------------------------------------------------------------------===//
569 
570 class GVNSink {
571 public:
572   GVNSink() = default;
573 
574   bool run(Function &F) {
575     LLVM_DEBUG(dbgs() << "GVNSink: running on function @" << F.getName()
576                       << "\n");
577 
578     unsigned NumSunk = 0;
579     ReversePostOrderTraversal<Function*> RPOT(&F);
580     VN.setReachableBBs(BasicBlocksSet(RPOT.begin(), RPOT.end()));
581     for (auto *N : RPOT)
582       NumSunk += sinkBB(N);
583 
584     return NumSunk > 0;
585   }
586 
587 private:
588   ValueTable VN;
589 
590   bool shouldAvoidSinkingInstruction(Instruction *I) {
591     // These instructions may change or break semantics if moved.
592     if (isa<PHINode>(I) || I->isEHPad() || isa<AllocaInst>(I) ||
593         I->getType()->isTokenTy())
594       return true;
595     return false;
596   }
597 
598   /// The main heuristic function. Analyze the set of instructions pointed to by
599   /// LRI and return a candidate solution if these instructions can be sunk, or
600   /// std::nullopt otherwise.
601   std::optional<SinkingInstructionCandidate> analyzeInstructionForSinking(
602       LockstepReverseIterator &LRI, unsigned &InstNum, unsigned &MemoryInstNum,
603       ModelledPHISet &NeededPHIs, SmallPtrSetImpl<Value *> &PHIContents);
604 
605   /// Create a ModelledPHI for each PHI in BB, adding to PHIs.
606   void analyzeInitialPHIs(BasicBlock *BB, ModelledPHISet &PHIs,
607                           SmallPtrSetImpl<Value *> &PHIContents) {
608     for (PHINode &PN : BB->phis()) {
609       auto MPHI = ModelledPHI(&PN);
610       PHIs.insert(MPHI);
611       for (auto *V : MPHI.getValues())
612         PHIContents.insert(V);
613     }
614   }
615 
616   /// The main instruction sinking driver. Set up state and try and sink
617   /// instructions into BBEnd from its predecessors.
618   unsigned sinkBB(BasicBlock *BBEnd);
619 
620   /// Perform the actual mechanics of sinking an instruction from Blocks into
621   /// BBEnd, which is their only successor.
622   void sinkLastInstruction(ArrayRef<BasicBlock *> Blocks, BasicBlock *BBEnd);
623 
624   /// Remove PHIs that all have the same incoming value.
625   void foldPointlessPHINodes(BasicBlock *BB) {
626     auto I = BB->begin();
627     while (PHINode *PN = dyn_cast<PHINode>(I++)) {
628       if (!llvm::all_of(PN->incoming_values(), [&](const Value *V) {
629             return V == PN->getIncomingValue(0);
630           }))
631         continue;
632       if (PN->getIncomingValue(0) != PN)
633         PN->replaceAllUsesWith(PN->getIncomingValue(0));
634       else
635         PN->replaceAllUsesWith(PoisonValue::get(PN->getType()));
636       PN->eraseFromParent();
637     }
638   }
639 };
640 
641 std::optional<SinkingInstructionCandidate>
642 GVNSink::analyzeInstructionForSinking(LockstepReverseIterator &LRI,
643                                       unsigned &InstNum,
644                                       unsigned &MemoryInstNum,
645                                       ModelledPHISet &NeededPHIs,
646                                       SmallPtrSetImpl<Value *> &PHIContents) {
647   auto Insts = *LRI;
648   LLVM_DEBUG(dbgs() << " -- Analyzing instruction set: [\n"; for (auto *I
649                                                                   : Insts) {
650     I->dump();
651   } dbgs() << " ]\n";);
652 
653   DenseMap<uint32_t, unsigned> VNums;
654   for (auto *I : Insts) {
655     uint32_t N = VN.lookupOrAdd(I);
656     LLVM_DEBUG(dbgs() << " VN=" << Twine::utohexstr(N) << " for" << *I << "\n");
657     if (N == ~0U)
658       return std::nullopt;
659     VNums[N]++;
660   }
661   unsigned VNumToSink =
662       std::max_element(VNums.begin(), VNums.end(), llvm::less_second())->first;
663 
664   if (VNums[VNumToSink] == 1)
665     // Can't sink anything!
666     return std::nullopt;
667 
668   // Now restrict the number of incoming blocks down to only those with
669   // VNumToSink.
670   auto &ActivePreds = LRI.getActiveBlocks();
671   unsigned InitialActivePredSize = ActivePreds.size();
672   SmallVector<Instruction *, 4> NewInsts;
673   for (auto *I : Insts) {
674     if (VN.lookup(I) != VNumToSink)
675       ActivePreds.remove(I->getParent());
676     else
677       NewInsts.push_back(I);
678   }
679   for (auto *I : NewInsts)
680     if (shouldAvoidSinkingInstruction(I))
681       return std::nullopt;
682 
683   // If we've restricted the incoming blocks, restrict all needed PHIs also
684   // to that set.
685   bool RecomputePHIContents = false;
686   if (ActivePreds.size() != InitialActivePredSize) {
687     ModelledPHISet NewNeededPHIs;
688     for (auto P : NeededPHIs) {
689       P.restrictToBlocks(ActivePreds);
690       NewNeededPHIs.insert(P);
691     }
692     NeededPHIs = NewNeededPHIs;
693     LRI.restrictToBlocks(ActivePreds);
694     RecomputePHIContents = true;
695   }
696 
697   // The sunk instruction's results.
698   ModelledPHI NewPHI(NewInsts, ActivePreds);
699 
700   // Does sinking this instruction render previous PHIs redundant?
701   if (NeededPHIs.erase(NewPHI))
702     RecomputePHIContents = true;
703 
704   if (RecomputePHIContents) {
705     // The needed PHIs have changed, so recompute the set of all needed
706     // values.
707     PHIContents.clear();
708     for (auto &PHI : NeededPHIs)
709       PHIContents.insert(PHI.getValues().begin(), PHI.getValues().end());
710   }
711 
712   // Is this instruction required by a later PHI that doesn't match this PHI?
713   // if so, we can't sink this instruction.
714   for (auto *V : NewPHI.getValues())
715     if (PHIContents.count(V))
716       // V exists in this PHI, but the whole PHI is different to NewPHI
717       // (else it would have been removed earlier). We cannot continue
718       // because this isn't representable.
719       return std::nullopt;
720 
721   // Which operands need PHIs?
722   // FIXME: If any of these fail, we should partition up the candidates to
723   // try and continue making progress.
724   Instruction *I0 = NewInsts[0];
725 
726   // If all instructions that are going to participate don't have the same
727   // number of operands, we can't do any useful PHI analysis for all operands.
728   auto hasDifferentNumOperands = [&I0](Instruction *I) {
729     return I->getNumOperands() != I0->getNumOperands();
730   };
731   if (any_of(NewInsts, hasDifferentNumOperands))
732     return std::nullopt;
733 
734   for (unsigned OpNum = 0, E = I0->getNumOperands(); OpNum != E; ++OpNum) {
735     ModelledPHI PHI(NewInsts, OpNum, ActivePreds);
736     if (PHI.areAllIncomingValuesSame())
737       continue;
738     if (!canReplaceOperandWithVariable(I0, OpNum))
739       // We can 't create a PHI from this instruction!
740       return std::nullopt;
741     if (NeededPHIs.count(PHI))
742       continue;
743     if (!PHI.areAllIncomingValuesSameType())
744       return std::nullopt;
745     // Don't create indirect calls! The called value is the final operand.
746     if ((isa<CallInst>(I0) || isa<InvokeInst>(I0)) && OpNum == E - 1 &&
747         PHI.areAnyIncomingValuesConstant())
748       return std::nullopt;
749 
750     NeededPHIs.reserve(NeededPHIs.size());
751     NeededPHIs.insert(PHI);
752     PHIContents.insert(PHI.getValues().begin(), PHI.getValues().end());
753   }
754 
755   if (isMemoryInst(NewInsts[0]))
756     ++MemoryInstNum;
757 
758   SinkingInstructionCandidate Cand;
759   Cand.NumInstructions = ++InstNum;
760   Cand.NumMemoryInsts = MemoryInstNum;
761   Cand.NumBlocks = ActivePreds.size();
762   Cand.NumPHIs = NeededPHIs.size();
763   append_range(Cand.Blocks, ActivePreds);
764 
765   return Cand;
766 }
767 
768 unsigned GVNSink::sinkBB(BasicBlock *BBEnd) {
769   LLVM_DEBUG(dbgs() << "GVNSink: running on basic block ";
770              BBEnd->printAsOperand(dbgs()); dbgs() << "\n");
771   SmallVector<BasicBlock *, 4> Preds;
772   for (auto *B : predecessors(BBEnd)) {
773     auto *T = B->getTerminator();
774     if (isa<BranchInst>(T) || isa<SwitchInst>(T))
775       Preds.push_back(B);
776     else
777       return 0;
778   }
779   if (Preds.size() < 2)
780     return 0;
781   llvm::sort(Preds);
782 
783   unsigned NumOrigPreds = Preds.size();
784   // We can only sink instructions through unconditional branches.
785   llvm::erase_if(Preds, [](BasicBlock *BB) {
786     return BB->getTerminator()->getNumSuccessors() != 1;
787   });
788 
789   LockstepReverseIterator LRI(Preds);
790   SmallVector<SinkingInstructionCandidate, 4> Candidates;
791   unsigned InstNum = 0, MemoryInstNum = 0;
792   ModelledPHISet NeededPHIs;
793   SmallPtrSet<Value *, 4> PHIContents;
794   analyzeInitialPHIs(BBEnd, NeededPHIs, PHIContents);
795   unsigned NumOrigPHIs = NeededPHIs.size();
796 
797   while (LRI.isValid()) {
798     auto Cand = analyzeInstructionForSinking(LRI, InstNum, MemoryInstNum,
799                                              NeededPHIs, PHIContents);
800     if (!Cand)
801       break;
802     Cand->calculateCost(NumOrigPHIs, Preds.size());
803     Candidates.emplace_back(*Cand);
804     --LRI;
805   }
806 
807   llvm::stable_sort(Candidates, std::greater<SinkingInstructionCandidate>());
808   LLVM_DEBUG(dbgs() << " -- Sinking candidates:\n"; for (auto &C
809                                                          : Candidates) dbgs()
810                                                     << "  " << C << "\n";);
811 
812   // Pick the top candidate, as long it is positive!
813   if (Candidates.empty() || Candidates.front().Cost <= 0)
814     return 0;
815   auto C = Candidates.front();
816 
817   LLVM_DEBUG(dbgs() << " -- Sinking: " << C << "\n");
818   BasicBlock *InsertBB = BBEnd;
819   if (C.Blocks.size() < NumOrigPreds) {
820     LLVM_DEBUG(dbgs() << " -- Splitting edge to ";
821                BBEnd->printAsOperand(dbgs()); dbgs() << "\n");
822     InsertBB = SplitBlockPredecessors(BBEnd, C.Blocks, ".gvnsink.split");
823     if (!InsertBB) {
824       LLVM_DEBUG(dbgs() << " -- FAILED to split edge!\n");
825       // Edge couldn't be split.
826       return 0;
827     }
828   }
829 
830   for (unsigned I = 0; I < C.NumInstructions; ++I)
831     sinkLastInstruction(C.Blocks, InsertBB);
832 
833   return C.NumInstructions;
834 }
835 
836 void GVNSink::sinkLastInstruction(ArrayRef<BasicBlock *> Blocks,
837                                   BasicBlock *BBEnd) {
838   SmallVector<Instruction *, 4> Insts;
839   for (BasicBlock *BB : Blocks)
840     Insts.push_back(BB->getTerminator()->getPrevNode());
841   Instruction *I0 = Insts.front();
842 
843   SmallVector<Value *, 4> NewOperands;
844   for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O) {
845     bool NeedPHI = llvm::any_of(Insts, [&I0, O](const Instruction *I) {
846       return I->getOperand(O) != I0->getOperand(O);
847     });
848     if (!NeedPHI) {
849       NewOperands.push_back(I0->getOperand(O));
850       continue;
851     }
852 
853     // Create a new PHI in the successor block and populate it.
854     auto *Op = I0->getOperand(O);
855     assert(!Op->getType()->isTokenTy() && "Can't PHI tokens!");
856     auto *PN = PHINode::Create(Op->getType(), Insts.size(),
857                                Op->getName() + ".sink", &BBEnd->front());
858     for (auto *I : Insts)
859       PN->addIncoming(I->getOperand(O), I->getParent());
860     NewOperands.push_back(PN);
861   }
862 
863   // Arbitrarily use I0 as the new "common" instruction; remap its operands
864   // and move it to the start of the successor block.
865   for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O)
866     I0->getOperandUse(O).set(NewOperands[O]);
867   I0->moveBefore(&*BBEnd->getFirstInsertionPt());
868 
869   // Update metadata and IR flags.
870   for (auto *I : Insts)
871     if (I != I0) {
872       combineMetadataForCSE(I0, I, true);
873       I0->andIRFlags(I);
874     }
875 
876   for (auto *I : Insts)
877     if (I != I0)
878       I->replaceAllUsesWith(I0);
879   foldPointlessPHINodes(BBEnd);
880 
881   // Finally nuke all instructions apart from the common instruction.
882   for (auto *I : Insts)
883     if (I != I0)
884       I->eraseFromParent();
885 
886   NumRemoved += Insts.size() - 1;
887 }
888 
889 ////////////////////////////////////////////////////////////////////////////////
890 // Pass machinery / boilerplate
891 
892 class GVNSinkLegacyPass : public FunctionPass {
893 public:
894   static char ID;
895 
896   GVNSinkLegacyPass() : FunctionPass(ID) {
897     initializeGVNSinkLegacyPassPass(*PassRegistry::getPassRegistry());
898   }
899 
900   bool runOnFunction(Function &F) override {
901     if (skipFunction(F))
902       return false;
903     GVNSink G;
904     return G.run(F);
905   }
906 
907   void getAnalysisUsage(AnalysisUsage &AU) const override {
908     AU.addPreserved<GlobalsAAWrapperPass>();
909   }
910 };
911 
912 } // end anonymous namespace
913 
914 PreservedAnalyses GVNSinkPass::run(Function &F, FunctionAnalysisManager &AM) {
915   GVNSink G;
916   if (!G.run(F))
917     return PreservedAnalyses::all();
918   return PreservedAnalyses::none();
919 }
920 
921 char GVNSinkLegacyPass::ID = 0;
922 
923 INITIALIZE_PASS_BEGIN(GVNSinkLegacyPass, "gvn-sink",
924                       "Early GVN sinking of Expressions", false, false)
925 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
926 INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
927 INITIALIZE_PASS_END(GVNSinkLegacyPass, "gvn-sink",
928                     "Early GVN sinking of Expressions", false, false)
929 
930 FunctionPass *llvm::createGVNSinkPass() { return new GVNSinkLegacyPass(); }
931