1 /*
   american fuzzy lop++ - LLVM dict2file (dictionary extraction) pass
   ------------------------------------------------------------------
4 
5    Written by Marc Heuse <mh@mh-sec.de>
6 
7    Copyright 2019-2020 AFLplusplus Project. All rights reserved.
8 
9    Licensed under the Apache License, Version 2.0 (the "License");
10    you may not use this file except in compliance with the License.
11    You may obtain a copy of the License at:
12 
13      http://www.apache.org/licenses/LICENSE-2.0
14 
15    This library is plugged into LLVM when invoking clang through afl-clang-lto.
16 
17  */
18 
19 #define AFL_LLVM_PASS
20 
21 #include "config.h"
22 #include "debug.h"
23 
24 #include <stdio.h>
25 #include <stdlib.h>
26 #include <unistd.h>
27 #include <string.h>
28 #include <sys/time.h>
29 #include <sys/types.h>
30 #include <sys/stat.h>
31 #include <fcntl.h>
32 #include <ctype.h>
33 
34 #include <list>
35 #include <string>
36 #include <fstream>
37 #include <set>
38 
39 #include "llvm/Config/llvm-config.h"
40 #include "llvm/ADT/Statistic.h"
41 #include "llvm/IR/IRBuilder.h"
42 #include "llvm/IR/LegacyPassManager.h"
43 #include "llvm/IR/BasicBlock.h"
44 #include "llvm/IR/Module.h"
45 #include "llvm/IR/DebugInfo.h"
46 #include "llvm/IR/CFG.h"
47 #include "llvm/IR/Verifier.h"
48 #include "llvm/Support/Debug.h"
49 #include "llvm/Support/raw_ostream.h"
50 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
51 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
52 #include "llvm/Analysis/LoopInfo.h"
53 #include "llvm/Analysis/ValueTracking.h"
54 #include "llvm/Pass.h"
55 #include "llvm/IR/Constants.h"
56 
57 #include "afl-llvm-common.h"
58 
59 #ifndef O_DSYNC
60   #define O_DSYNC O_SYNC
61 #endif
62 
63 using namespace llvm;
64 
65 namespace {
66 
67 class AFLdict2filePass : public ModulePass {
68 
69  public:
70   static char ID;
71 
AFLdict2filePass()72   AFLdict2filePass() : ModulePass(ID) {
73 
74     if (getenv("AFL_DEBUG")) debug = 1;
75 
76   }
77 
78   bool runOnModule(Module &M) override;
79 
80 };
81 
82 }  // namespace
83 
dict2file(int fd,u8 * mem,u32 len)84 void dict2file(int fd, u8 *mem, u32 len) {
85 
86   u32  i, j, binary = 0;
87   char line[MAX_AUTO_EXTRA * 8], tmp[8];
88 
89   strcpy(line, "\"");
90   j = 1;
91   for (i = 0; i < len; i++) {
92 
93     if (isprint(mem[i]) && mem[i] != '\\' && mem[i] != '"') {
94 
95       line[j++] = mem[i];
96 
97     } else {
98 
99       if (i + 1 != len || mem[i] != 0 || binary || len == 4 || len == 8) {
100 
101         line[j] = 0;
102         sprintf(tmp, "\\x%02x", (u8)mem[i]);
103         strcat(line, tmp);
104         j = strlen(line);
105 
106       }
107 
108       binary = 1;
109 
110     }
111 
112   }
113 
114   line[j] = 0;
115   strcat(line, "\"\n");
116   if (write(fd, line, strlen(line)) <= 0)
117     PFATAL("Could not write to dictionary file");
118   fsync(fd);
119 
120   if (!be_quiet) fprintf(stderr, "Found dictionary token: %s", line);
121 
122 }
123 
runOnModule(Module & M)124 bool AFLdict2filePass::runOnModule(Module &M) {
125 
126   DenseMap<Value *, std::string *> valueMap;
127   char *                           ptr;
128   int                              fd, found = 0;
129 
130   /* Show a banner */
131   setvbuf(stdout, NULL, _IONBF, 0);
132 
133   if ((isatty(2) && !getenv("AFL_QUIET")) || debug) {
134 
135     SAYF(cCYA "afl-llvm-dict2file" VERSION cRST
136               " by Marc \"vanHauser\" Heuse <mh@mh-sec.de>\n");
137 
138   } else
139 
140     be_quiet = 1;
141 
142   scanForDangerousFunctions(&M);
143 
144   ptr = getenv("AFL_LLVM_DICT2FILE");
145 
146   if (!ptr || *ptr != '/')
147     FATAL("AFL_LLVM_DICT2FILE is not set to an absolute path: %s", ptr);
148 
149   if ((fd = open(ptr, O_WRONLY | O_APPEND | O_CREAT | O_DSYNC, 0644)) < 0)
150     PFATAL("Could not open/create %s.", ptr);
151 
152   /* Instrument all the things! */
153 
154   for (auto &F : M) {
155 
156     if (isIgnoreFunction(&F)) continue;
157     if (!isInInstrumentList(&F) || !F.size()) { continue; }
158 
159     /*  Some implementation notes.
160      *
161      *  We try to handle 3 cases:
162      *  - memcmp("foo", arg, 3) <- literal string
163      *  - static char globalvar[] = "foo";
164      *    memcmp(globalvar, arg, 3) <- global variable
165      *  - char localvar[] = "foo";
166      *    memcmp(locallvar, arg, 3) <- local variable
167      *
168      *  The local variable case is the hardest. We can only detect that
169      *  case if there is no reassignment or change in the variable.
170      *  And it might not work across llvm version.
171      *  What we do is hooking the initializer function for local variables
172      *  (llvm.memcpy.p0i8.p0i8.i64) and note the string and the assigned
173      *  variable. And if that variable is then used in a compare function
174      *  we use that noted string.
175      *  This seems not to work for tokens that have a size <= 4 :-(
176      *
177      *  - if the compared length is smaller than the string length we
178      *    save the full string. This is likely better for fuzzing but
179      *    might be wrong in a few cases depending on optimizers
180      *
181      *  - not using StringRef because there is a bug in the llvm 11
182      *    checkout I am using which sometimes points to wrong strings
183      *
184      *  Over and out. Took me a full day. damn. mh/vh
185      */
186 
187     for (auto &BB : F) {
188 
189       for (auto &IN : BB) {
190 
191         CallInst *callInst = nullptr;
192         CmpInst * cmpInst = nullptr;
193 
194         if ((cmpInst = dyn_cast<CmpInst>(&IN))) {
195 
196           Value *      op = cmpInst->getOperand(1);
197           ConstantInt *ilen = dyn_cast<ConstantInt>(op);
198 
199           /* We skip > 64 bit integers. why? first because their value is
200              difficult to obtain, and second because clang does not support
201              literals > 64 bit (as of llvm 12) */
202 
203           if (ilen && ilen->uge(0xffffffffffffffff) == false) {
204 
205             u64 val2 = 0, val = ilen->getZExtValue();
206             u32 len = 0;
207             if (val > 0x10000 && val < 0xffffffff) len = 4;
208             if (val > 0x100000001 && val < 0xffffffffffffffff) len = 8;
209 
210             if (len) {
211 
212               auto c = cmpInst->getPredicate();
213 
214               switch (c) {
215 
216                 case CmpInst::FCMP_OGT:  // fall through
217                 case CmpInst::FCMP_OLE:  // fall through
218                 case CmpInst::ICMP_SLE:  // fall through
219                 case CmpInst::ICMP_SGT:
220 
221                   // signed comparison and it is a negative constant
222                   if ((len == 4 && (val & 80000000)) ||
223                       (len == 8 && (val & 8000000000000000))) {
224 
225                     if ((val & 0xffff) != 1) val2 = val - 1;
226                     break;
227 
228                   }
229 
230                   // fall through
231 
232                 case CmpInst::FCMP_UGT:  // fall through
233                 case CmpInst::FCMP_ULE:  // fall through
234                 case CmpInst::ICMP_UGT:  // fall through
235                 case CmpInst::ICMP_ULE:
236                   if ((val & 0xffff) != 0xfffe) val2 = val + 1;
237                   break;
238 
239                 case CmpInst::FCMP_OLT:  // fall through
240                 case CmpInst::FCMP_OGE:  // fall through
241                 case CmpInst::ICMP_SLT:  // fall through
242                 case CmpInst::ICMP_SGE:
243 
244                   // signed comparison and it is a negative constant
245                   if ((len == 4 && (val & 80000000)) ||
246                       (len == 8 && (val & 8000000000000000))) {
247 
248                     if ((val & 0xffff) != 1) val2 = val - 1;
249                     break;
250 
251                   }
252 
253                   // fall through
254 
255                 case CmpInst::FCMP_ULT:  // fall through
256                 case CmpInst::FCMP_UGE:  // fall through
257                 case CmpInst::ICMP_ULT:  // fall through
258                 case CmpInst::ICMP_UGE:
259                   if ((val & 0xffff) != 1) val2 = val - 1;
260                   break;
261 
262                 default:
263                   val2 = 0;
264 
265               }
266 
267               dict2file(fd, (u8 *)&val, len);
268               found++;
269               if (val2) {
270 
271                 dict2file(fd, (u8 *)&val2, len);
272                 found++;
273 
274               }
275 
276             }
277 
278           }
279 
280         }
281 
282         if ((callInst = dyn_cast<CallInst>(&IN))) {
283 
284           bool   isStrcmp = true;
285           bool   isMemcmp = true;
286           bool   isStrncmp = true;
287           bool   isStrcasecmp = true;
288           bool   isStrncasecmp = true;
289           bool   isIntMemcpy = true;
290           bool   isStdString = true;
291           bool   addedNull = false;
292           size_t optLen = 0;
293 
294           Function *Callee = callInst->getCalledFunction();
295           if (!Callee) continue;
296           if (callInst->getCallingConv() != llvm::CallingConv::C) continue;
297           std::string FuncName = Callee->getName().str();
298           isStrcmp &= !FuncName.compare("strcmp");
299           isMemcmp &=
300               (!FuncName.compare("memcmp") || !FuncName.compare("bcmp"));
301           isStrncmp &= !FuncName.compare("strncmp");
302           isStrcasecmp &= !FuncName.compare("strcasecmp");
303           isStrncasecmp &= !FuncName.compare("strncasecmp");
304           isIntMemcpy &= !FuncName.compare("llvm.memcpy.p0i8.p0i8.i64");
305           isStdString &= ((FuncName.find("basic_string") != std::string::npos &&
306                            FuncName.find("compare") != std::string::npos) ||
307                           (FuncName.find("basic_string") != std::string::npos &&
308                            FuncName.find("find") != std::string::npos));
309 
310           if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
311               !isStrncasecmp && !isIntMemcpy && !isStdString)
312             continue;
313 
314           /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function
315            * prototype */
316           FunctionType *FT = Callee->getFunctionType();
317 
318           isStrcmp &=
319               FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) &&
320               FT->getParamType(0) == FT->getParamType(1) &&
321               FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
322           isStrcasecmp &=
323               FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) &&
324               FT->getParamType(0) == FT->getParamType(1) &&
325               FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
326           isMemcmp &= FT->getNumParams() == 3 &&
327                       FT->getReturnType()->isIntegerTy(32) &&
328                       FT->getParamType(0)->isPointerTy() &&
329                       FT->getParamType(1)->isPointerTy() &&
330                       FT->getParamType(2)->isIntegerTy();
331           isStrncmp &= FT->getNumParams() == 3 &&
332                        FT->getReturnType()->isIntegerTy(32) &&
333                        FT->getParamType(0) == FT->getParamType(1) &&
334                        FT->getParamType(0) ==
335                            IntegerType::getInt8PtrTy(M.getContext()) &&
336                        FT->getParamType(2)->isIntegerTy();
337           isStrncasecmp &= FT->getNumParams() == 3 &&
338                            FT->getReturnType()->isIntegerTy(32) &&
339                            FT->getParamType(0) == FT->getParamType(1) &&
340                            FT->getParamType(0) ==
341                                IntegerType::getInt8PtrTy(M.getContext()) &&
342                            FT->getParamType(2)->isIntegerTy();
343           isStdString &= FT->getNumParams() >= 2 &&
344                          FT->getParamType(0)->isPointerTy() &&
345                          FT->getParamType(1)->isPointerTy();
346 
347           if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
348               !isStrncasecmp && !isIntMemcpy && !isStdString)
349             continue;
350 
351           /* is a str{n,}{case,}cmp/memcmp, check if we have
352            * str{case,}cmp(x, "const") or str{case,}cmp("const", x)
353            * strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..)
354            * memcmp(x, "const", ..) or memcmp("const", x, ..) */
355           Value *Str1P = callInst->getArgOperand(0),
356                 *Str2P = callInst->getArgOperand(1);
357           std::string Str1, Str2;
358           StringRef   TmpStr;
359           bool        HasStr1;
360           getConstantStringInfo(Str1P, TmpStr);
361 
362           if (TmpStr.empty()) {
363 
364             HasStr1 = false;
365 
366           } else {
367 
368             HasStr1 = true;
369             Str1 = TmpStr.str();
370 
371           }
372 
373           bool HasStr2;
374           getConstantStringInfo(Str2P, TmpStr);
375           if (TmpStr.empty()) {
376 
377             HasStr2 = false;
378 
379           } else {
380 
381             HasStr2 = true;
382             Str2 = TmpStr.str();
383 
384           }
385 
386           if (debug)
387             fprintf(stderr, "F:%s %p(%s)->\"%s\"(%s) %p(%s)->\"%s\"(%s)\n",
388                     FuncName.c_str(), (void *)Str1P,
389                     Str1P->getName().str().c_str(), Str1.c_str(),
390                     HasStr1 == true ? "true" : "false", (void *)Str2P,
391                     Str2P->getName().str().c_str(), Str2.c_str(),
392                     HasStr2 == true ? "true" : "false");
393 
394           // we handle the 2nd parameter first because of llvm memcpy
395           if (!HasStr2) {
396 
397             auto *Ptr = dyn_cast<ConstantExpr>(Str2P);
398             if (Ptr && Ptr->isGEPWithNoNotionalOverIndexing()) {
399 
400               if (auto *Var = dyn_cast<GlobalVariable>(Ptr->getOperand(0))) {
401 
402                 if (Var->hasInitializer()) {
403 
404                   if (auto *Array =
405                           dyn_cast<ConstantDataArray>(Var->getInitializer())) {
406 
407                     HasStr2 = true;
408                     Str2 = Array->getRawDataValues().str();
409 
410                   }
411 
412                 }
413 
414               }
415 
416             }
417 
418           }
419 
420           // for the internal memcpy routine we only care for the second
421           // parameter and are not reporting anything.
422           if (isIntMemcpy == true) {
423 
424             if (HasStr2 == true) {
425 
426               Value *      op2 = callInst->getArgOperand(2);
427               ConstantInt *ilen = dyn_cast<ConstantInt>(op2);
428               if (ilen) {
429 
430                 uint64_t literalLength = Str2.length();
431                 uint64_t optLength = ilen->getZExtValue();
432                 if (optLength > literalLength + 1) {
433 
434                   optLength = Str2.length() + 1;
435 
436                 }
437 
438                 if (literalLength + 1 == optLength) {
439 
440                   Str2.append("\0", 1);  // add null byte
441 
442                 }
443 
444                 if (optLength > Str2.length()) { optLength = Str2.length(); }
445 
446               }
447 
448               valueMap[Str1P] = new std::string(Str2);
449 
450               if (debug)
451                 fprintf(stderr, "Saved: %s for %p\n", Str2.c_str(),
452                         (void *)Str1P);
453               continue;
454 
455             }
456 
457             continue;
458 
459           }
460 
461           // Neither a literal nor a global variable?
462           // maybe it is a local variable that we saved
463           if (!HasStr2) {
464 
465             std::string *strng = valueMap[Str2P];
466             if (strng && !strng->empty()) {
467 
468               Str2 = *strng;
469               HasStr2 = true;
470               if (debug)
471                 fprintf(stderr, "Filled2: %s for %p\n", strng->c_str(),
472                         (void *)Str2P);
473 
474             }
475 
476           }
477 
478           if (!HasStr1) {
479 
480             auto Ptr = dyn_cast<ConstantExpr>(Str1P);
481 
482             if (Ptr && Ptr->isGEPWithNoNotionalOverIndexing()) {
483 
484               if (auto *Var = dyn_cast<GlobalVariable>(Ptr->getOperand(0))) {
485 
486                 if (Var->hasInitializer()) {
487 
488                   if (auto *Array =
489                           dyn_cast<ConstantDataArray>(Var->getInitializer())) {
490 
491                     HasStr1 = true;
492                     Str1 = Array->getRawDataValues().str();
493 
494                   }
495 
496                 }
497 
498               }
499 
500             }
501 
502           }
503 
504           // Neither a literal nor a global variable?
505           // maybe it is a local variable that we saved
506           if (!HasStr1) {
507 
508             std::string *strng = valueMap[Str1P];
509             if (strng && !strng->empty()) {
510 
511               Str1 = *strng;
512               HasStr1 = true;
513               if (debug)
514                 fprintf(stderr, "Filled1: %s for %p\n", strng->c_str(),
515                         (void *)Str1P);
516 
517             }
518 
519           }
520 
521           /* handle cases of one string is const, one string is variable */
522           if (!(HasStr1 ^ HasStr2)) continue;
523 
524           std::string thestring;
525 
526           if (HasStr1)
527             thestring = Str1;
528           else
529             thestring = Str2;
530 
531           optLen = thestring.length();
532 
533           if (optLen < 2 || (optLen == 2 && !thestring[1])) { continue; }
534 
535           if (isMemcmp || isStrncmp || isStrncasecmp) {
536 
537             Value *      op2 = callInst->getArgOperand(2);
538             ConstantInt *ilen = dyn_cast<ConstantInt>(op2);
539 
540             if (ilen) {
541 
542               uint64_t literalLength = optLen;
543               optLen = ilen->getZExtValue();
544               if (optLen > thestring.length() + 1) {
545 
546                 optLen = thestring.length() + 1;
547 
548               }
549 
550               if (optLen < 2) { continue; }
551               if (literalLength + 1 == optLen) {  // add null byte
552                 thestring.append("\0", 1);
553                 addedNull = true;
554 
555               }
556 
557             }
558 
559           }
560 
561           // add null byte if this is a string compare function and a null
562           // was not already added
563           if (!isMemcmp) {
564 
565             if (addedNull == false && thestring[optLen - 1] != '\0') {
566 
567               thestring.append("\0", 1);  // add null byte
568               optLen++;
569 
570             }
571 
572             if (!isStdString) {
573 
574               // ensure we do not have garbage
575               size_t offset = thestring.find('\0', 0);
576               if (offset + 1 < optLen) optLen = offset + 1;
577               thestring = thestring.substr(0, optLen);
578 
579             }
580 
581           }
582 
583           // we take the longer string, even if the compare was to a
584           // shorter part. Note that depending on the optimizer of the
585           // compiler this can be wrong, but it is more likely that this
586           // is helping the fuzzer
587           if (optLen != thestring.length()) optLen = thestring.length();
588           if (optLen > MAX_AUTO_EXTRA) optLen = MAX_AUTO_EXTRA;
589           if (optLen < 3)  // too short? skip
590             continue;
591 
592           ptr = (char *)thestring.c_str();
593 
594           dict2file(fd, (u8 *)ptr, optLen);
595           found++;
596 
597         }
598 
599       }
600 
601     }
602 
603   }
604 
605   close(fd);
606 
607   /* Say something nice. */
608 
609   if (!be_quiet) {
610 
611     if (!found)
612       OKF("No entries for a dictionary found.");
613     else
614       OKF("Wrote %d entries to the dictionary file.\n", found);
615 
616   }
617 
618   return true;
619 
620 }
621 
622 char AFLdict2filePass::ID = 0;
623 
registerAFLdict2filePass(const PassManagerBuilder &,legacy::PassManagerBase & PM)624 static void registerAFLdict2filePass(const PassManagerBuilder &,
625                                      legacy::PassManagerBase &PM) {
626 
627   PM.add(new AFLdict2filePass());
628 
629 }
630 
631 static RegisterPass<AFLdict2filePass> X("afl-dict2file",
632                                         "afl++ dict2file instrumentation pass",
633                                         false, false);
634 
635 static RegisterStandardPasses RegisterAFLdict2filePass(
636     PassManagerBuilder::EP_OptimizerLast, registerAFLdict2filePass);
637 
638 static RegisterStandardPasses RegisterAFLdict2filePass0(
639     PassManagerBuilder::EP_EnabledOnOptLevel0, registerAFLdict2filePass);
640 
641