1 //===- CorrelatedValuePropagation.cpp - Propagate CFG-derived info --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the Correlated Value Propagation pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "llvm/Transforms/Scalar/CorrelatedValuePropagation.h" 14 #include "llvm/ADT/DepthFirstIterator.h" 15 #include "llvm/ADT/Optional.h" 16 #include "llvm/ADT/SmallVector.h" 17 #include "llvm/ADT/Statistic.h" 18 #include "llvm/Analysis/DomTreeUpdater.h" 19 #include "llvm/Analysis/GlobalsModRef.h" 20 #include "llvm/Analysis/InstructionSimplify.h" 21 #include "llvm/Analysis/LazyValueInfo.h" 22 #include "llvm/IR/Attributes.h" 23 #include "llvm/IR/BasicBlock.h" 24 #include "llvm/IR/CFG.h" 25 #include "llvm/IR/Constant.h" 26 #include "llvm/IR/ConstantRange.h" 27 #include "llvm/IR/Constants.h" 28 #include "llvm/IR/DerivedTypes.h" 29 #include "llvm/IR/Function.h" 30 #include "llvm/IR/IRBuilder.h" 31 #include "llvm/IR/InstrTypes.h" 32 #include "llvm/IR/Instruction.h" 33 #include "llvm/IR/Instructions.h" 34 #include "llvm/IR/IntrinsicInst.h" 35 #include "llvm/IR/Operator.h" 36 #include "llvm/IR/PassManager.h" 37 #include "llvm/IR/Type.h" 38 #include "llvm/IR/Value.h" 39 #include "llvm/InitializePasses.h" 40 #include "llvm/Pass.h" 41 #include "llvm/Support/Casting.h" 42 #include "llvm/Support/CommandLine.h" 43 #include "llvm/Support/Debug.h" 44 #include "llvm/Support/raw_ostream.h" 45 #include "llvm/Transforms/Scalar.h" 46 #include "llvm/Transforms/Utils/Local.h" 47 #include <cassert> 48 #include <utility> 49 50 using namespace llvm; 51 52 #define DEBUG_TYPE "correlated-value-propagation" 53 54 STATISTIC(NumPhis, "Number of phis 
propagated"); 55 STATISTIC(NumPhiCommon, "Number of phis deleted via common incoming value"); 56 STATISTIC(NumSelects, "Number of selects propagated"); 57 STATISTIC(NumMemAccess, "Number of memory access targets propagated"); 58 STATISTIC(NumCmps, "Number of comparisons propagated"); 59 STATISTIC(NumReturns, "Number of return values propagated"); 60 STATISTIC(NumDeadCases, "Number of switch cases removed"); 61 STATISTIC(NumSDivSRemsNarrowed, 62 "Number of sdivs/srems whose width was decreased"); 63 STATISTIC(NumSDivs, "Number of sdiv converted to udiv"); 64 STATISTIC(NumUDivURemsNarrowed, 65 "Number of udivs/urems whose width was decreased"); 66 STATISTIC(NumAShrs, "Number of ashr converted to lshr"); 67 STATISTIC(NumSRems, "Number of srem converted to urem"); 68 STATISTIC(NumSExt, "Number of sext converted to zext"); 69 STATISTIC(NumAnd, "Number of ands removed"); 70 STATISTIC(NumNW, "Number of no-wrap deductions"); 71 STATISTIC(NumNSW, "Number of no-signed-wrap deductions"); 72 STATISTIC(NumNUW, "Number of no-unsigned-wrap deductions"); 73 STATISTIC(NumAddNW, "Number of no-wrap deductions for add"); 74 STATISTIC(NumAddNSW, "Number of no-signed-wrap deductions for add"); 75 STATISTIC(NumAddNUW, "Number of no-unsigned-wrap deductions for add"); 76 STATISTIC(NumSubNW, "Number of no-wrap deductions for sub"); 77 STATISTIC(NumSubNSW, "Number of no-signed-wrap deductions for sub"); 78 STATISTIC(NumSubNUW, "Number of no-unsigned-wrap deductions for sub"); 79 STATISTIC(NumMulNW, "Number of no-wrap deductions for mul"); 80 STATISTIC(NumMulNSW, "Number of no-signed-wrap deductions for mul"); 81 STATISTIC(NumMulNUW, "Number of no-unsigned-wrap deductions for mul"); 82 STATISTIC(NumShlNW, "Number of no-wrap deductions for shl"); 83 STATISTIC(NumShlNSW, "Number of no-signed-wrap deductions for shl"); 84 STATISTIC(NumShlNUW, "Number of no-unsigned-wrap deductions for shl"); 85 STATISTIC(NumOverflows, "Number of overflow checks removed"); 86 STATISTIC(NumSaturating, 87 "Number 
of saturating arithmetics converted to normal arithmetics");

// Debugging knob: suppress all no-wrap flag inference done by this pass.
static cl::opt<bool> DontAddNoWrapFlags("cvp-dont-add-nowrap-flags", cl::init(false));

namespace {

class CorrelatedValuePropagation : public FunctionPass {
public:
  static char ID;

  CorrelatedValuePropagation(): FunctionPass(ID) {
    initializeCorrelatedValuePropagationPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<LazyValueInfoWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<LazyValueInfoWrapperPass>();
  }
};

} // end anonymous namespace

char CorrelatedValuePropagation::ID = 0;

INITIALIZE_PASS_BEGIN(CorrelatedValuePropagation, "correlated-propagation",
                "Value Propagation", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass)
INITIALIZE_PASS_END(CorrelatedValuePropagation, "correlated-propagation",
                "Value Propagation", false, false)

