1 //===- StackSafetyAnalysis.cpp - Stack memory safety analysis -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #include "llvm/Analysis/StackSafetyAnalysis.h"
12 #include "llvm/ADT/APInt.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/ADT/Statistic.h"
16 #include "llvm/Analysis/ModuleSummaryAnalysis.h"
17 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
18 #include "llvm/Analysis/StackLifetime.h"
19 #include "llvm/IR/ConstantRange.h"
20 #include "llvm/IR/DerivedTypes.h"
21 #include "llvm/IR/GlobalValue.h"
22 #include "llvm/IR/InstIterator.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/IntrinsicInst.h"
25 #include "llvm/InitializePasses.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <algorithm>
31 #include <memory>
32 
33 using namespace llvm;
34 
35 #define DEBUG_TYPE "stack-safety"
36 
// Counters reported with -stats: how many allocas the global analysis
// examined, and how many of them were proven to be accessed only in bounds.
STATISTIC(NumAllocaStackSafe, "Number of safe allocas");
STATISTIC(NumAllocaTotal, "Number of total allocas");

// Per-function update budget for the inter-procedural data flow. Once a
// function's parameter info has been updated more than this many times, any
// further growth of its parameter ranges jumps straight to the full set.
static cl::opt<int> StackSafetyMaxIterations("stack-safety-max-iterations",
                                             cl::init(20), cl::Hidden);

// When set, StackSafetyGlobalInfo prints its results to stderr right after
// computing them.
static cl::opt<bool> StackSafetyPrint("stack-safety-print", cl::init(false),
                                      cl::Hidden);

// When set, StackSafetyGlobalInfo computes its results eagerly in its
// constructor instead of on first query.
static cl::opt<bool> StackSafetyRun("stack-safety-run", cl::init(false),
                                    cl::Hidden);
