1 #define DEBUG_TYPE "cheri-range-checker"
2
3 #include "Mips.h"
4 #include "llvm/ADT/StringSwitch.h"
5 #include "llvm/Analysis/ValueTracking.h"
6 #include "llvm/IR/Cheri.h"
7 #include "llvm/IR/Constants.h"
8 #include "llvm/IR/DataLayout.h"
9 #include "llvm/IR/Function.h"
10 #include "llvm/IR/GlobalVariable.h"
11 #include "llvm/IR/IRBuilder.h"
12 #include "llvm/IR/InstVisitor.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/Intrinsics.h"
15 #include "llvm/IR/LLVMContext.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/Pass.h"
18 #include "llvm/Transforms/Utils/CheriSetBounds.h"
19 #include "llvm/Transforms/Utils/Local.h"
20
21 #include <string>
22 #include <tuple>
23 #include <utility>
24
25 #include "llvm/IR/Verifier.h"
26
27 using namespace llvm;
28 using std::pair;
29
30 namespace {
// Operands for an allocation (see AllocOperands below): either one or two
// integers (constant or variable). If there are two, they must be multiplied
// together to obtain the allocation size.
33 struct ValueSource {
34 ValueSource() = default;
35 Value *Base = nullptr;
36 int64_t Offset = 0;
37 };
38 struct AllocOperands {
39 AllocOperands() = default;
40 Value *Size = nullptr;
41 Value *SizeMultiplier = nullptr;
42 ValueSource ValueSrc;
43 cheri::SetBoundsPointerSource Src = cheri::SetBoundsPointerSource::Unknown;
operator !=__anonba7517290111::AllocOperands44 bool operator!=(const AllocOperands &Other) {
45 return Size != Other.Size || SizeMultiplier != Other.SizeMultiplier ||
46 ValueSrc.Base != Other.ValueSrc.Base ||
47 ValueSrc.Offset != Other.ValueSrc.Offset || Src != Other.Src;
48 }
49 };
50 class CheriRangeChecker : public FunctionPass,
51 public InstVisitor<CheriRangeChecker> {
52 struct ConstantCast {
53 Instruction *Instr;
54 unsigned OpNo;
55 User *Origin;
56 };
57 std::unique_ptr<DataLayout> TD;
58 Module *M;
59 IntegerType *SizeTy;
60 PointerType *CapPtrTy;
61 SmallVector<pair<AllocOperands, Instruction *>, 32> Casts;
62 SmallVector<pair<AllocOperands, ConstantCast>, 32> ConstantCasts;
63 Function *SetLengthFn;
64
getValueSource(Value * Src)65 ValueSource getValueSource(Value *Src) {
66 int64_t Offset = 0;
67 Src = Src->stripPointerCasts();
68 auto Base = GetPointerBaseWithConstantOffset(Src, Offset, *TD);
69 if (Base && Base != Src) {
70 LLVM_DEBUG(dbgs() << "Found base: "; Base->dump());
71 Src = Base;
72 }
73 return ValueSource{Src, Offset};
74 }
75
getRangeForAllocation(ValueSource Src)76 AllocOperands getRangeForAllocation(ValueSource Src) {
77 // FIXME: This should not hardcode function names but instead use the
78 // alloc_size attribute!
79 if (auto Malloc = dyn_cast<CallBase>(Src.Base)) {
80 Function *Fn = Malloc->getCalledFunction();
81 if (!Fn)
82 return AllocOperands();
83 switch (StringSwitch<int>(Fn->getName())
84 .Case("malloc", 1)
85 .Case("valloc", 1)
86 .Case("realloc", 2)
87 .Case("aligned_alloc", 2)
88 .Case("reallocf", 2)
89 .Case("calloc", 3)
90 .Default(-1)) {
91 default:
92 return AllocOperands();
93 case 1:
94 return AllocOperands{Malloc->getArgOperand(0), nullptr, Src,
95 cheri::SetBoundsPointerSource::Heap};
96 case 2:
97 return AllocOperands{Malloc->getArgOperand(1), nullptr, Src,
98 cheri::SetBoundsPointerSource::Heap};
99 case 3:
100 return AllocOperands{Malloc->getArgOperand(0), Malloc->getArgOperand(1),
101 Src, cheri::SetBoundsPointerSource::Heap};
102 }
103 } else if (AllocaInst *AI = dyn_cast<AllocaInst>(Src.Base)) {
104 PointerType *AllocaTy = AI->getType();
105 Value *ArraySize = AI->getArraySize();
106 Type *AllocationTy = AllocaTy->getElementType();
107 unsigned ElementSize = TD->getTypeAllocSize(AllocationTy);
108 if (ElementSize == 1)
109 return AllocOperands{ArraySize, nullptr, Src,
110 cheri::SetBoundsPointerSource::Stack};
111 Value *Size = ConstantInt::get(ArraySize->getType(), ElementSize);
112 return AllocOperands{Size, ArraySize, Src,
113 cheri::SetBoundsPointerSource::Stack};
114 }
115 return AllocOperands();
116 }
RangeCheckedValue(Instruction * InsertPt,AllocOperands AO,Value * I2P,Value * & BitCast)117 Value *RangeCheckedValue(Instruction *InsertPt, AllocOperands AO, Value *I2P,
118 Value *&BitCast) {
119 LLVM_DEBUG(dbgs() << "Adding RangeChecker bounds\n";
120 dbgs() << "\tCast = "; I2P->dump();
121 dbgs() << "\tBase = "; AO.ValueSrc.Base->dump();
122 dbgs() << "\tOffset = " << AO.ValueSrc.Offset << "\n";);
123 IRBuilder<> B(InsertPt);
124 Value *Size =
125 AO.SizeMultiplier ? B.CreateMul(AO.Size, AO.SizeMultiplier) : AO.Size;
126 BitCast = B.CreatePointerBitCastOrAddrSpaceCast(AO.ValueSrc.Base, CapPtrTy);
127 if (Size->getType() != SizeTy)
128 Size = B.CreateZExt(Size, SizeTy);
129 CallInst *SetLength = B.CreateCall(SetLengthFn, {BitCast, Size});
130 if (cheri::ShouldCollectCSetBoundsStats) {
131 Value *AlignmentSource = BitCast;
132 Instruction *DebugInst = dyn_cast<Instruction>(AlignmentSource);
133 if (!DebugInst)
134 DebugInst = InsertPt;
135 cheri::addSetBoundsStats(Align(getKnownAlignment(AlignmentSource, *TD)),
136 Size, getPassName(), AO.Src, "",
137 cheri::inferSourceLocation(DebugInst));
138 }
139 if (BitCast == AO.ValueSrc.Base)
140 BitCast = SetLength;
141 Value* Result = SetLength;
142 if (AO.ValueSrc.Offset != 0) {
143 LLVM_DEBUG(dbgs() << "Inserting GEP for non-zero Offset "
144 << AO.ValueSrc.Offset << "\n";
145 BitCast->dump(););
146 Result = B.CreateConstGEP1_64(Result, AO.ValueSrc.Offset, "offs");
147 }
148 return B.CreateBitCast(Result, I2P->getType());
149 }
150
151 public:
152 static char ID;
CheriRangeChecker()153 CheriRangeChecker() : FunctionPass(ID) {}
getPassName() const154 StringRef getPassName() const override { return "CHERI range checker"; }
doInitialization(Module & Mod)155 bool doInitialization(Module &Mod) override {
156 M = &Mod;
157 TD = std::make_unique<DataLayout>(M);
158 SizeTy = IntegerType::get(M->getContext(), TD->getIndexSizeInBits(200));
159 CapPtrTy = PointerType::get(IntegerType::get(M->getContext(), 8), 200);
160 return true;
161 }
~CheriRangeChecker()162 virtual ~CheriRangeChecker() {}
checkOpcode(Value * V,unsigned Opcode)163 bool checkOpcode(Value *V, unsigned Opcode) {
164 if (Instruction *I = dyn_cast<Instruction>(V))
165 return I->getOpcode() == Opcode;
166 if (ConstantExpr *E = dyn_cast<ConstantExpr>(V))
167 return E->getOpcode() == Opcode;
168 return false;
169 }
170
testI2P(User & I2P)171 User *testI2P(User &I2P) {
172 PointerType *DestTy = dyn_cast<PointerType>(I2P.getType());
173 if (DestTy && isCheriPointer(DestTy, TD.get())) {
174 if (checkOpcode(I2P.getOperand(0), Instruction::PtrToInt)) {
175 User *P2I = cast<User>(I2P.getOperand(0));
176 PointerType *SrcTy =
177 dyn_cast<PointerType>(P2I->getOperand(0)->getType());
178 if (SrcTy && SrcTy->getAddressSpace() == 0) {
179 Value *Src = P2I->getOperand(0)->stripPointerCasts();
180 if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Src))
181 return GV->hasExternalLinkage() ? 0 : P2I;
182 if (isa<AllocaInst>(Src) || isa<CallBase>(Src)) {
183 return P2I;
184 }
185 }
186 }
187 }
188 return 0;
189 }
visitAddrSpaceCast(AddrSpaceCastInst & ASC)190 void visitAddrSpaceCast(AddrSpaceCastInst &ASC) {
191 PointerType *DestTy = dyn_cast<PointerType>(ASC.getType());
192 PointerType *SrcTy = dyn_cast<PointerType>(ASC.getOperand(0)->getType());
193 LLVM_DEBUG(dbgs() << "Visiting address space cast: "; ASC.dump());
194
195 if ((DestTy && isCheriPointer(DestTy, TD.get())) &&
196 (SrcTy && SrcTy->getAddressSpace() == 0)) {
197 auto Src = getValueSource(ASC.getOperand(0));
198 if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Src.Base)) {
199 if (GV->hasExternalLinkage())
200 return;
201 } else if (!(isa<AllocaInst>(Src.Base) || isa<CallBase>(Src.Base)))
202 return;
203 AllocOperands AO = getRangeForAllocation(Src);
204 if (AO != AllocOperands())
205 Casts.push_back(pair<AllocOperands, Instruction *>(AO, &ASC));
206 }
207 }
visitIntToPtrInst(IntToPtrInst & I2P)208 void visitIntToPtrInst(IntToPtrInst &I2P) {
209 if (User *P2I = testI2P(I2P)) {
210 auto Src = getValueSource(P2I->getOperand(0));
211 AllocOperands AO = getRangeForAllocation(Src);
212 if (AO != AllocOperands())
213 Casts.push_back(pair<AllocOperands, Instruction *>(AO, &I2P));
214 }
215 }
visitRet(ReturnInst & RI)216 void visitRet(ReturnInst &RI) {
217 Value *RV = RI.getReturnValue();
218 if (RV && isa<ConstantExpr>(RV)) {
219 ConstantCast C = {&RI, 0, testI2P(*cast<User>(RV))};
220 if (C.Origin) {
221 auto Src = getValueSource(C.Origin->getOperand(0));
222 AllocOperands AO = getRangeForAllocation(Src);
223 if (AO != AllocOperands())
224 ConstantCasts.push_back(pair<AllocOperands, ConstantCast>(AO, C));
225 }
226 }
227 }
visitCall(CallInst & CI)228 void visitCall(CallInst &CI) {
229 for (unsigned i = 0; i < CI.getNumOperands(); i++) {
230 Value *AV = CI.getOperand(i);
231 if (AV && isa<ConstantExpr>(AV)) {
232 ConstantCast C = {&CI, i, testI2P(*cast<User>(AV))};
233 if (C.Origin) {
234 auto Src = getValueSource(C.Origin->getOperand(0));
235 AllocOperands AO = getRangeForAllocation(Src);
236 if (AO != AllocOperands())
237 ConstantCasts.push_back(pair<AllocOperands, ConstantCast>(AO, C));
238 }
239 }
240 }
241 }
runOnFunction(Function & F)242 bool runOnFunction(Function &F) override{
243 Casts.clear();
244 ConstantCasts.clear();
245
246 visit(F);
247
248 if (!(Casts.empty() && ConstantCasts.empty())) {
249 Intrinsic::ID SetLength = Intrinsic::cheri_cap_bounds_set;
250 SetLengthFn = Intrinsic::getDeclaration(M, SetLength, SizeTy);
251 Value *BitCast = 0;
252
253 for (auto *i = Casts.begin(), *e = Casts.end(); i != e; ++i) {
254 Instruction *I2P = i->second;
255 auto InsertPt = I2P->getParent()->begin();
256 while (&(*InsertPt) != I2P) {
257 ++InsertPt;
258 }
259 ++InsertPt;
260 Value *New = RangeCheckedValue(&*InsertPt, i->first,
261 I2P, BitCast);
262 LLVM_DEBUG(dbgs() << "Replacing "; I2P->dump(); dbgs() << " with "; New->dump());
263 I2P->replaceAllUsesWith(New);
264 // XXX: why was this needed?
265 // cast<Instruction>(BitCast)->setOperand(0, I2P);
266 RecursivelyDeleteTriviallyDeadInstructions(I2P);
267 }
268 for (pair<AllocOperands, ConstantCast> *i = ConstantCasts.begin(),
269 *e = ConstantCasts.end();
270 i != e; ++i) {
271 Value *I2P = i->second.Instr->getOperand(i->second.OpNo);
272 Value *New = RangeCheckedValue(i->second.Instr, i->first, I2P, BitCast);
273 i->second.Instr->setOperand(i->second.OpNo, New);
274 }
275 return true;
276 }
277 return false;
278 }
279 };
280 }
281
282 char CheriRangeChecker::ID;
283 INITIALIZE_PASS(CheriRangeChecker, DEBUG_TYPE, "CHERI rage checker", false,
284 false)
285
286 namespace llvm {
createCheriRangeChecker(void)287 FunctionPass *createCheriRangeChecker(void) { return new CheriRangeChecker(); }
288 } // namespace llvm
289