1 /*
2  *  Compile LLVM bytecode to ClamAV bytecode.
3  *
4  *  Copyright (C) 2013-2022 Cisco Systems, Inc. and/or its affiliates. All rights reserved.
5  *  Copyright (C) 2009-2013 Sourcefire, Inc.
6  *
7  *  Authors: Török Edvin, Kevin Lin
8  *
9  *  This program is free software; you can redistribute it and/or modify
10  *  it under the terms of the GNU General Public License version 2 as
11  *  published by the Free Software Foundation.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
21  *  MA 02110-1301, USA.
22  */
23 
24 #define DEBUG_TYPE "clambc-rtcheck"
25 #include "ClamBCModule.h"
26 #include "ClamBCDiagnostics.h"
27 #include "llvm30_compat.h" /* libclamav-specific */
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/PostOrderIterator.h"
30 #include "llvm/ADT/SCCIterator.h"
31 #include "llvm/Analysis/CallGraph.h"
32 #if LLVM_VERSION < 32
33 #include "llvm/Analysis/DebugInfo.h"
34 #elif LLVM_VERSION < 35
35 #include "llvm/DebugInfo.h"
36 #else
37 #include "llvm/IR/DebugInfo.h"
38 #endif
39 #if LLVM_VERSION < 35
40 #include "llvm/Analysis/Dominators.h"
41 #include "llvm/Analysis/Verifier.h"
42 #else
43 #include "llvm/IR/Dominators.h"
44 #include "llvm/IR/Verifier.h"
45 #endif
46 #include "llvm/Analysis/ConstantFolding.h"
47 #if LLVM_VERSION < 29
48 //#include "llvm/Analysis/LiveValues.h" (unused)
49 #include "llvm/Analysis/PointerTracking.h"
50 #else
51 #include "llvm/Analysis/ValueTracking.h"
52 #include "PointerTracking.h" /* included from old LLVM source */
53 #endif
54 #include "llvm/Analysis/ScalarEvolution.h"
55 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
56 #include "llvm/Analysis/ScalarEvolutionExpander.h"
57 #include "llvm/Config/config.h"
58 #include "llvm/Pass.h"
59 #include "llvm/Support/CommandLine.h"
60 #if LLVM_VERSION < 35
61 #include "llvm/Support/DataFlow.h"
62 #include "llvm/Support/InstIterator.h"
63 #include "llvm/Support/GetElementPtrTypeIterator.h"
64 #else
65 #include "llvm/IR/InstIterator.h"
66 #include "llvm/IR/GetElementPtrTypeIterator.h"
67 #endif
68 #include "llvm/ADT/DepthFirstIterator.h"
69 #include "llvm/Transforms/Scalar.h"
70 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
71 #include "llvm/Support/Debug.h"
72 #if LLVM_VERSION < 32
73 #include "llvm/Target/TargetData.h"
74 #elif LLVM_VERSION < 33
75 #include "llvm/DataLayout.h"
76 #else
77 #include "llvm/IR/DataLayout.h"
78 #endif
79 #if LLVM_VERSION < 33
80 #include "llvm/DerivedTypes.h"
81 #include "llvm/Instructions.h"
82 #include "llvm/IntrinsicInst.h"
83 #include "llvm/Intrinsics.h"
84 #include "llvm/LLVMContext.h"
85 #include "llvm/Module.h"
86 #else
87 #include "llvm/IR/DerivedTypes.h"
88 #include "llvm/IR/Instructions.h"
89 #include "llvm/IR/IntrinsicInst.h"
90 #include "llvm/IR/Intrinsics.h"
91 #include "llvm/IR/LLVMContext.h"
92 #include "llvm/IR/Module.h"
93 #endif
94 
95 #if LLVM_VERSION < 33
96 #include "llvm/Support/InstVisitor.h"
97 #elif LLVM_VERSION < 35
98 #include "llvm/InstVisitor.h"
99 #else
100 #include "llvm/IR/InstVisitor.h"
101 #endif
102 
103 #define DEFINEPASS(passname) passname() : FunctionPass(ID)
104 
105 using namespace llvm;
106 #if LLVM_VERSION < 29
107 /* function is succeeded in later LLVM with LLVM corresponding standalone */
GetUnderlyingObject(Value * P,TargetData * TD)108 static Value *GetUnderlyingObject(Value *P, TargetData *TD)
109 {
110     return P->getUnderlyingObject();
111 }
112 #endif
113 
114 namespace llvm {
115   class PtrVerifier;
116 #if LLVM_VERSION >= 29
117   void initializePtrVerifierPass(PassRegistry&);
118 #endif
119 
120   class PtrVerifier : public FunctionPass {
121   private:
122       DenseSet<Function*> badFunctions;
123       std::vector<Instruction*> delInst;
124 #if LLVM_VERSION < 35
125       CallGraphNode *rootNode;
126 #else
127       CallGraph *CG;
128 #endif
129   public:
130       static char ID;
131 #if LLVM_VERSION < 35
DEFINEPASS(PtrVerifier)132       DEFINEPASS(PtrVerifier), rootNode(0), PT(), TD(), SE(), expander(),
133 #else
134       DEFINEPASS(PtrVerifier), CG(0), PT(), TD(), SE(), expander(),
135 #endif
136           DT(), AbrtBB(), Changed(false), valid(false), EP() {
137 #if LLVM_VERSION >= 29
138           initializePtrVerifierPass(*PassRegistry::getPassRegistry());
139 #endif
140       }
141 
runOnFunction(Function & F)142       virtual bool runOnFunction(Function &F) {
143           /*
144 #ifndef CLAMBC_COMPILER
145           // Bytecode was already verified and had stack protector applied.
146           // We get called again because ALL bytecode functions loaded are part of
147           // the same module.
148           if (F.hasFnAttr(Attribute::StackProtectReq))
149               return false;
150 #endif
151           */
152 
153           DEBUG(errs() << "Running on " << F.getName() << "\n");
154           DEBUG(F.dump());
155           Changed = false;
156           BaseMap.clear();
157           BoundsMap.clear();
158           delInst.clear();
159           AbrtBB = 0;
160           valid = true;
161 
162 #if LLVM_VERSION < 35
163           if (!rootNode) {
164               rootNode = getAnalysis<CallGraph>().getRoot();
165 #else
166           if (!CG) {
167               CG = &getAnalysis<CallGraphWrapperPass>().getCallGraph();
168 #endif
169               // No recursive functions for now.
170               // In the future we may insert runtime checks for stack depth.
171 #if LLVM_VERSION < 35
172               for (scc_iterator<CallGraphNode*> SCCI = scc_begin(rootNode),
173                        E = scc_end(rootNode); SCCI != E; ++SCCI) {
174 #else
175               for (scc_iterator<CallGraph*> SCCI = scc_begin(CG); !SCCI.isAtEnd(); ++SCCI) {
176 #endif
177                   const std::vector<CallGraphNode*> &nextSCC = *SCCI;
178                   if (nextSCC.size() > 1 || SCCI.hasLoop()) {
179                       errs() << "INVALID: Recursion detected, callgraph SCC components: ";
180                       for (std::vector<CallGraphNode*>::const_iterator I = nextSCC.begin(),
181                                E = nextSCC.end(); I != E; ++I) {
182                           Function *FF = (*I)->getFunction();
183                           if (FF) {
184                               errs() << FF->getName() << ", ";
185                               badFunctions.insert(FF);
186                           }
187                       }
188                       if (SCCI.hasLoop())
189                           errs() << "(self-loop)";
190                       errs() << "\n";
191                   }
192                   // we could also have recursion via function pointers, but we don't
193                   // allow calls to unknown functions, see runOnFunction() below
194               }
195           }
196 
197           BasicBlock::iterator It = F.getEntryBlock().begin();
198           while (isa<AllocaInst>(It) || isa<PHINode>(It)) ++It;
199           EP = &*It;
200 #if LLVM_VERSION < 32
201           TD = &getAnalysis<TargetData>();
202 #elif LLVM_VERSION < 35
203           TD = &getAnalysis<DataLayout>();
204 #else
205           DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
206           TD = DLP ? &DLP->getDataLayout() : 0;
207 #endif
208           SE = &getAnalysis<ScalarEvolution>();
209           PT = &getAnalysis<PointerTracking>();
210 #if LLVM_VERSION < 35
211           DT = &getAnalysis<DominatorTree>();
212 #else
213           DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
214 #endif
215           expander = new SCEVExpander(*SE OPT("SCEVexpander"));
216 
217           std::vector<Instruction*> insns;
218 
219           BasicBlock *LastBB = 0;
220           for (inst_iterator I=inst_begin(F),E=inst_end(F); I != E;++I) {
221               Instruction *II = &*I;
222               /* only appears in the libclamav version */
223               if (II->getParent() != LastBB) {
224                   LastBB = II->getParent();
225                   if (DT->getNode(LastBB) == 0)
226                       continue;
227               }
228               /* end-block */
229               if (isa<LoadInst>(II) || isa<StoreInst>(II) || isa<MemIntrinsic>(II))
230                   insns.push_back(II);
231               else if (CallInst *CI = dyn_cast<CallInst>(II)) {
232                   Value *V = CI->getCalledValue()->stripPointerCasts();
233                   Function *F = dyn_cast<Function>(V);
234                   if (!F) {
235                       printLocation(CI, true);
236                       errs() << "Could not determine call target\n";
237                       valid = 0;
238                       continue;
239                   }
240                   // this statement disable checks on user-defined CallInst
241                   //if (!F->isDeclaration())
242                   //continue;
243                   insns.push_back(CI);
244               }
245           }
246 
247           for (unsigned Idx = 0; Idx < insns.size(); ++Idx) {
248               Instruction *II = insns[Idx];
249               DEBUG(dbgs() << "checking " << *II << "\n");
250               if (LoadInst *LI = dyn_cast<LoadInst>(II)) {
251                   constType *Ty = LI->getType();
252                   valid &= validateAccess(LI->getPointerOperand(),
253                                           TD->getTypeAllocSize(Ty), LI);
254               } else if (StoreInst *SI = dyn_cast<StoreInst>(II)) {
255                   constType *Ty = SI->getOperand(0)->getType();
256                   valid &= validateAccess(SI->getPointerOperand(),
257                                           TD->getTypeAllocSize(Ty), SI);
258               } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(II)) {
259                   valid &= validateAccess(MI->getDest(), MI->getLength(), MI);
260                   if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
261                       valid &= validateAccess(MTI->getSource(), MI->getLength(), MI);
262                   }
263               } else if (CallInst *CI = dyn_cast<CallInst>(II)) {
264                   Value *V = CI->getCalledValue()->stripPointerCasts();
265                   Function *F = cast<Function>(V);
266                   constFunctionType *FTy = F->getFunctionType();
267                   CallSite CS(CI);
268                   if (F->getName().equals("memcmp") && FTy->getNumParams() == 3) {
269                       valid &= validateAccess(CS.getArgument(0), CS.getArgument(2), CI);
270                       valid &= validateAccess(CS.getArgument(1), CS.getArgument(2), CI);
271                       continue;
272                   }
273                   unsigned i;
274 #ifdef CLAMBC_COMPILER
275                   i = 0;
276 #else
277                   i = 1;// skip hidden ctx*
278 #endif
279                   for (;i<FTy->getNumParams();i++) {
280                       if (isa<PointerType>(FTy->getParamType(i))) {
281                           Value *Ptr = CS.getArgument(i);
282                           if (i+1 >= FTy->getNumParams()) {
283                               printLocation(CI, false);
284                               errs() << "Call to external function with pointer parameter last"
285                                   " cannot be analyzed\n";
286                               errs() << *CI << "\n";
287                               valid = 0;
288                               break;
289                           }
290                           Value *Size = CS.getArgument(i+1);
291                           if (!Size->getType()->isIntegerTy()) {
292                               printLocation(CI, false);
293                               errs() << "Pointer argument must be followed by integer argument"
294                                   " representing its size\n";
295                               errs() << *CI << "\n";
296                               valid = 0;
297                               break;
298                           }
299                           valid &= validateAccess(Ptr, Size, CI);
300                       }
301                   }
302               }
303           }
304           if (badFunctions.count(&F))
305               valid = 0;
306 
307           if (!valid) {
308               DEBUG(F.dump());
309               ClamBCModule::stop("Verification found errors!", &F);
310               // replace function with call to abort
311               std::vector<constType*>args;
312               FunctionType* abrtTy = FunctionType::get(Type::getVoidTy(F.getContext()),args,false);
313               Constant *func_abort = F.getParent()->getOrInsertFunction("abort", abrtTy);
314 
315               BasicBlock *BB = &F.getEntryBlock();
316               Instruction *I = &*BB->begin();
317               Instruction *UI = new UnreachableInst(F.getContext(), I);
318               CallInst *AbrtC = CallInst::Create(func_abort, "", UI);
319               AbrtC->setCallingConv(CallingConv::C);
320               AbrtC->setTailCall(true);
321 #if LLVM_VERSION < 32
322               AbrtC->setDoesNotReturn(true);
323               AbrtC->setDoesNotThrow(true);
324 #else
325               AbrtC->setDoesNotReturn();
326               AbrtC->setDoesNotThrow();
327 #endif
328               // remove all instructions from entry
329               BasicBlock::iterator BBI = I, BBE=BB->end();
330               while (BBI != BBE) {
331                   if (!BBI->use_empty())
332                       BBI->replaceAllUsesWith(UndefValue::get(BBI->getType()));
333                   BB->getInstList().erase(BBI++);
334               }
335           }
336 
337           // bb#9967 - deleting obsolete termination instructions
338           for (unsigned i = 0; i < delInst.size(); ++i)
339               delInst[i]->eraseFromParent();
340 
341           delete expander;
342           return Changed;
343       }
344 
345       virtual void releaseMemory() {
346           badFunctions.clear();
347       }
348 
349       virtual void getAnalysisUsage(AnalysisUsage &AU) const {
350 #if LLVM_VERSION < 32
351           AU.addRequired<TargetData>();
352 #elif LLVM_VERSION < 35
353           AU.addRequired<DataLayout>();
354 #else
355           AU.addRequired<DataLayoutPass>();
356 #endif
357 #if LLVM_VERSION < 35
358           AU.addRequired<DominatorTree>();
359 #else
360           AU.addRequired<DominatorTreeWrapperPass>();
361 #endif
362           AU.addRequired<ScalarEvolution>();
363           AU.addRequired<PointerTracking>();
364 #if LLVM_VERSION < 35
365           AU.addRequired<CallGraph>();
366 #else
367           AU.addRequired<CallGraphWrapperPass>();
368 #endif
369       }
370 
371       bool isValid() const { return valid; }
372   private:
373       PointerTracking *PT;
374 #if LLVM_VERSION < 32
375       TargetData *TD;
376 #elif LLVM_VERSION < 35
377       DataLayout *TD;
378 #else
379       const DataLayout *TD;
380 #endif
381       ScalarEvolution *SE;
382       SCEVExpander *expander;
383       DominatorTree *DT;
384       DenseMap<Value*, Value*> BaseMap;
385       DenseMap<Value*, Value*> BoundsMap;
386       BasicBlock *AbrtBB;
387       bool Changed;
388       bool valid;
389       Instruction *EP;
390 
391       Instruction *getInsertPoint(Value *V)
392       {
393           BasicBlock::iterator It = EP;
394           if (Instruction *I = dyn_cast<Instruction>(V)) {
395               It = I;
396               ++It;
397           }
398           return &*It;
399       }
400 
401       Value *getPointerBase(Value *Ptr)
402       {
403           if (BaseMap.count(Ptr))
404               return BaseMap[Ptr];
405           Value *P = Ptr->stripPointerCasts();
406           if (BaseMap.count(P)) {
407               return BaseMap[Ptr] = BaseMap[P];
408           }
409           Value *P2 = GetUnderlyingObject(P, TD);
410           if (P2 != P) {
411               Value *V = getPointerBase(P2);
412               return BaseMap[Ptr] = V;
413           }
414 
415           constType *P8Ty =
416               PointerType::getUnqual(Type::getInt8Ty(Ptr->getContext()));
417           if (PHINode *PN = dyn_cast<PHINode>(Ptr)) {
418               BasicBlock::iterator It = PN;
419               ++It;
420               PHINode *newPN = PHINode::Create(P8Ty, HINT(PN->getNumIncomingValues()) ".verif.base", &*It);
421               Changed = true;
422               BaseMap[Ptr] = newPN;
423 
424               for (unsigned i=0;i<PN->getNumIncomingValues();i++) {
425                   Value *Inc = PN->getIncomingValue(i);
426                   Value *V = getPointerBase(Inc);
427                   newPN->addIncoming(V, PN->getIncomingBlock(i));
428               }
429               return newPN;
430           }
431           if (SelectInst *SI = dyn_cast<SelectInst>(Ptr)) {
432               BasicBlock::iterator It = SI;
433               ++It;
434               Value *TrueB = getPointerBase(SI->getTrueValue());
435               Value *FalseB = getPointerBase(SI->getFalseValue());
436               if (TrueB && FalseB) {
437                   SelectInst *NewSI = SelectInst::Create(SI->getCondition(), TrueB,
438                                                          FalseB, ".select.base", &*It);
439                   Changed = true;
440                   return BaseMap[Ptr] = NewSI;
441               }
442           }
443           if (Ptr->getType() != P8Ty) {
444               if (Constant *C = dyn_cast<Constant>(Ptr))
445                   Ptr = ConstantExpr::getPointerCast(C, P8Ty);
446               else {
447                   Instruction *I = getInsertPoint(Ptr);
448                   Ptr = new BitCastInst(Ptr, P8Ty, "", I);
449               }
450           }
451           return BaseMap[Ptr] = Ptr;
452       }
453 
454       Value* getValAtIdx(Function *F, unsigned Idx) {
455           Value *Val= NULL;
456 
457           // check if accessed Idx is within function parameter list
458           if (Idx < F->arg_size()) {
459               Function::arg_iterator It = F->arg_begin();
460               Function::arg_iterator ItEnd = F->arg_end();
461               for (unsigned i = 0; i < Idx; ++i, ++It) {
462                   // redundant check, should not be possible
463                   if (It == ItEnd) {
464                       // Houston, the impossible has become possible
465                       //printDiagnostic("Idx is outside of Function parameters", F);
466                       errs() << "Idx is outside of Function parameters\n";
467                       errs() << *F << "\n";
468                       //valid = 0;
469                       break;
470                   }
471               }
472               // retrieve value ptr of argument of F at Idx
473               Val = &(*It);
474           }
475           else {
476               // Idx is outside function parameter list
477               //printDiagnostic("Idx is outside of Function parameters", F);
478               errs() << "Idx is outside of Function parameters\n";
479               errs() << *F << "\n";
480               //valid = 0;
481           }
482           return Val;
483       }
484 
485       Value* getPointerBounds(Value *Base) {
486           if (BoundsMap.count(Base))
487               return BoundsMap[Base];
488           constType *I64Ty =
489               Type::getInt64Ty(Base->getContext());
490 
491 #ifndef CLAMBC_COMPILER
492           // first arg is hidden ctx
493           if (Argument *A = dyn_cast<Argument>(Base)) {
494               if (A->getArgNo() == 0) {
495                   constType *Ty = cast<PointerType>(A->getType())->getElementType();
496                   return ConstantInt::get(I64Ty, TD->getTypeAllocSize(Ty));
497               } else if (Base->getType()->isPointerTy()) {
498                   Function *F = A->getParent();
499                   const FunctionType *FT = F->getFunctionType();
500 
501                   bool checks = true;
502                   // last argument check
503                   if (A->getArgNo() == (FT->getNumParams()-1)) {
504                       //printDiagnostic("pointer argument cannot be last argument", F);
505                       errs() << "pointer argument cannot be last argument\n";
506                       errs() << *F << "\n";
507                       checks = false;
508                   }
509 
510                   // argument after pointer MUST be a integer (unsigned probably too)
511                   if (checks && !FT->getParamType(A->getArgNo()+1)->isIntegerTy()) {
512                       //printDiagnostic("argument following pointer argument is not an integer", F);
513                       errs() << "argument following pointer argument is not an integer\n";
514                       errs() << *F << "\n";
515                       checks = false;
516                   }
517 
518                   if (checks)
519                       return BoundsMap[Base] = getValAtIdx(F, A->getArgNo()+1);
520               }
521           }
522           if (LoadInst *LI = dyn_cast<LoadInst>(Base)) {
523               Value *V = GetUnderlyingObject(LI->getPointerOperand()->stripPointerCasts(), TD);
524               if (Argument *A = dyn_cast<Argument>(V)) {
525                   if (A->getArgNo() == 0) {
526                       // pointers from hidden ctx are trusted to be at least the
527                       // size they say they are
528                       constType *Ty = cast<PointerType>(LI->getType())->getElementType();
529                       return ConstantInt::get(I64Ty, TD->getTypeAllocSize(Ty));
530                   }
531               }
532           }
533 #else
534           if (Base->getType()->isPointerTy()) {
535               if (Argument *A = dyn_cast<Argument>(Base)) {
536                   Function *F = A->getParent();
537                   const FunctionType *FT = F->getFunctionType();
538 
539                   bool checks = true;
540                   // last argument check
541                   if (A->getArgNo() == (FT->getNumParams()-1)) {
542                       //printDiagnostic("pointer argument cannot be last argument", F);
543                       errs() << "pointer argument cannot be last argument\n";
544                       errs() << *F << "\n";
545                       checks = false;
546                   }
547 
548                   // argument after pointer MUST be a integer (unsigned probably too)
549                   if (checks && !FT->getParamType(A->getArgNo()+1)->isIntegerTy()) {
550                       //printDiagnostic("argument following pointer argument is not an integer", F);
551                       errs() << "argument following pointer argument is not an integer\n";
552                       errs() << *F << "\n";
553                       checks = false;
554                   }
555 
556                   if (checks)
557                       return BoundsMap[Base] = getValAtIdx(F, A->getArgNo()+1);
558               }
559           }
560 #endif
561           if (PHINode *PN = dyn_cast<PHINode>(Base)) {
562               BasicBlock::iterator It = PN;
563               ++It;
564               PHINode *newPN = PHINode::Create(I64Ty, HINT(PN->getNumIncomingValues()) ".verif.bounds", &*It);
565               Changed = true;
566               BoundsMap[Base] = newPN;
567 
568               bool good = true;
569               for (unsigned i=0;i<PN->getNumIncomingValues();i++) {
570                   Value *Inc = PN->getIncomingValue(i);
571                   Value *B = getPointerBounds(Inc);
572                   if (!B) {
573                       good = false;
574                       B = ConstantInt::get(newPN->getType(), 0);
575                       DEBUG(dbgs() << "bounds not found while solving phi node: " << *Inc
576                             << "\n");
577                   }
578                   newPN->addIncoming(B, PN->getIncomingBlock(i));
579               }
580               if (!good)
581                   newPN = 0;
582               return BoundsMap[Base] = newPN;
583           }
584           if (SelectInst *SI = dyn_cast<SelectInst>(Base)) {
585               BasicBlock::iterator It = SI;
586               ++It;
587               Value *TrueB = getPointerBounds(SI->getTrueValue());
588               Value *FalseB = getPointerBounds(SI->getFalseValue());
589               if (TrueB && FalseB) {
590                   SelectInst *NewSI = SelectInst::Create(SI->getCondition(), TrueB,
591                                                          FalseB, ".select.bounds", &*It);
592                   Changed = true;
593                   return BoundsMap[Base] = NewSI;
594               }
595           }
596 
597           constType *Ty;
598           Value *V = PT->computeAllocationCountValue(Base, Ty);
599           if (!V) {
600               Base = Base->stripPointerCasts();
601               if (CallInst *CI = dyn_cast<CallInst>(Base)) {
602                   Function *F = CI->getCalledFunction();
603                   constFunctionType *FTy = F->getFunctionType();
604                   // last operand is always size for this API call kind
605                   if (F->isDeclaration() && FTy->getNumParams() > 0) {
606                       CallSite CS(CI);
607                       if (FTy->getParamType(FTy->getNumParams()-1)->isIntegerTy())
608                           V = CS.getArgument(FTy->getNumParams()-1);
609                   }
610               }
611               if (!V)
612                   return BoundsMap[Base] = 0;
613           } else {
614               unsigned size = TD->getTypeAllocSize(Ty);
615               if (size > 1) {
616                   Constant *C = cast<Constant>(V);
617                   C = ConstantExpr::getMul(C,
618                                            ConstantInt::get(Type::getInt32Ty(C->getContext()),
619                                                             size));
620                   V = C;
621               }
622           }
623           if (V->getType() != I64Ty) {
624               if (Constant *C = dyn_cast<Constant>(V))
625                   V = ConstantExpr::getZExt(C, I64Ty);
626               else {
627                   Instruction *I = getInsertPoint(V);
628                   V = new ZExtInst(V, I64Ty, "", I);
629               }
630           }
631           return BoundsMap[Base] = V;
632       }
633 
634       MDNode *getLocation(Instruction *I, bool &Approximate, unsigned MDDbgKind)
635       {
636           Approximate = false;
637           if (MDNode *Dbg = I->getMetadata(MDDbgKind))
638               return Dbg;
639           if (!MDDbgKind)
640               return 0;
641           Approximate = true;
642           BasicBlock::iterator It = I;
643           while (It != I->getParent()->begin()) {
644               --It;
645               if (MDNode *Dbg = It->getMetadata(MDDbgKind))
646                   return Dbg;
647           }
648           BasicBlock *BB = I->getParent();
649           while ((BB = BB->getUniquePredecessor())) {
650               It = BB->end();
651               while (It != BB->begin()) {
652                   --It;
653                   if (MDNode *Dbg = It->getMetadata(MDDbgKind))
654                       return Dbg;
655               }
656           }
657           return 0;
658       }
659 
660       bool insertCheck(const SCEV *Idx, const SCEV *Limit, Instruction *I,
661                        bool strict)
662       {
663           if (isa<SCEVCouldNotCompute>(Idx) && isa<SCEVCouldNotCompute>(Limit)) {
664               errs() << "Could not compute the index and the limit!: \n" << *I << "\n";
665               return false;
666           }
667           if (isa<SCEVCouldNotCompute>(Idx)) {
668               errs() << "Could not compute index: \n" << *I << "\n";
669               return false;
670           }
671           if (isa<SCEVCouldNotCompute>(Limit)) {
672               errs() << "Could not compute limit: " << *I << "\n";
673               return false;
674           }
675           BasicBlock *BB = I->getParent();
676           BasicBlock::iterator It = I;
677           BasicBlock *newBB = SplitBlock(BB, &*It, this);
678           PHINode *PN;
679           unsigned MDDbgKind = I->getContext().getMDKindID("dbg");
680           //verifyFunction(*BB->getParent());
681           if (!AbrtBB) {
682               std::vector<constType*>args;
683               FunctionType* abrtTy = FunctionType::get(Type::getVoidTy(BB->getContext()),args,false);
684               args.push_back(Type::getInt32Ty(BB->getContext()));
685               FunctionType* rterrTy = FunctionType::get(Type::getInt32Ty(BB->getContext()),args,false);
686               Constant *func_abort = BB->getParent()->getParent()->getOrInsertFunction("abort", abrtTy);
687               Constant *func_rterr = BB->getParent()->getParent()->getOrInsertFunction("bytecode_rt_error",
688                                                                                        rterrTy);
689               AbrtBB = BasicBlock::Create(BB->getContext(), "rterr.trig", BB->getParent());
690 
691               PN = PHINode::Create(Type::getInt32Ty(BB->getContext()),HINT(1) "",
692                                    AbrtBB);
693               if (MDDbgKind) {
694                   CallInst *RtErrCall = CallInst::Create(func_rterr, PN, "", AbrtBB);
695                   RtErrCall->setCallingConv(CallingConv::C);
696                   RtErrCall->setTailCall(true);
697 #if LLVM_VERSION < 32
698                   RtErrCall->setDoesNotThrow(true);
699 #else
700                   RtErrCall->setDoesNotThrow();
701 #endif
702               }
703               CallInst* AbrtC = CallInst::Create(func_abort, "", AbrtBB);
704               AbrtC->setCallingConv(CallingConv::C);
705               AbrtC->setTailCall(true);
706 #if LLVM_VERSION < 32
707               AbrtC->setDoesNotReturn(true);
708               AbrtC->setDoesNotThrow(true);
709 #else
710               AbrtC->setDoesNotReturn();
711               AbrtC->setDoesNotThrow();
712 #endif
713               new UnreachableInst(BB->getContext(), AbrtBB);
714               DT->addNewBlock(AbrtBB, BB);
715               //verifyFunction(*BB->getParent());
716           } else {
717               PN = cast<PHINode>(AbrtBB->begin());
718           }
719           unsigned locationid = 0;
720           bool Approximate;
721           if (MDNode *Dbg = getLocation(I, Approximate, MDDbgKind)) {
722               DILocation Loc(Dbg);
723               locationid = Loc.getLineNumber() << 8;
724               unsigned col = Loc.getColumnNumber();
725               if (col > 254)
726                   col = 254;
727               if (Approximate)
728                   col = 255;
729               locationid |= col;
730           }
731           PN->addIncoming(ConstantInt::get(Type::getInt32Ty(BB->getContext()),
732                                            locationid), BB);
733           TerminatorInst *TI = BB->getTerminator();
734           Value *IdxV = expander->expandCodeFor(Idx, Limit->getType(), TI);
735           Value *LimitV = expander->expandCodeFor(Limit, Limit->getType(), TI);
736           if (isa<Instruction>(IdxV) &&
737               !DT->dominates(cast<Instruction>(IdxV)->getParent(),I->getParent())) {
738               printLocation(I, true);
739               errs() << "basic block with value [ " << IdxV->getName();
740               errs() << " ] with limit [ " << LimitV->getName();
741               errs() << " ] does not dominate" << *I << "\n";
742               return false;
743           }
744           if (isa<Instruction>(LimitV) &&
745               !DT->dominates(cast<Instruction>(LimitV)->getParent(),I->getParent())) {
746               printLocation(I, true);
747               errs() << "basic block with limit [" << LimitV->getName();
748               errs() << " ] on value [ " << IdxV->getName();
749               errs() << " ] does not dominate" << *I << "\n";
750               return false;
751           }
752           Value *Cond = new ICmpInst(TI, strict ?
753                                      ICmpInst::ICMP_ULT :
754                                      ICmpInst::ICMP_ULE, IdxV, LimitV);
755           BranchInst::Create(newBB, AbrtBB, Cond, TI);
756           //TI->eraseFromParent();
757           delInst.push_back(TI);
758           // Update dominator info
759           BasicBlock *DomBB =
760               DT->findNearestCommonDominator(BB, DT->getNode(AbrtBB)->getIDom()->getBlock());
761           DT->changeImmediateDominator(AbrtBB, DomBB);
762           return true;
763       }
764 
765       static void MakeCompatible(ScalarEvolution *SE, const SCEV*& LHS, const SCEV*& RHS)
766       {
767           if (const SCEVZeroExtendExpr *ZL = dyn_cast<SCEVZeroExtendExpr>(LHS))
768               LHS = ZL->getOperand();
769           if (const SCEVZeroExtendExpr *ZR = dyn_cast<SCEVZeroExtendExpr>(RHS))
770               RHS = ZR->getOperand();
771 
772           constType* LTy = SE->getEffectiveSCEVType(LHS->getType());
773           constType *RTy = SE->getEffectiveSCEVType(RHS->getType());
774           if (SE->getTypeSizeInBits(RTy) > SE->getTypeSizeInBits(LTy))
775               LTy = RTy;
776           LHS = SE->getNoopOrZeroExtend(LHS, LTy);
777           RHS = SE->getNoopOrZeroExtend(RHS, LTy);
778       }
779 
780       bool checkCond(Instruction *ICI, Instruction *I, bool equal)
781       {
782 #if LLVM_VERSION < 35
783           for (Value::use_iterator JU=ICI->use_begin(),JUE=ICI->use_end();
784                JU != JUE; ++JU) {
785 #else
786           for (Value::user_iterator JU=ICI->user_begin(),JUE=ICI->user_end();
787                JU != JUE; ++JU) {
788 #endif
789               Value *JU_V = *JU;
790               if (BranchInst *BI = dyn_cast<BranchInst>(JU_V)) {
791                   if (!BI->isConditional())
792                       continue;
793                   BasicBlock *S = BI->getSuccessor(equal);
794                   if (DT->dominates(S, I->getParent()))
795                       return true;
796               }
797               if (BinaryOperator *BI = dyn_cast<BinaryOperator>(JU_V)) {
798                   if (BI->getOpcode() == Instruction::Or &&
799                       checkCond(BI, I, equal))
800                       return true;
801                   if (BI->getOpcode() == Instruction::And &&
802                       checkCond(BI, I, !equal))
803                       return true;
804               }
805           }
806           return false;
807       }
808 
809       bool checkCondition(Instruction *CI, Instruction *I)
810       {
811 #if LLVM_VERSION < 35
812           for (Value::use_iterator U=CI->use_begin(),UE=CI->use_end();
813                U != UE; ++U) {
814 #else
815           for (Value::user_iterator U=CI->user_begin(),UE=CI->user_end();
816                U != UE; ++U) {
817 #endif
818               Value *U_V = *U;
819               if (ICmpInst *ICI = dyn_cast<ICmpInst>(U_V)) {
820                   if (ICI->getOperand(0)->stripPointerCasts() == CI &&
821                       isa<ConstantPointerNull>(ICI->getOperand(1))) {
822                       if (checkCond(ICI, I, ICI->getPredicate() == ICmpInst::ICMP_EQ))
823                           return true;
824                   }
825               }
826           }
827           return false;
828       }
829 
830       bool validateAccess(Value *Pointer, Value *Length, Instruction *I)
831       {
832           // get base
833           Value *Base = getPointerBase(Pointer);
834 
835           Value *SBase = Base->stripPointerCasts();
836           // get bounds
837           Value *Bounds = getPointerBounds(SBase);
838           if (!Bounds) {
839               printLocation(I, true);
840               errs() << "no bounds for base ";
841               printValue(SBase);
842               errs() << " while checking access to ";
843               printValue(Pointer);
844               errs() << " of length ";
845               printValue(Length);
846               errs() << "\n";
847 
848               return false;
849           }
850 
851           // checks if a NULL pointer check (returned from function) is made:
852           if (CallInst *CI = dyn_cast<CallInst>(Base->stripPointerCasts())) {
853               // by checking if use is in the same block (i.e. no branching decisions)
854               if (I->getParent() == CI->getParent()) {
855                   printLocation(I, true);
856                   errs() << "no null pointer check of pointer ";
857                   printValue(Base, false, true);
858                   errs() << " obtained by function call";
859                   errs() << " before use in same block\n";
860                   return false;
861               }
862               // by checking if a conditional contains the values in question somewhere
863               // between their usage
864               if (!checkCondition(CI, I)) {
865                   printLocation(I, true);
866                   errs() << "no null pointer check of pointer ";
867                   printValue(Base, false, true);
868                   errs() << " obtained by function call";
869                   errs() << " before use\n";
870                   return false;
871               }
872           }
873 
874       constType *I64Ty =
875           Type::getInt64Ty(Base->getContext());
876       const SCEV *SLen = SE->getSCEV(Length);
877       const SCEV *OffsetP = SE->getMinusSCEV(SE->getSCEV(Pointer),
878                                              SE->getSCEV(Base));
879       SLen = SE->getNoopOrZeroExtend(SLen, I64Ty);
880       OffsetP = SE->getNoopOrZeroExtend(OffsetP, I64Ty);
881       const SCEV *Limit = SE->getSCEV(Bounds);
882       Limit = SE->getNoopOrZeroExtend(Limit, I64Ty);
883 
884       DEBUG(dbgs() << "Checking access to " << *Pointer << " of length " <<
885             *Length << "\n");
886       if (OffsetP == Limit) {
887           printLocation(I, true);
888           errs() << "OffsetP == Limit: " << *OffsetP << "\n";
889           errs() << " while checking access to ";
890           printValue(Pointer);
891           errs() << " of length ";
892           printValue(Length);
893           errs() << "\n";
894           return false;
895       }
896 
897       if (SLen == Limit) {
898           if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(OffsetP)) {
899               if (SC->isZero())
900                   return true;
901           }
902           errs() << "SLen == Limit: " << *SLen << "\n";
903           errs() << " while checking access to " << *Pointer << " of length "
904                  << *Length << " at " << *I << "\n";
905           return false;
906       }
907 
908       bool valid = true;
909       SLen = SE->getAddExpr(OffsetP, SLen);
910       // check that offset + slen <= limit;
911       // umax(offset+slen, limit) == limit is a sufficient (but not necessary
912       // condition)
913       const SCEV *MaxL = SE->getUMaxExpr(SLen, Limit);
914       if (MaxL != Limit) {
915           DEBUG(dbgs() << "MaxL != Limit: " << *MaxL << ", " << *Limit << "\n");
916           valid &= insertCheck(SLen, Limit, I, false);
917       }
918 
919       //TODO: nullpointer check
920       const SCEV *Max = SE->getUMaxExpr(OffsetP, Limit);
921       if (Max == Limit)
922           return valid;
923       DEBUG(dbgs() << "Max != Limit: " << *Max << ", " << *Limit << "\n");
924 
925       // check that offset < limit
926       valid &= insertCheck(OffsetP, Limit, I, true);
927       return valid;
928       }
929 
930       bool validateAccess(Value *Pointer, unsigned size, Instruction *I)
931       {
932           return validateAccess(Pointer,
933                                 ConstantInt::get(Type::getInt32Ty(Pointer->getContext()),
934                                                  size), I);
935       }
936 
937   };
938     char PtrVerifier::ID;
939 
940 } /* end namespace llvm */
// Registration of PtrVerifier and the passes it depends on; only emitted for
// LLVM_VERSION >= 29. Several dependencies go by different names depending on
// LLVM_VERSION, so each one is chosen by a version test (newest first).
#if LLVM_VERSION >= 29
INITIALIZE_PASS_BEGIN(PtrVerifier, "", "", false, false)
// Data layout dependency.
#if LLVM_VERSION >= 35
INITIALIZE_PASS_DEPENDENCY(DataLayoutPass)
#elif LLVM_VERSION >= 32
INITIALIZE_PASS_DEPENDENCY(DataLayout)
#else
INITIALIZE_PASS_DEPENDENCY(TargetData)
#endif
// Dominator tree dependency.
#if LLVM_VERSION >= 35
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
#else
INITIALIZE_PASS_DEPENDENCY(DominatorTree)
#endif
INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
// Call graph dependency.
#if LLVM_VERSION >= 35
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
#elif LLVM_VERSION >= 34
INITIALIZE_PASS_DEPENDENCY(CallGraph)
#else
INITIALIZE_AG_DEPENDENCY(CallGraph)
#endif
INITIALIZE_PASS_DEPENDENCY(PointerTracking)
INITIALIZE_PASS_END(PtrVerifier, "clambc-rtchecks", "ClamBC RTchecks", false, false)
#endif
966 
967 
968 llvm::Pass *createClamBCRTChecks()
969 {
970     return new PtrVerifier();
971 }
972