// Public interface to the Value Propagation pass
Pass *llvm::createCorrelatedValuePropagationPass() {
  return new CorrelatedValuePropagation();
}

/// If LVI proves the (scalar) select condition is a constant i1, replace the
/// select with the corresponding arm and erase it.
static bool processSelect(SelectInst *S, LazyValueInfo *LVI) {
  if (S->getType()->isVectorTy()) return false;
  if (isa<Constant>(S->getCondition())) return false;

  Constant *C = LVI->getConstant(S->getCondition(), S);
  if (!C) return false;

  ConstantInt *CI = dyn_cast<ConstantInt>(C);
  if (!CI) return false;

  Value *ReplaceWith = CI->isOne() ? S->getTrueValue() : S->getFalseValue();
  S->replaceAllUsesWith(ReplaceWith);
  S->eraseFromParent();

  ++NumSelects;

  return true;
}

/// Try to simplify a phi with constant incoming values that match the edge
/// values of a non-constant value on all other edges:
/// bb0:
///   %isnull = icmp eq i8* %x, null
///   br i1 %isnull, label %bb2, label %bb1
/// bb1:
///   br label %bb2
/// bb2:
///   %r = phi i8* [ %x, %bb1 ], [ null, %bb0 ]
/// -->
///   %r = %x
static bool simplifyCommonValuePhi(PHINode *P, LazyValueInfo *LVI,
                                   DominatorTree *DT) {
  // Collect incoming constants and initialize possible common value.
  SmallVector<std::pair<Constant *, unsigned>, 4> IncomingConstants;
  Value *CommonValue = nullptr;
  for (unsigned i = 0, e = P->getNumIncomingValues(); i != e; ++i) {
    Value *Incoming = P->getIncomingValue(i);
    if (auto *IncomingConstant = dyn_cast<Constant>(Incoming)) {
      IncomingConstants.push_back(std::make_pair(IncomingConstant, i));
    } else if (!CommonValue) {
      // The potential common value is initialized to the first non-constant.
      CommonValue = Incoming;
    } else if (Incoming != CommonValue) {
      // There can be only one non-constant common value.
      return false;
    }
  }

  if (!CommonValue || IncomingConstants.empty())
    return false;

  // The common value must be valid in all incoming blocks.
  BasicBlock *ToBB = P->getParent();
  if (auto *CommonInst = dyn_cast<Instruction>(CommonValue))
    if (!DT->dominates(CommonInst, ToBB))
      return false;

  // We have a phi with exactly 1 variable incoming value and 1 or more constant
  // incoming values. See if all constant incoming values can be mapped back to
  // the same incoming variable value.
  for (auto &IncomingConstant : IncomingConstants) {
    Constant *C = IncomingConstant.first;
    BasicBlock *IncomingBB = P->getIncomingBlock(IncomingConstant.second);
    if (C != LVI->getConstantOnEdge(CommonValue, IncomingBB, ToBB, P))
      return false;
  }

  // All constant incoming values map to the same variable along the incoming
  // edges of the phi. The phi is unnecessary. However, we must drop all
  // poison-generating flags to ensure that no poison is propagated to the phi
  // location by performing this substitution.
  // Warning: If the underlying analysis changes, this may not be enough to
  // guarantee that poison is not propagated.
  // TODO: We may be able to re-infer flags by re-analyzing the instruction.
  if (auto *CommonInst = dyn_cast<Instruction>(CommonValue))
    CommonInst->dropPoisonGeneratingFlags();
  P->replaceAllUsesWith(CommonValue);
  P->eraseFromParent();
  ++NumPhiCommon;
  return true;
}

/// Replace phi incoming values that LVI can prove constant along their edge
/// (directly, or by looking through an incoming select), then attempt to fold
/// the phi away entirely via SimplifyInstruction or simplifyCommonValuePhi.
static bool processPHI(PHINode *P, LazyValueInfo *LVI, DominatorTree *DT,
                       const SimplifyQuery &SQ) {
  bool Changed = false;

  BasicBlock *BB = P->getParent();
  for (unsigned i = 0, e = P->getNumIncomingValues(); i < e; ++i) {
    Value *Incoming = P->getIncomingValue(i);
    if (isa<Constant>(Incoming)) continue;

    Value *V = LVI->getConstantOnEdge(Incoming, P->getIncomingBlock(i), BB, P);

    // Look if the incoming value is a select with a scalar condition for which
    // LVI can tell us the value. In that case replace the incoming value with
    // the appropriate value of the select. This often allows us to remove the
    // select later.
    if (!V) {
      SelectInst *SI = dyn_cast<SelectInst>(Incoming);
      if (!SI) continue;

      Value *Condition = SI->getCondition();
      if (!Condition->getType()->isVectorTy()) {
        if (Constant *C = LVI->getConstantOnEdge(
                Condition, P->getIncomingBlock(i), BB, P)) {
          if (C->isOneValue()) {
            V = SI->getTrueValue();
          } else if (C->isZeroValue()) {
            V = SI->getFalseValue();
          }
          // Once LVI learns to handle vector types, we could also add support
          // for vector type constants that are not all zeroes or all ones.
        }
      }

      // Look if the select has a constant but LVI tells us that the incoming
      // value can never be that constant. In that case replace the incoming
      // value with the other value of the select. This often allows us to
      // remove the select later.
      if (!V) {
        Constant *C = dyn_cast<Constant>(SI->getFalseValue());
        if (!C) continue;

        if (LVI->getPredicateOnEdge(ICmpInst::ICMP_EQ, SI, C,
                                    P->getIncomingBlock(i), BB, P) !=
            LazyValueInfo::False)
          continue;
        V = SI->getTrueValue();
      }

      LLVM_DEBUG(dbgs() << "CVP: Threading PHI over " << *SI << '\n');
    }

    P->setIncomingValue(i, V);
    Changed = true;
  }

  if (Value *V = SimplifyInstruction(P, SQ)) {
    P->replaceAllUsesWith(V);
    P->eraseFromParent();
    Changed = true;
  }

  if (!Changed)
    Changed = simplifyCommonValuePhi(P, LVI, DT);

  if (Changed)
    ++NumPhis;

  return Changed;
}