48 
49 namespace {
50 
51 /// Describes use of address in as a function call argument.
52 template <typename CalleeTy> struct CallInfo {
53   /// Function being called.
54   const CalleeTy *Callee = nullptr;
55   /// Index of argument which pass address.
56   size_t ParamNo = 0;
57   // Offset range of address from base address (alloca or calling function
58   // argument).
59   // Range should never set to empty-set, that is an invalid access range
60   // that can cause empty-set to be propagated with ConstantRange::add
61   ConstantRange Offset;
CallInfo__anonb65e3eb70111::CallInfo62   CallInfo(const CalleeTy *Callee, size_t ParamNo, const ConstantRange &Offset)
63       : Callee(Callee), ParamNo(ParamNo), Offset(Offset) {}
64 };
65 
66 template <typename CalleeTy>
operator <<(raw_ostream & OS,const CallInfo<CalleeTy> & P)67 raw_ostream &operator<<(raw_ostream &OS, const CallInfo<CalleeTy> &P) {
68   return OS << "@" << P.Callee->getName() << "(arg" << P.ParamNo << ", "
69             << P.Offset << ")";
70 }
71 
72 /// Describe uses of address (alloca or parameter) inside of the function.
73 template <typename CalleeTy> struct UseInfo {
74   // Access range if the address (alloca or parameters).
75   // It is allowed to be empty-set when there are no known accesses.
76   ConstantRange Range;
77 
78   // List of calls which pass address as an argument.
79   SmallVector<CallInfo<CalleeTy>, 4> Calls;
80 
UseInfo__anonb65e3eb70111::UseInfo81   UseInfo(unsigned PointerSize) : Range{PointerSize, false} {}
82 
updateRange__anonb65e3eb70111::UseInfo83   void updateRange(const ConstantRange &R) {
84     assert(!R.isUpperSignWrapped());
85     Range = Range.unionWith(R);
86     assert(!Range.isUpperSignWrapped());
87   }
88 };
89 
90 template <typename CalleeTy>
operator <<(raw_ostream & OS,const UseInfo<CalleeTy> & U)91 raw_ostream &operator<<(raw_ostream &OS, const UseInfo<CalleeTy> &U) {
92   OS << U.Range;
93   for (auto &Call : U.Calls)
94     OS << ", " << Call;
95   return OS;
96 }
97 
98 // Check if we should bailout for such ranges.
isUnsafe(const ConstantRange & R)99 bool isUnsafe(const ConstantRange &R) {
100   return R.isEmptySet() || R.isFullSet() || R.isUpperSignWrapped();
101 }
102 
/// Adds two ranges. Returns the full set unless the signed addition is
/// guaranteed never to overflow.
ConstantRange addOverflowNever(const ConstantRange &L, const ConstantRange &R) {
  const bool NoOverflow = L.signedAddMayOverflow(R) ==
                          ConstantRange::OverflowResult::NeverOverflows;
  return NoOverflow ? L.add(R) : ConstantRange(L.getBitWidth(), true);
}
109 
110 /// Calculate the allocation size of a given alloca. Returns empty range
111 // in case of confution.
getStaticAllocaSizeRange(const AllocaInst & AI)112 ConstantRange getStaticAllocaSizeRange(const AllocaInst &AI) {
113   const DataLayout &DL = AI.getModule()->getDataLayout();
114   TypeSize TS = DL.getTypeAllocSize(AI.getAllocatedType());
115   unsigned PointerSize = DL.getMaxPointerSizeInBits();
116   // Fallback to empty range for alloca size.
117   ConstantRange R = ConstantRange::getEmpty(PointerSize);
118   if (TS.isScalable())
119     return R;
120   APInt APSize(PointerSize, TS.getFixedSize(), true);
121   if (APSize.isNonPositive())
122     return R;
123   if (AI.isArrayAllocation()) {
124     const auto *C = dyn_cast<ConstantInt>(AI.getArraySize());
125     if (!C)
126       return R;
127     bool Overflow = false;
128     APInt Mul = C->getValue();
129     if (Mul.isNonPositive())
130       return R;
131     Mul = Mul.sextOrTrunc(PointerSize);
132     APSize = APSize.smul_ov(Mul, Overflow);
133     if (Overflow)
134       return R;
135   }
136   R = ConstantRange(APInt::getNullValue(PointerSize), APSize);
137   assert(!isUnsafe(R));
138   return R;
139 }
140 
141 template <typename CalleeTy> struct FunctionInfo {
142   std::map<const AllocaInst *, UseInfo<CalleeTy>> Allocas;
143   std::map<uint32_t, UseInfo<CalleeTy>> Params;
144   // TODO: describe return value as depending on one or more of its arguments.
145 
146   // StackSafetyDataFlowAnalysis counter stored here for faster access.
147   int UpdateCount = 0;
148 
print__anonb65e3eb70111::FunctionInfo149   void print(raw_ostream &O, StringRef Name, const Function *F) const {
150     // TODO: Consider different printout format after
151     // StackSafetyDataFlowAnalysis. Calls and parameters are irrelevant then.
152     O << "  @" << Name << ((F && F->isDSOLocal()) ? "" : " dso_preemptable")
153       << ((F && F->isInterposable()) ? " interposable" : "") << "\n";
154 
155     O << "    args uses:\n";
156     for (auto &KV : Params) {
157       O << "      ";
158       if (F)
159         O << F->getArg(KV.first)->getName();
160       else
161         O << formatv("arg{0}", KV.first);
162       O << "[]: " << KV.second << "\n";
163     }
164 
165     O << "    allocas uses:\n";
166     if (F) {
167       for (auto &I : instructions(F)) {
168         if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
169           auto &AS = Allocas.find(AI)->second;
170           O << "      " << AI->getName() << "["
171             << getStaticAllocaSizeRange(*AI).getUpper() << "]: " << AS << "\n";
172         }
173       }
174     } else {
175       assert(Allocas.empty());
176     }
177   }
178 };
179 
180 using GVToSSI = std::map<const GlobalValue *, FunctionInfo<GlobalValue>>;
181 
182 } // namespace
183 
// Result type that StackSafetyInfo allocates lazily in getInfo(): the local
// (single-function) analysis result.
struct StackSafetyInfo::InfoTy {
  FunctionInfo<GlobalValue> Info;
};
187 
// Result type that StackSafetyGlobalInfo allocates lazily in getInfo(): the
// per-function results after the inter-procedural data flow, plus the set of
// allocas whose accessed range lies within their static size.
struct StackSafetyGlobalInfo::InfoTy {
  GVToSSI Info;
  SmallPtrSet<const AllocaInst *, 8> SafeAllocas;
};
192 
193 namespace {
194 
/// Intra-procedural part of the analysis. For one function it computes the
/// accessed byte range of every alloca and pointer argument, and records the
/// calls those addresses are passed to.
class StackSafetyLocalAnalysis {
  Function &F;
  const DataLayout &DL;
  ScalarEvolution &SE;
  // Bit width of every range produced here (pointer size of the default
  // address space). Members are initialized in declaration order, and
  // PointerSize depends on DL while UnknownRange depends on PointerSize, so
  // keep this order.
  unsigned PointerSize = 0;

  // Full-set range, returned whenever an access cannot be bounded.
  const ConstantRange UnknownRange;

  ConstantRange offsetFrom(Value *Addr, Value *Base);
  ConstantRange getAccessRange(Value *Addr, Value *Base,
                               const ConstantRange &SizeRange);
  ConstantRange getAccessRange(Value *Addr, Value *Base, TypeSize Size);
  ConstantRange getMemIntrinsicAccessRange(const MemIntrinsic *MI, const Use &U,
                                           Value *Base);

  bool analyzeAllUses(Value *Ptr, UseInfo<GlobalValue> &AS,
                      const StackLifetime &SL);

public:
  StackSafetyLocalAnalysis(Function &F, ScalarEvolution &SE)
      : F(F), DL(F.getParent()->getDataLayout()), SE(SE),
        PointerSize(DL.getPointerSizeInBits()),
        UnknownRange(PointerSize, true) {}

  // Run the analysis on the associated function.
  FunctionInfo<GlobalValue> run();
};
222 
offsetFrom(Value * Addr,Value * Base)223 ConstantRange StackSafetyLocalAnalysis::offsetFrom(Value *Addr, Value *Base) {
224   if (!SE.isSCEVable(Addr->getType()) || !SE.isSCEVable(Base->getType()))
225     return UnknownRange;
226 
227   auto *PtrTy = IntegerType::getInt8PtrTy(SE.getContext());
228   const SCEV *AddrExp = SE.getTruncateOrZeroExtend(SE.getSCEV(Addr), PtrTy);
229   const SCEV *BaseExp = SE.getTruncateOrZeroExtend(SE.getSCEV(Base), PtrTy);
230   const SCEV *Diff = SE.getMinusSCEV(AddrExp, BaseExp);
231 
232   ConstantRange Offset = SE.getSignedRange(Diff);
233   if (isUnsafe(Offset))
234     return UnknownRange;
235   return Offset.sextOrTrunc(PointerSize);
236 }
237 
238 ConstantRange
getAccessRange(Value * Addr,Value * Base,const ConstantRange & SizeRange)239 StackSafetyLocalAnalysis::getAccessRange(Value *Addr, Value *Base,
240                                          const ConstantRange &SizeRange) {
241   // Zero-size loads and stores do not access memory.
242   if (SizeRange.isEmptySet())
243     return ConstantRange::getEmpty(PointerSize);
244   assert(!isUnsafe(SizeRange));
245 
246   ConstantRange Offsets = offsetFrom(Addr, Base);
247   if (isUnsafe(Offsets))
248     return UnknownRange;
249 
250   Offsets = addOverflowNever(Offsets, SizeRange);
251   if (isUnsafe(Offsets))
252     return UnknownRange;
253   return Offsets;
254 }
255 
getAccessRange(Value * Addr,Value * Base,TypeSize Size)256 ConstantRange StackSafetyLocalAnalysis::getAccessRange(Value *Addr, Value *Base,
257                                                        TypeSize Size) {
258   if (Size.isScalable())
259     return UnknownRange;
260   APInt APSize(PointerSize, Size.getFixedSize(), true);
261   if (APSize.isNegative())
262     return UnknownRange;
263   return getAccessRange(
264       Addr, Base, ConstantRange(APInt::getNullValue(PointerSize), APSize));
265 }
266 
getMemIntrinsicAccessRange(const MemIntrinsic * MI,const Use & U,Value * Base)267 ConstantRange StackSafetyLocalAnalysis::getMemIntrinsicAccessRange(
268     const MemIntrinsic *MI, const Use &U, Value *Base) {
269   if (const auto *MTI = dyn_cast<MemTransferInst>(MI)) {
270     if (MTI->getRawSource() != U && MTI->getRawDest() != U)
271       return ConstantRange::getEmpty(PointerSize);
272   } else {
273     if (MI->getRawDest() != U)
274       return ConstantRange::getEmpty(PointerSize);
275   }
276 
277   auto *CalculationTy = IntegerType::getIntNTy(SE.getContext(), PointerSize);
278   if (!SE.isSCEVable(MI->getLength()->getType()))
279     return UnknownRange;
280 
281   const SCEV *Expr =
282       SE.getTruncateOrZeroExtend(SE.getSCEV(MI->getLength()), CalculationTy);
283   ConstantRange Sizes = SE.getSignedRange(Expr);
284   if (Sizes.getUpper().isNegative() || isUnsafe(Sizes))
285     return UnknownRange;
286   Sizes = Sizes.sextOrTrunc(PointerSize);
287   ConstantRange SizeRange(APInt::getNullValue(PointerSize),
288                           Sizes.getUpper() - 1);
289   return getAccessRange(U, Base, SizeRange);
290 }
291 
/// The function analyzes all local uses of Ptr (alloca or argument) and
/// calculates local access range and all function calls where it was used.
///
/// The use graph of Ptr is walked depth-first:
/// - Loads, stores, memory intrinsics and byval call arguments widen
///   US.Range by the accessed bytes, measured relative to Ptr.
/// - Other call arguments to a known GlobalValue callee are recorded in
///   US.Calls.
/// - Lifetime markers and va_arg are ignored.
/// - Any other user (cast, GEP, PHI, ...) is treated as a derived address,
///   and its own uses are walked in turn.
///
/// Returns false, after widening US.Range to the full set, as soon as a use
/// cannot be bounded. This happens when the address is stored, returned,
/// used as a non-argument call operand, or passed to an unknown callee, or
/// when an alloca is accessed where StackLifetime does not report it alive.
/// Returns true otherwise.
bool StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
                                              UseInfo<GlobalValue> &US,
                                              const StackLifetime &SL) {
  SmallPtrSet<const Value *, 16> Visited;
  SmallVector<const Value *, 8> WorkList;
  WorkList.push_back(Ptr);
  // Null when Ptr is an argument; lifetime checks only apply to allocas.
  const AllocaInst *AI = dyn_cast<AllocaInst>(Ptr);

  // A DFS search through all uses of the alloca in bitcasts/PHI/GEPs/etc.
  while (!WorkList.empty()) {
    const Value *V = WorkList.pop_back_val();
    for (const Use &UI : V->uses()) {
      const auto *I = cast<Instruction>(UI.getUser());
      // Skip users that StackLifetime considers unreachable.
      if (!SL.isReachable(I))
        continue;

      assert(V == UI.get());

      switch (I->getOpcode()) {
      case Instruction::Load: {
        // Accessing an alloca where it is not known to be alive is unsafe.
        if (AI && !SL.isAliveAfter(AI, I)) {
          US.updateRange(UnknownRange);
          return false;
        }
        US.updateRange(
            getAccessRange(UI, Ptr, DL.getTypeStoreSize(I->getType())));
        break;
      }

      case Instruction::VAArg:
        // "va-arg" from a pointer is safe.
        break;
      case Instruction::Store: {
        if (V == I->getOperand(0)) {
          // Stored the pointer - conservatively assume it may be unsafe.
          US.updateRange(UnknownRange);
          return false;
        }
        if (AI && !SL.isAliveAfter(AI, I)) {
          US.updateRange(UnknownRange);
          return false;
        }
        US.updateRange(getAccessRange(
            UI, Ptr, DL.getTypeStoreSize(I->getOperand(0)->getType())));
        break;
      }

      case Instruction::Ret:
        // Information leak.
        // FIXME: Process parameters correctly. This is a leak only if we return
        // alloca.
        US.updateRange(UnknownRange);
        return false;

      case Instruction::Call:
      case Instruction::Invoke: {
        // Lifetime markers do not access memory.
        if (I->isLifetimeStartOrEnd())
          break;

        if (AI && !SL.isAliveAfter(AI, I)) {
          US.updateRange(UnknownRange);
          return false;
        }

        if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
          US.updateRange(getMemIntrinsicAccessRange(MI, UI, Ptr));
          break;
        }

        // The address is used as something other than an argument (e.g. as
        // the called operand): give up.
        const auto &CB = cast<CallBase>(*I);
        if (!CB.isArgOperand(&UI)) {
          US.updateRange(UnknownRange);
          return false;
        }

        // byval arguments are copied at the call site, so only the copied
        // bytes are accessed.
        unsigned ArgNo = CB.getArgOperandNo(&UI);
        if (CB.isByValArgument(ArgNo)) {
          US.updateRange(getAccessRange(
              UI, Ptr, DL.getTypeStoreSize(CB.getParamByValType(ArgNo))));
          break;
        }

        // FIXME: consult devirt?
        // Do not follow aliases, otherwise we could inadvertently follow
        // dso_preemptable aliases or aliases with interposable linkage.
        const GlobalValue *Callee =
            dyn_cast<GlobalValue>(CB.getCalledOperand()->stripPointerCasts());
        if (!Callee) {
          US.updateRange(UnknownRange);
          return false;
        }

        // Defer to the callee's parameter info; resolved by the global
        // (inter-procedural) analysis.
        assert(isa<Function>(Callee) || isa<GlobalAlias>(Callee));
        US.Calls.emplace_back(Callee, ArgNo, offsetFrom(UI, Ptr));
        break;
      }

      default:
        // Any other user is treated as producing a derived address; visit
        // its uses once.
        if (Visited.insert(I).second)
          WorkList.push_back(cast<const Instruction>(I));
      }
    }
  }

  return true;
}
400 
run()401 FunctionInfo<GlobalValue> StackSafetyLocalAnalysis::run() {
402   FunctionInfo<GlobalValue> Info;
403   assert(!F.isDeclaration() &&
404          "Can't run StackSafety on a function declaration");
405 
406   LLVM_DEBUG(dbgs() << "[StackSafety] " << F.getName() << "\n");
407 
408   SmallVector<AllocaInst *, 64> Allocas;
409   for (auto &I : instructions(F))
410     if (auto *AI = dyn_cast<AllocaInst>(&I))
411       Allocas.push_back(AI);
412   StackLifetime SL(F, Allocas, StackLifetime::LivenessType::Must);
413   SL.run();
414 
415   for (auto *AI : Allocas) {
416     auto &UI = Info.Allocas.emplace(AI, PointerSize).first->second;
417     analyzeAllUses(AI, UI, SL);
418   }
419 
420   for (Argument &A : make_range(F.arg_begin(), F.arg_end())) {
421     // Non pointers and bypass arguments are not going to be used in any global
422     // processing.
423     if (A.getType()->isPointerTy() && !A.hasByValAttr()) {
424       auto &UI = Info.Params.emplace(A.getArgNo(), PointerSize).first->second;
425       analyzeAllUses(&A, UI, SL);
426     }
427   }
428 
429   LLVM_DEBUG(Info.print(dbgs(), F.getName(), &F));
430   LLVM_DEBUG(dbgs() << "[StackSafety] done\n");
431   return Info;
432 }
433 
/// Inter-procedural fixed-point computation over per-function results. Each
/// parameter's access range is widened by the ranges its callees access
/// through the arguments it is passed as.
template <typename CalleeTy> class StackSafetyDataFlowAnalysis {
  using FunctionMap = std::map<const CalleeTy *, FunctionInfo<CalleeTy>>;

  FunctionMap Functions;
  // Full-set range of the analysis bit width, used for unknown accesses.
  const ConstantRange UnknownRange;

  // Callee-to-Caller multimap.
  DenseMap<const CalleeTy *, SmallVector<const CalleeTy *, 4>> Callers;
  // Functions queued for re-evaluation because one of their callees changed.
  SetVector<const CalleeTy *> WorkList;

  bool updateOneUse(UseInfo<CalleeTy> &US, bool UpdateToFullSet);
  void updateOneNode(const CalleeTy *Callee, FunctionInfo<CalleeTy> &FS);
  // Overload that looks up Callee's FunctionInfo (which must exist).
  void updateOneNode(const CalleeTy *Callee) {
    updateOneNode(Callee, Functions.find(Callee)->second);
  }
  // Re-evaluates every function once.
  void updateAllNodes() {
    for (auto &F : Functions)
      updateOneNode(F.first, F.second);
  }
  void runDataFlow();
#ifndef NDEBUG
  void verifyFixedPoint();
#endif

public:
  StackSafetyDataFlowAnalysis(uint32_t PointerBitWidth, FunctionMap Functions)
      : Functions(std::move(Functions)),
        UnknownRange(ConstantRange::getFull(PointerBitWidth)) {}

  // Runs the analysis and returns the final per-function info.
  const FunctionMap &run();

  // Range accessed through parameter ParamNo of Callee when it receives an
  // address at Offsets from some base.
  ConstantRange getArgumentAccessRange(const CalleeTy *Callee, unsigned ParamNo,
                                       const ConstantRange &Offsets) const;
};
468 
469 template <typename CalleeTy>
getArgumentAccessRange(const CalleeTy * Callee,unsigned ParamNo,const ConstantRange & Offsets) const470 ConstantRange StackSafetyDataFlowAnalysis<CalleeTy>::getArgumentAccessRange(
471     const CalleeTy *Callee, unsigned ParamNo,
472     const ConstantRange &Offsets) const {
473   auto FnIt = Functions.find(Callee);
474   // Unknown callee (outside of LTO domain or an indirect call).
475   if (FnIt == Functions.end())
476     return UnknownRange;
477   auto &FS = FnIt->second;
478   auto ParamIt = FS.Params.find(ParamNo);
479   if (ParamIt == FS.Params.end())
480     return UnknownRange;
481   auto &Access = ParamIt->second.Range;
482   if (Access.isEmptySet())
483     return Access;
484   if (Access.isFullSet())
485     return UnknownRange;
486   return addOverflowNever(Access, Offsets);
487 }
488 
489 template <typename CalleeTy>
updateOneUse(UseInfo<CalleeTy> & US,bool UpdateToFullSet)490 bool StackSafetyDataFlowAnalysis<CalleeTy>::updateOneUse(UseInfo<CalleeTy> &US,
491                                                          bool UpdateToFullSet) {
492   bool Changed = false;
493   for (auto &CS : US.Calls) {
494     assert(!CS.Offset.isEmptySet() &&
495            "Param range can't be empty-set, invalid offset range");
496 
497     ConstantRange CalleeRange =
498         getArgumentAccessRange(CS.Callee, CS.ParamNo, CS.Offset);
499     if (!US.Range.contains(CalleeRange)) {
500       Changed = true;
501       if (UpdateToFullSet)
502         US.Range = UnknownRange;
503       else
504         US.Range = US.Range.unionWith(CalleeRange);
505     }
506   }
507   return Changed;
508 }
509 
510 template <typename CalleeTy>
updateOneNode(const CalleeTy * Callee,FunctionInfo<CalleeTy> & FS)511 void StackSafetyDataFlowAnalysis<CalleeTy>::updateOneNode(
512     const CalleeTy *Callee, FunctionInfo<CalleeTy> &FS) {
513   bool UpdateToFullSet = FS.UpdateCount > StackSafetyMaxIterations;
514   bool Changed = false;
515   for (auto &KV : FS.Params)
516     Changed |= updateOneUse(KV.second, UpdateToFullSet);
517 
518   if (Changed) {
519     LLVM_DEBUG(dbgs() << "=== update [" << FS.UpdateCount
520                       << (UpdateToFullSet ? ", full-set" : "") << "] " << &FS
521                       << "\n");
522     // Callers of this function may need updating.
523     for (auto &CallerID : Callers[Callee])
524       WorkList.insert(CallerID);
525 
526     ++FS.UpdateCount;
527   }
528 }
529 
530 template <typename CalleeTy>
runDataFlow()531 void StackSafetyDataFlowAnalysis<CalleeTy>::runDataFlow() {
532   SmallVector<const CalleeTy *, 16> Callees;
533   for (auto &F : Functions) {
534     Callees.clear();
535     auto &FS = F.second;
536     for (auto &KV : FS.Params)
537       for (auto &CS : KV.second.Calls)
538         Callees.push_back(CS.Callee);
539 
540     llvm::sort(Callees);
541     Callees.erase(std::unique(Callees.begin(), Callees.end()), Callees.end());
542 
543     for (auto &Callee : Callees)
544       Callers[Callee].push_back(F.first);
545   }
546 
547   updateAllNodes();
548 
549   while (!WorkList.empty()) {
550     const CalleeTy *Callee = WorkList.back();
551     WorkList.pop_back();
552     updateOneNode(Callee);
553   }
554 }
555 
#ifndef NDEBUG
// Debug-only check: one more evaluation of every function must not queue
// any function for re-evaluation. Note that updateAllNodes() still updates
// the results if the check fails.
template <typename CalleeTy>
void StackSafetyDataFlowAnalysis<CalleeTy>::verifyFixedPoint() {
  WorkList.clear();
  updateAllNodes();
  assert(WorkList.empty());
}
#endif
564 
/// Runs the data flow to completion and returns the per-function results.
/// When LLVM_DEBUG output is enabled, the fixed point is also re-checked.
template <typename CalleeTy>
const typename StackSafetyDataFlowAnalysis<CalleeTy>::FunctionMap &
StackSafetyDataFlowAnalysis<CalleeTy>::run() {
  runDataFlow();
  LLVM_DEBUG(verifyFixedPoint());
  return Functions;
}
572 
resolveCallee(GlobalValueSummary * S)573 FunctionSummary *resolveCallee(GlobalValueSummary *S) {
574   while (S) {
575     if (!S->isLive() || !S->isDSOLocal())
576       return nullptr;
577     if (FunctionSummary *FS = dyn_cast<FunctionSummary>(S))
578       return FS;
579     AliasSummary *AS = dyn_cast<AliasSummary>(S);
580     if (!AS)
581       return nullptr;
582     S = AS->getBaseObject();
583     if (S == AS)
584       return nullptr;
585   }
586   return nullptr;
587 }
588 
findCalleeInModule(const GlobalValue * GV)589 const Function *findCalleeInModule(const GlobalValue *GV) {
590   while (GV) {
591     if (GV->isDeclaration() || GV->isInterposable() || !GV->isDSOLocal())
592       return nullptr;
593     if (const Function *F = dyn_cast<Function>(GV))
594       return F;
595     const GlobalAlias *A = dyn_cast<GlobalAlias>(GV);
596     if (!A)
597       return nullptr;
598     GV = A->getBaseObject();
599     if (GV == A)
600       return nullptr;
601   }
602   return nullptr;
603 }
604 
getGlobalValueSummary(const ModuleSummaryIndex * Index,uint64_t ValueGUID)605 GlobalValueSummary *getGlobalValueSummary(const ModuleSummaryIndex *Index,
606                                           uint64_t ValueGUID) {
607   auto VI = Index->getValueInfo(ValueGUID);
608   if (!VI || VI.getSummaryList().empty())
609     return nullptr;
610   assert(VI.getSummaryList().size() == 1);
611   auto &Summary = VI.getSummaryList()[0];
612   return Summary.get();
613 }
614 
findParamAccess(const FunctionSummary & FS,uint32_t ParamNo)615 const ConstantRange *findParamAccess(const FunctionSummary &FS,
616                                      uint32_t ParamNo) {
617   assert(FS.isLive());
618   assert(FS.isDSOLocal());
619   for (auto &PS : FS.paramAccesses())
620     if (ParamNo == PS.ParamNo)
621       return &PS.Use;
622   return nullptr;
623 }
624 
resolveAllCalls(UseInfo<GlobalValue> & Use,const ModuleSummaryIndex * Index)625 void resolveAllCalls(UseInfo<GlobalValue> &Use,
626                      const ModuleSummaryIndex *Index) {
627   ConstantRange FullSet(Use.Range.getBitWidth(), true);
628   for (auto &C : Use.Calls) {
629     const Function *F = findCalleeInModule(C.Callee);
630     if (F) {
631       C.Callee = F;
632       continue;
633     }
634 
635     if (!Index)
636       return Use.updateRange(FullSet);
637     GlobalValueSummary *GVS = getGlobalValueSummary(Index, C.Callee->getGUID());
638 
639     FunctionSummary *FS = resolveCallee(GVS);
640     if (!FS)
641       return Use.updateRange(FullSet);
642     const ConstantRange *Found = findParamAccess(*FS, C.ParamNo);
643     if (!Found)
644       return Use.updateRange(FullSet);
645     ConstantRange Access = Found->sextOrTrunc(Use.Range.getBitWidth());
646     Use.updateRange(addOverflowNever(Access, C.Offset));
647     C.Callee = nullptr;
648   }
649 
650   Use.Calls.erase(std::remove_if(Use.Calls.begin(), Use.Calls.end(),
651                                  [](auto &T) { return !T.Callee; }),
652                   Use.Calls.end());
653 }
654 
createGlobalStackSafetyInfo(std::map<const GlobalValue *,FunctionInfo<GlobalValue>> Functions,const ModuleSummaryIndex * Index)655 GVToSSI createGlobalStackSafetyInfo(
656     std::map<const GlobalValue *, FunctionInfo<GlobalValue>> Functions,
657     const ModuleSummaryIndex *Index) {
658   GVToSSI SSI;
659   if (Functions.empty())
660     return SSI;
661 
662   // FIXME: Simplify printing and remove copying here.
663   auto Copy = Functions;
664 
665   for (auto &FnKV : Copy)
666     for (auto &KV : FnKV.second.Params)
667       resolveAllCalls(KV.second, Index);
668 
669   uint32_t PointerSize = Copy.begin()
670                              ->first->getParent()
671                              ->getDataLayout()
672                              .getMaxPointerSizeInBits();
673   StackSafetyDataFlowAnalysis<GlobalValue> SSDFA(PointerSize, std::move(Copy));
674 
675   for (auto &F : SSDFA.run()) {
676     auto FI = F.second;
677     auto &SrcF = Functions[F.first];
678     for (auto &KV : FI.Allocas) {
679       auto &A = KV.second;
680       resolveAllCalls(A, Index);
681       for (auto &C : A.Calls) {
682         A.updateRange(
683             SSDFA.getArgumentAccessRange(C.Callee, C.ParamNo, C.Offset));
684       }
685       // FIXME: This is needed only to preserve calls in print() results.
686       A.Calls = SrcF.Allocas.find(KV.first)->second.Calls;
687     }
688     for (auto &KV : FI.Params) {
689       auto &P = KV.second;
690       P.Calls = SrcF.Params.find(KV.first)->second.Calls;
691     }
692     SSI[F.first] = std::move(FI);
693   }
694 
695   return SSI;
696 }
697 
698 } // end anonymous namespace
699 
// The special members are defined out of line, presumably so that InfoTy
// (defined earlier in this file) is a complete type wherever the pointer
// that owns it is created, moved, or destroyed.
StackSafetyInfo::StackSafetyInfo() = default;

