1 //===- CorrelatedValuePropagation.cpp - Propagate CFG-derived info --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Correlated Value Propagation pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Transforms/Scalar/CorrelatedValuePropagation.h"
14 #include "llvm/ADT/DepthFirstIterator.h"
15 #include "llvm/ADT/Optional.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/Analysis/DomTreeUpdater.h"
19 #include "llvm/Analysis/GlobalsModRef.h"
20 #include "llvm/Analysis/InstructionSimplify.h"
21 #include "llvm/Analysis/LazyValueInfo.h"
22 #include "llvm/IR/Attributes.h"
23 #include "llvm/IR/BasicBlock.h"
24 #include "llvm/IR/CFG.h"
25 #include "llvm/IR/Constant.h"
26 #include "llvm/IR/ConstantRange.h"
27 #include "llvm/IR/Constants.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/InstrTypes.h"
32 #include "llvm/IR/Instruction.h"
33 #include "llvm/IR/Instructions.h"
34 #include "llvm/IR/IntrinsicInst.h"
35 #include "llvm/IR/Operator.h"
36 #include "llvm/IR/PassManager.h"
37 #include "llvm/IR/Type.h"
38 #include "llvm/IR/Value.h"
39 #include "llvm/InitializePasses.h"
40 #include "llvm/Pass.h"
41 #include "llvm/Support/Casting.h"
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/Support/raw_ostream.h"
45 #include "llvm/Transforms/Scalar.h"
46 #include "llvm/Transforms/Utils/Local.h"
47 #include <cassert>
48 #include <utility>
49 
50 using namespace llvm;
51 
52 #define DEBUG_TYPE "correlated-value-propagation"
53 
// Counters for the individual transformations performed by this pass.

// PHI, select, memory-access, compare, return and switch simplifications.
STATISTIC(NumPhis,      "Number of phis propagated");
STATISTIC(NumPhiCommon, "Number of phis deleted via common incoming value");
STATISTIC(NumSelects,   "Number of selects propagated");
STATISTIC(NumMemAccess, "Number of memory access targets propagated");
STATISTIC(NumCmps,      "Number of comparisons propagated");
STATISTIC(NumReturns,   "Number of return values propagated");
STATISTIC(NumDeadCases, "Number of switch cases removed");

// Division, remainder, shift, extension and mask rewrites.
STATISTIC(NumSDivSRemsNarrowed,
          "Number of sdivs/srems whose width was decreased");
STATISTIC(NumSDivs,     "Number of sdiv converted to udiv");
STATISTIC(NumUDivURemsNarrowed,
          "Number of udivs/urems whose width was decreased");
STATISTIC(NumAShrs,     "Number of ashr converted to lshr");
STATISTIC(NumSRems,     "Number of srem converted to urem");
STATISTIC(NumSExt,      "Number of sext converted to zext");
STATISTIC(NumAnd,       "Number of ands removed");

// No-wrap flag deductions: overall totals first, then per-opcode breakdowns
// (both are bumped together by setDeducedOverflowingFlags).
STATISTIC(NumNW,        "Number of no-wrap deductions");
STATISTIC(NumNSW,       "Number of no-signed-wrap deductions");
STATISTIC(NumNUW,       "Number of no-unsigned-wrap deductions");
STATISTIC(NumAddNW,     "Number of no-wrap deductions for add");
STATISTIC(NumAddNSW,    "Number of no-signed-wrap deductions for add");
STATISTIC(NumAddNUW,    "Number of no-unsigned-wrap deductions for add");
STATISTIC(NumSubNW,     "Number of no-wrap deductions for sub");
STATISTIC(NumSubNSW,    "Number of no-signed-wrap deductions for sub");
STATISTIC(NumSubNUW,    "Number of no-unsigned-wrap deductions for sub");
STATISTIC(NumMulNW,     "Number of no-wrap deductions for mul");
STATISTIC(NumMulNSW,    "Number of no-signed-wrap deductions for mul");
STATISTIC(NumMulNUW,    "Number of no-unsigned-wrap deductions for mul");
STATISTIC(NumShlNW,     "Number of no-wrap deductions for shl");
STATISTIC(NumShlNSW,    "Number of no-signed-wrap deductions for shl");
STATISTIC(NumShlNUW,    "Number of no-unsigned-wrap deductions for shl");

// Overflow-checking and saturating intrinsics replaced by plain arithmetic.
STATISTIC(NumOverflows, "Number of overflow checks removed");
STATISTIC(NumSaturating,
    "Number of saturating arithmetics converted to normal arithmetics");

// Command-line switch "cvp-dont-add-nowrap-flags" (defaults to false). When it
// is set, processBinOp() returns immediately without deducing nsw/nuw flags.
static cl::opt<bool> DontAddNoWrapFlags("cvp-dont-add-nowrap-flags", cl::init(false));
90 
91 namespace {
92 
  /// FunctionPass wrapper for the Correlated Value Propagation transform.
  class CorrelatedValuePropagation : public FunctionPass {
  public:
    /// Pass identifier handed to the FunctionPass base-class constructor.
    static char ID;

    /// Registers the pass with the global PassRegistry on construction.
    CorrelatedValuePropagation(): FunctionPass(ID) {
     initializeCorrelatedValuePropagationPass(*PassRegistry::getPassRegistry());
    }

    /// Entry point invoked per function; defined out of line.
    bool runOnFunction(Function &F) override;

    /// Declares the analyses this pass needs and the ones it keeps valid:
    /// the dominator tree and LazyValueInfo are both required and preserved,
    /// and GlobalsAA is preserved.
    void getAnalysisUsage(AnalysisUsage &AU) const override {
      AU.addRequired<DominatorTreeWrapperPass>();
      AU.addRequired<LazyValueInfoWrapperPass>();
      AU.addPreserved<GlobalsAAWrapperPass>();
      AU.addPreserved<DominatorTreeWrapperPass>();
      AU.addPreserved<LazyValueInfoWrapperPass>();
    }
  };
111 
112 } // end anonymous namespace
113 
// Definition of the pass's static ID member; the constructor passes it to the
// FunctionPass base class.
char CorrelatedValuePropagation::ID = 0;