/// Replace the pointer operand of a load/store with a constant when LVI can
/// prove the address is a known constant at this program point.
static bool processMemAccess(Instruction *I, LazyValueInfo *LVI) {
  Value *Pointer = nullptr;
  if (LoadInst *L = dyn_cast<LoadInst>(I))
    Pointer = L->getPointerOperand();
  else
    Pointer = cast<StoreInst>(I)->getPointerOperand();

  if (isa<Constant>(Pointer)) return false;

  Constant *C = LVI->getConstant(Pointer, I);
  if (!C) return false;

  ++NumMemAccess;
  I->replaceUsesOfWith(Pointer, C);
  return true;
}

/// See if LazyValueInfo's ability to exploit edge conditions or range
/// information is sufficient to prove this comparison. Even for local
/// conditions, this can sometimes prove conditions instcombine can't by
/// exploiting range information.
static bool processCmp(CmpInst *Cmp, LazyValueInfo *LVI) {
  Value *Op0 = Cmp->getOperand(0);
  auto *C = dyn_cast<Constant>(Cmp->getOperand(1));
  if (!C)
    return false;

  LazyValueInfo::Tristate Result =
      LVI->getPredicateAt(Cmp->getPredicate(), Op0, C, Cmp,
                          /*UseBlockValue=*/true);
  if (Result == LazyValueInfo::Unknown)
    return false;

  ++NumCmps;
  // Tristate False/True are 0/1, so Result can be used directly as the i1.
  Constant *TorF = ConstantInt::get(Type::getInt1Ty(Cmp->getContext()), Result);
  Cmp->replaceAllUsesWith(TorF);
  Cmp->eraseFromParent();
  return true;
}

/// Simplify a switch instruction by removing cases which can never fire. If the
/// uselessness of a case could be determined locally then constant propagation
/// would already have figured it out. Instead, walk the predecessors and
/// statically evaluate cases based on information available on that edge. Cases
/// that cannot fire no matter what the incoming edge can safely be removed. If
/// a case fires on every incoming edge then the entire switch can be removed
/// and replaced with a branch to the case destination.
static bool processSwitch(SwitchInst *I, LazyValueInfo *LVI,
                          DominatorTree *DT) {
  // Lazy strategy: queue CFG updates and flush them after all edits are done.
  DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Lazy);
  Value *Cond = I->getCondition();
  BasicBlock *BB = I->getParent();

  // Analyse each switch case in turn.
  bool Changed = false;
  // Count duplicate edges to each successor so we only delete the DT edge when
  // the last case targeting that successor is removed.
  DenseMap<BasicBlock*, int> SuccessorsCount;
  for (auto *Succ : successors(BB))
    SuccessorsCount[Succ]++;

  { // Scope for SwitchInstProfUpdateWrapper. It must not live during
    // ConstantFoldTerminator() as the underlying SwitchInst can be changed.
    SwitchInstProfUpdateWrapper SI(*I);

    for (auto CI = SI->case_begin(), CE = SI->case_end(); CI != CE;) {
      ConstantInt *Case = CI->getCaseValue();
      LazyValueInfo::Tristate State =
          LVI->getPredicateAt(CmpInst::ICMP_EQ, Cond, Case, I,
                              /* UseBlockValue */ true);

      if (State == LazyValueInfo::False) {
        // This case never fires - remove it.
        BasicBlock *Succ = CI->getCaseSuccessor();
        Succ->removePredecessor(BB);
        // removeCase invalidates iterators; re-fetch the end iterator.
        CI = SI.removeCase(CI);
        CE = SI->case_end();

        // The condition can be modified by removePredecessor's PHI simplification
        // logic.
        Cond = SI->getCondition();

        ++NumDeadCases;
        Changed = true;
        if (--SuccessorsCount[Succ] == 0)
          DTU.applyUpdatesPermissive({{DominatorTree::Delete, BB, Succ}});
        continue;
      }
      if (State == LazyValueInfo::True) {
        // This case always fires. Arrange for the switch to be turned into an
        // unconditional branch by replacing the switch condition with the case
        // value.
        SI->setCondition(Case);
        NumDeadCases += SI->getNumCases();
        Changed = true;
        break;
      }

      // Increment the case iterator since we didn't delete it.
      ++CI;
    }
  }

  if (Changed)
    // If the switch has been simplified to the point where it can be replaced
    // by a branch then do so now.
    ConstantFoldTerminator(BB, /*DeleteDeadConditions = */ false,
                           /*TLI = */ nullptr, &DTU);
  return Changed;
}