// Binds the analysis to F. GetSE is only invoked later, when getInfo() first
// runs the local analysis.
StackSafetyInfo::StackSafetyInfo(Function *F,
                                 std::function<ScalarEvolution &()> GetSE)
    : F(F), GetSE(GetSE) {}

StackSafetyInfo::StackSafetyInfo(StackSafetyInfo &&) = default;

StackSafetyInfo &StackSafetyInfo::operator=(StackSafetyInfo &&) = default;

StackSafetyInfo::~StackSafetyInfo() = default;
711 
getInfo() const712 const StackSafetyInfo::InfoTy &StackSafetyInfo::getInfo() const {
713   if (!Info) {
714     StackSafetyLocalAnalysis SSLA(*F, GetSE());
715     Info.reset(new InfoTy{SSLA.run()});
716   }
717   return *Info;
718 }
719 
print(raw_ostream & O) const720 void StackSafetyInfo::print(raw_ostream &O) const {
721   getInfo().Info.print(O, F->getName(), dyn_cast<Function>(F));
722 }
723 
getInfo() const724 const StackSafetyGlobalInfo::InfoTy &StackSafetyGlobalInfo::getInfo() const {
725   if (!Info) {
726     std::map<const GlobalValue *, FunctionInfo<GlobalValue>> Functions;
727     for (auto &F : M->functions()) {
728       if (!F.isDeclaration()) {
729         auto FI = GetSSI(F).getInfo().Info;
730         Functions.emplace(&F, std::move(FI));
731       }
732     }
733     Info.reset(new InfoTy{
734         createGlobalStackSafetyInfo(std::move(Functions), Index), {}});
735     for (auto &FnKV : Info->Info) {
736       for (auto &KV : FnKV.second.Allocas) {
737         ++NumAllocaTotal;
738         const AllocaInst *AI = KV.first;
739         if (getStaticAllocaSizeRange(*AI).contains(KV.second.Range)) {
740           Info->SafeAllocas.insert(AI);
741           ++NumAllocaStackSafe;
742         }
743       }
744     }
745     if (StackSafetyPrint)
746       print(errs());
747   }
748   return *Info;
749 }
750 
751 std::vector<FunctionSummary::ParamAccess>
getParamAccesses() const752 StackSafetyInfo::getParamAccesses() const {
753   // Implementation transforms internal representation of parameter information
754   // into FunctionSummary format.
755   std::vector<FunctionSummary::ParamAccess> ParamAccesses;
756   for (const auto &KV : getInfo().Info.Params) {
757     auto &PS = KV.second;
758     // Parameter accessed by any or unknown offset, represented as FullSet by
759     // StackSafety, is handled as the parameter for which we have no
760     // StackSafety info at all. So drop it to reduce summary size.
761     if (PS.Range.isFullSet())
762       continue;
763 
764     ParamAccesses.emplace_back(KV.first, PS.Range);
765     FunctionSummary::ParamAccess &Param = ParamAccesses.back();
766 
767     Param.Calls.reserve(PS.Calls.size());
768     for (auto &C : PS.Calls) {
769       // Parameter forwarded into another function by any or unknown offset
770       // will make ParamAccess::Range as FullSet anyway. So we can drop the
771       // entire parameter like we did above.
772       // TODO(vitalybuka): Return already filtered parameters from getInfo().
773       if (C.Offset.isFullSet()) {
774         ParamAccesses.pop_back();
775         break;
776       }
777       Param.Calls.emplace_back(C.ParamNo, C.Callee->getGUID(), C.Offset);
778     }
779   }
780   return ParamAccesses;
781 }
782 
StackSafetyGlobalInfo::StackSafetyGlobalInfo() = default;

