1 //===- AssumptionCache.cpp - Cache finding @llvm.assume calls -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a pass that keeps track of @llvm.assume and
10 // @llvm.experimental.guard intrinsics in the functions of a module.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/Analysis/AssumptionCache.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/SmallPtrSet.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/AssumeBundleQueries.h"
19 #include "llvm/Analysis/TargetTransformInfo.h"
20 #include "llvm/IR/BasicBlock.h"
21 #include "llvm/IR/Function.h"
22 #include "llvm/IR/InstrTypes.h"
23 #include "llvm/IR/Instruction.h"
24 #include "llvm/IR/Instructions.h"
25 #include "llvm/IR/PassManager.h"
26 #include "llvm/IR/PatternMatch.h"
27 #include "llvm/InitializePasses.h"
28 #include "llvm/Pass.h"
29 #include "llvm/Support/Casting.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/ErrorHandling.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include <cassert>
34 #include <utility>
35
36 using namespace llvm;
37 using namespace llvm::PatternMatch;
38
39 static cl::opt<bool>
40 VerifyAssumptionCache("verify-assumption-cache", cl::Hidden,
41 cl::desc("Enable verification of assumption cache"),
42 cl::init(false));
43
44 SmallVector<AssumptionCache::ResultElem, 1> &
getOrInsertAffectedValues(Value * V)45 AssumptionCache::getOrInsertAffectedValues(Value *V) {
46 // Try using find_as first to avoid creating extra value handles just for the
47 // purpose of doing the lookup.
48 auto AVI = AffectedValues.find_as(V);
49 if (AVI != AffectedValues.end())
50 return AVI->second;
51
52 auto AVIP = AffectedValues.insert(
53 {AffectedValueCallbackVH(V, this), SmallVector<ResultElem, 1>()});
54 return AVIP.first->second;
55 }
56
57 static void
findAffectedValues(CallBase * CI,TargetTransformInfo * TTI,SmallVectorImpl<AssumptionCache::ResultElem> & Affected)58 findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
59 SmallVectorImpl<AssumptionCache::ResultElem> &Affected) {
60 // Note: This code must be kept in-sync with the code in
61 // computeKnownBitsFromAssume in ValueTracking.
62
63 auto AddAffected = [&Affected](Value *V, unsigned Idx =
64 AssumptionCache::ExprResultIdx) {
65 if (isa<Argument>(V)) {
66 Affected.push_back({V, Idx});
67 } else if (auto *I = dyn_cast<Instruction>(V)) {
68 Affected.push_back({I, Idx});
69
70 // Peek through unary operators to find the source of the condition.
71 Value *Op;
72 if (match(I, m_BitCast(m_Value(Op))) ||
73 match(I, m_PtrToInt(m_Value(Op))) || match(I, m_Not(m_Value(Op)))) {
74 if (isa<Instruction>(Op) || isa<Argument>(Op))
75 Affected.push_back({Op, Idx});
76 }
77 }
78 };
79
80 for (unsigned Idx = 0; Idx != CI->getNumOperandBundles(); Idx++) {
81 if (CI->getOperandBundleAt(Idx).Inputs.size() > ABA_WasOn &&
82 CI->getOperandBundleAt(Idx).getTagName() != IgnoreBundleTag)
83 AddAffected(CI->getOperandBundleAt(Idx).Inputs[ABA_WasOn], Idx);
84 }
85
86 Value *Cond = CI->getArgOperand(0), *A, *B;
87 AddAffected(Cond);
88
89 CmpInst::Predicate Pred;
90 if (match(Cond, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
91 AddAffected(A);
92 AddAffected(B);
93
94 if (Pred == ICmpInst::ICMP_EQ) {
95 // For equality comparisons, we handle the case of bit inversion.
96 auto AddAffectedFromEq = [&AddAffected](Value *V) {
97 Value *A;
98 if (match(V, m_Not(m_Value(A)))) {
99 AddAffected(A);
100 V = A;
101 }
102
103 Value *B;
104 // (A & B) or (A | B) or (A ^ B).
105 if (match(V, m_BitwiseLogic(m_Value(A), m_Value(B)))) {
106 AddAffected(A);
107 AddAffected(B);
108 // (A << C) or (A >>_s C) or (A >>_u C) where C is some constant.
109 } else if (match(V, m_Shift(m_Value(A), m_ConstantInt()))) {
110 AddAffected(A);
111 }
112 };
113
114 AddAffectedFromEq(A);
115 AddAffectedFromEq(B);
116 } else if (Pred == ICmpInst::ICMP_NE) {
117 Value *X, *Y;
118 // Handle (a & b != 0). If a/b is a power of 2 we can use this
119 // information.
120 if (match(A, m_And(m_Value(X), m_Value(Y))) && match(B, m_Zero())) {
121 AddAffected(X);
122 AddAffected(Y);
123 }
124 } else if (Pred == ICmpInst::ICMP_ULT) {
125 Value *X;
126 // Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4,
127 // and recognized by LVI at least.
128 if (match(A, m_Add(m_Value(X), m_ConstantInt())) &&
129 match(B, m_ConstantInt()))
130 AddAffected(X);
131 }
132 }
133
134 if (TTI) {
135 const Value *Ptr;
136 unsigned AS;
137 std::tie(Ptr, AS) = TTI->getPredicatedAddrSpace(Cond);
138 if (Ptr)
139 AddAffected(const_cast<Value *>(Ptr->stripInBoundsOffsets()));
140 }
141 }
142
updateAffectedValues(CondGuardInst * CI)143 void AssumptionCache::updateAffectedValues(CondGuardInst *CI) {
144 SmallVector<AssumptionCache::ResultElem, 16> Affected;
145 findAffectedValues(CI, TTI, Affected);
146
147 for (auto &AV : Affected) {
148 auto &AVV = getOrInsertAffectedValues(AV.Assume);
149 if (llvm::none_of(AVV, [&](ResultElem &Elem) {
150 return Elem.Assume == CI && Elem.Index == AV.Index;
151 }))
152 AVV.push_back({CI, AV.Index});
153 }
154 }
155
unregisterAssumption(CondGuardInst * CI)156 void AssumptionCache::unregisterAssumption(CondGuardInst *CI) {
157 SmallVector<AssumptionCache::ResultElem, 16> Affected;
158 findAffectedValues(CI, TTI, Affected);
159
160 for (auto &AV : Affected) {
161 auto AVI = AffectedValues.find_as(AV.Assume);
162 if (AVI == AffectedValues.end())
163 continue;
164 bool Found = false;
165 bool HasNonnull = false;
166 for (ResultElem &Elem : AVI->second) {
167 if (Elem.Assume == CI) {
168 Found = true;
169 Elem.Assume = nullptr;
170 }
171 HasNonnull |= !!Elem.Assume;
172 if (HasNonnull && Found)
173 break;
174 }
175 assert(Found && "already unregistered or incorrect cache state");
176 if (!HasNonnull)
177 AffectedValues.erase(AVI);
178 }
179
180 erase_value(AssumeHandles, CI);
181 }
182
deleted()183 void AssumptionCache::AffectedValueCallbackVH::deleted() {
184 AC->AffectedValues.erase(getValPtr());
185 // 'this' now dangles!
186 }
187
transferAffectedValuesInCache(Value * OV,Value * NV)188 void AssumptionCache::transferAffectedValuesInCache(Value *OV, Value *NV) {
189 auto &NAVV = getOrInsertAffectedValues(NV);
190 auto AVI = AffectedValues.find(OV);
191 if (AVI == AffectedValues.end())
192 return;
193
194 for (auto &A : AVI->second)
195 if (!llvm::is_contained(NAVV, A))
196 NAVV.push_back(A);
197 AffectedValues.erase(OV);
198 }
199
allUsesReplacedWith(Value * NV)200 void AssumptionCache::AffectedValueCallbackVH::allUsesReplacedWith(Value *NV) {
201 if (!isa<Instruction>(NV) && !isa<Argument>(NV))
202 return;
203
204 // Any assumptions that affected this value now affect the new value.
205
206 AC->transferAffectedValuesInCache(getValPtr(), NV);
207 // 'this' now might dangle! If the AffectedValues map was resized to add an
208 // entry for NV then this object might have been destroyed in favor of some
209 // copy in the grown map.
210 }
211
scanFunction()212 void AssumptionCache::scanFunction() {
213 assert(!Scanned && "Tried to scan the function twice!");
214 assert(AssumeHandles.empty() && "Already have assumes when scanning!");
215
216 // Go through all instructions in all blocks, add all calls to @llvm.assume
217 // to this cache.
218 for (BasicBlock &B : F)
219 for (Instruction &I : B)
220 if (isa<CondGuardInst>(&I))
221 AssumeHandles.push_back({&I, ExprResultIdx});
222
223 // Mark the scan as complete.
224 Scanned = true;
225
226 // Update affected values.
227 for (auto &A : AssumeHandles)
228 updateAffectedValues(cast<CondGuardInst>(A));
229 }
230
registerAssumption(CondGuardInst * CI)231 void AssumptionCache::registerAssumption(CondGuardInst *CI) {
232 // If we haven't scanned the function yet, just drop this assumption. It will
233 // be found when we scan later.
234 if (!Scanned)
235 return;
236
237 AssumeHandles.push_back({CI, ExprResultIdx});
238
239 #ifndef NDEBUG
240 assert(CI->getParent() &&
241 "Cannot a register CondGuardInst not in a basic block");
242 assert(&F == CI->getParent()->getParent() &&
243 "Cannot a register CondGuardInst not in this function");
244
245 // We expect the number of assumptions to be small, so in an asserts build
246 // check that we don't accumulate duplicates and that all assumptions point
247 // to the same function.
248 SmallPtrSet<Value *, 16> AssumptionSet;
249 for (auto &VH : AssumeHandles) {
250 if (!VH)
251 continue;
252
253 assert(&F == cast<Instruction>(VH)->getParent()->getParent() &&
254 "Cached assumption not inside this function!");
255 assert(isa<CondGuardInst>(VH) &&
256 "Cached something other than CondGuardInst!");
257 assert(AssumptionSet.insert(VH).second &&
258 "Cache contains multiple copies of a call!");
259 }
260 #endif
261
262 updateAffectedValues(CI);
263 }
264
run(Function & F,FunctionAnalysisManager & FAM)265 AssumptionCache AssumptionAnalysis::run(Function &F,
266 FunctionAnalysisManager &FAM) {
267 auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
268 return AssumptionCache(F, &TTI);
269 }
270
271 AnalysisKey AssumptionAnalysis::Key;
272
run(Function & F,FunctionAnalysisManager & AM)273 PreservedAnalyses AssumptionPrinterPass::run(Function &F,
274 FunctionAnalysisManager &AM) {
275 AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
276
277 OS << "Cached assumptions for function: " << F.getName() << "\n";
278 for (auto &VH : AC.assumptions())
279 if (VH)
280 OS << " " << *cast<CallInst>(VH)->getArgOperand(0) << "\n";
281
282 return PreservedAnalyses::all();
283 }
284
deleted()285 void AssumptionCacheTracker::FunctionCallbackVH::deleted() {
286 auto I = ACT->AssumptionCaches.find_as(cast<Function>(getValPtr()));
287 if (I != ACT->AssumptionCaches.end())
288 ACT->AssumptionCaches.erase(I);
289 // 'this' now dangles!
290 }
291
getAssumptionCache(Function & F)292 AssumptionCache &AssumptionCacheTracker::getAssumptionCache(Function &F) {
293 // We probe the function map twice to try and avoid creating a value handle
294 // around the function in common cases. This makes insertion a bit slower,
295 // but if we have to insert we're going to scan the whole function so that
296 // shouldn't matter.
297 auto I = AssumptionCaches.find_as(&F);
298 if (I != AssumptionCaches.end())
299 return *I->second;
300
301 auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
302 auto *TTI = TTIWP ? &TTIWP->getTTI(F) : nullptr;
303
304 // Ok, build a new cache by scanning the function, insert it and the value
305 // handle into our map, and return the newly populated cache.
306 auto IP = AssumptionCaches.insert(std::make_pair(
307 FunctionCallbackVH(&F, this), std::make_unique<AssumptionCache>(F, TTI)));
308 assert(IP.second && "Scanning function already in the map?");
309 return *IP.first->second;
310 }
311
lookupAssumptionCache(Function & F)312 AssumptionCache *AssumptionCacheTracker::lookupAssumptionCache(Function &F) {
313 auto I = AssumptionCaches.find_as(&F);
314 if (I != AssumptionCaches.end())
315 return I->second.get();
316 return nullptr;
317 }
318
verifyAnalysis() const319 void AssumptionCacheTracker::verifyAnalysis() const {
320 // FIXME: In the long term the verifier should not be controllable with a
321 // flag. We should either fix all passes to correctly update the assumption
322 // cache and enable the verifier unconditionally or somehow arrange for the
323 // assumption list to be updated automatically by passes.
324 if (!VerifyAssumptionCache)
325 return;
326
327 SmallPtrSet<const CallInst *, 4> AssumptionSet;
328 for (const auto &I : AssumptionCaches) {
329 for (auto &VH : I.second->assumptions())
330 if (VH)
331 AssumptionSet.insert(cast<CallInst>(VH));
332
333 for (const BasicBlock &B : cast<Function>(*I.first))
334 for (const Instruction &II : B)
335 if (match(&II, m_Intrinsic<Intrinsic::assume>()) &&
336 !AssumptionSet.count(cast<CallInst>(&II)))
337 report_fatal_error("Assumption in scanned function not in cache");
338 }
339 }
340
AssumptionCacheTracker()341 AssumptionCacheTracker::AssumptionCacheTracker() : ImmutablePass(ID) {
342 initializeAssumptionCacheTrackerPass(*PassRegistry::getPassRegistry());
343 }
344
345 AssumptionCacheTracker::~AssumptionCacheTracker() = default;
346
347 char AssumptionCacheTracker::ID = 0;
348
349 INITIALIZE_PASS(AssumptionCacheTracker, "assumption-cache-tracker",
350 "Assumption Cache Tracker", false, true)
351