// See if we can prove that the given binary op intrinsic will not overflow.
390 static bool willNotOverflow(BinaryOpIntrinsic *BO, LazyValueInfo *LVI) { 391 ConstantRange LRange = LVI->getConstantRange(BO->getLHS(), BO); 392 ConstantRange RRange = LVI->getConstantRange(BO->getRHS(), BO); 393 ConstantRange NWRegion = ConstantRange::makeGuaranteedNoWrapRegion( 394 BO->getBinaryOp(), RRange, BO->getNoWrapKind()); 395 return NWRegion.contains(LRange); 396 } 397 398 static void setDeducedOverflowingFlags(Value *V, Instruction::BinaryOps Opcode, 399 bool NewNSW, bool NewNUW) { 400 Statistic *OpcNW, *OpcNSW, *OpcNUW; 401 switch (Opcode) { 402 case Instruction::Add: 403 OpcNW = &NumAddNW; 404 OpcNSW = &NumAddNSW; 405 OpcNUW = &NumAddNUW; 406 break; 407 case Instruction::Sub: 408 OpcNW = &NumSubNW; 409 OpcNSW = &NumSubNSW; 410 OpcNUW = &NumSubNUW; 411 break; 412 case Instruction::Mul: 413 OpcNW = &NumMulNW; 414 OpcNSW = &NumMulNSW; 415 OpcNUW = &NumMulNUW; 416 break; 417 case Instruction::Shl: 418 OpcNW = &NumShlNW; 419 OpcNSW = &NumShlNSW; 420 OpcNUW = &NumShlNUW; 421 break; 422 default: 423 llvm_unreachable("Will not be called with other binops"); 424 } 425 426 auto *Inst = dyn_cast<Instruction>(V); 427 if (NewNSW) { 428 ++NumNW; 429 ++*OpcNW; 430 ++NumNSW; 431 ++*OpcNSW; 432 if (Inst) 433 Inst->setHasNoSignedWrap(); 434 } 435 if (NewNUW) { 436 ++NumNW; 437 ++*OpcNW; 438 ++NumNUW; 439 ++*OpcNUW; 440 if (Inst) 441 Inst->setHasNoUnsignedWrap(); 442 } 443 } 444 445 static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI); 446 447 // Rewrite this with.overflow intrinsic as non-overflowing. 
/// Replace a proven-non-overflowing llvm.*.with.overflow call with the plain
/// binary op (tagged nsw/nuw) and a {result, false} aggregate.
static void processOverflowIntrinsic(WithOverflowInst *WO, LazyValueInfo *LVI) {
  IRBuilder<> B(WO);
  Instruction::BinaryOps Opcode = WO->getBinaryOp();
  // Signed intrinsics prove nsw, unsigned ones prove nuw.
  bool NSW = WO->isSigned();
  bool NUW = !WO->isSigned();

  Value *NewOp =
      B.CreateBinOp(Opcode, WO->getLHS(), WO->getRHS(), WO->getName());
  setDeducedOverflowingFlags(NewOp, Opcode, NSW, NUW);

  StructType *ST = cast<StructType>(WO->getType());
  // Element 0 starts as undef; the insertvalue below fills in the real result.
  // The overflow bit (element 1) is known false.
  Constant *Struct = ConstantStruct::get(ST,
      { UndefValue::get(ST->getElementType(0)),
        ConstantInt::getFalse(ST->getElementType(1)) });
  Value *NewI = B.CreateInsertValue(Struct, NewOp, 0);
  WO->replaceAllUsesWith(NewI);
  WO->eraseFromParent();
  ++NumOverflows;

  // See if we can infer the other no-wrap too.
  if (auto *BO = dyn_cast<BinaryOperator>(NewOp))
    processBinOp(BO, LVI);
}

/// Replace a proven-non-saturating llvm.*.sat call with the plain binary op
/// tagged with the corresponding no-wrap flag.
static void processSaturatingInst(SaturatingInst *SI, LazyValueInfo *LVI) {
  Instruction::BinaryOps Opcode = SI->getBinaryOp();
  // Signed intrinsics prove nsw, unsigned ones prove nuw.
  bool NSW = SI->isSigned();
  bool NUW = !SI->isSigned();
  BinaryOperator *BinOp = BinaryOperator::Create(
      Opcode, SI->getLHS(), SI->getRHS(), SI->getName(), SI);
  BinOp->setDebugLoc(SI->getDebugLoc());
  setDeducedOverflowingFlags(BinOp, Opcode, NSW, NUW);

  SI->replaceAllUsesWith(BinOp);
  SI->eraseFromParent();
  ++NumSaturating;

  // See if we can infer the other no-wrap too.
  if (auto *BO = dyn_cast<BinaryOperator>(BinOp))
    processBinOp(BO, LVI);
}

