1 /*
2 * Compile LLVM bytecode to ClamAV bytecode.
3 *
4 * Copyright (C) 2013-2022 Cisco Systems, Inc. and/or its affiliates. All rights reserved.
5 * Copyright (C) 2009-2013 Sourcefire, Inc.
6 *
7 * Authors: Török Edvin, Kevin Lin
8 *
9 * This program is free software; you can redistribute it and/or modify
10 * it under the terms of the GNU General Public License version 2 as
11 * published by the Free Software Foundation.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
21 * MA 02110-1301, USA.
22 */
23
24 #define DEBUG_TYPE "clambc-rtcheck"
25 #include "ClamBCModule.h"
26 #include "ClamBCDiagnostics.h"
27 #include "llvm30_compat.h" /* libclamav-specific */
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/PostOrderIterator.h"
30 #include "llvm/ADT/SCCIterator.h"
31 #include "llvm/Analysis/CallGraph.h"
32 #if LLVM_VERSION < 32
33 #include "llvm/Analysis/DebugInfo.h"
34 #elif LLVM_VERSION < 35
35 #include "llvm/DebugInfo.h"
36 #else
37 #include "llvm/IR/DebugInfo.h"
38 #endif
39 #if LLVM_VERSION < 35
40 #include "llvm/Analysis/Dominators.h"
41 #include "llvm/Analysis/Verifier.h"
42 #else
43 #include "llvm/IR/Dominators.h"
44 #include "llvm/IR/Verifier.h"
45 #endif
46 #include "llvm/Analysis/ConstantFolding.h"
47 #if LLVM_VERSION < 29
48 //#include "llvm/Analysis/LiveValues.h" (unused)
49 #include "llvm/Analysis/PointerTracking.h"
50 #else
51 #include "llvm/Analysis/ValueTracking.h"
52 #include "PointerTracking.h" /* included from old LLVM source */
53 #endif
54 #include "llvm/Analysis/ScalarEvolution.h"
55 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
56 #include "llvm/Analysis/ScalarEvolutionExpander.h"
57 #include "llvm/Config/config.h"
58 #include "llvm/Pass.h"
59 #include "llvm/Support/CommandLine.h"
60 #if LLVM_VERSION < 35
61 #include "llvm/Support/DataFlow.h"
62 #include "llvm/Support/InstIterator.h"
63 #include "llvm/Support/GetElementPtrTypeIterator.h"
64 #else
65 #include "llvm/IR/InstIterator.h"
66 #include "llvm/IR/GetElementPtrTypeIterator.h"
67 #endif
68 #include "llvm/ADT/DepthFirstIterator.h"
69 #include "llvm/Transforms/Scalar.h"
70 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
71 #include "llvm/Support/Debug.h"
72 #if LLVM_VERSION < 32
73 #include "llvm/Target/TargetData.h"
74 #elif LLVM_VERSION < 33
75 #include "llvm/DataLayout.h"
76 #else
77 #include "llvm/IR/DataLayout.h"
78 #endif
79 #if LLVM_VERSION < 33
80 #include "llvm/DerivedTypes.h"
81 #include "llvm/Instructions.h"
82 #include "llvm/IntrinsicInst.h"
83 #include "llvm/Intrinsics.h"
84 #include "llvm/LLVMContext.h"
85 #include "llvm/Module.h"
86 #else
87 #include "llvm/IR/DerivedTypes.h"
88 #include "llvm/IR/Instructions.h"
89 #include "llvm/IR/IntrinsicInst.h"
90 #include "llvm/IR/Intrinsics.h"
91 #include "llvm/IR/LLVMContext.h"
92 #include "llvm/IR/Module.h"
93 #endif
94
95 #if LLVM_VERSION < 33
96 #include "llvm/Support/InstVisitor.h"
97 #elif LLVM_VERSION < 35
98 #include "llvm/InstVisitor.h"
99 #else
100 #include "llvm/IR/InstVisitor.h"
101 #endif
102
103 #define DEFINEPASS(passname) passname() : FunctionPass(ID)
104
105 using namespace llvm;
106 #if LLVM_VERSION < 29
107 /* function is succeeded in later LLVM with LLVM corresponding standalone */
GetUnderlyingObject(Value * P,TargetData * TD)108 static Value *GetUnderlyingObject(Value *P, TargetData *TD)
109 {
110 return P->getUnderlyingObject();
111 }
112 #endif
113
114 namespace llvm {
115 class PtrVerifier;
116 #if LLVM_VERSION >= 29
117 void initializePtrVerifierPass(PassRegistry&);
118 #endif
119
120 class PtrVerifier : public FunctionPass {
121 private:
122 DenseSet<Function*> badFunctions;
123 std::vector<Instruction*> delInst;
124 #if LLVM_VERSION < 35
125 CallGraphNode *rootNode;
126 #else
127 CallGraph *CG;
128 #endif
129 public:
130 static char ID;
131 #if LLVM_VERSION < 35
DEFINEPASS(PtrVerifier)132 DEFINEPASS(PtrVerifier), rootNode(0), PT(), TD(), SE(), expander(),
133 #else
134 DEFINEPASS(PtrVerifier), CG(0), PT(), TD(), SE(), expander(),
135 #endif
136 DT(), AbrtBB(), Changed(false), valid(false), EP() {
137 #if LLVM_VERSION >= 29
138 initializePtrVerifierPass(*PassRegistry::getPassRegistry());
139 #endif
140 }
141
runOnFunction(Function & F)142 virtual bool runOnFunction(Function &F) {
143 /*
144 #ifndef CLAMBC_COMPILER
145 // Bytecode was already verified and had stack protector applied.
146 // We get called again because ALL bytecode functions loaded are part of
147 // the same module.
148 if (F.hasFnAttr(Attribute::StackProtectReq))
149 return false;
150 #endif
151 */
152
153 DEBUG(errs() << "Running on " << F.getName() << "\n");
154 DEBUG(F.dump());
155 Changed = false;
156 BaseMap.clear();
157 BoundsMap.clear();
158 delInst.clear();
159 AbrtBB = 0;
160 valid = true;
161
162 #if LLVM_VERSION < 35
163 if (!rootNode) {
164 rootNode = getAnalysis<CallGraph>().getRoot();
165 #else
166 if (!CG) {
167 CG = &getAnalysis<CallGraphWrapperPass>().getCallGraph();
168 #endif
169 // No recursive functions for now.
170 // In the future we may insert runtime checks for stack depth.
171 #if LLVM_VERSION < 35
172 for (scc_iterator<CallGraphNode*> SCCI = scc_begin(rootNode),
173 E = scc_end(rootNode); SCCI != E; ++SCCI) {
174 #else
175 for (scc_iterator<CallGraph*> SCCI = scc_begin(CG); !SCCI.isAtEnd(); ++SCCI) {
176 #endif
177 const std::vector<CallGraphNode*> &nextSCC = *SCCI;
178 if (nextSCC.size() > 1 || SCCI.hasLoop()) {
179 errs() << "INVALID: Recursion detected, callgraph SCC components: ";
180 for (std::vector<CallGraphNode*>::const_iterator I = nextSCC.begin(),
181 E = nextSCC.end(); I != E; ++I) {
182 Function *FF = (*I)->getFunction();
183 if (FF) {
184 errs() << FF->getName() << ", ";
185 badFunctions.insert(FF);
186 }
187 }
188 if (SCCI.hasLoop())
189 errs() << "(self-loop)";
190 errs() << "\n";
191 }
192 // we could also have recursion via function pointers, but we don't
193 // allow calls to unknown functions, see runOnFunction() below
194 }
195 }
196
197 BasicBlock::iterator It = F.getEntryBlock().begin();
198 while (isa<AllocaInst>(It) || isa<PHINode>(It)) ++It;
199 EP = &*It;
200 #if LLVM_VERSION < 32
201 TD = &getAnalysis<TargetData>();
202 #elif LLVM_VERSION < 35
203 TD = &getAnalysis<DataLayout>();
204 #else
205 DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
206 TD = DLP ? &DLP->getDataLayout() : 0;
207 #endif
208 SE = &getAnalysis<ScalarEvolution>();
209 PT = &getAnalysis<PointerTracking>();
210 #if LLVM_VERSION < 35
211 DT = &getAnalysis<DominatorTree>();
212 #else
213 DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
214 #endif
215 expander = new SCEVExpander(*SE OPT("SCEVexpander"));
216
217 std::vector<Instruction*> insns;
218
219 BasicBlock *LastBB = 0;
220 for (inst_iterator I=inst_begin(F),E=inst_end(F); I != E;++I) {
221 Instruction *II = &*I;
222 /* only appears in the libclamav version */
223 if (II->getParent() != LastBB) {
224 LastBB = II->getParent();
225 if (DT->getNode(LastBB) == 0)
226 continue;
227 }
228 /* end-block */
229 if (isa<LoadInst>(II) || isa<StoreInst>(II) || isa<MemIntrinsic>(II))
230 insns.push_back(II);
231 else if (CallInst *CI = dyn_cast<CallInst>(II)) {
232 Value *V = CI->getCalledValue()->stripPointerCasts();
233 Function *F = dyn_cast<Function>(V);
234 if (!F) {
235 printLocation(CI, true);
236 errs() << "Could not determine call target\n";
237 valid = 0;
238 continue;
239 }
240 // this statement disable checks on user-defined CallInst
241 //if (!F->isDeclaration())
242 //continue;
243 insns.push_back(CI);
244 }
245 }
246
247 for (unsigned Idx = 0; Idx < insns.size(); ++Idx) {
248 Instruction *II = insns[Idx];
249 DEBUG(dbgs() << "checking " << *II << "\n");
250 if (LoadInst *LI = dyn_cast<LoadInst>(II)) {
251 constType *Ty = LI->getType();
252 valid &= validateAccess(LI->getPointerOperand(),
253 TD->getTypeAllocSize(Ty), LI);
254 } else if (StoreInst *SI = dyn_cast<StoreInst>(II)) {
255 constType *Ty = SI->getOperand(0)->getType();
256 valid &= validateAccess(SI->getPointerOperand(),
257 TD->getTypeAllocSize(Ty), SI);
258 } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(II)) {
259 valid &= validateAccess(MI->getDest(), MI->getLength(), MI);
260 if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
261 valid &= validateAccess(MTI->getSource(), MI->getLength(), MI);
262 }
263 } else if (CallInst *CI = dyn_cast<CallInst>(II)) {
264 Value *V = CI->getCalledValue()->stripPointerCasts();
265 Function *F = cast<Function>(V);
266 constFunctionType *FTy = F->getFunctionType();
267 CallSite CS(CI);
268 if (F->getName().equals("memcmp") && FTy->getNumParams() == 3) {
269 valid &= validateAccess(CS.getArgument(0), CS.getArgument(2), CI);
270 valid &= validateAccess(CS.getArgument(1), CS.getArgument(2), CI);
271 continue;
272 }
273 unsigned i;
274 #ifdef CLAMBC_COMPILER
275 i = 0;
276 #else
277 i = 1;// skip hidden ctx*
278 #endif
279 for (;i<FTy->getNumParams();i++) {
280 if (isa<PointerType>(FTy->getParamType(i))) {
281 Value *Ptr = CS.getArgument(i);
282 if (i+1 >= FTy->getNumParams()) {
283 printLocation(CI, false);
284 errs() << "Call to external function with pointer parameter last"
285 " cannot be analyzed\n";
286 errs() << *CI << "\n";
287 valid = 0;
288 break;
289 }
290 Value *Size = CS.getArgument(i+1);
291 if (!Size->getType()->isIntegerTy()) {
292 printLocation(CI, false);
293 errs() << "Pointer argument must be followed by integer argument"
294 " representing its size\n";
295 errs() << *CI << "\n";
296 valid = 0;
297 break;
298 }
299 valid &= validateAccess(Ptr, Size, CI);
300 }
301 }
302 }
303 }
304 if (badFunctions.count(&F))
305 valid = 0;
306
307 if (!valid) {
308 DEBUG(F.dump());
309 ClamBCModule::stop("Verification found errors!", &F);
310 // replace function with call to abort
311 std::vector<constType*>args;
312 FunctionType* abrtTy = FunctionType::get(Type::getVoidTy(F.getContext()),args,false);
313 Constant *func_abort = F.getParent()->getOrInsertFunction("abort", abrtTy);
314
315 BasicBlock *BB = &F.getEntryBlock();
316 Instruction *I = &*BB->begin();
317 Instruction *UI = new UnreachableInst(F.getContext(), I);
318 CallInst *AbrtC = CallInst::Create(func_abort, "", UI);
319 AbrtC->setCallingConv(CallingConv::C);
320 AbrtC->setTailCall(true);
321 #if LLVM_VERSION < 32
322 AbrtC->setDoesNotReturn(true);
323 AbrtC->setDoesNotThrow(true);
324 #else
325 AbrtC->setDoesNotReturn();
326 AbrtC->setDoesNotThrow();
327 #endif
328 // remove all instructions from entry
329 BasicBlock::iterator BBI = I, BBE=BB->end();
330 while (BBI != BBE) {
331 if (!BBI->use_empty())
332 BBI->replaceAllUsesWith(UndefValue::get(BBI->getType()));
333 BB->getInstList().erase(BBI++);
334 }
335 }
336
337 // bb#9967 - deleting obsolete termination instructions
338 for (unsigned i = 0; i < delInst.size(); ++i)
339 delInst[i]->eraseFromParent();
340
341 delete expander;
342 return Changed;
343 }
344
345 virtual void releaseMemory() {
346 badFunctions.clear();
347 }
348
349 virtual void getAnalysisUsage(AnalysisUsage &AU) const {
350 #if LLVM_VERSION < 32
351 AU.addRequired<TargetData>();
352 #elif LLVM_VERSION < 35
353 AU.addRequired<DataLayout>();
354 #else
355 AU.addRequired<DataLayoutPass>();
356 #endif
357 #if LLVM_VERSION < 35
358 AU.addRequired<DominatorTree>();
359 #else
360 AU.addRequired<DominatorTreeWrapperPass>();
361 #endif
362 AU.addRequired<ScalarEvolution>();
363 AU.addRequired<PointerTracking>();
364 #if LLVM_VERSION < 35
365 AU.addRequired<CallGraph>();
366 #else
367 AU.addRequired<CallGraphWrapperPass>();
368 #endif
369 }
370
371 bool isValid() const { return valid; }
372 private:
373 PointerTracking *PT;
374 #if LLVM_VERSION < 32
375 TargetData *TD;
376 #elif LLVM_VERSION < 35
377 DataLayout *TD;
378 #else
379 const DataLayout *TD;
380 #endif
381 ScalarEvolution *SE;
382 SCEVExpander *expander;
383 DominatorTree *DT;
384 DenseMap<Value*, Value*> BaseMap;
385 DenseMap<Value*, Value*> BoundsMap;
386 BasicBlock *AbrtBB;
387 bool Changed;
388 bool valid;
389 Instruction *EP;
390
391 Instruction *getInsertPoint(Value *V)
392 {
393 BasicBlock::iterator It = EP;
394 if (Instruction *I = dyn_cast<Instruction>(V)) {
395 It = I;
396 ++It;
397 }
398 return &*It;
399 }
400
401 Value *getPointerBase(Value *Ptr)
402 {
403 if (BaseMap.count(Ptr))
404 return BaseMap[Ptr];
405 Value *P = Ptr->stripPointerCasts();
406 if (BaseMap.count(P)) {
407 return BaseMap[Ptr] = BaseMap[P];
408 }
409 Value *P2 = GetUnderlyingObject(P, TD);
410 if (P2 != P) {
411 Value *V = getPointerBase(P2);
412 return BaseMap[Ptr] = V;
413 }
414
415 constType *P8Ty =
416 PointerType::getUnqual(Type::getInt8Ty(Ptr->getContext()));
417 if (PHINode *PN = dyn_cast<PHINode>(Ptr)) {
418 BasicBlock::iterator It = PN;
419 ++It;
420 PHINode *newPN = PHINode::Create(P8Ty, HINT(PN->getNumIncomingValues()) ".verif.base", &*It);
421 Changed = true;
422 BaseMap[Ptr] = newPN;
423
424 for (unsigned i=0;i<PN->getNumIncomingValues();i++) {
425 Value *Inc = PN->getIncomingValue(i);
426 Value *V = getPointerBase(Inc);
427 newPN->addIncoming(V, PN->getIncomingBlock(i));
428 }
429 return newPN;
430 }
431 if (SelectInst *SI = dyn_cast<SelectInst>(Ptr)) {
432 BasicBlock::iterator It = SI;
433 ++It;
434 Value *TrueB = getPointerBase(SI->getTrueValue());
435 Value *FalseB = getPointerBase(SI->getFalseValue());
436 if (TrueB && FalseB) {
437 SelectInst *NewSI = SelectInst::Create(SI->getCondition(), TrueB,
438 FalseB, ".select.base", &*It);
439 Changed = true;
440 return BaseMap[Ptr] = NewSI;
441 }
442 }
443 if (Ptr->getType() != P8Ty) {
444 if (Constant *C = dyn_cast<Constant>(Ptr))
445 Ptr = ConstantExpr::getPointerCast(C, P8Ty);
446 else {
447 Instruction *I = getInsertPoint(Ptr);
448 Ptr = new BitCastInst(Ptr, P8Ty, "", I);
449 }
450 }
451 return BaseMap[Ptr] = Ptr;
452 }
453
454 Value* getValAtIdx(Function *F, unsigned Idx) {
455 Value *Val= NULL;
456
457 // check if accessed Idx is within function parameter list
458 if (Idx < F->arg_size()) {
459 Function::arg_iterator It = F->arg_begin();
460 Function::arg_iterator ItEnd = F->arg_end();
461 for (unsigned i = 0; i < Idx; ++i, ++It) {
462 // redundant check, should not be possible
463 if (It == ItEnd) {
464 // Houston, the impossible has become possible
465 //printDiagnostic("Idx is outside of Function parameters", F);
466 errs() << "Idx is outside of Function parameters\n";
467 errs() << *F << "\n";
468 //valid = 0;
469 break;
470 }
471 }
472 // retrieve value ptr of argument of F at Idx
473 Val = &(*It);
474 }
475 else {
476 // Idx is outside function parameter list
477 //printDiagnostic("Idx is outside of Function parameters", F);
478 errs() << "Idx is outside of Function parameters\n";
479 errs() << *F << "\n";
480 //valid = 0;
481 }
482 return Val;
483 }
484
485 Value* getPointerBounds(Value *Base) {
486 if (BoundsMap.count(Base))
487 return BoundsMap[Base];
488 constType *I64Ty =
489 Type::getInt64Ty(Base->getContext());
490
491 #ifndef CLAMBC_COMPILER
492 // first arg is hidden ctx
493 if (Argument *A = dyn_cast<Argument>(Base)) {
494 if (A->getArgNo() == 0) {
495 constType *Ty = cast<PointerType>(A->getType())->getElementType();
496 return ConstantInt::get(I64Ty, TD->getTypeAllocSize(Ty));
497 } else if (Base->getType()->isPointerTy()) {
498 Function *F = A->getParent();
499 const FunctionType *FT = F->getFunctionType();
500
501 bool checks = true;
502 // last argument check
503 if (A->getArgNo() == (FT->getNumParams()-1)) {
504 //printDiagnostic("pointer argument cannot be last argument", F);
505 errs() << "pointer argument cannot be last argument\n";
506 errs() << *F << "\n";
507 checks = false;
508 }
509
510 // argument after pointer MUST be a integer (unsigned probably too)
511 if (checks && !FT->getParamType(A->getArgNo()+1)->isIntegerTy()) {
512 //printDiagnostic("argument following pointer argument is not an integer", F);
513 errs() << "argument following pointer argument is not an integer\n";
514 errs() << *F << "\n";
515 checks = false;
516 }
517
518 if (checks)
519 return BoundsMap[Base] = getValAtIdx(F, A->getArgNo()+1);
520 }
521 }
522 if (LoadInst *LI = dyn_cast<LoadInst>(Base)) {
523 Value *V = GetUnderlyingObject(LI->getPointerOperand()->stripPointerCasts(), TD);
524 if (Argument *A = dyn_cast<Argument>(V)) {
525 if (A->getArgNo() == 0) {
526 // pointers from hidden ctx are trusted to be at least the
527 // size they say they are
528 constType *Ty = cast<PointerType>(LI->getType())->getElementType();
529 return ConstantInt::get(I64Ty, TD->getTypeAllocSize(Ty));
530 }
531 }
532 }
533 #else
534 if (Base->getType()->isPointerTy()) {
535 if (Argument *A = dyn_cast<Argument>(Base)) {
536 Function *F = A->getParent();
537 const FunctionType *FT = F->getFunctionType();
538
539 bool checks = true;
540 // last argument check
541 if (A->getArgNo() == (FT->getNumParams()-1)) {
542 //printDiagnostic("pointer argument cannot be last argument", F);
543 errs() << "pointer argument cannot be last argument\n";
544 errs() << *F << "\n";
545 checks = false;
546 }
547
548 // argument after pointer MUST be a integer (unsigned probably too)
549 if (checks && !FT->getParamType(A->getArgNo()+1)->isIntegerTy()) {
550 //printDiagnostic("argument following pointer argument is not an integer", F);
551 errs() << "argument following pointer argument is not an integer\n";
552 errs() << *F << "\n";
553 checks = false;
554 }
555
556 if (checks)
557 return BoundsMap[Base] = getValAtIdx(F, A->getArgNo()+1);
558 }
559 }
560 #endif
561 if (PHINode *PN = dyn_cast<PHINode>(Base)) {
562 BasicBlock::iterator It = PN;
563 ++It;
564 PHINode *newPN = PHINode::Create(I64Ty, HINT(PN->getNumIncomingValues()) ".verif.bounds", &*It);
565 Changed = true;
566 BoundsMap[Base] = newPN;
567
568 bool good = true;
569 for (unsigned i=0;i<PN->getNumIncomingValues();i++) {
570 Value *Inc = PN->getIncomingValue(i);
571 Value *B = getPointerBounds(Inc);
572 if (!B) {
573 good = false;
574 B = ConstantInt::get(newPN->getType(), 0);
575 DEBUG(dbgs() << "bounds not found while solving phi node: " << *Inc
576 << "\n");
577 }
578 newPN->addIncoming(B, PN->getIncomingBlock(i));
579 }
580 if (!good)
581 newPN = 0;
582 return BoundsMap[Base] = newPN;
583 }
584 if (SelectInst *SI = dyn_cast<SelectInst>(Base)) {
585 BasicBlock::iterator It = SI;
586 ++It;
587 Value *TrueB = getPointerBounds(SI->getTrueValue());
588 Value *FalseB = getPointerBounds(SI->getFalseValue());
589 if (TrueB && FalseB) {
590 SelectInst *NewSI = SelectInst::Create(SI->getCondition(), TrueB,
591 FalseB, ".select.bounds", &*It);
592 Changed = true;
593 return BoundsMap[Base] = NewSI;
594 }
595 }
596
597 constType *Ty;
598 Value *V = PT->computeAllocationCountValue(Base, Ty);
599 if (!V) {
600 Base = Base->stripPointerCasts();
601 if (CallInst *CI = dyn_cast<CallInst>(Base)) {
602 Function *F = CI->getCalledFunction();
603 constFunctionType *FTy = F->getFunctionType();
604 // last operand is always size for this API call kind
605 if (F->isDeclaration() && FTy->getNumParams() > 0) {
606 CallSite CS(CI);
607 if (FTy->getParamType(FTy->getNumParams()-1)->isIntegerTy())
608 V = CS.getArgument(FTy->getNumParams()-1);
609 }
610 }
611 if (!V)
612 return BoundsMap[Base] = 0;
613 } else {
614 unsigned size = TD->getTypeAllocSize(Ty);
615 if (size > 1) {
616 Constant *C = cast<Constant>(V);
617 C = ConstantExpr::getMul(C,
618 ConstantInt::get(Type::getInt32Ty(C->getContext()),
619 size));
620 V = C;
621 }
622 }
623 if (V->getType() != I64Ty) {
624 if (Constant *C = dyn_cast<Constant>(V))
625 V = ConstantExpr::getZExt(C, I64Ty);
626 else {
627 Instruction *I = getInsertPoint(V);
628 V = new ZExtInst(V, I64Ty, "", I);
629 }
630 }
631 return BoundsMap[Base] = V;
632 }
633
634 MDNode *getLocation(Instruction *I, bool &Approximate, unsigned MDDbgKind)
635 {
636 Approximate = false;
637 if (MDNode *Dbg = I->getMetadata(MDDbgKind))
638 return Dbg;
639 if (!MDDbgKind)
640 return 0;
641 Approximate = true;
642 BasicBlock::iterator It = I;
643 while (It != I->getParent()->begin()) {
644 --It;
645 if (MDNode *Dbg = It->getMetadata(MDDbgKind))
646 return Dbg;
647 }
648 BasicBlock *BB = I->getParent();
649 while ((BB = BB->getUniquePredecessor())) {
650 It = BB->end();
651 while (It != BB->begin()) {
652 --It;
653 if (MDNode *Dbg = It->getMetadata(MDDbgKind))
654 return Dbg;
655 }
656 }
657 return 0;
658 }
659
660 bool insertCheck(const SCEV *Idx, const SCEV *Limit, Instruction *I,
661 bool strict)
662 {
663 if (isa<SCEVCouldNotCompute>(Idx) && isa<SCEVCouldNotCompute>(Limit)) {
664 errs() << "Could not compute the index and the limit!: \n" << *I << "\n";
665 return false;
666 }
667 if (isa<SCEVCouldNotCompute>(Idx)) {
668 errs() << "Could not compute index: \n" << *I << "\n";
669 return false;
670 }
671 if (isa<SCEVCouldNotCompute>(Limit)) {
672 errs() << "Could not compute limit: " << *I << "\n";
673 return false;
674 }
675 BasicBlock *BB = I->getParent();
676 BasicBlock::iterator It = I;
677 BasicBlock *newBB = SplitBlock(BB, &*It, this);
678 PHINode *PN;
679 unsigned MDDbgKind = I->getContext().getMDKindID("dbg");
680 //verifyFunction(*BB->getParent());
681 if (!AbrtBB) {
682 std::vector<constType*>args;
683 FunctionType* abrtTy = FunctionType::get(Type::getVoidTy(BB->getContext()),args,false);
684 args.push_back(Type::getInt32Ty(BB->getContext()));
685 FunctionType* rterrTy = FunctionType::get(Type::getInt32Ty(BB->getContext()),args,false);
686 Constant *func_abort = BB->getParent()->getParent()->getOrInsertFunction("abort", abrtTy);
687 Constant *func_rterr = BB->getParent()->getParent()->getOrInsertFunction("bytecode_rt_error",
688 rterrTy);
689 AbrtBB = BasicBlock::Create(BB->getContext(), "rterr.trig", BB->getParent());
690
691 PN = PHINode::Create(Type::getInt32Ty(BB->getContext()),HINT(1) "",
692 AbrtBB);
693 if (MDDbgKind) {
694 CallInst *RtErrCall = CallInst::Create(func_rterr, PN, "", AbrtBB);
695 RtErrCall->setCallingConv(CallingConv::C);
696 RtErrCall->setTailCall(true);
697 #if LLVM_VERSION < 32
698 RtErrCall->setDoesNotThrow(true);
699 #else
700 RtErrCall->setDoesNotThrow();
701 #endif
702 }
703 CallInst* AbrtC = CallInst::Create(func_abort, "", AbrtBB);
704 AbrtC->setCallingConv(CallingConv::C);
705 AbrtC->setTailCall(true);
706 #if LLVM_VERSION < 32
707 AbrtC->setDoesNotReturn(true);
708 AbrtC->setDoesNotThrow(true);
709 #else
710 AbrtC->setDoesNotReturn();
711 AbrtC->setDoesNotThrow();
712 #endif
713 new UnreachableInst(BB->getContext(), AbrtBB);
714 DT->addNewBlock(AbrtBB, BB);
715 //verifyFunction(*BB->getParent());
716 } else {
717 PN = cast<PHINode>(AbrtBB->begin());
718 }
719 unsigned locationid = 0;
720 bool Approximate;
721 if (MDNode *Dbg = getLocation(I, Approximate, MDDbgKind)) {
722 DILocation Loc(Dbg);
723 locationid = Loc.getLineNumber() << 8;
724 unsigned col = Loc.getColumnNumber();
725 if (col > 254)
726 col = 254;
727 if (Approximate)
728 col = 255;
729 locationid |= col;
730 }
731 PN->addIncoming(ConstantInt::get(Type::getInt32Ty(BB->getContext()),
732 locationid), BB);
733 TerminatorInst *TI = BB->getTerminator();
734 Value *IdxV = expander->expandCodeFor(Idx, Limit->getType(), TI);
735 Value *LimitV = expander->expandCodeFor(Limit, Limit->getType(), TI);
736 if (isa<Instruction>(IdxV) &&
737 !DT->dominates(cast<Instruction>(IdxV)->getParent(),I->getParent())) {
738 printLocation(I, true);
739 errs() << "basic block with value [ " << IdxV->getName();
740 errs() << " ] with limit [ " << LimitV->getName();
741 errs() << " ] does not dominate" << *I << "\n";
742 return false;
743 }
744 if (isa<Instruction>(LimitV) &&
745 !DT->dominates(cast<Instruction>(LimitV)->getParent(),I->getParent())) {
746 printLocation(I, true);
747 errs() << "basic block with limit [" << LimitV->getName();
748 errs() << " ] on value [ " << IdxV->getName();
749 errs() << " ] does not dominate" << *I << "\n";
750 return false;
751 }
752 Value *Cond = new ICmpInst(TI, strict ?
753 ICmpInst::ICMP_ULT :
754 ICmpInst::ICMP_ULE, IdxV, LimitV);
755 BranchInst::Create(newBB, AbrtBB, Cond, TI);
756 //TI->eraseFromParent();
757 delInst.push_back(TI);
758 // Update dominator info
759 BasicBlock *DomBB =
760 DT->findNearestCommonDominator(BB, DT->getNode(AbrtBB)->getIDom()->getBlock());
761 DT->changeImmediateDominator(AbrtBB, DomBB);
762 return true;
763 }
764
765 static void MakeCompatible(ScalarEvolution *SE, const SCEV*& LHS, const SCEV*& RHS)
766 {
767 if (const SCEVZeroExtendExpr *ZL = dyn_cast<SCEVZeroExtendExpr>(LHS))
768 LHS = ZL->getOperand();
769 if (const SCEVZeroExtendExpr *ZR = dyn_cast<SCEVZeroExtendExpr>(RHS))
770 RHS = ZR->getOperand();
771
772 constType* LTy = SE->getEffectiveSCEVType(LHS->getType());
773 constType *RTy = SE->getEffectiveSCEVType(RHS->getType());
774 if (SE->getTypeSizeInBits(RTy) > SE->getTypeSizeInBits(LTy))
775 LTy = RTy;
776 LHS = SE->getNoopOrZeroExtend(LHS, LTy);
777 RHS = SE->getNoopOrZeroExtend(RHS, LTy);
778 }
779
780 bool checkCond(Instruction *ICI, Instruction *I, bool equal)
781 {
782 #if LLVM_VERSION < 35
783 for (Value::use_iterator JU=ICI->use_begin(),JUE=ICI->use_end();
784 JU != JUE; ++JU) {
785 #else
786 for (Value::user_iterator JU=ICI->user_begin(),JUE=ICI->user_end();
787 JU != JUE; ++JU) {
788 #endif
789 Value *JU_V = *JU;
790 if (BranchInst *BI = dyn_cast<BranchInst>(JU_V)) {
791 if (!BI->isConditional())
792 continue;
793 BasicBlock *S = BI->getSuccessor(equal);
794 if (DT->dominates(S, I->getParent()))
795 return true;
796 }
797 if (BinaryOperator *BI = dyn_cast<BinaryOperator>(JU_V)) {
798 if (BI->getOpcode() == Instruction::Or &&
799 checkCond(BI, I, equal))
800 return true;
801 if (BI->getOpcode() == Instruction::And &&
802 checkCond(BI, I, !equal))
803 return true;
804 }
805 }
806 return false;
807 }
808
809 bool checkCondition(Instruction *CI, Instruction *I)
810 {
811 #if LLVM_VERSION < 35
812 for (Value::use_iterator U=CI->use_begin(),UE=CI->use_end();
813 U != UE; ++U) {
814 #else
815 for (Value::user_iterator U=CI->user_begin(),UE=CI->user_end();
816 U != UE; ++U) {
817 #endif
818 Value *U_V = *U;
819 if (ICmpInst *ICI = dyn_cast<ICmpInst>(U_V)) {
820 if (ICI->getOperand(0)->stripPointerCasts() == CI &&
821 isa<ConstantPointerNull>(ICI->getOperand(1))) {
822 if (checkCond(ICI, I, ICI->getPredicate() == ICmpInst::ICMP_EQ))
823 return true;
824 }
825 }
826 }
827 return false;
828 }
829
830 bool validateAccess(Value *Pointer, Value *Length, Instruction *I)
831 {
832 // get base
833 Value *Base = getPointerBase(Pointer);
834
835 Value *SBase = Base->stripPointerCasts();
836 // get bounds
837 Value *Bounds = getPointerBounds(SBase);
838 if (!Bounds) {
839 printLocation(I, true);
840 errs() << "no bounds for base ";
841 printValue(SBase);
842 errs() << " while checking access to ";
843 printValue(Pointer);
844 errs() << " of length ";
845 printValue(Length);
846 errs() << "\n";
847
848 return false;
849 }
850
851 // checks if a NULL pointer check (returned from function) is made:
852 if (CallInst *CI = dyn_cast<CallInst>(Base->stripPointerCasts())) {
853 // by checking if use is in the same block (i.e. no branching decisions)
854 if (I->getParent() == CI->getParent()) {
855 printLocation(I, true);
856 errs() << "no null pointer check of pointer ";
857 printValue(Base, false, true);
858 errs() << " obtained by function call";
859 errs() << " before use in same block\n";
860 return false;
861 }
862 // by checking if a conditional contains the values in question somewhere
863 // between their usage
864 if (!checkCondition(CI, I)) {
865 printLocation(I, true);
866 errs() << "no null pointer check of pointer ";
867 printValue(Base, false, true);
868 errs() << " obtained by function call";
869 errs() << " before use\n";
870 return false;
871 }
872 }
873
874 constType *I64Ty =
875 Type::getInt64Ty(Base->getContext());
876 const SCEV *SLen = SE->getSCEV(Length);
877 const SCEV *OffsetP = SE->getMinusSCEV(SE->getSCEV(Pointer),
878 SE->getSCEV(Base));
879 SLen = SE->getNoopOrZeroExtend(SLen, I64Ty);
880 OffsetP = SE->getNoopOrZeroExtend(OffsetP, I64Ty);
881 const SCEV *Limit = SE->getSCEV(Bounds);
882 Limit = SE->getNoopOrZeroExtend(Limit, I64Ty);
883
884 DEBUG(dbgs() << "Checking access to " << *Pointer << " of length " <<
885 *Length << "\n");
886 if (OffsetP == Limit) {
887 printLocation(I, true);
888 errs() << "OffsetP == Limit: " << *OffsetP << "\n";
889 errs() << " while checking access to ";
890 printValue(Pointer);
891 errs() << " of length ";
892 printValue(Length);
893 errs() << "\n";
894 return false;
895 }
896
897 if (SLen == Limit) {
898 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(OffsetP)) {
899 if (SC->isZero())
900 return true;
901 }
902 errs() << "SLen == Limit: " << *SLen << "\n";
903 errs() << " while checking access to " << *Pointer << " of length "
904 << *Length << " at " << *I << "\n";
905 return false;
906 }
907
908 bool valid = true;
909 SLen = SE->getAddExpr(OffsetP, SLen);
910 // check that offset + slen <= limit;
911 // umax(offset+slen, limit) == limit is a sufficient (but not necessary
912 // condition)
913 const SCEV *MaxL = SE->getUMaxExpr(SLen, Limit);
914 if (MaxL != Limit) {
915 DEBUG(dbgs() << "MaxL != Limit: " << *MaxL << ", " << *Limit << "\n");
916 valid &= insertCheck(SLen, Limit, I, false);
917 }
918
919 //TODO: nullpointer check
920 const SCEV *Max = SE->getUMaxExpr(OffsetP, Limit);
921 if (Max == Limit)
922 return valid;
923 DEBUG(dbgs() << "Max != Limit: " << *Max << ", " << *Limit << "\n");
924
925 // check that offset < limit
926 valid &= insertCheck(OffsetP, Limit, I, true);
927 return valid;
928 }
929
930 bool validateAccess(Value *Pointer, unsigned size, Instruction *I)
931 {
932 return validateAccess(Pointer,
933 ConstantInt::get(Type::getInt32Ty(Pointer->getContext()),
934 size), I);
935 }
936
937 };
938 char PtrVerifier::ID;
939
940 } /* end namespace llvm */
#if LLVM_VERSION >= 29
/*
 * Pass registration for LLVM >= 2.9: registers PtrVerifier as
 * "clambc-rtchecks" together with the analyses it depends on. The
 * dependency class names vary with the LLVM version.
 */
INITIALIZE_PASS_BEGIN(PtrVerifier, "", "", false, false)
/* data layout + dominator tree */
#if LLVM_VERSION < 32
INITIALIZE_PASS_DEPENDENCY(TargetData)
INITIALIZE_PASS_DEPENDENCY(DominatorTree)
#elif LLVM_VERSION < 35
INITIALIZE_PASS_DEPENDENCY(DataLayout)
INITIALIZE_PASS_DEPENDENCY(DominatorTree)
#else
INITIALIZE_PASS_DEPENDENCY(DataLayoutPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
#endif
INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
/* call graph */
#if LLVM_VERSION < 34
INITIALIZE_AG_DEPENDENCY(CallGraph)
#elif LLVM_VERSION < 35
INITIALIZE_PASS_DEPENDENCY(CallGraph)
#else
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
#endif
INITIALIZE_PASS_DEPENDENCY(PointerTracking)
INITIALIZE_PASS_END(PtrVerifier, "clambc-rtchecks", "ClamBC RTchecks", false, false)
#endif
966
967
968 llvm::Pass *createClamBCRTChecks()
969 {
970 return new PtrVerifier();
971 }
972