//===-- SCCP.cpp ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements Interprocedural Sparse Conditional Constant Propagation.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/SCCP.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueLattice.h"
#include "llvm/Analysis/ValueLatticeUtils.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ModRef.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/FunctionSpecialization.h"
#include "llvm/Transforms/Scalar/SCCP.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/SCCPSolver.h"

using namespace llvm;

#define DEBUG_TYPE "sccp"

STATISTIC(NumInstRemoved, "Number of instructions removed");
STATISTIC(NumArgsElimed, "Number of arguments constant propagated");
STATISTIC(NumGlobalConst, "Number of globals found to be constant");
STATISTIC(NumDeadBlocks, "Number of basic blocks unreachable");
STATISTIC(NumInstReplaced,
          "Number of instructions replaced with (simpler) instruction");

static cl::opt<unsigned> FuncSpecializationMaxIters(
    "func-specialization-max-iters", cl::init(1), cl::Hidden,
    cl::desc("The maximum number of iterations function specialization is "
             "run"));

static void findReturnsToZap(Function &F,
                             SmallVector<ReturnInst *, 8> &ReturnsToZap,
                             SCCPSolver &Solver) {
  // We can only do this if we know that nothing else can call the function.
  if (!Solver.isArgumentTrackedFunction(&F))
    return;

  if (Solver.mustPreserveReturn(&F)) {
    LLVM_DEBUG(dbgs() << "Can't zap returns of function " << F.getName()
                      << " because it contains a musttail or "
                         "\"clang.arc.attachedcall\" call\n");
    return;
  }

  assert(
      all_of(F.users(),
             [&Solver](User *U) {
               if (isa<Instruction>(U) &&
                   !Solver.isBlockExecutable(cast<Instruction>(U)->getParent()))
                 return true;
               // Non-callsite uses are not impacted by zapping. Also, constant
               // uses (like blockaddresses) could stick around without being
               // used in the underlying IR, meaning we do not have lattice
               // values for them.
               if (!isa<CallBase>(U))
                 return true;
               if (U->getType()->isStructTy()) {
                 return all_of(Solver.getStructLatticeValueFor(U),
                               [](const ValueLatticeElement &LV) {
                                 return !SCCPSolver::isOverdefined(LV);
                               });
               }

               // We don't consider assume-like intrinsics to be actual address
               // captures.
               if (auto *II = dyn_cast<IntrinsicInst>(U)) {
                 if (II->isAssumeLikeIntrinsic())
                   return true;
               }

               return !SCCPSolver::isOverdefined(Solver.getLatticeValueFor(U));
             }) &&
      "We can only zap functions where all live users have a concrete value");

  for (BasicBlock &BB : F) {
    if (CallInst *CI = BB.getTerminatingMustTailCall()) {
      LLVM_DEBUG(dbgs() << "Can't zap return of the block due to present "
                        << "musttail call: " << *CI << "\n");
      (void)CI;
      return;
    }

    if (auto *RI = dyn_cast<ReturnInst>(BB.getTerminator()))
      if (!isa<UndefValue>(RI->getOperand(0)))
        ReturnsToZap.push_back(RI);
  }
}

static bool runIPSCCP(
    Module &M, const DataLayout &DL, FunctionAnalysisManager *FAM,
    std::function<const TargetLibraryInfo &(Function &)> GetTLI,
    std::function<TargetTransformInfo &(Function &)> GetTTI,
    std::function<AssumptionCache &(Function &)> GetAC,
    function_ref<AnalysisResultsForFn(Function &)> getAnalysis,
    bool IsFuncSpecEnabled) {
  SCCPSolver Solver(DL, GetTLI, M.getContext());
  FunctionSpecializer Specializer(Solver, M, FAM, GetTLI, GetTTI, GetAC);

  // Loop over all functions, marking arguments to those with their addresses
  // taken or that are external as overdefined.
  for (Function &F : M) {
    if (F.isDeclaration())
      continue;

    Solver.addAnalysis(F, getAnalysis(F));

    // Determine if we can track the function's return values. If so, add the
    // function to the solver's set of return-tracked functions.
    if (canTrackReturnsInterprocedurally(&F))
      Solver.addTrackedFunction(&F);

    // Determine if we can track the function's arguments. If so, add the
    // function to the solver's set of argument-tracked functions.
    if (canTrackArgumentsInterprocedurally(&F)) {
      Solver.addArgumentTrackedFunction(&F);
      continue;
    }

    // Assume the function is called.
    Solver.markBlockExecutable(&F.front());

    // Assume nothing about the incoming arguments.
    for (Argument &AI : F.args())
      Solver.markOverdefined(&AI);
  }

  // Determine if we can track any of the module's global variables. If so, add
  // the global variables we can track to the solver's set of tracked global
  // variables.
  for (GlobalVariable &G : M.globals()) {
    G.removeDeadConstantUsers();
    if (canTrackGlobalVariableInterprocedurally(&G))
      Solver.trackValueOfGlobalVariable(&G);
  }

  // Solve for constants.
  Solver.solveWhileResolvedUndefsIn(M);

  if (IsFuncSpecEnabled) {
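    // Run the function specializer for at most FuncSpecializationMaxIters
    // iterations; Specializer.run() returns false once no further
    // specializations are made, which ends the loop early.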
    unsigned Iters = 0;
    while (Iters++ < FuncSpecializationMaxIters && Specializer.run());
  }

  // Iterate over all of the instructions in the module, replacing them with
  // constants where we have found them to be constant.
  bool MadeChanges = false;
  for (Function &F : M) {
    if (F.isDeclaration())
      continue;

    SmallVector<BasicBlock *, 512> BlocksToErase;

    if (Solver.isBlockExecutable(&F.front())) {
      bool ReplacedPointerArg = false;
      for (Argument &Arg : F.args()) {
        if (!Arg.use_empty() && Solver.tryToReplaceWithConstant(&Arg)) {
          ReplacedPointerArg |= Arg.getType()->isPointerTy();
          ++NumArgsElimed;
        }
      }

      // If we replaced an argument, we may now also access a global (currently
      // classified as "other" memory). Update memory attribute to reflect this.
      if (ReplacedPointerArg) {
        auto UpdateAttrs = [&](AttributeList AL) {
          MemoryEffects ME = AL.getMemoryEffects();
          if (ME == MemoryEffects::unknown())
            return AL;

          ME |= MemoryEffects(MemoryEffects::Other,
                              ME.getModRef(MemoryEffects::ArgMem));
          return AL.addFnAttribute(
              F.getContext(),
              Attribute::getWithMemoryEffects(F.getContext(), ME));
        };

        F.setAttributes(UpdateAttrs(F.getAttributes()));
        for (User *U : F.users()) {
          auto *CB = dyn_cast<CallBase>(U);
          if (!CB || CB->getCalledFunction() != &F)
            continue;

          CB->setAttributes(UpdateAttrs(CB->getAttributes()));
        }
      }
      MadeChanges |= ReplacedPointerArg;
    }

    SmallPtrSet<Value *, 32> InsertedValues;
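    // Simplify the instructions in each executable block, and collect the
    // dead blocks so they can be turned into unreachable code below.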
    for (BasicBlock &BB : F) {
      if (!Solver.isBlockExecutable(&BB)) {
        LLVM_DEBUG(dbgs() << "  BasicBlock Dead:" << BB);
        ++NumDeadBlocks;

        MadeChanges = true;

        if (&BB != &F.front())
          BlocksToErase.push_back(&BB);
        continue;
      }

      MadeChanges |= Solver.simplifyInstsInBlock(
          BB, InsertedValues, NumInstRemoved, NumInstReplaced);
    }

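    // For functions cloned by the specializer, fall back to a lazily updated
    // DomTreeUpdater; otherwise reuse the updater from the solver's
    // per-function analysis results.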
    DomTreeUpdater DTU = IsFuncSpecEnabled && Specializer.isClonedFunction(&F)
        ? DomTreeUpdater(DomTreeUpdater::UpdateStrategy::Lazy)
        : Solver.getDTU(F);

    // Change dead blocks to unreachable. We do it after replacing constants
    // in all executable blocks, because changeToUnreachable may remove PHI
    // nodes in executable blocks we found values for. The function's entry
    // block is not part of BlocksToErase, so we have to handle it separately.
    for (BasicBlock *BB : BlocksToErase) {
      NumInstRemoved += changeToUnreachable(BB->getFirstNonPHI(),
                                            /*PreserveLCSSA=*/false, &DTU);
    }
    if (!Solver.isBlockExecutable(&F.front()))
      NumInstRemoved += changeToUnreachable(F.front().getFirstNonPHI(),
                                            /*PreserveLCSSA=*/false, &DTU);

    BasicBlock *NewUnreachableBB = nullptr;
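    // Remove control-flow edges into successors the solver has proven
    // infeasible, creating a shared unreachable block when one is needed.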
    for (BasicBlock &BB : F)
      MadeChanges |= Solver.removeNonFeasibleEdges(&BB, DTU, NewUnreachableBB);

    for (BasicBlock *DeadBB : BlocksToErase)
      if (!DeadBB->hasAddressTaken())
        DTU.deleteBB(DeadBB);

    for (BasicBlock &BB : F) {
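      // PredicateInfo inserted ssa.copy intrinsics for the solver's benefit;
      // replace each one with its original operand and erase it.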
      for (Instruction &Inst : llvm::make_early_inc_range(BB)) {
        if (Solver.getPredicateInfoFor(&Inst)) {
          if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) {
            if (II->getIntrinsicID() == Intrinsic::ssa_copy) {
              Value *Op = II->getOperand(0);
              Inst.replaceAllUsesWith(Op);
              Inst.eraseFromParent();
            }
          }
        }
      }
    }
  }

  // If we inferred constant or undef return values for a function, we replaced
  // all call uses with the inferred value.  This means we don't need to bother
  // actually returning anything from the function.  Replace all return
  // instructions with return undef.
  //
  // Do this in two stages: first identify the functions we should process, then
  // actually zap their returns.  This is important because we can only do this
  // if the address of the function isn't taken.  In cases where a return is the
  // last use of a function, the order of processing functions would affect
  // whether other functions are optimizable.
  SmallVector<ReturnInst*, 8> ReturnsToZap;

  for (const auto &I : Solver.getTrackedRetVals()) {
    Function *F = I.first;
    const ValueLatticeElement &ReturnValue = I.second;

    // If there is a known constant range for the return value, add !range
    // metadata to the function's call sites.
    if (ReturnValue.isConstantRange() &&
        !ReturnValue.getConstantRange().isSingleElement()) {
      // Do not add range metadata if the return value may include undef.
      if (ReturnValue.isConstantRangeIncludingUndef())
        continue;

      auto &CR = ReturnValue.getConstantRange();
      for (User *User : F->users()) {
        auto *CB = dyn_cast<CallBase>(User);
        if (!CB || CB->getCalledFunction() != F)
          continue;

        // Limit to cases where the return value is guaranteed to be neither
        // poison nor undef. Poison will be outside any range and currently
        // values outside of the specified range cause immediate undefined
        // behavior.
        if (!isGuaranteedNotToBeUndefOrPoison(CB, nullptr, CB))
          continue;

        // Do not touch existing metadata for now.
        // TODO: We should be able to take the intersection of the existing
        // metadata and the inferred range.
        if (CB->getMetadata(LLVMContext::MD_range))
          continue;

        LLVMContext &Context = CB->getParent()->getContext();
        Metadata *RangeMD[] = {
            ConstantAsMetadata::get(ConstantInt::get(Context, CR.getLower())),
            ConstantAsMetadata::get(ConstantInt::get(Context, CR.getUpper()))};
        CB->setMetadata(LLVMContext::MD_range, MDNode::get(Context, RangeMD));
      }
      continue;
    }
    if (F->getReturnType()->isVoidTy())
      continue;
    if (SCCPSolver::isConstant(ReturnValue) || ReturnValue.isUnknownOrUndef())
      findReturnsToZap(*F, ReturnsToZap, Solver);
  }

  for (auto *F : Solver.getMRVFunctionsTracked()) {
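    // Functions returning a struct are tracked as multiple return values;
    // their returns can be zapped once every element of the struct lattice
    // is constant.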
    assert(F->getReturnType()->isStructTy() &&
           "The return type should be a struct");
    StructType *STy = cast<StructType>(F->getReturnType());
    if (Solver.isStructLatticeConstant(F, STy))
      findReturnsToZap(*F, ReturnsToZap, Solver);
  }

  // Zap all returns which we've identified above as needing to be changed.
  SmallSetVector<Function *, 8> FuncZappedReturn;
  for (ReturnInst *RI : ReturnsToZap) {
    Function *F = RI->getParent()->getParent();
    RI->setOperand(0, UndefValue::get(F->getReturnType()));
    // Record all functions that are zapped.
    FuncZappedReturn.insert(F);
  }

  // Remove the returned attribute for zapped functions and the
  // corresponding call sites.
  for (Function *F : FuncZappedReturn) {
    for (Argument &A : F->args())
      F->removeParamAttr(A.getArgNo(), Attribute::Returned);
    for (Use &U : F->uses()) {
      CallBase *CB = dyn_cast<CallBase>(U.getUser());
      if (!CB) {
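        // The only remaining non-call uses should be blockaddresses or
        // constants whose users are all assume-like intrinsics.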
        assert(isa<BlockAddress>(U.getUser()) ||
               (isa<Constant>(U.getUser()) &&
                all_of(U.getUser()->users(), [](const User *UserUser) {
                  return cast<IntrinsicInst>(UserUser)->isAssumeLikeIntrinsic();
                })));
        continue;
      }

      for (Use &Arg : CB->args())
        CB->removeParamAttr(CB->getArgOperandNo(&Arg), Attribute::Returned);
    }
  }

  // If we inferred constant or undef values for global variables, we can
  // delete the global and any stores that remain to it.
  for (const auto &I : make_early_inc_range(Solver.getTrackedGlobals())) {
    GlobalVariable *GV = I.first;
    if (SCCPSolver::isOverdefined(I.second))
      continue;
    LLVM_DEBUG(dbgs() << "Found that GV '" << GV->getName()
                      << "' is constant!\n");
    while (!GV->use_empty()) {
      StoreInst *SI = cast<StoreInst>(GV->user_back());
      SI->eraseFromParent();
      MadeChanges = true;
    }
    M.getGlobalList().erase(GV);
    ++NumGlobalConst;
  }

  return MadeChanges;
}

PreservedAnalyses IPSCCPPass::run(Module &M, ModuleAnalysisManager &AM) {
  const DataLayout &DL = M.getDataLayout();
  auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  auto GetTLI = [&FAM](Function &F) -> const TargetLibraryInfo & {
    return FAM.getResult<TargetLibraryAnalysis>(F);
  };
  auto GetTTI = [&FAM](Function &F) -> TargetTransformInfo & {
    return FAM.getResult<TargetIRAnalysis>(F);
  };
  auto GetAC = [&FAM](Function &F) -> AssumptionCache & {
    return FAM.getResult<AssumptionAnalysis>(F);
  };
  auto getAnalysis = [&FAM, this](Function &F) -> AnalysisResultsForFn {
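    // Bundle the per-function analyses the solver consumes; LoopInfo is only
    // required when function specialization is enabled.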
    DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
    return {
        std::make_unique<PredicateInfo>(F, DT,
                                        FAM.getResult<AssumptionAnalysis>(F)),
        &DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F),
        isFuncSpecEnabled() ? &FAM.getResult<LoopAnalysis>(F) : nullptr};
  };

  if (!runIPSCCP(M, DL, &FAM, GetTLI, GetTTI, GetAC, getAnalysis,
                 isFuncSpecEnabled()))
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  PA.preserve<DominatorTreeAnalysis>();
  PA.preserve<PostDominatorTreeAnalysis>();
  PA.preserve<FunctionAnalysisManagerModuleProxy>();
  return PA;
}

namespace {

//===--------------------------------------------------------------------===//
//
/// IPSCCP Class - This class implements interprocedural Sparse Conditional
/// Constant Propagation.
///
class IPSCCPLegacyPass : public ModulePass {
public:
  static char ID;

  IPSCCPLegacyPass() : ModulePass(ID) {
    initializeIPSCCPLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override {
    if (skipModule(M))
      return false;
    const DataLayout &DL = M.getDataLayout();
    auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
      return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    };
    auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
      return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    };
    auto GetAC = [this](Function &F) -> AssumptionCache & {
      return this->getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
    };
    auto getAnalysis = [this](Function &F) -> AnalysisResultsForFn {
      DominatorTree &DT =
          this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
      return {
          std::make_unique<PredicateInfo>(
              F, DT,
              this->getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
                  F)),
          nullptr,  // We cannot preserve the LI, DT or PDT with the legacy pass
          nullptr,  // manager, so set them to nullptr.
          nullptr};
    };

    return runIPSCCP(M, DL, nullptr, GetTLI, GetTTI, GetAC, getAnalysis, false);
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }
};

} // end anonymous namespace

char IPSCCPLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(IPSCCPLegacyPass, "ipsccp",
                      "Interprocedural Sparse Conditional Constant Propagation",
                      false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(IPSCCPLegacyPass, "ipsccp",
                    "Interprocedural Sparse Conditional Constant Propagation",
                    false, false)

// createIPSCCPPass - This is the public interface to this file.
ModulePass *llvm::createIPSCCPPass() { return new IPSCCPLegacyPass(); }