/// Infer nonnull attributes for the arguments at the specified callsite.
491 static bool processCallSite(CallBase &CB, LazyValueInfo *LVI) { 492 493 if (auto *WO = dyn_cast<WithOverflowInst>(&CB)) { 494 if (WO->getLHS()->getType()->isIntegerTy() && willNotOverflow(WO, LVI)) { 495 processOverflowIntrinsic(WO, LVI); 496 return true; 497 } 498 } 499 500 if (auto *SI = dyn_cast<SaturatingInst>(&CB)) { 501 if (SI->getType()->isIntegerTy() && willNotOverflow(SI, LVI)) { 502 processSaturatingInst(SI, LVI); 503 return true; 504 } 505 } 506 507 bool Changed = false; 508 509 // Deopt bundle operands are intended to capture state with minimal 510 // perturbance of the code otherwise. If we can find a constant value for 511 // any such operand and remove a use of the original value, that's 512 // desireable since it may allow further optimization of that value (e.g. via 513 // single use rules in instcombine). Since deopt uses tend to, 514 // idiomatically, appear along rare conditional paths, it's reasonable likely 515 // we may have a conditional fact with which LVI can fold. 516 if (auto DeoptBundle = CB.getOperandBundle(LLVMContext::OB_deopt)) { 517 for (const Use &ConstU : DeoptBundle->Inputs) { 518 Use &U = const_cast<Use&>(ConstU); 519 Value *V = U.get(); 520 if (V->getType()->isVectorTy()) continue; 521 if (isa<Constant>(V)) continue; 522 523 Constant *C = LVI->getConstant(V, &CB); 524 if (!C) continue; 525 U.set(C); 526 Changed = true; 527 } 528 } 529 530 SmallVector<unsigned, 4> ArgNos; 531 unsigned ArgNo = 0; 532 533 for (Value *V : CB.args()) { 534 PointerType *Type = dyn_cast<PointerType>(V->getType()); 535 // Try to mark pointer typed parameters as non-null. We skip the 536 // relatively expensive analysis for constants which are obviously either 537 // null or non-null to start with. 
538 if (Type && !CB.paramHasAttr(ArgNo, Attribute::NonNull) && 539 !isa<Constant>(V) && 540 LVI->getPredicateAt(ICmpInst::ICMP_EQ, V, 541 ConstantPointerNull::get(Type), 542 &CB) == LazyValueInfo::False) 543 ArgNos.push_back(ArgNo); 544 ArgNo++; 545 } 546 547 assert(ArgNo == CB.arg_size() && "sanity check"); 548 549 if (ArgNos.empty()) 550 return Changed; 551 552 AttributeList AS = CB.getAttributes(); 553 LLVMContext &Ctx = CB.getContext(); 554 AS = AS.addParamAttribute(Ctx, ArgNos, 555 Attribute::get(Ctx, Attribute::NonNull)); 556 CB.setAttributes(AS); 557 558 return true; 559 } 560 561 static bool isNonNegative(Value *V, LazyValueInfo *LVI, Instruction *CxtI) { 562 Constant *Zero = ConstantInt::get(V->getType(), 0); 563 auto Result = LVI->getPredicateAt(ICmpInst::ICMP_SGE, V, Zero, CxtI); 564 return Result == LazyValueInfo::True; 565 } 566 567 static bool isNonPositive(Value *V, LazyValueInfo *LVI, Instruction *CxtI) { 568 Constant *Zero = ConstantInt::get(V->getType(), 0); 569 auto Result = LVI->getPredicateAt(ICmpInst::ICMP_SLE, V, Zero, CxtI); 570 return Result == LazyValueInfo::True; 571 } 572 573 enum class Domain { NonNegative, NonPositive, Unknown }; 574 575 Domain getDomain(Value *V, LazyValueInfo *LVI, Instruction *CxtI) { 576 if (isNonNegative(V, LVI, CxtI)) 577 return Domain::NonNegative; 578 if (isNonPositive(V, LVI, CxtI)) 579 return Domain::NonPositive; 580 return Domain::Unknown; 581 } 582 583 /// Try to shrink a sdiv/srem's width down to the smallest power of two that's 584 /// sufficient to contain its operands. 585 static bool narrowSDivOrSRem(BinaryOperator *Instr, LazyValueInfo *LVI) { 586 assert(Instr->getOpcode() == Instruction::SDiv || 587 Instr->getOpcode() == Instruction::SRem); 588 if (Instr->getType()->isVectorTy()) 589 return false; 590 591 // Find the smallest power of two bitwidth that's sufficient to hold Instr's 592 // operands. 
593 unsigned OrigWidth = Instr->getType()->getIntegerBitWidth(); 594 595 // What is the smallest bit width that can accomodate the entire value ranges 596 // of both of the operands? 597 std::array<Optional<ConstantRange>, 2> CRs; 598 unsigned MinSignedBits = 0; 599 for (auto I : zip(Instr->operands(), CRs)) { 600 std::get<1>(I) = LVI->getConstantRange(std::get<0>(I), Instr); 601 MinSignedBits = std::max(std::get<1>(I)->getMinSignedBits(), MinSignedBits); 602 } 603 604 // sdiv/srem is UB if divisor is -1 and divident is INT_MIN, so unless we can 605 // prove that such a combination is impossible, we need to bump the bitwidth. 606 if (CRs[1]->contains(APInt::getAllOnesValue(OrigWidth)) && 607 CRs[0]->contains( 608 APInt::getSignedMinValue(MinSignedBits).sextOrSelf(OrigWidth))) 609 ++MinSignedBits; 610 611 // Don't shrink below 8 bits wide. 612 unsigned NewWidth = std::max<unsigned>(PowerOf2Ceil(MinSignedBits), 8); 613 614 // NewWidth might be greater than OrigWidth if OrigWidth is not a power of 615 // two. 616 if (NewWidth >= OrigWidth) 617 return false; 618 619 ++NumSDivSRemsNarrowed; 620 IRBuilder<> B{Instr}; 621 auto *TruncTy = Type::getIntNTy(Instr->getContext(), NewWidth); 622 auto *LHS = B.CreateTruncOrBitCast(Instr->getOperand(0), TruncTy, 623 Instr->getName() + ".lhs.trunc"); 624 auto *RHS = B.CreateTruncOrBitCast(Instr->getOperand(1), TruncTy, 625 Instr->getName() + ".rhs.trunc"); 626 auto *BO = B.CreateBinOp(Instr->getOpcode(), LHS, RHS, Instr->getName()); 627 auto *Sext = B.CreateSExt(BO, Instr->getType(), Instr->getName() + ".sext"); 628 if (auto *BinOp = dyn_cast<BinaryOperator>(BO)) 629 if (BinOp->getOpcode() == Instruction::SDiv) 630 BinOp->setIsExact(Instr->isExact()); 631 632 Instr->replaceAllUsesWith(Sext); 633 Instr->eraseFromParent(); 634 return true; 635 } 636 637 /// Try to shrink a udiv/urem's width down to the smallest power of two that's 638 /// sufficient to contain its operands. 
static bool processUDivOrURem(BinaryOperator *Instr, LazyValueInfo *LVI) {
  assert(Instr->getOpcode() == Instruction::UDiv ||
         Instr->getOpcode() == Instruction::URem);
  if (Instr->getType()->isVectorTy())
    return false;

  // Find the smallest power of two bitwidth that's sufficient to hold Instr's
  // operands.

  // What is the smallest bit width that can accommodate the entire value ranges
  // of both of the operands?
  unsigned MaxActiveBits = 0;
  for (Value *Operand : Instr->operands()) {
    ConstantRange CR = LVI->getConstantRange(Operand, Instr);
    MaxActiveBits = std::max(CR.getActiveBits(), MaxActiveBits);
  }
  // Don't shrink below 8 bits wide.
  unsigned NewWidth = std::max<unsigned>(PowerOf2Ceil(MaxActiveBits), 8);

  // NewWidth might be greater than OrigWidth if OrigWidth is not a power of
  // two.
  if (NewWidth >= Instr->getType()->getIntegerBitWidth())
    return false;

  ++NumUDivURemsNarrowed;
  IRBuilder<> B{Instr};
  auto *TruncTy = Type::getIntNTy(Instr->getContext(), NewWidth);
  auto *LHS = B.CreateTruncOrBitCast(Instr->getOperand(0), TruncTy,
                                     Instr->getName() + ".lhs.trunc");
  auto *RHS = B.CreateTruncOrBitCast(Instr->getOperand(1), TruncTy,
                                     Instr->getName() + ".rhs.trunc");
  auto *BO = B.CreateBinOp(Instr->getOpcode(), LHS, RHS, Instr->getName());
  auto *Zext = B.CreateZExt(BO, Instr->getType(), Instr->getName() + ".zext");
  if (auto *BinOp = dyn_cast<BinaryOperator>(BO))
    if (BinOp->getOpcode() == Instruction::UDiv)
      BinOp->setIsExact(Instr->isExact());

  Instr->replaceAllUsesWith(Zext);
  Instr->eraseFromParent();
  return true;
}

