1 //===- LowerExpectIntrinsic.cpp - Lower expect intrinsic ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass lowers the 'expect' intrinsic to LLVM metadata.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "llvm/Transforms/Scalar/LowerExpectIntrinsic.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/ADT/Statistic.h"
16 #include "llvm/ADT/iterator_range.h"
17 #include "llvm/IR/BasicBlock.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/Function.h"
20 #include "llvm/IR/Instructions.h"
21 #include "llvm/IR/Intrinsics.h"
22 #include "llvm/IR/LLVMContext.h"
23 #include "llvm/IR/MDBuilder.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Transforms/Scalar.h"
28 #include "llvm/Transforms/Utils/MisExpect.h"
29
30 #include <cmath>
31
32 using namespace llvm;
33
34 #define DEBUG_TYPE "lower-expect-intrinsic"
35
// Counts branches, switches and selects that received branch weights derived
// from an expect intrinsic (incremented in lowerExpectIntrinsic() whenever a
// handle*Expect helper reports success).
STATISTIC(ExpectIntrinsicsHandled,
          "Number of 'expect' intrinsic instructions handled");
38
// These default values are chosen to represent an extremely skewed outcome for
// a condition, but they leave some room for interpretation by later passes.
//
// If the documentation for __builtin_expect() were explicit that it should
// only be used in extreme cases, we could make this ratio higher. As it stands,
// programmers may be using __builtin_expect() / llvm.expect to annotate that a
// branch is likely or unlikely to be taken.

// WARNING: these values are an internal implementation detail of the pass.
// They should not be exposed outside of the pass; front-end codegen should
// emit @llvm.expect intrinsics instead of using these weights directly.
// Transforms should use TargetTransformInfo's getPredictableBranchThreshold().
//
// Both options are only consulted for plain @llvm.expect; weights for
// @llvm.expect.with.probability are computed from its probability argument.

// Weight assigned to the edge matching the expected value.
static cl::opt<uint32_t> LikelyBranchWeight(
    "likely-branch-weight", cl::Hidden, cl::init(2000),
    cl::desc("Weight of the branch likely to be taken (default = 2000)"));
// Weight assigned to each edge that does not match the expected value.
static cl::opt<uint32_t> UnlikelyBranchWeight(
    "unlikely-branch-weight", cl::Hidden, cl::init(1),
    cl::desc("Weight of the branch unlikely to be taken (default = 1)"));