// GetSSI supplies the local result of each function. Index, when non-null,
// is used to resolve calls to functions not defined in this module. With
// -stack-safety-run the analysis runs here instead of on first query.
StackSafetyGlobalInfo::StackSafetyGlobalInfo(
    Module *M, std::function<const StackSafetyInfo &(Function &F)> GetSSI,
    const ModuleSummaryIndex *Index)
    : M(M), GetSSI(GetSSI), Index(Index) {
  if (StackSafetyRun)
    getInfo();
}

StackSafetyGlobalInfo::StackSafetyGlobalInfo(StackSafetyGlobalInfo &&) =
    default;

StackSafetyGlobalInfo &
StackSafetyGlobalInfo::operator=(StackSafetyGlobalInfo &&) = default;

StackSafetyGlobalInfo::~StackSafetyGlobalInfo() = default;
800 
isSafe(const AllocaInst & AI) const801 bool StackSafetyGlobalInfo::isSafe(const AllocaInst &AI) const {
802   const auto &Info = getInfo();
803   return Info.SafeAllocas.count(&AI);
804 }
805 
print(raw_ostream & O) const806 void StackSafetyGlobalInfo::print(raw_ostream &O) const {
807   auto &SSI = getInfo().Info;
808   if (SSI.empty())
809     return;
810   const Module &M = *SSI.begin()->first->getParent();
811   for (auto &F : M.functions()) {
812     if (!F.isDeclaration()) {
813       SSI.find(&F)->second.print(O, F.getName(), &F);
814       O << "\n";
815     }
816   }
817 }
818 
/// Debugging helper: prints the module-wide result to dbgs().
LLVM_DUMP_METHOD void StackSafetyGlobalInfo::dump() const { print(dbgs()); }
820 
821 AnalysisKey StackSafetyAnalysis::Key;
822 
run(Function & F,FunctionAnalysisManager & AM)823 StackSafetyInfo StackSafetyAnalysis::run(Function &F,
824                                          FunctionAnalysisManager &AM) {
825   return StackSafetyInfo(&F, [&AM, &F]() -> ScalarEvolution & {
826     return AM.getResult<ScalarEvolutionAnalysis>(F);
827   });
828 }
829 
run(Function & F,FunctionAnalysisManager & AM)830 PreservedAnalyses StackSafetyPrinterPass::run(Function &F,
831                                               FunctionAnalysisManager &AM) {
832   OS << "'Stack Safety Local Analysis' for function '" << F.getName() << "'\n";
833   AM.getResult<StackSafetyAnalysis>(F).print(OS);
834   return PreservedAnalyses::all();
835 }
836 
837 char StackSafetyInfoWrapperPass::ID = 0;
838 
StackSafetyInfoWrapperPass()839 StackSafetyInfoWrapperPass::StackSafetyInfoWrapperPass() : FunctionPass(ID) {
840   initializeStackSafetyInfoWrapperPassPass(*PassRegistry::getPassRegistry());
841 }
842 
getAnalysisUsage(AnalysisUsage & AU) const843 void StackSafetyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
844   AU.addRequiredTransitive<ScalarEvolutionWrapperPass>();
845   AU.setPreservesAll();
846 }
847 
print(raw_ostream & O,const Module * M) const848 void StackSafetyInfoWrapperPass::print(raw_ostream &O, const Module *M) const {
849   SSI.print(O);
850 }
851 
runOnFunction(Function & F)852 bool StackSafetyInfoWrapperPass::runOnFunction(Function &F) {
853   auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
854   SSI = {&F, [SE]() -> ScalarEvolution & { return *SE; }};
855   return false;
856 }
857 
858 AnalysisKey StackSafetyGlobalAnalysis::Key;
859 
860 StackSafetyGlobalInfo
run(Module & M,ModuleAnalysisManager & AM)861 StackSafetyGlobalAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
862   // FIXME: Lookup Module Summary.
863   FunctionAnalysisManager &FAM =
864       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
865   return {&M,
866           [&FAM](Function &F) -> const StackSafetyInfo & {
867             return FAM.getResult<StackSafetyAnalysis>(F);
868           },
869           nullptr};
870 }
871 
run(Module & M,ModuleAnalysisManager & AM)872 PreservedAnalyses StackSafetyGlobalPrinterPass::run(Module &M,
873                                                     ModuleAnalysisManager &AM) {
874   OS << "'Stack Safety Analysis' for module '" << M.getName() << "'\n";
875   AM.getResult<StackSafetyGlobalAnalysis>(M).print(OS);
876   return PreservedAnalyses::all();
877 }
878 
879 char StackSafetyGlobalInfoWrapperPass::ID = 0;
880 
StackSafetyGlobalInfoWrapperPass()881 StackSafetyGlobalInfoWrapperPass::StackSafetyGlobalInfoWrapperPass()
882     : ModulePass(ID) {
883   initializeStackSafetyGlobalInfoWrapperPassPass(
884       *PassRegistry::getPassRegistry());
885 }
886 
887 StackSafetyGlobalInfoWrapperPass::~StackSafetyGlobalInfoWrapperPass() = default;
888 
print(raw_ostream & O,const Module * M) const889 void StackSafetyGlobalInfoWrapperPass::print(raw_ostream &O,
890                                              const Module *M) const {
891   SSGI.print(O);
892 }
893 
getAnalysisUsage(AnalysisUsage & AU) const894 void StackSafetyGlobalInfoWrapperPass::getAnalysisUsage(
895     AnalysisUsage &AU) const {
896   AU.setPreservesAll();
897   AU.addRequired<StackSafetyInfoWrapperPass>();
898 }
899 
runOnModule(Module & M)900 bool StackSafetyGlobalInfoWrapperPass::runOnModule(Module &M) {
901   const ModuleSummaryIndex *ImportSummary = nullptr;
902   if (auto *IndexWrapperPass =
903           getAnalysisIfAvailable<ImmutableModuleSummaryIndexWrapperPass>())
904     ImportSummary = IndexWrapperPass->getIndex();
905 
906   SSGI = {&M,
907           [this](Function &F) -> const StackSafetyInfo & {
908             return getAnalysis<StackSafetyInfoWrapperPass>(F).getResult();
909           },
910           ImportSummary};
911   return false;
912 }
913 
needsParamAccessSummary(const Module & M)914 bool llvm::needsParamAccessSummary(const Module &M) {
915   if (StackSafetyRun)
916     return true;
917   for (auto &F : M.functions())
918     if (F.hasFnAttribute(Attribute::SanitizeMemTag))
919       return true;
920   return false;
921 }
922 
generateParamAccessSummary(ModuleSummaryIndex & Index)923 void llvm::generateParamAccessSummary(ModuleSummaryIndex &Index) {
924   const ConstantRange FullSet(FunctionSummary::ParamAccess::RangeWidth, true);
925   std::map<const FunctionSummary *, FunctionInfo<FunctionSummary>> Functions;
926 
927   // Convert the ModuleSummaryIndex to a FunctionMap
928   for (auto &GVS : Index) {
929     for (auto &GV : GVS.second.SummaryList) {
930       FunctionSummary *FS = dyn_cast<FunctionSummary>(GV.get());
931       if (!FS)
932         continue;
933       if (FS->isLive() && FS->isDSOLocal()) {
934         FunctionInfo<FunctionSummary> FI;
935         for (auto &PS : FS->paramAccesses()) {
936           auto &US =
937               FI.Params
938                   .emplace(PS.ParamNo, FunctionSummary::ParamAccess::RangeWidth)
939                   .first->second;
940           US.Range = PS.Use;
941           for (auto &Call : PS.Calls) {
942             assert(!Call.Offsets.isFullSet());
943             FunctionSummary *S = resolveCallee(
944                 Index.findSummaryInModule(Call.Callee, FS->modulePath()));
945             if (!S) {
946               US.Range = FullSet;
947               US.Calls.clear();
948               break;
949             }
950             US.Calls.emplace_back(S, Call.ParamNo, Call.Offsets);
951           }
952         }
953         Functions.emplace(FS, std::move(FI));
954       }
955       // Reset data for all summaries. Alive and DSO local will be set back from
956       // of data flow results below. Anything else will not be accessed
957       // by ThinLTO backend, so we can save on bitcode size.
958       FS->setParamAccesses({});
959     }
960   }
961   StackSafetyDataFlowAnalysis<FunctionSummary> SSDFA(
962       FunctionSummary::ParamAccess::RangeWidth, std::move(Functions));
963   for (auto &KV : SSDFA.run()) {
964     std::vector<FunctionSummary::ParamAccess> NewParams;
965     NewParams.reserve(KV.second.Params.size());
966     for (auto &Param : KV.second.Params) {
967       NewParams.emplace_back();
968       FunctionSummary::ParamAccess &New = NewParams.back();
969       New.ParamNo = Param.first;
970       New.Use = Param.second.Range; // Only range is needed.
971     }
972     const_cast<FunctionSummary *>(KV.first)->setParamAccesses(
973         std::move(NewParams));
974   }
975 }
976 
977 static const char LocalPassArg[] = "stack-safety-local";
978 static const char LocalPassName[] = "Stack Safety Local Analysis";
979 INITIALIZE_PASS_BEGIN(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
980                       false, true)
981 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
982 INITIALIZE_PASS_END(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName,
983                     false, true)
984 
985 static const char GlobalPassName[] = "Stack Safety Analysis";
986 INITIALIZE_PASS_BEGIN(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
987                       GlobalPassName, false, true)
988 INITIALIZE_PASS_DEPENDENCY(StackSafetyInfoWrapperPass)
989 INITIALIZE_PASS_DEPENDENCY(ImmutableModuleSummaryIndexWrapperPass)
990 INITIALIZE_PASS_END(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE,
991                     GlobalPassName, false, true)
992