1 #define DEBUG_TYPE "cheri-range-checker"
2 
3 #include "Mips.h"
4 #include "llvm/ADT/StringSwitch.h"
5 #include "llvm/Analysis/ValueTracking.h"
6 #include "llvm/IR/Cheri.h"
7 #include "llvm/IR/Constants.h"
8 #include "llvm/IR/DataLayout.h"
9 #include "llvm/IR/Function.h"
10 #include "llvm/IR/GlobalVariable.h"
11 #include "llvm/IR/IRBuilder.h"
12 #include "llvm/IR/InstVisitor.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/Intrinsics.h"
15 #include "llvm/IR/LLVMContext.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/Pass.h"
18 #include "llvm/Transforms/Utils/CheriSetBounds.h"
19 #include "llvm/Transforms/Utils/Local.h"
20 
21 #include <string>
22 #include <tuple>
23 #include <utility>
24 
25 #include "llvm/IR/Verifier.h"
26 
27 using namespace llvm;
28 using std::pair;
29 
30 namespace {
31 // Operands for an allocation.  Either one or two integers (constant or
32 // variable).  If there are two, then they must be multiplied together.
33 struct ValueSource {
34   ValueSource() = default;
35   Value *Base = nullptr;
36   int64_t Offset = 0;
37 };
38 struct AllocOperands {
39   AllocOperands() = default;
40   Value *Size = nullptr;
41   Value *SizeMultiplier = nullptr;
42   ValueSource ValueSrc;
43   cheri::SetBoundsPointerSource Src = cheri::SetBoundsPointerSource::Unknown;
operator !=__anonba7517290111::AllocOperands44   bool operator!=(const AllocOperands &Other) {
45     return Size != Other.Size || SizeMultiplier != Other.SizeMultiplier ||
46         ValueSrc.Base != Other.ValueSrc.Base ||
47         ValueSrc.Offset != Other.ValueSrc.Offset || Src != Other.Src;
48   }
49 };
50 class CheriRangeChecker : public FunctionPass,
51                           public InstVisitor<CheriRangeChecker> {
52   struct ConstantCast {
53     Instruction *Instr;
54     unsigned OpNo;
55     User *Origin;
56   };
57   std::unique_ptr<DataLayout> TD;
58   Module *M;
59   IntegerType *SizeTy;
60   PointerType *CapPtrTy;
61   SmallVector<pair<AllocOperands, Instruction *>, 32> Casts;
62   SmallVector<pair<AllocOperands, ConstantCast>, 32> ConstantCasts;
63   Function *SetLengthFn;
64 
getValueSource(Value * Src)65   ValueSource getValueSource(Value *Src) {
66     int64_t Offset = 0;
67     Src = Src->stripPointerCasts();
68     auto Base = GetPointerBaseWithConstantOffset(Src, Offset, *TD);
69     if (Base && Base != Src) {
70       LLVM_DEBUG(dbgs() << "Found base: "; Base->dump());
71       Src = Base;
72     }
73     return ValueSource{Src, Offset};
74   }
75 
getRangeForAllocation(ValueSource Src)76   AllocOperands getRangeForAllocation(ValueSource Src) {
77     // FIXME: This should not hardcode function names but instead use the
78     //  alloc_size attribute!
79     if (auto Malloc = dyn_cast<CallBase>(Src.Base)) {
80       Function *Fn = Malloc->getCalledFunction();
81       if (!Fn)
82         return AllocOperands();
83       switch (StringSwitch<int>(Fn->getName())
84                   .Case("malloc", 1)
85                   .Case("valloc", 1)
86                   .Case("realloc", 2)
87                   .Case("aligned_alloc", 2)
88                   .Case("reallocf", 2)
89                   .Case("calloc", 3)
90                   .Default(-1)) {
91       default:
92         return AllocOperands();
93       case 1:
94         return AllocOperands{Malloc->getArgOperand(0), nullptr, Src,
95                              cheri::SetBoundsPointerSource::Heap};
96       case 2:
97         return AllocOperands{Malloc->getArgOperand(1), nullptr, Src,
98                              cheri::SetBoundsPointerSource::Heap};
99       case 3:
100         return AllocOperands{Malloc->getArgOperand(0), Malloc->getArgOperand(1),
101                              Src, cheri::SetBoundsPointerSource::Heap};
102       }
103     } else if (AllocaInst *AI = dyn_cast<AllocaInst>(Src.Base)) {
104       PointerType *AllocaTy = AI->getType();
105       Value *ArraySize = AI->getArraySize();
106       Type *AllocationTy = AllocaTy->getElementType();
107       unsigned ElementSize = TD->getTypeAllocSize(AllocationTy);
108       if (ElementSize == 1)
109         return AllocOperands{ArraySize, nullptr, Src,
110                              cheri::SetBoundsPointerSource::Stack};
111       Value *Size = ConstantInt::get(ArraySize->getType(), ElementSize);
112       return AllocOperands{Size, ArraySize, Src,
113                            cheri::SetBoundsPointerSource::Stack};
114     }
115     return AllocOperands();
116   }
RangeCheckedValue(Instruction * InsertPt,AllocOperands AO,Value * I2P,Value * & BitCast)117   Value *RangeCheckedValue(Instruction *InsertPt, AllocOperands AO, Value *I2P,
118                            Value *&BitCast) {
119     LLVM_DEBUG(dbgs() << "Adding RangeChecker bounds\n";
120                dbgs() << "\tCast = "; I2P->dump();
121                dbgs() << "\tBase = "; AO.ValueSrc.Base->dump();
122                dbgs() << "\tOffset = " << AO.ValueSrc.Offset << "\n";);
123     IRBuilder<> B(InsertPt);
124     Value *Size =
125         AO.SizeMultiplier ? B.CreateMul(AO.Size, AO.SizeMultiplier) : AO.Size;
126     BitCast = B.CreatePointerBitCastOrAddrSpaceCast(AO.ValueSrc.Base, CapPtrTy);
127     if (Size->getType() != SizeTy)
128       Size = B.CreateZExt(Size, SizeTy);
129     CallInst *SetLength = B.CreateCall(SetLengthFn, {BitCast, Size});
130     if (cheri::ShouldCollectCSetBoundsStats) {
131       Value *AlignmentSource = BitCast;
132       Instruction *DebugInst = dyn_cast<Instruction>(AlignmentSource);
133       if (!DebugInst)
134         DebugInst = InsertPt;
135       cheri::addSetBoundsStats(Align(getKnownAlignment(AlignmentSource, *TD)),
136                                Size, getPassName(), AO.Src, "",
137                                cheri::inferSourceLocation(DebugInst));
138     }
139     if (BitCast == AO.ValueSrc.Base)
140       BitCast = SetLength;
141     Value* Result = SetLength;
142     if (AO.ValueSrc.Offset != 0) {
143       LLVM_DEBUG(dbgs() << "Inserting GEP for non-zero Offset "
144                         << AO.ValueSrc.Offset << "\n";
145                      BitCast->dump(););
146       Result = B.CreateConstGEP1_64(Result, AO.ValueSrc.Offset, "offs");
147     }
148     return B.CreateBitCast(Result, I2P->getType());
149   }
150 
151 public:
152   static char ID;
CheriRangeChecker()153   CheriRangeChecker() : FunctionPass(ID) {}
getPassName() const154   StringRef getPassName() const override { return "CHERI range checker"; }
doInitialization(Module & Mod)155   bool doInitialization(Module &Mod) override {
156     M = &Mod;
157     TD = std::make_unique<DataLayout>(M);
158     SizeTy = IntegerType::get(M->getContext(), TD->getIndexSizeInBits(200));
159     CapPtrTy = PointerType::get(IntegerType::get(M->getContext(), 8), 200);
160     return true;
161   }
~CheriRangeChecker()162   virtual ~CheriRangeChecker() {}
checkOpcode(Value * V,unsigned Opcode)163   bool checkOpcode(Value *V, unsigned Opcode) {
164     if (Instruction *I = dyn_cast<Instruction>(V))
165       return I->getOpcode() == Opcode;
166     if (ConstantExpr *E = dyn_cast<ConstantExpr>(V))
167       return E->getOpcode() == Opcode;
168     return false;
169   }
170 
testI2P(User & I2P)171   User *testI2P(User &I2P) {
172     PointerType *DestTy = dyn_cast<PointerType>(I2P.getType());
173     if (DestTy && isCheriPointer(DestTy, TD.get())) {
174       if (checkOpcode(I2P.getOperand(0), Instruction::PtrToInt)) {
175         User *P2I = cast<User>(I2P.getOperand(0));
176         PointerType *SrcTy =
177             dyn_cast<PointerType>(P2I->getOperand(0)->getType());
178         if (SrcTy && SrcTy->getAddressSpace() == 0) {
179           Value *Src = P2I->getOperand(0)->stripPointerCasts();
180           if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Src))
181             return GV->hasExternalLinkage() ? 0 : P2I;
182           if (isa<AllocaInst>(Src) || isa<CallBase>(Src)) {
183             return P2I;
184           }
185         }
186       }
187     }
188     return 0;
189   }
visitAddrSpaceCast(AddrSpaceCastInst & ASC)190   void visitAddrSpaceCast(AddrSpaceCastInst &ASC) {
191     PointerType *DestTy = dyn_cast<PointerType>(ASC.getType());
192     PointerType *SrcTy = dyn_cast<PointerType>(ASC.getOperand(0)->getType());
193     LLVM_DEBUG(dbgs() << "Visiting address space cast: "; ASC.dump());
194 
195     if ((DestTy && isCheriPointer(DestTy, TD.get())) &&
196         (SrcTy && SrcTy->getAddressSpace() == 0)) {
197       auto Src = getValueSource(ASC.getOperand(0));
198       if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Src.Base)) {
199         if (GV->hasExternalLinkage())
200           return;
201       } else if (!(isa<AllocaInst>(Src.Base) || isa<CallBase>(Src.Base)))
202         return;
203       AllocOperands AO = getRangeForAllocation(Src);
204       if (AO != AllocOperands())
205         Casts.push_back(pair<AllocOperands, Instruction *>(AO, &ASC));
206     }
207   }
visitIntToPtrInst(IntToPtrInst & I2P)208   void visitIntToPtrInst(IntToPtrInst &I2P) {
209     if (User *P2I = testI2P(I2P)) {
210       auto Src = getValueSource(P2I->getOperand(0));
211       AllocOperands AO = getRangeForAllocation(Src);
212       if (AO != AllocOperands())
213         Casts.push_back(pair<AllocOperands, Instruction *>(AO, &I2P));
214     }
215   }
visitRet(ReturnInst & RI)216   void visitRet(ReturnInst &RI) {
217     Value *RV = RI.getReturnValue();
218     if (RV && isa<ConstantExpr>(RV)) {
219       ConstantCast C = {&RI, 0, testI2P(*cast<User>(RV))};
220       if (C.Origin) {
221         auto Src = getValueSource(C.Origin->getOperand(0));
222         AllocOperands AO = getRangeForAllocation(Src);
223         if (AO != AllocOperands())
224           ConstantCasts.push_back(pair<AllocOperands, ConstantCast>(AO, C));
225       }
226     }
227   }
visitCall(CallInst & CI)228   void visitCall(CallInst &CI) {
229     for (unsigned i = 0; i < CI.getNumOperands(); i++) {
230       Value *AV = CI.getOperand(i);
231       if (AV && isa<ConstantExpr>(AV)) {
232         ConstantCast C = {&CI, i, testI2P(*cast<User>(AV))};
233         if (C.Origin) {
234           auto Src = getValueSource(C.Origin->getOperand(0));
235           AllocOperands AO = getRangeForAllocation(Src);
236           if (AO != AllocOperands())
237             ConstantCasts.push_back(pair<AllocOperands, ConstantCast>(AO, C));
238         }
239       }
240     }
241   }
runOnFunction(Function & F)242   bool runOnFunction(Function &F) override{
243     Casts.clear();
244     ConstantCasts.clear();
245 
246     visit(F);
247 
248     if (!(Casts.empty() && ConstantCasts.empty())) {
249       Intrinsic::ID SetLength = Intrinsic::cheri_cap_bounds_set;
250       SetLengthFn = Intrinsic::getDeclaration(M, SetLength, SizeTy);
251       Value *BitCast = 0;
252 
253       for (auto *i = Casts.begin(), *e = Casts.end(); i != e; ++i) {
254         Instruction *I2P = i->second;
255         auto InsertPt = I2P->getParent()->begin();
256         while (&(*InsertPt) != I2P) {
257           ++InsertPt;
258         }
259         ++InsertPt;
260         Value *New = RangeCheckedValue(&*InsertPt, i->first,
261                                        I2P, BitCast);
262         LLVM_DEBUG(dbgs() << "Replacing "; I2P->dump(); dbgs() << "  with "; New->dump());
263         I2P->replaceAllUsesWith(New);
264         // XXX: why was this needed?
265         // cast<Instruction>(BitCast)->setOperand(0, I2P);
266         RecursivelyDeleteTriviallyDeadInstructions(I2P);
267       }
268       for (pair<AllocOperands, ConstantCast> *i = ConstantCasts.begin(),
269                                              *e = ConstantCasts.end();
270            i != e; ++i) {
271         Value *I2P = i->second.Instr->getOperand(i->second.OpNo);
272         Value *New = RangeCheckedValue(i->second.Instr, i->first, I2P, BitCast);
273         i->second.Instr->setOperand(i->second.OpNo, New);
274       }
275       return true;
276     }
277     return false;
278   }
279 };
280 }
281 
282 char CheriRangeChecker::ID;
283 INITIALIZE_PASS(CheriRangeChecker, DEBUG_TYPE, "CHERI rage checker", false,
284                 false)
285 
286 namespace llvm {
createCheriRangeChecker(void)287 FunctionPass *createCheriRangeChecker(void) { return new CheriRangeChecker(); }
288 } // namespace llvm
289