1 #include "testing.h"
2 #include "llvm/Support/raw_ostream.h"
3 #include <cstdarg>
4 #include <cstdio>
5 #include <cstdlib>
6 
7 namespace testing {
8 
namespace {
// Running tallies of checks that passed and failed in this test program.
// Complete() reports them and then resets them to zero.
int passes = 0;
int failures = 0;
} // namespace
13 
// Detail printer returned for passing checks: it accepts any printf-style
// arguments and deliberately discards them.
static void BitBucket(const char * /*format*/, ...) {
  // Intentionally empty.
}
15 
// Detail printer returned for failing checks: writes a tab, the
// printf-style formatted message, and a trailing newline to stderr.
static void PrintFailureDetails(const char *format, ...) {
  std::fputc('\t', stderr);
  va_list args;
  va_start(args, format);
  std::vfprintf(stderr, format, args);
  va_end(args);
  std::fputc('\n', stderr);
}
24 
Test(const char * file,int line,const char * predicate,bool pass)25 FailureDetailPrinter Test(
26     const char *file, int line, const char *predicate, bool pass) {
27   if (pass) {
28     ++passes;
29     return BitBucket;
30   } else {
31     ++failures;
32     fprintf(stderr, "%s:%d: FAIL: %s\n", file, line, predicate);
33     return PrintFailureDetails;
34   }
35 }
36 
Match(const char * file,int line,std::uint64_t want,const char * gots,std::uint64_t got)37 FailureDetailPrinter Match(const char *file, int line, std::uint64_t want,
38     const char *gots, std::uint64_t got) {
39   if (want == got) {
40     ++passes;
41     return BitBucket;
42   } else {
43     ++failures;
44     fprintf(stderr, "%s:%d: FAIL: %s == 0x%jx, not 0x%jx\n", file, line, gots,
45         static_cast<std::uintmax_t>(got), static_cast<std::uintmax_t>(want));
46     return PrintFailureDetails;
47   }
48 }
49 
Match(const char * file,int line,const char * want,const char * gots,const std::string & got)50 FailureDetailPrinter Match(const char *file, int line, const char *want,
51     const char *gots, const std::string &got) {
52   if (want == got) {
53     ++passes;
54     return BitBucket;
55   } else {
56     ++failures;
57     fprintf(stderr, "%s:%d: FAIL: %s == \"%s\", not \"%s\"\n", file, line, gots,
58         got.data(), want);
59     return PrintFailureDetails;
60   }
61 }
62 
Match(const char * file,int line,const std::string & want,const char * gots,const std::string & got)63 FailureDetailPrinter Match(const char *file, int line, const std::string &want,
64     const char *gots, const std::string &got) {
65   return Match(file, line, want.data(), gots, got);
66 }
67 
Compare(const char * file,int line,const char * xs,const char * rel,const char * ys,std::uint64_t x,std::uint64_t y)68 FailureDetailPrinter Compare(const char *file, int line, const char *xs,
69     const char *rel, const char *ys, std::uint64_t x, std::uint64_t y) {
70   while (*rel == ' ') {
71     ++rel;
72   }
73   bool pass{false};
74   if (*rel == '<') {
75     if (rel[1] == '=') {
76       pass = x <= y;
77     } else {
78       pass = x < y;
79     }
80   } else if (*rel == '>') {
81     if (rel[1] == '=') {
82       pass = x >= y;
83     } else {
84       pass = x > y;
85     }
86   } else if (*rel == '=') {
87     pass = x == y;
88   } else if (*rel == '!') {
89     pass = x != y;
90   }
91   if (pass) {
92     ++passes;
93     return BitBucket;
94   } else {
95     ++failures;
96     fprintf(stderr, "%s:%d: FAIL: %s[0x%jx] %s %s[0x%jx]\n", file, line, xs,
97         static_cast<std::uintmax_t>(x), rel, ys,
98         static_cast<std::uintmax_t>(y));
99     return PrintFailureDetails;
100   }
101 }
102 
Complete()103 int Complete() {
104   if (failures == 0) {
105     if (passes == 1) {
106       llvm::outs() << "single test PASSES\n";
107     } else {
108       llvm::outs() << "all " << passes << " tests PASS\n";
109     }
110     passes = 0;
111     return EXIT_SUCCESS;
112   } else {
113     if (passes == 1) {
114       llvm::errs() << "1 test passes, ";
115     } else {
116       llvm::errs() << passes << " tests pass, ";
117     }
118     if (failures == 1) {
119       llvm::errs() << "1 test FAILS\n";
120     } else {
121       llvm::errs() << failures << " tests FAIL\n";
122     }
123     passes = failures = 0;
124     return EXIT_FAILURE;
125   }
126 }
127 } // namespace testing
128