1 //===-- AbstractCallSite.cpp - Implementation of abstract call sites ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements abstract call sites which unify the interface for
10 // direct, indirect, and callback call sites.
11 //
12 // For more information see:
13 // https://llvm.org/devmtg/2018-10/talk-abstracts.html#talk20
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/IR/AbstractCallSite.h"
18 #include "llvm/ADT/Statistic.h"
19 
20 using namespace llvm;
21 
// Debug category name for this file; it is defined ahead of the STATISTIC
// counters below, which are declared while it is in effect.
#define DEBUG_TYPE "abstract-call-sites"

// Counters bumped by the AbstractCallSite(const Use *) constructor, one per
// outcome: a valid callback call site, a direct (callee-use) call site, or one
// of the reasons the abstract call site was left invalid.
STATISTIC(NumCallbackCallSites, "Number of callback call sites created");
STATISTIC(NumDirectAbstractCallSites,
          "Number of direct abstract call sites created");
// The use's user is not a call (even after looking through a single-use
// constant cast expression).
STATISTIC(NumInvalidAbstractCallSitesUnknownUse,
          "Number of invalid abstract call sites created (unknown use)");
// The call has no statically known called function (the broker).
STATISTIC(NumInvalidAbstractCallSitesUnknownCallee,
          "Number of invalid abstract call sites created (unknown callee)");
// The broker has no !callback metadata, or none of its encodings names the
// argument the use occupies.
STATISTIC(NumInvalidAbstractCallSitesNoCallback,
          "Number of invalid abstract call sites created (no callback)");
33 
getCallbackUses(const CallBase & CB,SmallVectorImpl<const Use * > & CallbackUses)34 void AbstractCallSite::getCallbackUses(
35     const CallBase &CB, SmallVectorImpl<const Use *> &CallbackUses) {
36   const Function *Callee = CB.getCalledFunction();
37   if (!Callee)
38     return;
39 
40   MDNode *CallbackMD = Callee->getMetadata(LLVMContext::MD_callback);
41   if (!CallbackMD)
42     return;
43 
44   for (const MDOperand &Op : CallbackMD->operands()) {
45     MDNode *OpMD = cast<MDNode>(Op.get());
46     auto *CBCalleeIdxAsCM = cast<ConstantAsMetadata>(OpMD->getOperand(0));
47     uint64_t CBCalleeIdx =
48         cast<ConstantInt>(CBCalleeIdxAsCM->getValue())->getZExtValue();
49     if (CBCalleeIdx < CB.arg_size())
50       CallbackUses.push_back(CB.arg_begin() + CBCalleeIdx);
51   }
52 }
53 
54 /// Create an abstract call site from a use.
AbstractCallSite(const Use * U)55 AbstractCallSite::AbstractCallSite(const Use *U)
56     : CB(dyn_cast<CallBase>(U->getUser())) {
57 
58   // First handle unknown users.
59   if (!CB) {
60 
61     // If the use is actually in a constant cast expression which itself
62     // has only one use, we look through the constant cast expression.
63     // This happens by updating the use @p U to the use of the constant
64     // cast expression and afterwards re-initializing CB accordingly.
65     if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U->getUser()))
66       if (CE->hasOneUse() && CE->isCast()) {
67         U = &*CE->use_begin();
68         CB = dyn_cast<CallBase>(U->getUser());
69       }
70 
71     if (!CB) {
72       NumInvalidAbstractCallSitesUnknownUse++;
73       return;
74     }
75   }
76 
77   // Then handle direct or indirect calls. Thus, if U is the callee of the
78   // call site CB it is not a callback and we are done.
79   if (CB->isCallee(U)) {
80     NumDirectAbstractCallSites++;
81     return;
82   }
83 
84   // If we cannot identify the broker function we cannot create a callback and
85   // invalidate the abstract call site.
86   Function *Callee = CB->getCalledFunction();
87   if (!Callee) {
88     NumInvalidAbstractCallSitesUnknownCallee++;
89     CB = nullptr;
90     return;
91   }
92 
93   MDNode *CallbackMD = Callee->getMetadata(LLVMContext::MD_callback);
94   if (!CallbackMD) {
95     NumInvalidAbstractCallSitesNoCallback++;
96     CB = nullptr;
97     return;
98   }
99 
100   unsigned UseIdx = CB->getArgOperandNo(U);
101   MDNode *CallbackEncMD = nullptr;
102   for (const MDOperand &Op : CallbackMD->operands()) {
103     MDNode *OpMD = cast<MDNode>(Op.get());
104     auto *CBCalleeIdxAsCM = cast<ConstantAsMetadata>(OpMD->getOperand(0));
105     uint64_t CBCalleeIdx =
106         cast<ConstantInt>(CBCalleeIdxAsCM->getValue())->getZExtValue();
107     if (CBCalleeIdx != UseIdx)
108       continue;
109     CallbackEncMD = OpMD;
110     break;
111   }
112 
113   if (!CallbackEncMD) {
114     NumInvalidAbstractCallSitesNoCallback++;
115     CB = nullptr;
116     return;
117   }
118 
119   NumCallbackCallSites++;
120 
121   assert(CallbackEncMD->getNumOperands() >= 2 && "Incomplete !callback metadata");
122 
123   unsigned NumCallOperands = CB->arg_size();
124   // Skip the var-arg flag at the end when reading the metadata.
125   for (unsigned u = 0, e = CallbackEncMD->getNumOperands() - 1; u < e; u++) {
126     Metadata *OpAsM = CallbackEncMD->getOperand(u).get();
127     auto *OpAsCM = cast<ConstantAsMetadata>(OpAsM);
128     assert(OpAsCM->getType()->isIntegerTy(64) &&
129            "Malformed !callback metadata");
130 
131     int64_t Idx = cast<ConstantInt>(OpAsCM->getValue())->getSExtValue();
132     assert(-1 <= Idx && Idx <= NumCallOperands &&
133            "Out-of-bounds !callback metadata index");
134 
135     CI.ParameterEncoding.push_back(Idx);
136   }
137 
138   if (!Callee->isVarArg())
139     return;
140 
141   Metadata *VarArgFlagAsM =
142       CallbackEncMD->getOperand(CallbackEncMD->getNumOperands() - 1).get();
143   auto *VarArgFlagAsCM = cast<ConstantAsMetadata>(VarArgFlagAsM);
144   assert(VarArgFlagAsCM->getType()->isIntegerTy(1) &&
145          "Malformed !callback metadata var-arg flag");
146 
147   if (VarArgFlagAsCM->getValue()->isNullValue())
148     return;
149 
150   // Add all variadic arguments at the end.
151   for (unsigned u = Callee->arg_size(); u < NumCallOperands; u++)
152     CI.ParameterEncoding.push_back(u);
153 }
154