/// If LVI proves the sign of both srem operands, rewrite the srem as a urem on
/// the absolute values (negating operands/result as needed).
static bool processSRem(BinaryOperator *SDI, LazyValueInfo *LVI) {
  assert(SDI->getOpcode() == Instruction::SRem);
  if (SDI->getType()->isVectorTy())
    return false;

  struct Operand {
    Value *V;
    Domain D;
  };
  std::array<Operand, 2> Ops;

  for (const auto I : zip(Ops, SDI->operands())) {
    Operand &Op = std::get<0>(I);
    Op.V = std::get<1>(I);
    Op.D = getDomain(Op.V, LVI, SDI);
    if (Op.D == Domain::Unknown)
      return false;
  }

  // We know domains of both of the operands!
  ++NumSRems;

  // We need operands to be non-negative, so negate each one that isn't.
  for (Operand &Op : Ops) {
    if (Op.D == Domain::NonNegative)
      continue;
    auto *BO =
        BinaryOperator::CreateNeg(Op.V, Op.V->getName() + ".nonneg", SDI);
    BO->setDebugLoc(SDI->getDebugLoc());
    Op.V = BO;
  }

  auto *URem =
      BinaryOperator::CreateURem(Ops[0].V, Ops[1].V, SDI->getName(), SDI);
  URem->setDebugLoc(SDI->getDebugLoc());

  Value *Res = URem;

  // If the dividend was non-positive, we need to negate the result.
  if (Ops[0].D == Domain::NonPositive)
    Res = BinaryOperator::CreateNeg(Res, Res->getName() + ".neg", SDI);

  SDI->replaceAllUsesWith(Res);
  SDI->eraseFromParent();

  // Try to simplify our new urem.
  processUDivOrURem(URem, LVI);

  return true;
}

/// See if LazyValueInfo's ability to exploit edge conditions or range
/// information is sufficient to prove the signs of both operands of this SDiv.
/// If this is the case, replace the SDiv with a UDiv. Even for local
/// conditions, this can sometimes prove conditions instcombine can't by
/// exploiting range information.
static bool processSDiv(BinaryOperator *SDI, LazyValueInfo *LVI) {
  assert(SDI->getOpcode() == Instruction::SDiv);
  if (SDI->getType()->isVectorTy())
    return false;

  struct Operand {
    Value *V;
    Domain D;
  };
  std::array<Operand, 2> Ops;

  for (const auto I : zip(Ops, SDI->operands())) {
    Operand &Op = std::get<0>(I);
    Op.V = std::get<1>(I);
    Op.D = getDomain(Op.V, LVI, SDI);
    if (Op.D == Domain::Unknown)
      return false;
  }

  // We know domains of both of the operands!
  ++NumSDivs;

  // We need operands to be non-negative, so negate each one that isn't.
  for (Operand &Op : Ops) {
    if (Op.D == Domain::NonNegative)
      continue;
    auto *BO =
        BinaryOperator::CreateNeg(Op.V, Op.V->getName() + ".nonneg", SDI);
    BO->setDebugLoc(SDI->getDebugLoc());
    Op.V = BO;
  }

  auto *UDiv =
      BinaryOperator::CreateUDiv(Ops[0].V, Ops[1].V, SDI->getName(), SDI);
  UDiv->setDebugLoc(SDI->getDebugLoc());
  UDiv->setIsExact(SDI->isExact());

  Value *Res = UDiv;

  // If the operands had two different domains, we need to negate the result.
  if (Ops[0].D != Ops[1].D)
    Res = BinaryOperator::CreateNeg(Res, Res->getName() + ".neg", SDI);

  SDI->replaceAllUsesWith(Res);
  SDI->eraseFromParent();

  // Try to simplify our new udiv.
  processUDivOrURem(UDiv, LVI);

  return true;
}