57
58 static std::tuple<uint32_t, uint32_t>
getBranchWeight(Intrinsic::ID IntrinsicID,CallInst * CI,int BranchCount)59 getBranchWeight(Intrinsic::ID IntrinsicID, CallInst *CI, int BranchCount) {
60 if (IntrinsicID == Intrinsic::expect) {
61 // __builtin_expect
62 return std::make_tuple(LikelyBranchWeight.getValue(),
63 UnlikelyBranchWeight.getValue());
64 } else {
65 // __builtin_expect_with_probability
66 assert(CI->getNumOperands() >= 3 &&
67 "expect with probability must have 3 arguments");
68 auto *Confidence = cast<ConstantFP>(CI->getArgOperand(2));
69 double TrueProb = Confidence->getValueAPF().convertToDouble();
70 assert((TrueProb >= 0.0 && TrueProb <= 1.0) &&
71 "probability value must be in the range [0.0, 1.0]");
72 double FalseProb = (1.0 - TrueProb) / (BranchCount - 1);
73 uint32_t LikelyBW = ceil((TrueProb * (double)(INT32_MAX - 1)) + 1.0);
74 uint32_t UnlikelyBW = ceil((FalseProb * (double)(INT32_MAX - 1)) + 1.0);
75 return std::make_tuple(LikelyBW, UnlikelyBW);
76 }
77 }
78
handleSwitchExpect(SwitchInst & SI)79 static bool handleSwitchExpect(SwitchInst &SI) {
80 CallInst *CI = dyn_cast<CallInst>(SI.getCondition());
81 if (!CI)
82 return false;
83
84 Function *Fn = CI->getCalledFunction();
85 if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect &&
86 Fn->getIntrinsicID() != Intrinsic::expect_with_probability))
87 return false;
88
89 Value *ArgValue = CI->getArgOperand(0);
90 ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
91 if (!ExpectedValue)
92 return false;
93
94 SwitchInst::CaseHandle Case = *SI.findCaseValue(ExpectedValue);
95 unsigned n = SI.getNumCases(); // +1 for default case.
96 uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
97 std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) =
98 getBranchWeight(Fn->getIntrinsicID(), CI, n + 1);
99
100 SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeightVal);
101
102 uint64_t Index = (Case == *SI.case_default()) ? 0 : Case.getCaseIndex() + 1;
103 Weights[Index] = LikelyBranchWeightVal;
104
105 misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true);
106
107 SI.setCondition(ArgValue);
108
109 SI.setMetadata(LLVMContext::MD_prof,
110 MDBuilder(CI->getContext()).createBranchWeights(Weights));
111
112 return true;
113 }
114
/// Handler for PHINodes that define the value argument to an
/// @llvm.expect (or @llvm.expect.with.probability) call.
///
/// If an operand of the phi is a constant whose value 'contradicts' the
/// expected value of the phi def, then the corresponding incoming edge of
/// the phi is unlikely to be taken. Using that information, branch weights
/// for the conditional branch that leads to that edge can be inferred.
///
/// The expect's first argument may reach the phi through a chain of zext,
/// sext and xor-with-constant instructions. That chain is replayed on each
/// constant phi operand before it is compared with the expected value.
///
/// \param Expect The intrinsic call. Within this file it is only passed from
///        lowerExpectIntrinsic() for calls whose callee is llvm.expect or
///        llvm.expect.with.probability, so getCalledFunction() is non-null.
///
/// The call itself is not modified; only !prof metadata on branches is set.
static void handlePhiDef(CallInst *Expect) {
  Value &Arg = *Expect->getArgOperand(0);
  ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(Expect->getArgOperand(1));
  // Nothing can be inferred when the expected value is not a constant.
  if (!ExpectedValue)
    return;
  const APInt &ExpectedPhiValue = ExpectedValue->getValue();
  bool ExpectedValueIsLikely = true;
  Function *Fn = Expect->getCalledFunction();
  // If the function is expect_with_probability, then we need to take the
  // probability into consideration. For example, in
  // expect.with.probability.i64(i64 %a, i64 1, double 0.0), the
  // "ExpectedValue" 1 is unlikely. This affects probability propagation later.
  if (Fn->getIntrinsicID() == Intrinsic::expect_with_probability) {
    auto *Confidence = cast<ConstantFP>(Expect->getArgOperand(2));
    double TrueProb = Confidence->getValueAPF().convertToDouble();
    ExpectedValueIsLikely = (TrueProb > 0.5);
  }

  // Walk backwards through a chain of instructions that have 'copy'
  // semantics, 'stripping' the copies until a PHI node or an instruction
  // of unknown kind is reached. Zero/sign extension and negation via xor
  // with a constant are also looked through, and recorded in Operations.
  //
  //       C = PHI(...);
  //       B = C;
  //       A = B;
  //       D = __builtin_expect(A, 0);
  //
  Value *V = &Arg;
  SmallVector<Instruction *, 4> Operations;
  while (!isa<PHINode>(V)) {
    if (ZExtInst *ZExt = dyn_cast<ZExtInst>(V)) {
      V = ZExt->getOperand(0);
      Operations.push_back(ZExt);
      continue;
    }

    if (SExtInst *SExt = dyn_cast<SExtInst>(V)) {
      V = SExt->getOperand(0);
      Operations.push_back(SExt);
      continue;
    }

    // Anything other than "xor V, <constant>" ends the walk before a phi is
    // found, so nothing can be inferred.
    BinaryOperator *BinOp = dyn_cast<BinaryOperator>(V);
    if (!BinOp || BinOp->getOpcode() != Instruction::Xor)
      return;

    ConstantInt *CInt = dyn_cast<ConstantInt>(BinOp->getOperand(1));
    if (!CInt)
      return;

    V = BinOp->getOperand(0);
    Operations.push_back(BinOp);
  }

  // Executes the recorded operations on input 'Value'.
  // Operations were collected walking from the expect argument towards the
  // phi, so they are replayed in reverse (phi side first). This maps a phi
  // operand to the value the expect call would observe.
  auto ApplyOperations = [&](const APInt &Value) {
    APInt Result = Value;
    for (auto *Op : llvm::reverse(Operations)) {
      switch (Op->getOpcode()) {
      case Instruction::Xor:
        Result ^= cast<ConstantInt>(Op->getOperand(1))->getValue();
        break;
      case Instruction::ZExt:
        Result = Result.zext(Op->getType()->getIntegerBitWidth());
        break;
      case Instruction::SExt:
        Result = Result.sext(Op->getType()->getIntegerBitWidth());
        break;
      default:
        llvm_unreachable("Unexpected operation");
      }
    }
    return Result;
  };

  auto *PhiDef = cast<PHINode>(V);

  // Get the first dominating conditional branch of the operand
  // i's incoming block.
  // Only two candidates are examined: the terminator of the incoming block
  // itself, and, if that is not a conditional branch, the terminator of the
  // incoming block's single predecessor. Otherwise nullptr is returned.
  auto GetDomConditional = [&](unsigned i) -> BranchInst * {
    BasicBlock *BB = PhiDef->getIncomingBlock(i);
    BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
    if (BI && BI->isConditional())
      return BI;
    BB = BB->getSinglePredecessor();
    if (!BB)
      return nullptr;
    BI = dyn_cast<BranchInst>(BB->getTerminator());
    if (!BI || BI->isUnconditional())
      return nullptr;
    return BI;
  };

  // Now walk through all Phi operands to find phi operands with values
  // conflicting with the expected phi output value. Any such operand
  // indicates the incoming edge to that operand is unlikely.
  for (unsigned i = 0, e = PhiDef->getNumIncomingValues(); i != e; ++i) {

    Value *PhiOpnd = PhiDef->getIncomingValue(i);
    ConstantInt *CI = dyn_cast<ConstantInt>(PhiOpnd);
    if (!CI)
      continue;

    // Skip operands that carry no information -- we can not infer
    // anything useful when:
    //  (1) We expect some phi output and the operand value matches it, or
    //  (2) We don't expect some phi output (i.e. the "ExpectedValue" has low
    //      probability) and the operand value doesn't match that.
    const APInt &CurrentPhiValue = ApplyOperations(CI->getValue());
    if (ExpectedValueIsLikely == (ExpectedPhiValue == CurrentPhiValue))
      continue;

    BranchInst *BI = GetDomConditional(i);
    if (!BI)
      continue;

    MDBuilder MDB(PhiDef->getContext());

    // There are two situations in which an operand of the PhiDef comes
    // from a given successor of a branch instruction BI.
    // 1) When the incoming block of the operand is the successor block;
    // 2) When the incoming block is BI's enclosing block and the
    // successor is the PhiDef's enclosing block.
    //
    // Returns true if the operand which comes from OpndIncomingBB
    // comes from outgoing edge of BI that leads to Succ block.
    auto *OpndIncomingBB = PhiDef->getIncomingBlock(i);
    auto IsOpndComingFromSuccessor = [&](BasicBlock *Succ) {
      if (OpndIncomingBB == Succ)
        // If this successor is the incoming block for this
        // Phi operand, then this successor does lead to the Phi.
        return true;
      if (OpndIncomingBB == BI->getParent() && Succ == PhiDef->getParent())
        // Otherwise, if the edge is directly from the branch
        // to the Phi, this successor is the one feeding this
        // Phi operand.
        return true;
      return false;
    };
    // BI is a two-way conditional branch, hence a BranchCount of 2.
    uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
    std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = getBranchWeight(
        Expect->getCalledFunction()->getIntrinsicID(), Expect, 2);
    // When the expected value itself is unlikely, the operands that reach
    // this point are the ones that match it, so the two weights trade places.
    if (!ExpectedValueIsLikely)
      std::swap(LikelyBranchWeightVal, UnlikelyBranchWeightVal);

    // Give the edge leading to this operand the "unlikely" weight and the
    // other successor the "likely" weight. If neither successor feeds this
    // operand, BI is left untouched. If a later operand maps to the same BI,
    // its metadata overwrites this one.
    if (IsOpndComingFromSuccessor(BI->getSuccessor(1)))
      BI->setMetadata(LLVMContext::MD_prof,
                      MDB.createBranchWeights(LikelyBranchWeightVal,
                                              UnlikelyBranchWeightVal));
    else if (IsOpndComingFromSuccessor(BI->getSuccessor(0)))
      BI->setMetadata(LLVMContext::MD_prof,
                      MDB.createBranchWeights(UnlikelyBranchWeightVal,
                                              LikelyBranchWeightVal));
  }
}
278
279 // Handle both BranchInst and SelectInst.
handleBrSelExpect(BrSelInst & BSI)280 template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) {
281
282 // Handle non-optimized IR code like:
283 // %expval = call i64 @llvm.expect.i64(i64 %conv1, i64 1)
284 // %tobool = icmp ne i64 %expval, 0
285 // br i1 %tobool, label %if.then, label %if.end
286 //
287 // Or the following simpler case:
288 // %expval = call i1 @llvm.expect.i1(i1 %cmp, i1 1)
289 // br i1 %expval, label %if.then, label %if.end
290
291 CallInst *CI;
292
293 ICmpInst *CmpI = dyn_cast<ICmpInst>(BSI.getCondition());
294 CmpInst::Predicate Predicate;
295 ConstantInt *CmpConstOperand = nullptr;
296 if (!CmpI) {
297 CI = dyn_cast<CallInst>(BSI.getCondition());
298 Predicate = CmpInst::ICMP_NE;
299 } else {
300 Predicate = CmpI->getPredicate();
301 if (Predicate != CmpInst::ICMP_NE && Predicate != CmpInst::ICMP_EQ)
302 return false;
303
304 CmpConstOperand = dyn_cast<ConstantInt>(CmpI->getOperand(1));
305 if (!CmpConstOperand)
306 return false;
307 CI = dyn_cast<CallInst>(CmpI->getOperand(0));
308 }
309
310 if (!CI)
311 return false;
312
313 uint64_t ValueComparedTo = 0;
314 if (CmpConstOperand) {
315 if (CmpConstOperand->getBitWidth() > 64)
316 return false;
317 ValueComparedTo = CmpConstOperand->getZExtValue();
318 }
319
320 Function *Fn = CI->getCalledFunction();
321 if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect &&
322 Fn->getIntrinsicID() != Intrinsic::expect_with_probability))
323 return false;
324
325 Value *ArgValue = CI->getArgOperand(0);
326 ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
327 if (!ExpectedValue)
328 return false;
329
330 MDBuilder MDB(CI->getContext());
331 MDNode *Node;
332
333 uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
334 std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) =
335 getBranchWeight(Fn->getIntrinsicID(), CI, 2);
336
337 SmallVector<uint32_t, 4> ExpectedWeights;
338 if ((ExpectedValue->getZExtValue() == ValueComparedTo) ==
339 (Predicate == CmpInst::ICMP_EQ)) {
340 Node =
341 MDB.createBranchWeights(LikelyBranchWeightVal, UnlikelyBranchWeightVal);
342 ExpectedWeights = {LikelyBranchWeightVal, UnlikelyBranchWeightVal};
343 } else {
344 Node =
345 MDB.createBranchWeights(UnlikelyBranchWeightVal, LikelyBranchWeightVal);
346 ExpectedWeights = {UnlikelyBranchWeightVal, LikelyBranchWeightVal};
347 }
348
349 if (CmpI)
350 CmpI->setOperand(0, ArgValue);
351 else
352 BSI.setCondition(ArgValue);
353
354 misexpect::checkFrontendInstrumentation(BSI, ExpectedWeights);
355
356 BSI.setMetadata(LLVMContext::MD_prof, Node);
357
358 return true;
359 }
360
handleBranchExpect(BranchInst & BI)361 static bool handleBranchExpect(BranchInst &BI) {
362 if (BI.isUnconditional())
363 return false;
364
365 return handleBrSelExpect<BranchInst>(BI);
366 }
367
lowerExpectIntrinsic(Function & F)368 static bool lowerExpectIntrinsic(Function &F) {
369 bool Changed = false;
370
371 for (BasicBlock &BB : F) {
372 // Create "block_weights" metadata.
373 if (BranchInst *BI = dyn_cast<BranchInst>(BB.getTerminator())) {
374 if (handleBranchExpect(*BI))
375 ExpectIntrinsicsHandled++;
376 } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator())) {
377 if (handleSwitchExpect(*SI))
378 ExpectIntrinsicsHandled++;
379 }
380
381 // Remove llvm.expect intrinsics. Iterate backwards in order
382 // to process select instructions before the intrinsic gets
383 // removed.
384 for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(BB))) {
385 CallInst *CI = dyn_cast<CallInst>(&Inst);
386 if (!CI) {
387 if (SelectInst *SI = dyn_cast<SelectInst>(&Inst)) {
388 if (handleBrSelExpect(*SI))
389 ExpectIntrinsicsHandled++;
390 }
391 continue;
392 }
393
394 Function *Fn = CI->getCalledFunction();
395 if (Fn && (Fn->getIntrinsicID() == Intrinsic::expect ||
396 Fn->getIntrinsicID() == Intrinsic::expect_with_probability)) {
397 // Before erasing the llvm.expect, walk backward to find
398 // phi that define llvm.expect's first arg, and
399 // infer branch probability:
400 handlePhiDef(CI);
401 Value *Exp = CI->getArgOperand(0);
402 CI->replaceAllUsesWith(Exp);
403 CI->eraseFromParent();
404 Changed = true;
405 }
406 }
407 }
408
409 return Changed;
410 }
411
run(Function & F,FunctionAnalysisManager &)412 PreservedAnalyses LowerExpectIntrinsicPass::run(Function &F,
413 FunctionAnalysisManager &) {
414 if (lowerExpectIntrinsic(F))
415 return PreservedAnalyses::none();
416
417 return PreservedAnalyses::all();
418 }
419
420 namespace {
421 /// Legacy pass for lowering expect intrinsics out of the IR.
422 ///
423 /// When this pass is run over a function it uses expect intrinsics which feed
424 /// branches and switches to provide branch weight metadata for those
425 /// terminators. It then removes the expect intrinsics from the IR so the rest
426 /// of the optimizer can ignore them.
427 class LowerExpectIntrinsic : public FunctionPass {
428 public:
429 static char ID;
LowerExpectIntrinsic()430 LowerExpectIntrinsic() : FunctionPass(ID) {
431 initializeLowerExpectIntrinsicPass(*PassRegistry::getPassRegistry());
432 }
433
runOnFunction(Function & F)434 bool runOnFunction(Function &F) override { return lowerExpectIntrinsic(F); }
435 };
436 } // namespace
437
// Pass identity; the legacy pass manager identifies the pass by the address
// of this object, not by its value.
char LowerExpectIntrinsic::ID = 0;
// Registers the legacy pass as "lower-expect". The two trailing flags are
// INITIALIZE_PASS's "CFG only" and "is analysis" arguments, both false here.
INITIALIZE_PASS(LowerExpectIntrinsic, "lower-expect",
                "Lower 'expect' Intrinsics", false, false)
441
createLowerExpectIntrinsicPass()442 FunctionPass *llvm::createLowerExpectIntrinsicPass() {
443 return new LowerExpectIntrinsic();
444 }
445