// Pass registration under the name "correlated-propagation", together with the
// wrapper passes (dominator tree and LazyValueInfo) it depends on.
INITIALIZE_PASS_BEGIN(CorrelatedValuePropagation, "correlated-propagation",
                "Value Propagation", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass)
INITIALIZE_PASS_END(CorrelatedValuePropagation, "correlated-propagation",
                "Value Propagation", false, false)

// Public interface to the Value Propagation pass: returns a newly allocated
// CorrelatedValuePropagation instance.
Pass *llvm::createCorrelatedValuePropagationPass() {
  return new CorrelatedValuePropagation();
}
127 
128 static bool processSelect(SelectInst *S, LazyValueInfo *LVI) {
129   if (S->getType()->isVectorTy()) return false;
130   if (isa<Constant>(S->getCondition())) return false;
131 
132   Constant *C = LVI->getConstant(S->getCondition(), S);
133   if (!C) return false;
134 
135   ConstantInt *CI = dyn_cast<ConstantInt>(C);
136   if (!CI) return false;
137 
138   Value *ReplaceWith = CI->isOne() ? S->getTrueValue() : S->getFalseValue();
139   S->replaceAllUsesWith(ReplaceWith);
140   S->eraseFromParent();
141 
142   ++NumSelects;
143 
144   return true;
145 }
146 
147 /// Try to simplify a phi with constant incoming values that match the edge
148 /// values of a non-constant value on all other edges:
149 /// bb0:
150 ///   %isnull = icmp eq i8* %x, null
151 ///   br i1 %isnull, label %bb2, label %bb1
152 /// bb1:
153 ///   br label %bb2
154 /// bb2:
155 ///   %r = phi i8* [ %x, %bb1 ], [ null, %bb0 ]
156 /// -->
157 ///   %r = %x
158 static bool simplifyCommonValuePhi(PHINode *P, LazyValueInfo *LVI,
159                                    DominatorTree *DT) {
160   // Collect incoming constants and initialize possible common value.
161   SmallVector<std::pair<Constant *, unsigned>, 4> IncomingConstants;
162   Value *CommonValue = nullptr;
163   for (unsigned i = 0, e = P->getNumIncomingValues(); i != e; ++i) {
164     Value *Incoming = P->getIncomingValue(i);
165     if (auto *IncomingConstant = dyn_cast<Constant>(Incoming)) {
166       IncomingConstants.push_back(std::make_pair(IncomingConstant, i));
167     } else if (!CommonValue) {
168       // The potential common value is initialized to the first non-constant.
169       CommonValue = Incoming;
170     } else if (Incoming != CommonValue) {
171       // There can be only one non-constant common value.
172       return false;
173     }
174   }
175 
176   if (!CommonValue || IncomingConstants.empty())
177     return false;
178 
179   // The common value must be valid in all incoming blocks.
180   BasicBlock *ToBB = P->getParent();
181   if (auto *CommonInst = dyn_cast<Instruction>(CommonValue))
182     if (!DT->dominates(CommonInst, ToBB))
183       return false;
184 
185   // We have a phi with exactly 1 variable incoming value and 1 or more constant
186   // incoming values. See if all constant incoming values can be mapped back to
187   // the same incoming variable value.
188   for (auto &IncomingConstant : IncomingConstants) {
189     Constant *C = IncomingConstant.first;
190     BasicBlock *IncomingBB = P->getIncomingBlock(IncomingConstant.second);
191     if (C != LVI->getConstantOnEdge(CommonValue, IncomingBB, ToBB, P))
192       return false;
193   }
194 
195   // All constant incoming values map to the same variable along the incoming
196   // edges of the phi. The phi is unnecessary. However, we must drop all
197   // poison-generating flags to ensure that no poison is propagated to the phi
198   // location by performing this substitution.
199   // Warning: If the underlying analysis changes, this may not be enough to
200   //          guarantee that poison is not propagated.
201   // TODO: We may be able to re-infer flags by re-analyzing the instruction.
202   if (auto *CommonInst = dyn_cast<Instruction>(CommonValue))
203     CommonInst->dropPoisonGeneratingFlags();
204   P->replaceAllUsesWith(CommonValue);
205   P->eraseFromParent();
206   ++NumPhiCommon;
207   return true;
208 }
209 
210 static bool processPHI(PHINode *P, LazyValueInfo *LVI, DominatorTree *DT,
211                        const SimplifyQuery &SQ) {
212   bool Changed = false;
213 
214   BasicBlock *BB = P->getParent();
215   for (unsigned i = 0, e = P->getNumIncomingValues(); i < e; ++i) {
216     Value *Incoming = P->getIncomingValue(i);
217     if (isa<Constant>(Incoming)) continue;
218 
219     Value *V = LVI->getConstantOnEdge(Incoming, P->getIncomingBlock(i), BB, P);
220 
221     // Look if the incoming value is a select with a scalar condition for which
222     // LVI can tells us the value. In that case replace the incoming value with
223     // the appropriate value of the select. This often allows us to remove the
224     // select later.
225     if (!V) {
226       SelectInst *SI = dyn_cast<SelectInst>(Incoming);
227       if (!SI) continue;
228 
229       Value *Condition = SI->getCondition();
230       if (!Condition->getType()->isVectorTy()) {
231         if (Constant *C = LVI->getConstantOnEdge(
232                 Condition, P->getIncomingBlock(i), BB, P)) {
233           if (C->isOneValue()) {
234             V = SI->getTrueValue();
235           } else if (C->isZeroValue()) {
236             V = SI->getFalseValue();
237           }
238           // Once LVI learns to handle vector types, we could also add support
239           // for vector type constants that are not all zeroes or all ones.
240         }
241       }
242 
243       // Look if the select has a constant but LVI tells us that the incoming
244       // value can never be that constant. In that case replace the incoming
245       // value with the other value of the select. This often allows us to
246       // remove the select later.
247       if (!V) {
248         Constant *C = dyn_cast<Constant>(SI->getFalseValue());
249         if (!C) continue;
250 
251         if (LVI->getPredicateOnEdge(ICmpInst::ICMP_EQ, SI, C,
252               P->getIncomingBlock(i), BB, P) !=
253             LazyValueInfo::False)
254           continue;
255         V = SI->getTrueValue();
256       }
257 
258       LLVM_DEBUG(dbgs() << "CVP: Threading PHI over " << *SI << '\n');
259     }
260 
261     P->setIncomingValue(i, V);
262     Changed = true;
263   }
264 
265   if (Value *V = SimplifyInstruction(P, SQ)) {
266     P->replaceAllUsesWith(V);
267     P->eraseFromParent();
268     Changed = true;
269   }
270 
271   if (!Changed)
272     Changed = simplifyCommonValuePhi(P, LVI, DT);
273 
274   if (Changed)
275     ++NumPhis;
276 
277   return Changed;
278 }
279 
280 static bool processMemAccess(Instruction *I, LazyValueInfo *LVI) {
281   Value *Pointer = nullptr;
282   if (LoadInst *L = dyn_cast<LoadInst>(I))
283     Pointer = L->getPointerOperand();
284   else
285     Pointer = cast<StoreInst>(I)->getPointerOperand();
286 
287   if (isa<Constant>(Pointer)) return false;
288 
289   Constant *C = LVI->getConstant(Pointer, I);
290   if (!C) return false;
291 
292   ++NumMemAccess;
293   I->replaceUsesOfWith(Pointer, C);
294   return true;
295 }
296 
297 /// See if LazyValueInfo's ability to exploit edge conditions or range
298 /// information is sufficient to prove this comparison. Even for local
299 /// conditions, this can sometimes prove conditions instcombine can't by
300 /// exploiting range information.
301 static bool processCmp(CmpInst *Cmp, LazyValueInfo *LVI) {
302   Value *Op0 = Cmp->getOperand(0);
303   auto *C = dyn_cast<Constant>(Cmp->getOperand(1));
304   if (!C)
305     return false;
306 
307   LazyValueInfo::Tristate Result =
308       LVI->getPredicateAt(Cmp->getPredicate(), Op0, C, Cmp,
309                           /*UseBlockValue=*/true);
310   if (Result == LazyValueInfo::Unknown)
311     return false;
312 
313   ++NumCmps;
314   Constant *TorF = ConstantInt::get(Type::getInt1Ty(Cmp->getContext()), Result);
315   Cmp->replaceAllUsesWith(TorF);
316   Cmp->eraseFromParent();
317   return true;
318 }
319 
/// Simplify a switch instruction by removing cases which can never fire. If the
/// uselessness of a case could be determined locally then constant propagation
/// would already have figured it out. Instead, walk the predecessors and
/// statically evaluate cases based on information available on that edge. Cases
/// that cannot fire no matter what the incoming edge can safely be removed. If
/// a case fires on every incoming edge then the entire switch can be removed
/// and replaced with a branch to the case destination.
///
/// Dominator tree changes are queued on a lazy DomTreeUpdater built over
/// \p DT.
///
/// \returns true if any case was removed or the switch condition was replaced
/// by a constant.
static bool processSwitch(SwitchInst *I, LazyValueInfo *LVI,
                          DominatorTree *DT) {
  DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Lazy);
  Value *Cond = I->getCondition();
  BasicBlock *BB = I->getParent();

  // Analyse each switch case in turn.
  bool Changed = false;
  // Number of CFG edges from BB to each successor. Several cases (and the
  // default) may share a destination, so the BB -> Succ dominator-tree edge is
  // only deleted once the last such edge is gone.
  DenseMap<BasicBlock*, int> SuccessorsCount;
  for (auto *Succ : successors(BB))
    SuccessorsCount[Succ]++;

  { // Scope for SwitchInstProfUpdateWrapper. It must not live during
    // ConstantFoldTerminator() as the underlying SwitchInst can be changed.
    SwitchInstProfUpdateWrapper SI(*I);

    // The iterator is only advanced explicitly: removing a case hands back the
    // iterator to continue from, and the end iterator is refreshed as well.
    for (auto CI = SI->case_begin(), CE = SI->case_end(); CI != CE;) {
      ConstantInt *Case = CI->getCaseValue();
      LazyValueInfo::Tristate State =
          LVI->getPredicateAt(CmpInst::ICMP_EQ, Cond, Case, I,
                              /* UseBlockValue */ true);

      if (State == LazyValueInfo::False) {
        // This case never fires - remove it.
        BasicBlock *Succ = CI->getCaseSuccessor();
        Succ->removePredecessor(BB);
        CI = SI.removeCase(CI);
        CE = SI->case_end();

        // The condition can be modified by removePredecessor's PHI simplification
        // logic.
        Cond = SI->getCondition();

        ++NumDeadCases;
        Changed = true;
        // Drop the dominator-tree edge only when no remaining case (or the
        // default) still branches from BB to Succ.
        if (--SuccessorsCount[Succ] == 0)
          DTU.applyUpdatesPermissive({{DominatorTree::Delete, BB, Succ}});
        continue;
      }
      if (State == LazyValueInfo::True) {
        // This case always fires.  Arrange for the switch to be turned into an
        // unconditional branch by replacing the switch condition with the case
        // value.
        SI->setCondition(Case);
        // Every case still present is accounted as dead here.
        NumDeadCases += SI->getNumCases();
        Changed = true;
        break;
      }

      // Increment the case iterator since we didn't delete it.
      ++CI;
    }
  }

  if (Changed)
    // If the switch has been simplified to the point where it can be replaced
    // by a branch then do so now.
    ConstantFoldTerminator(BB, /*DeleteDeadConditions = */ false,
                           /*TLI = */ nullptr, &DTU);
  return Changed;
}
388 
389 // See if we can prove that the given binary op intrinsic will not overflow.
390 static bool willNotOverflow(BinaryOpIntrinsic *BO, LazyValueInfo *LVI) {
391   ConstantRange LRange = LVI->getConstantRange(BO->getLHS(), BO);
392   ConstantRange RRange = LVI->getConstantRange(BO->getRHS(), BO);
393   ConstantRange NWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
394       BO->getBinaryOp(), RRange, BO->getNoWrapKind());
395   return NWRegion.contains(LRange);
396 }
397 
398 static void setDeducedOverflowingFlags(Value *V, Instruction::BinaryOps Opcode,
399                                        bool NewNSW, bool NewNUW) {
400   Statistic *OpcNW, *OpcNSW, *OpcNUW;
401   switch (Opcode) {
402   case Instruction::Add:
403     OpcNW = &NumAddNW;
404     OpcNSW = &NumAddNSW;
405     OpcNUW = &NumAddNUW;
406     break;
407   case Instruction::Sub:
408     OpcNW = &NumSubNW;
409     OpcNSW = &NumSubNSW;
410     OpcNUW = &NumSubNUW;
411     break;
412   case Instruction::Mul:
413     OpcNW = &NumMulNW;
414     OpcNSW = &NumMulNSW;
415     OpcNUW = &NumMulNUW;
416     break;
417   case Instruction::Shl:
418     OpcNW = &NumShlNW;
419     OpcNSW = &NumShlNSW;
420     OpcNUW = &NumShlNUW;
421     break;
422   default:
423     llvm_unreachable("Will not be called with other binops");
424   }
425 
426   auto *Inst = dyn_cast<Instruction>(V);
427   if (NewNSW) {
428     ++NumNW;
429     ++*OpcNW;
430     ++NumNSW;
431     ++*OpcNSW;
432     if (Inst)
433       Inst->setHasNoSignedWrap();
434   }
435   if (NewNUW) {
436     ++NumNW;
437     ++*OpcNW;
438     ++NumNUW;
439     ++*OpcNUW;
440     if (Inst)
441       Inst->setHasNoUnsignedWrap();
442   }
443 }
444 
445 static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI);
446 
447 // Rewrite this with.overflow intrinsic as non-overflowing.
448 static void processOverflowIntrinsic(WithOverflowInst *WO, LazyValueInfo *LVI) {
449   IRBuilder<> B(WO);
450   Instruction::BinaryOps Opcode = WO->getBinaryOp();
451   bool NSW = WO->isSigned();
452   bool NUW = !WO->isSigned();
453 
454   Value *NewOp =
455       B.CreateBinOp(Opcode, WO->getLHS(), WO->getRHS(), WO->getName());
456   setDeducedOverflowingFlags(NewOp, Opcode, NSW, NUW);
457 
458   StructType *ST = cast<StructType>(WO->getType());
459   Constant *Struct = ConstantStruct::get(ST,
460       { UndefValue::get(ST->getElementType(0)),
461         ConstantInt::getFalse(ST->getElementType(1)) });
462   Value *NewI = B.CreateInsertValue(Struct, NewOp, 0);
463   WO->replaceAllUsesWith(NewI);
464   WO->eraseFromParent();
465   ++NumOverflows;
466 
467   // See if we can infer the other no-wrap too.
468   if (auto *BO = dyn_cast<BinaryOperator>(NewOp))
469     processBinOp(BO, LVI);
470 }
471 
472 static void processSaturatingInst(SaturatingInst *SI, LazyValueInfo *LVI) {
473   Instruction::BinaryOps Opcode = SI->getBinaryOp();
474   bool NSW = SI->isSigned();
475   bool NUW = !SI->isSigned();
476   BinaryOperator *BinOp = BinaryOperator::Create(
477       Opcode, SI->getLHS(), SI->getRHS(), SI->getName(), SI);
478   BinOp->setDebugLoc(SI->getDebugLoc());
479   setDeducedOverflowingFlags(BinOp, Opcode, NSW, NUW);
480 
481   SI->replaceAllUsesWith(BinOp);
482   SI->eraseFromParent();
483   ++NumSaturating;
484 
485   // See if we can infer the other no-wrap too.
486   if (auto *BO = dyn_cast<BinaryOperator>(BinOp))
487     processBinOp(BO, LVI);
488 }
489 
490 /// Infer nonnull attributes for the arguments at the specified callsite.
491 static bool processCallSite(CallBase &CB, LazyValueInfo *LVI) {
492 
493   if (auto *WO = dyn_cast<WithOverflowInst>(&CB)) {
494     if (WO->getLHS()->getType()->isIntegerTy() && willNotOverflow(WO, LVI)) {
495       processOverflowIntrinsic(WO, LVI);
496       return true;
497     }
498   }
499 
500   if (auto *SI = dyn_cast<SaturatingInst>(&CB)) {
501     if (SI->getType()->isIntegerTy() && willNotOverflow(SI, LVI)) {
502       processSaturatingInst(SI, LVI);
503       return true;
504     }
505   }
506 
507   bool Changed = false;
508 
509   // Deopt bundle operands are intended to capture state with minimal
510   // perturbance of the code otherwise.  If we can find a constant value for
511   // any such operand and remove a use of the original value, that's
512   // desireable since it may allow further optimization of that value (e.g. via
513   // single use rules in instcombine).  Since deopt uses tend to,
514   // idiomatically, appear along rare conditional paths, it's reasonable likely
515   // we may have a conditional fact with which LVI can fold.
516   if (auto DeoptBundle = CB.getOperandBundle(LLVMContext::OB_deopt)) {
517     for (const Use &ConstU : DeoptBundle->Inputs) {
518       Use &U = const_cast<Use&>(ConstU);
519       Value *V = U.get();
520       if (V->getType()->isVectorTy()) continue;
521       if (isa<Constant>(V)) continue;
522 
523       Constant *C = LVI->getConstant(V, &CB);
524       if (!C) continue;
525       U.set(C);
526       Changed = true;
527     }
528   }
529 
530   SmallVector<unsigned, 4> ArgNos;
531   unsigned ArgNo = 0;
532 
533   for (Value *V : CB.args()) {
534     PointerType *Type = dyn_cast<PointerType>(V->getType());
535     // Try to mark pointer typed parameters as non-null.  We skip the
536     // relatively expensive analysis for constants which are obviously either
537     // null or non-null to start with.
538     if (Type && !CB.paramHasAttr(ArgNo, Attribute::NonNull) &&
539         !isa<Constant>(V) &&
540         LVI->getPredicateAt(ICmpInst::ICMP_EQ, V,
541                             ConstantPointerNull::get(Type),
542                             &CB) == LazyValueInfo::False)
543       ArgNos.push_back(ArgNo);
544     ArgNo++;
545   }
546 
547   assert(ArgNo == CB.arg_size() && "sanity check");
548 
549   if (ArgNos.empty())
550     return Changed;
551 
552   AttributeList AS = CB.getAttributes();
553   LLVMContext &Ctx = CB.getContext();
554   AS = AS.addParamAttribute(Ctx, ArgNos,
555                             Attribute::get(Ctx, Attribute::NonNull));
556   CB.setAttributes(AS);
557 
558   return true;
559 }
560 
561 static bool isNonNegative(Value *V, LazyValueInfo *LVI, Instruction *CxtI) {
562   Constant *Zero = ConstantInt::get(V->getType(), 0);
563   auto Result = LVI->getPredicateAt(ICmpInst::ICMP_SGE, V, Zero, CxtI);
564   return Result == LazyValueInfo::True;
565 }
566 
567 static bool isNonPositive(Value *V, LazyValueInfo *LVI, Instruction *CxtI) {
568   Constant *Zero = ConstantInt::get(V->getType(), 0);
569   auto Result = LVI->getPredicateAt(ICmpInst::ICMP_SLE, V, Zero, CxtI);
570   return Result == LazyValueInfo::True;
571 }
572 
573 enum class Domain { NonNegative, NonPositive, Unknown };
574 
575 Domain getDomain(Value *V, LazyValueInfo *LVI, Instruction *CxtI) {
576   if (isNonNegative(V, LVI, CxtI))
577     return Domain::NonNegative;
578   if (isNonPositive(V, LVI, CxtI))
579     return Domain::NonPositive;
580   return Domain::Unknown;
581 }
582 
583 /// Try to shrink a sdiv/srem's width down to the smallest power of two that's
584 /// sufficient to contain its operands.
585 static bool narrowSDivOrSRem(BinaryOperator *Instr, LazyValueInfo *LVI) {
586   assert(Instr->getOpcode() == Instruction::SDiv ||
587          Instr->getOpcode() == Instruction::SRem);
588   if (Instr->getType()->isVectorTy())
589     return false;
590 
591   // Find the smallest power of two bitwidth that's sufficient to hold Instr's
592   // operands.
593   unsigned OrigWidth = Instr->getType()->getIntegerBitWidth();
594 
595   // What is the smallest bit width that can accomodate the entire value ranges
596   // of both of the operands?
597   std::array<Optional<ConstantRange>, 2> CRs;
598   unsigned MinSignedBits = 0;
599   for (auto I : zip(Instr->operands(), CRs)) {
600     std::get<1>(I) = LVI->getConstantRange(std::get<0>(I), Instr);
601     MinSignedBits = std::max(std::get<1>(I)->getMinSignedBits(), MinSignedBits);
602   }
603 
604   // sdiv/srem is UB if divisor is -1 and divident is INT_MIN, so unless we can
605   // prove that such a combination is impossible, we need to bump the bitwidth.
606   if (CRs[1]->contains(APInt::getAllOnesValue(OrigWidth)) &&
607       CRs[0]->contains(
608           APInt::getSignedMinValue(MinSignedBits).sextOrSelf(OrigWidth)))
609     ++MinSignedBits;
610 
611   // Don't shrink below 8 bits wide.
612   unsigned NewWidth = std::max<unsigned>(PowerOf2Ceil(MinSignedBits), 8);
613 
614   // NewWidth might be greater than OrigWidth if OrigWidth is not a power of
615   // two.
616   if (NewWidth >= OrigWidth)
617     return false;
618 
619   ++NumSDivSRemsNarrowed;
620   IRBuilder<> B{Instr};
621   auto *TruncTy = Type::getIntNTy(Instr->getContext(), NewWidth);
622   auto *LHS = B.CreateTruncOrBitCast(Instr->getOperand(0), TruncTy,
623                                      Instr->getName() + ".lhs.trunc");
624   auto *RHS = B.CreateTruncOrBitCast(Instr->getOperand(1), TruncTy,
625                                      Instr->getName() + ".rhs.trunc");
626   auto *BO = B.CreateBinOp(Instr->getOpcode(), LHS, RHS, Instr->getName());
627   auto *Sext = B.CreateSExt(BO, Instr->getType(), Instr->getName() + ".sext");
628   if (auto *BinOp = dyn_cast<BinaryOperator>(BO))
629     if (BinOp->getOpcode() == Instruction::SDiv)
630       BinOp->setIsExact(Instr->isExact());
631 
632   Instr->replaceAllUsesWith(Sext);
633   Instr->eraseFromParent();
634   return true;
635 }
636 
637 /// Try to shrink a udiv/urem's width down to the smallest power of two that's
638 /// sufficient to contain its operands.
639 static bool processUDivOrURem(BinaryOperator *Instr, LazyValueInfo *LVI) {
640   assert(Instr->getOpcode() == Instruction::UDiv ||
641          Instr->getOpcode() == Instruction::URem);
642   if (Instr->getType()->isVectorTy())
643     return false;
644 
645   // Find the smallest power of two bitwidth that's sufficient to hold Instr's
646   // operands.
647 
648   // What is the smallest bit width that can accomodate the entire value ranges
649   // of both of the operands?
650   unsigned MaxActiveBits = 0;
651   for (Value *Operand : Instr->operands()) {
652     ConstantRange CR = LVI->getConstantRange(Operand, Instr);
653     MaxActiveBits = std::max(CR.getActiveBits(), MaxActiveBits);
654   }
655   // Don't shrink below 8 bits wide.
656   unsigned NewWidth = std::max<unsigned>(PowerOf2Ceil(MaxActiveBits), 8);
657 
658   // NewWidth might be greater than OrigWidth if OrigWidth is not a power of
659   // two.
660   if (NewWidth >= Instr->getType()->getIntegerBitWidth())
661     return false;
662 
663   ++NumUDivURemsNarrowed;
664   IRBuilder<> B{Instr};
665   auto *TruncTy = Type::getIntNTy(Instr->getContext(), NewWidth);
666   auto *LHS = B.CreateTruncOrBitCast(Instr->getOperand(0), TruncTy,
667                                      Instr->getName() + ".lhs.trunc");
668   auto *RHS = B.CreateTruncOrBitCast(Instr->getOperand(1), TruncTy,
669                                      Instr->getName() + ".rhs.trunc");
670   auto *BO = B.CreateBinOp(Instr->getOpcode(), LHS, RHS, Instr->getName());
671   auto *Zext = B.CreateZExt(BO, Instr->getType(), Instr->getName() + ".zext");
672   if (auto *BinOp = dyn_cast<BinaryOperator>(BO))
673     if (BinOp->getOpcode() == Instruction::UDiv)
674       BinOp->setIsExact(Instr->isExact());
675 
676   Instr->replaceAllUsesWith(Zext);
677   Instr->eraseFromParent();
678   return true;
679 }
680 
681 static bool processSRem(BinaryOperator *SDI, LazyValueInfo *LVI) {
682   assert(SDI->getOpcode() == Instruction::SRem);
683   if (SDI->getType()->isVectorTy())
684     return false;
685 
686   struct Operand {
687     Value *V;
688     Domain D;
689   };
690   std::array<Operand, 2> Ops;
691 
692   for (const auto I : zip(Ops, SDI->operands())) {
693     Operand &Op = std::get<0>(I);
694     Op.V = std::get<1>(I);
695     Op.D = getDomain(Op.V, LVI, SDI);
696     if (Op.D == Domain::Unknown)
697       return false;
698   }
699 
700   // We know domains of both of the operands!
701   ++NumSRems;
702 
703   // We need operands to be non-negative, so negate each one that isn't.
704   for (Operand &Op : Ops) {
705     if (Op.D == Domain::NonNegative)
706       continue;
707     auto *BO =
708         BinaryOperator::CreateNeg(Op.V, Op.V->getName() + ".nonneg", SDI);
709     BO->setDebugLoc(SDI->getDebugLoc());
710     Op.V = BO;
711   }
712 
713   auto *URem =
714       BinaryOperator::CreateURem(Ops[0].V, Ops[1].V, SDI->getName(), SDI);
715   URem->setDebugLoc(SDI->getDebugLoc());
716 
717   Value *Res = URem;
718 
719   // If the divident was non-positive, we need to negate the result.
720   if (Ops[0].D == Domain::NonPositive)
721     Res = BinaryOperator::CreateNeg(Res, Res->getName() + ".neg", SDI);
722 
723   SDI->replaceAllUsesWith(Res);
724   SDI->eraseFromParent();
725 
726   // Try to simplify our new urem.
727   processUDivOrURem(URem, LVI);
728 
729   return true;
730 }
731 
732 /// See if LazyValueInfo's ability to exploit edge conditions or range
733 /// information is sufficient to prove the signs of both operands of this SDiv.
734 /// If this is the case, replace the SDiv with a UDiv. Even for local
735 /// conditions, this can sometimes prove conditions instcombine can't by
736 /// exploiting range information.
737 static bool processSDiv(BinaryOperator *SDI, LazyValueInfo *LVI) {
738   assert(SDI->getOpcode() == Instruction::SDiv);
739   if (SDI->getType()->isVectorTy())
740     return false;
741 
742   struct Operand {
743     Value *V;
744     Domain D;
745   };
746   std::array<Operand, 2> Ops;
747 
748   for (const auto I : zip(Ops, SDI->operands())) {
749     Operand &Op = std::get<0>(I);
750     Op.V = std::get<1>(I);
751     Op.D = getDomain(Op.V, LVI, SDI);
752     if (Op.D == Domain::Unknown)
753       return false;
754   }
755 
756   // We know domains of both of the operands!
757   ++NumSDivs;
758 
759   // We need operands to be non-negative, so negate each one that isn't.
760   for (Operand &Op : Ops) {
761     if (Op.D == Domain::NonNegative)
762       continue;
763     auto *BO =
764         BinaryOperator::CreateNeg(Op.V, Op.V->getName() + ".nonneg", SDI);
765     BO->setDebugLoc(SDI->getDebugLoc());
766     Op.V = BO;
767   }
768 
769   auto *UDiv =
770       BinaryOperator::CreateUDiv(Ops[0].V, Ops[1].V, SDI->getName(), SDI);
771   UDiv->setDebugLoc(SDI->getDebugLoc());
772   UDiv->setIsExact(SDI->isExact());
773 
774   Value *Res = UDiv;
775 
776   // If the operands had two different domains, we need to negate the result.
777   if (Ops[0].D != Ops[1].D)
778     Res = BinaryOperator::CreateNeg(Res, Res->getName() + ".neg", SDI);
779 
780   SDI->replaceAllUsesWith(Res);
781   SDI->eraseFromParent();
782 
783   // Try to simplify our new udiv.
784   processUDivOrURem(UDiv, LVI);
785 
786   return true;
787 }
788 
789 static bool processSDivOrSRem(BinaryOperator *Instr, LazyValueInfo *LVI) {
790   assert(Instr->getOpcode() == Instruction::SDiv ||
791          Instr->getOpcode() == Instruction::SRem);
792   if (Instr->getType()->isVectorTy())
793     return false;
794 
795   if (Instr->getOpcode() == Instruction::SDiv)
796     if (processSDiv(Instr, LVI))
797       return true;
798 
799   if (Instr->getOpcode() == Instruction::SRem)
800     if (processSRem(Instr, LVI))
801       return true;
802 
803   return narrowSDivOrSRem(Instr, LVI);
804 }
805 
806 static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) {
807   if (SDI->getType()->isVectorTy())
808     return false;
809 
810   if (!isNonNegative(SDI->getOperand(0), LVI, SDI))
811     return false;
812 
813   ++NumAShrs;
814   auto *BO = BinaryOperator::CreateLShr(SDI->getOperand(0), SDI->getOperand(1),
815                                         SDI->getName(), SDI);
816   BO->setDebugLoc(SDI->getDebugLoc());
817   BO->setIsExact(SDI->isExact());
818   SDI->replaceAllUsesWith(BO);
819   SDI->eraseFromParent();
820 
821   return true;
822 }
823 
824 static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) {
825   if (SDI->getType()->isVectorTy())
826     return false;
827 
828   Value *Base = SDI->getOperand(0);
829 
830   if (!isNonNegative(Base, LVI, SDI))
831     return false;
832 
833   ++NumSExt;
834   auto *ZExt =
835       CastInst::CreateZExtOrBitCast(Base, SDI->getType(), SDI->getName(), SDI);
836   ZExt->setDebugLoc(SDI->getDebugLoc());
837   SDI->replaceAllUsesWith(ZExt);
838   SDI->eraseFromParent();
839 
840   return true;
841 }
842 
843 static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) {
844   using OBO = OverflowingBinaryOperator;
845 
846   if (DontAddNoWrapFlags)
847     return false;
848 
849   if (BinOp->getType()->isVectorTy())
850     return false;
851 
852   bool NSW = BinOp->hasNoSignedWrap();
853   bool NUW = BinOp->hasNoUnsignedWrap();
854   if (NSW && NUW)
855     return false;
856 
857   Instruction::BinaryOps Opcode = BinOp->getOpcode();
858   Value *LHS = BinOp->getOperand(0);
859   Value *RHS = BinOp->getOperand(1);
860 
861   ConstantRange LRange = LVI->getConstantRange(LHS, BinOp);
862   ConstantRange RRange = LVI->getConstantRange(RHS, BinOp);
863 
864   bool Changed = false;
865   bool NewNUW = false, NewNSW = false;
866   if (!NUW) {
867     ConstantRange NUWRange = ConstantRange::makeGuaranteedNoWrapRegion(
868         Opcode, RRange, OBO::NoUnsignedWrap);
869     NewNUW = NUWRange.contains(LRange);
870     Changed |= NewNUW;
871   }
872   if (!NSW) {
873     ConstantRange NSWRange = ConstantRange::makeGuaranteedNoWrapRegion(
874         Opcode, RRange, OBO::NoSignedWrap);
875     NewNSW = NSWRange.contains(LRange);
876     Changed |= NewNSW;
877   }
878 
879   setDeducedOverflowingFlags(BinOp, Opcode, NewNSW, NewNUW);
880 
881   return Changed;
882 }
883 
884 static bool processAnd(BinaryOperator *BinOp, LazyValueInfo *LVI) {
885   if (BinOp->getType()->isVectorTy())
886     return false;
887 
888   // Pattern match (and lhs, C) where C includes a superset of bits which might
889   // be set in lhs.  This is a common truncation idiom created by instcombine.
890   Value *LHS = BinOp->getOperand(0);
891   ConstantInt *RHS = dyn_cast<ConstantInt>(BinOp->getOperand(1));
892   if (!RHS || !RHS->getValue().isMask())
893     return false;
894 
895   // We can only replace the AND with LHS based on range info if the range does
896   // not include undef.
897   ConstantRange LRange =
898       LVI->getConstantRange(LHS, BinOp, /*UndefAllowed=*/false);
899   if (!LRange.getUnsignedMax().ule(RHS->getValue()))
900     return false;
901 
902   BinOp->replaceAllUsesWith(LHS);
903   BinOp->eraseFromParent();
904   NumAnd++;
905   return true;
906 }
907 
908 
909 static Constant *getConstantAt(Value *V, Instruction *At, LazyValueInfo *LVI) {
910   if (Constant *C = LVI->getConstant(V, At))
911     return C;
912 
913   // TODO: The following really should be sunk inside LVI's core algorithm, or
914   // at least the outer shims around such.
915   auto *C = dyn_cast<CmpInst>(V);
916   if (!C) return nullptr;
917 
918   Value *Op0 = C->getOperand(0);
919   Constant *Op1 = dyn_cast<Constant>(C->getOperand(1));
920   if (!Op1) return nullptr;
921 
922   LazyValueInfo::Tristate Result =
923     LVI->getPredicateAt(C->getPredicate(), Op0, Op1, At);
924   if (Result == LazyValueInfo::Unknown)
925     return nullptr;
926 
927   return (Result == LazyValueInfo::True) ?
928     ConstantInt::getTrue(C->getContext()) :
929     ConstantInt::getFalse(C->getContext());
930 }
931 
932 static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT,
933                     const SimplifyQuery &SQ) {
934   bool FnChanged = false;
935   // Visiting in a pre-order depth-first traversal causes us to simplify early
936   // blocks before querying later blocks (which require us to analyze early
937   // blocks).  Eagerly simplifying shallow blocks means there is strictly less
938   // work to do for deep blocks.  This also means we don't visit unreachable
939   // blocks.
940   for (BasicBlock *BB : depth_first(&F.getEntryBlock())) {
941     bool BBChanged = false;
942     for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) {
943       Instruction *II = &*BI++;
944       switch (II->getOpcode()) {
945       case Instruction::Select:
946         BBChanged |= processSelect(cast<SelectInst>(II), LVI);
947         break;
948       case Instruction::PHI:
949         BBChanged |= processPHI(cast<PHINode>(II), LVI, DT, SQ);
950         break;
951       case Instruction::ICmp:
952       case Instruction::FCmp:
953         BBChanged |= processCmp(cast<CmpInst>(II), LVI);
954         break;
955       case Instruction::Load:
956       case Instruction::Store:
957         BBChanged |= processMemAccess(II, LVI);
958         break;
959       case Instruction::Call:
960       case Instruction::Invoke:
961         BBChanged |= processCallSite(cast<CallBase>(*II), LVI);
962         break;
963       case Instruction::SRem:
964       case Instruction::SDiv:
965         BBChanged |= processSDivOrSRem(cast<BinaryOperator>(II), LVI);
966         break;
967       case Instruction::UDiv:
968       case Instruction::URem:
969         BBChanged |= processUDivOrURem(cast<BinaryOperator>(II), LVI);
970         break;
971       case Instruction::AShr:
972         BBChanged |= processAShr(cast<BinaryOperator>(II), LVI);
973         break;
974       case Instruction::SExt:
975         BBChanged |= processSExt(cast<SExtInst>(II), LVI);
976         break;
977       case Instruction::Add:
978       case Instruction::Sub:
979       case Instruction::Mul:
980       case Instruction::Shl:
981         BBChanged |= processBinOp(cast<BinaryOperator>(II), LVI);
982         break;
983       case Instruction::And:
984         BBChanged |= processAnd(cast<BinaryOperator>(II), LVI);
985         break;
986       }
987     }
988 
989     Instruction *Term = BB->getTerminator();
990     switch (Term->getOpcode()) {
991     case Instruction::Switch:
992       BBChanged |= processSwitch(cast<SwitchInst>(Term), LVI, DT);
993       break;
994     case Instruction::Ret: {
995       auto *RI = cast<ReturnInst>(Term);
996       // Try to determine the return value if we can.  This is mainly here to
997       // simplify the writing of unit tests, but also helps to enable IPO by
998       // constant folding the return values of callees.
999       auto *RetVal = RI->getReturnValue();
1000       if (!RetVal) break; // handle "ret void"
1001       if (isa<Constant>(RetVal)) break; // nothing to do
1002       if (auto *C = getConstantAt(RetVal, RI, LVI)) {
1003         ++NumReturns;
1004         RI->replaceUsesOfWith(RetVal, C);
1005         BBChanged = true;
1006       }
1007     }
1008     }
1009 
1010     FnChanged |= BBChanged;
1011   }
1012 
1013   return FnChanged;
1014 }
1015 
1016 bool CorrelatedValuePropagation::runOnFunction(Function &F) {
1017   if (skipFunction(F))
1018     return false;
1019 
1020   LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI();
1021   DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
1022 
1023   return runImpl(F, LVI, DT, getBestSimplifyQuery(*this, F));
1024 }
1025 
1026 PreservedAnalyses
1027 CorrelatedValuePropagationPass::run(Function &F, FunctionAnalysisManager &AM) {
1028   LazyValueInfo *LVI = &AM.getResult<LazyValueAnalysis>(F);
1029   DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
1030 
1031   bool Changed = runImpl(F, LVI, DT, getBestSimplifyQuery(AM, F));
1032 
1033   PreservedAnalyses PA;
1034   if (!Changed) {
1035     PA = PreservedAnalyses::all();
1036   } else {
1037     PA.preserve<GlobalsAA>();
1038     PA.preserve<DominatorTreeAnalysis>();
1039     PA.preserve<LazyValueAnalysis>();
1040   }
1041 
1042   // Keeping LVI alive is expensive, both because it uses a lot of memory, and
1043   // because invalidating values in LVI is expensive. While CVP does preserve
1044   // LVI, we know that passes after JumpThreading+CVP will not need the result
1045   // of this analysis, so we forcefully discard it early.
1046   PA.abandon<LazyValueAnalysis>();
1047   return PA;
1048 }
1049