/// Dispatch for sdiv/srem: first try the signed->unsigned rewrite, then fall
/// back to narrowing the operation's bitwidth.
static bool processSDivOrSRem(BinaryOperator *Instr, LazyValueInfo *LVI) {
  assert(Instr->getOpcode() == Instruction::SDiv ||
         Instr->getOpcode() == Instruction::SRem);
  if (Instr->getType()->isVectorTy())
    return false;

  if (Instr->getOpcode() == Instruction::SDiv)
    if (processSDiv(Instr, LVI))
      return true;

  if (Instr->getOpcode() == Instruction::SRem)
    if (processSRem(Instr, LVI))
      return true;

  return narrowSDivOrSRem(Instr, LVI);
}

/// Convert an ashr to an lshr when LVI proves the shifted value is
/// non-negative (so the sign bit is known zero).
static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) {
  if (SDI->getType()->isVectorTy())
    return false;

  if (!isNonNegative(SDI->getOperand(0), LVI, SDI))
    return false;

  ++NumAShrs;
  auto *BO = BinaryOperator::CreateLShr(SDI->getOperand(0), SDI->getOperand(1),
                                        SDI->getName(), SDI);
  BO->setDebugLoc(SDI->getDebugLoc());
  BO->setIsExact(SDI->isExact());
  SDI->replaceAllUsesWith(BO);
  SDI->eraseFromParent();

  return true;
}

/// Convert a sext to a zext when LVI proves the extended value is
/// non-negative (both extensions then produce the same bits).
static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) {
  if (SDI->getType()->isVectorTy())
    return false;

  Value *Base = SDI->getOperand(0);

  if (!isNonNegative(Base, LVI, SDI))
    return false;

  ++NumSExt;
  auto *ZExt =
      CastInst::CreateZExtOrBitCast(Base, SDI->getType(), SDI->getName(), SDI);
  ZExt->setDebugLoc(SDI->getDebugLoc());
  SDI->replaceAllUsesWith(ZExt);
  SDI->eraseFromParent();

  return true;
}

/// Try to attach missing nsw/nuw flags to an add/sub/mul/shl using the operand
/// ranges LVI knows at this point.
static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) {
  using OBO = OverflowingBinaryOperator;

  if (DontAddNoWrapFlags)
    return false;

  if (BinOp->getType()->isVectorTy())
    return false;

  bool NSW = BinOp->hasNoSignedWrap();
  bool NUW = BinOp->hasNoUnsignedWrap();
  if (NSW && NUW)
    return false;

  Instruction::BinaryOps Opcode = BinOp->getOpcode();
  Value *LHS = BinOp->getOperand(0);
  Value *RHS = BinOp->getOperand(1);

  ConstantRange LRange = LVI->getConstantRange(LHS, BinOp);
  ConstantRange RRange = LVI->getConstantRange(RHS, BinOp);

  bool Changed = false;
  bool NewNUW = false, NewNSW = false;
  if (!NUW) {
    ConstantRange NUWRange = ConstantRange::makeGuaranteedNoWrapRegion(
        Opcode, RRange, OBO::NoUnsignedWrap);
    NewNUW = NUWRange.contains(LRange);
    Changed |= NewNUW;
  }
  if (!NSW) {
    ConstantRange NSWRange = ConstantRange::makeGuaranteedNoWrapRegion(
        Opcode, RRange, OBO::NoSignedWrap);
    NewNSW = NSWRange.contains(LRange);
    Changed |= NewNSW;
  }

  setDeducedOverflowingFlags(BinOp, Opcode, NewNSW, NewNUW);

  return Changed;
}

/// Remove an 'and' with a low-bit mask when LVI proves the masked value
/// already fits within the mask.
static bool processAnd(BinaryOperator *BinOp, LazyValueInfo *LVI) {
  if (BinOp->getType()->isVectorTy())
    return false;

  // Pattern match (and lhs, C) where C includes a superset of bits which might
  // be set in lhs. This is a common truncation idiom created by instcombine.
  Value *LHS = BinOp->getOperand(0);
  ConstantInt *RHS = dyn_cast<ConstantInt>(BinOp->getOperand(1));
  if (!RHS || !RHS->getValue().isMask())
    return false;

  // We can only replace the AND with LHS based on range info if the range does
  // not include undef.
  ConstantRange LRange =
      LVI->getConstantRange(LHS, BinOp, /*UndefAllowed=*/false);
  if (!LRange.getUnsignedMax().ule(RHS->getValue()))
    return false;

  BinOp->replaceAllUsesWith(LHS);
  BinOp->eraseFromParent();
  NumAnd++;
  return true;
}


/// Try to fold V to a constant at instruction At, either directly via LVI or,
/// for an icmp against a constant, by evaluating the predicate at At.
static Constant *getConstantAt(Value *V, Instruction *At, LazyValueInfo *LVI) {
  if (Constant *C = LVI->getConstant(V, At))
    return C;

  // TODO: The following really should be sunk inside LVI's core algorithm, or
  // at least the outer shims around such.
  auto *C = dyn_cast<CmpInst>(V);
  if (!C) return nullptr;

  Value *Op0 = C->getOperand(0);
  Constant *Op1 = dyn_cast<Constant>(C->getOperand(1));
  if (!Op1) return nullptr;

  LazyValueInfo::Tristate Result =
      LVI->getPredicateAt(C->getPredicate(), Op0, Op1, At);
  if (Result == LazyValueInfo::Unknown)
    return nullptr;

  return (Result == LazyValueInfo::True) ?
    ConstantInt::getTrue(C->getContext()) :
    ConstantInt::getFalse(C->getContext());
}

static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT,
                    const SimplifyQuery &SQ) {
  bool FnChanged = false;
  // Visiting in a pre-order depth-first traversal causes us to simplify early
  // blocks before querying later blocks (which require us to analyze early
  // blocks). Eagerly simplifying shallow blocks means there is strictly less
  // work to do for deep blocks. This also means we don't visit unreachable
  // blocks.
  for (BasicBlock *BB : depth_first(&F.getEntryBlock())) {
    bool BBChanged = false;
    // Advance the iterator before dispatching: several handlers (e.g.
    // processSExt, processAnd) erase the current instruction.
    for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) {
      Instruction *II = &*BI++;
      // Dispatch on opcode to the per-instruction simplification routines;
      // each returns true if it changed the IR.
      switch (II->getOpcode()) {
      case Instruction::Select:
        BBChanged |= processSelect(cast<SelectInst>(II), LVI);
        break;
      case Instruction::PHI:
        BBChanged |= processPHI(cast<PHINode>(II), LVI, DT, SQ);
        break;
      case Instruction::ICmp:
      case Instruction::FCmp:
        BBChanged |= processCmp(cast<CmpInst>(II), LVI);
        break;
      case Instruction::Load:
      case Instruction::Store:
        BBChanged |= processMemAccess(II, LVI);
        break;
      case Instruction::Call:
      case Instruction::Invoke:
        BBChanged |= processCallSite(cast<CallBase>(*II), LVI);
        break;
      case Instruction::SRem:
      case Instruction::SDiv:
        BBChanged |= processSDivOrSRem(cast<BinaryOperator>(II), LVI);
        break;
      case Instruction::UDiv:
      case Instruction::URem:
        BBChanged |= processUDivOrURem(cast<BinaryOperator>(II), LVI);
        break;
      case Instruction::AShr:
        BBChanged |= processAShr(cast<BinaryOperator>(II), LVI);
        break;
      case Instruction::SExt:
        BBChanged |= processSExt(cast<SExtInst>(II), LVI);
        break;
      case Instruction::Add:
      case Instruction::Sub:
      case Instruction::Mul:
      case Instruction::Shl:
        BBChanged |= processBinOp(cast<BinaryOperator>(II), LVI);
        break;
      case Instruction::And:
        BBChanged |= processAnd(cast<BinaryOperator>(II), LVI);
        break;
      }
    }

    // Terminators get their own dispatch: switches and returns.
    Instruction *Term = BB->getTerminator();
    switch (Term->getOpcode()) {
    case Instruction::Switch:
      BBChanged |= processSwitch(cast<SwitchInst>(Term), LVI, DT);
      break;
    case Instruction::Ret: {
      auto *RI = cast<ReturnInst>(Term);
      // Try to determine the return value if we can. This is mainly here to
      // simplify the writing of unit tests, but also helps to enable IPO by
      // constant folding the return values of callees.
      auto *RetVal = RI->getReturnValue();
      if (!RetVal) break; // handle "ret void"
      if (isa<Constant>(RetVal)) break; // nothing to do
      if (auto *C = getConstantAt(RetVal, RI, LVI)) {
        ++NumReturns;
        RI->replaceUsesOfWith(RetVal, C);
        BBChanged = true;
      }
    }
    }

    FnChanged |= BBChanged;
  }

  return FnChanged;
}

/// Legacy pass-manager entry point: fetch LVI and the dominator tree from the
/// wrapper passes and forward to the shared runImpl driver.
bool CorrelatedValuePropagation::runOnFunction(Function &F) {
  // Honor the legacy PM's request to skip this function.
  if (skipFunction(F))
    return false;

  LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI();
  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();

  return runImpl(F, LVI, DT, getBestSimplifyQuery(*this, F));
}

/// New pass-manager entry point: run the shared driver, then report which
/// analyses survive. LVI is explicitly abandoned even though the pass
/// preserves it (see comment below).
PreservedAnalyses
CorrelatedValuePropagationPass::run(Function &F, FunctionAnalysisManager &AM) {
  LazyValueInfo *LVI = &AM.getResult<LazyValueAnalysis>(F);
  DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);

  bool Changed = runImpl(F, LVI, DT, getBestSimplifyQuery(AM, F));

  PreservedAnalyses PA;
  if (!Changed) {
    PA = PreservedAnalyses::all();
  } else {
    // The pass never alters the CFG, so these remain valid when IR changed.
    PA.preserve<GlobalsAA>();
    PA.preserve<DominatorTreeAnalysis>();
    PA.preserve<LazyValueAnalysis>();
  }

  // Keeping LVI alive is expensive, both because it uses a lot of memory, and
  // because invalidating values in LVI is expensive. While CVP does preserve
  // LVI, we know that passes after JumpThreading+CVP will not need the result
  // of this analysis, so we forcefully discard it early.
  PA.abandon<LazyValueAnalysis>();
  return PA;
}