1 
2 /**
3  *    Copyright (C) 2018-present MongoDB, Inc.
4  *
5  *    This program is free software: you can redistribute it and/or modify
6  *    it under the terms of the Server Side Public License, version 1,
7  *    as published by MongoDB, Inc.
8  *
9  *    This program is distributed in the hope that it will be useful,
10  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
11  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  *    Server Side Public License for more details.
13  *
14  *    You should have received a copy of the Server Side Public License
15  *    along with this program. If not, see
16  *    <http://www.mongodb.com/licensing/server-side-public-license>.
17  *
18  *    As a special exception, the copyright holders give permission to link the
19  *    code of portions of this program with the OpenSSL library under certain
20  *    conditions as described in each individual source file and distribute
21  *    linked combinations including the program with the OpenSSL library. You
22  *    must comply with the Server Side Public License in all respects for
23  *    all of the code used other than as permitted herein. If you modify file(s)
24  *    with this exception, you may extend this exception to your version of the
25  *    file(s), but you are not obligated to do so. If you do not wish to do so,
26  *    delete this exception statement from your version. If you delete this
27  *    exception statement from all source files in the program, then also delete
28  *    it in the license file.
29  */
30 
31 #define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kDefault
32 
33 #include "mongo/platform/basic.h"
34 
35 #include "mongo/unittest/unittest.h"
36 
#include <algorithm>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
39 
40 #include "mongo/base/checked_cast.h"
41 #include "mongo/base/init.h"
42 #include "mongo/logger/console_appender.h"
43 #include "mongo/logger/log_manager.h"
44 #include "mongo/logger/logger.h"
45 #include "mongo/logger/message_event_utf8_encoder.h"
46 #include "mongo/logger/message_log_domain.h"
47 #include "mongo/stdx/functional.h"
48 #include "mongo/stdx/memory.h"
49 #include "mongo/stdx/mutex.h"
50 #include "mongo/util/assert_util.h"
51 #include "mongo/util/log.h"
52 #include "mongo/util/timer.h"
53 
54 namespace mongo {
55 
56 using std::shared_ptr;
57 using std::string;
58 
59 namespace unittest {
60 
61 namespace {
62 
// Reports whether `needle` appears as a contiguous substring of `haystack`.
// An empty needle is found in every haystack.
bool stringContains(const std::string& haystack, const std::string& needle) {
    const std::string::size_type position = haystack.find(needle);
    return position != std::string::npos;
}
66 
// The "unittest" named log domain obtained from the global log manager. The runner's log()
// helper writes to it, and the UnitTestOutput initializer attaches a console appender to it.
logger::MessageLogDomain* unittestOutput = logger::globalLogManager()->getNamedDomain("unittest");
68 
69 typedef std::map<std::string, std::shared_ptr<Suite>> SuiteMap;
70 
_allSuites()71 inline SuiteMap& _allSuites() {
72     static SuiteMap allSuites;
73     return allSuites;
74 }
75 
76 }  // namespace
77 
/**
 * Returns a log stream builder for the runner's own output.
 *
 * The builder targets the "unittest" log domain (unittestOutput) at Log severity and is tagged
 * with the current thread's name.
 */
logger::LogstreamBuilder log() {
    return LogstreamBuilder(unittestOutput, getThreadName(), logger::LogSeverity::Log());
}
81 
// Global initializer that makes the test runner's output visible. It attaches a console appender,
// which formats events with MessageEventDetailsEncoder, to the "unittest" log domain.
// It declares "GlobalLogManager" and "default" as prerequisites.
MONGO_INITIALIZER_WITH_PREREQUISITES(UnitTestOutput, ("GlobalLogManager", "default"))
(InitializerContext*) {
    unittestOutput->attachAppender(logger::MessageLogDomain::AppenderAutoPtr(
        new logger::ConsoleAppender<logger::MessageLogDomain::Event>(
            new logger::MessageEventDetailsEncoder)));
    return Status::OK();
}
89 
/**
 * Accumulated outcome of running one suite, or the grand totals across suites.
 *
 * The fields are public and are filled in directly by Suite::run.
 */
class Result {
public:
    Result(const std::string& name)
        : _name(name), _rc(0), _tests(0), _fails(), _asserts(0), _millis(0) {}

    /**
     * Renders the result as text.
     *
     * The first line summarizes name, test count, failure count, assert count and elapsed
     * seconds. It is followed by one tab-indented line per recorded failure message.
     *
     * Stream formatting is used instead of a fixed-size character buffer, so long suite names
     * or large counters cannot overflow. The layout matches the printf format
     * "%-30s | tests: %4d | fails: %4d | assert calls: %10d | time secs: %6.3f\n".
     */
    std::string toString() const {
        std::ostringstream ss;
        // "%-30s": left-justify the name in a 30-wide field; longer names are not truncated.
        ss << std::left << std::setw(30) << _name << std::right;
        ss << " | tests: " << std::setw(4) << _tests;
        ss << " | fails: " << std::setw(4) << static_cast<int>(_fails.size());
        ss << " | assert calls: " << std::setw(10) << _asserts;
        // "%6.3f": fixed notation, three decimals, right-aligned in six columns.
        ss << " | time secs: " << std::fixed << std::setprecision(3) << std::setw(6)
           << (_millis / 1000.0) << '\n';

        for (const std::string& message : _messages) {
            ss << '\t' << message << '\n';
        }

        return ss.str();
    }

    /** Exit code contribution of this result (Suite::run sets it to 17 when any test failed). */
    int rc() const {
        return _rc;
    }

    std::string _name;

    int _rc;
    int _tests;
    std::vector<std::string> _fails;  // Names of failed tests.
    int _asserts;
    int _millis;
    std::vector<std::string> _messages;  // One failure description per failed test.

    // Result of the suite currently running, or null when no suite is running.
    static Result* cur;
};
130 
131 Result* Result::cur = 0;
132 
namespace {

/**
 * This unsafe scope guard allows exceptions in its destructor. Thus, if it goes out of scope when
 * an exception is active and the guard function also throws an exception, the program will call
 * std::terminate. This should only be used in unittests where termination on exception is okay.
 *
 * The guard is move-only. Moving transfers the responsibility for calling the function, so the
 * function runs exactly once even when the guard is moved rather than elided. An implicitly
 * generated copy constructor would run it once per copy.
 */
template <typename F>
class UnsafeScopeGuard {
public:
    explicit UnsafeScopeGuard(F fun) : _fun(std::move(fun)) {}

    UnsafeScopeGuard(UnsafeScopeGuard&& other)
        : _fun(std::move(other._fun)), _active(other._active) {
        // The moved-from guard must not invoke the function when it is destroyed.
        other._active = false;
    }

    UnsafeScopeGuard(const UnsafeScopeGuard&) = delete;
    UnsafeScopeGuard& operator=(const UnsafeScopeGuard&) = delete;
    UnsafeScopeGuard& operator=(UnsafeScopeGuard&&) = delete;

    ~UnsafeScopeGuard() noexcept(false) {
        if (_active) {
            _fun();
        }
    }

private:
    F _fun;
    bool _active = true;
};

/** Returns a guard that invokes `fun` when it goes out of scope. */
template <typename F>
inline UnsafeScopeGuard<F> MakeUnsafeScopeGuard(F fun) {
    return UnsafeScopeGuard<F>(std::move(fun));
}

}  // namespace
159 
Test()160 Test::Test() : _isCapturingLogMessages(false) {}
161 
~Test()162 Test::~Test() {
163     if (_isCapturingLogMessages) {
164         stopCapturingLogMessages();
165     }
166 }
167 
run()168 void Test::run() {
169     setUp();
170     auto guard = MakeUnsafeScopeGuard([this] { tearDown(); });
171 
172     // An uncaught exception does not prevent the tear down from running. But
173     // such an event still constitutes an error. To test this behavior we use a
174     // special exception here that when thrown does trigger the tear down but is
175     // not considered an error.
176     try {
177         _doTest();
178     } catch (FixtureExceptionForTesting&) {
179         return;
180     }
181 }
182 
setUp()183 void Test::setUp() {}
tearDown()184 void Test::tearDown() {}
185 
186 namespace {
187 class StringVectorAppender : public logger::MessageLogDomain::EventAppender {
188 public:
StringVectorAppender(std::vector<std::string> * lines)189     explicit StringVectorAppender(std::vector<std::string>* lines) : _lines(lines) {}
~StringVectorAppender()190     virtual ~StringVectorAppender() {}
append(const logger::MessageLogDomain::Event & event)191     virtual Status append(const logger::MessageLogDomain::Event& event) {
192         std::ostringstream _os;
193         if (!_encoder.encode(event, _os)) {
194             return Status(ErrorCodes::LogWriteFailed, "Failed to append to LogTestAppender.");
195         }
196         stdx::lock_guard<stdx::mutex> lk(_mutex);
197         if (_enabled) {
198             _lines->push_back(_os.str());
199         }
200         return Status::OK();
201     }
202 
enable()203     void enable() {
204         stdx::lock_guard<stdx::mutex> lk(_mutex);
205         invariant(!_enabled);
206         _enabled = true;
207     }
208 
disable()209     void disable() {
210         stdx::lock_guard<stdx::mutex> lk(_mutex);
211         invariant(_enabled);
212         _enabled = false;
213     }
214 
215 private:
216     stdx::mutex _mutex;
217     bool _enabled = false;
218     logger::MessageEventDetailsEncoder _encoder;
219     std::vector<std::string>* _lines;
220 };
221 }  // namespace
222 
startCapturingLogMessages()223 void Test::startCapturingLogMessages() {
224     invariant(!_isCapturingLogMessages);
225     _capturedLogMessages.clear();
226     if (!_captureAppender) {
227         _captureAppender = stdx::make_unique<StringVectorAppender>(&_capturedLogMessages);
228     }
229     checked_cast<StringVectorAppender*>(_captureAppender.get())->enable();
230     _captureAppenderHandle = logger::globalLogDomain()->attachAppender(std::move(_captureAppender));
231     _isCapturingLogMessages = true;
232 }
233 
stopCapturingLogMessages()234 void Test::stopCapturingLogMessages() {
235     invariant(_isCapturingLogMessages);
236     invariant(!_captureAppender);
237     _captureAppender = logger::globalLogDomain()->detachAppender(_captureAppenderHandle);
238     checked_cast<StringVectorAppender*>(_captureAppender.get())->disable();
239     _isCapturingLogMessages = false;
240 }
printCapturedLogLines() const241 void Test::printCapturedLogLines() const {
242     log() << "****************************** Captured Lines (start) *****************************";
243     std::for_each(getCapturedLogMessages().begin(),
244                   getCapturedLogMessages().end(),
245                   [](std::string line) { log() << line; });
246     log() << "****************************** Captured Lines (end) ******************************";
247 }
248 
countLogLinesContaining(const std::string & needle)249 int64_t Test::countLogLinesContaining(const std::string& needle) {
250     return std::count_if(getCapturedLogMessages().begin(),
251                          getCapturedLogMessages().end(),
252                          stdx::bind(stringContains, stdx::placeholders::_1, needle));
253 }
254 
Suite(const std::string & name)255 Suite::Suite(const std::string& name) : _name(name) {
256     registerSuite(name, this);
257 }
258 
~Suite()259 Suite::~Suite() {}
260 
add(const std::string & name,const TestFunction & testFn)261 void Suite::add(const std::string& name, const TestFunction& testFn) {
262     _tests.push_back(std::shared_ptr<TestHolder>(new TestHolder(name, testFn)));
263 }
264 
run(const std::string & filter,int runsPerTest)265 Result* Suite::run(const std::string& filter, int runsPerTest) {
266     LOG(1) << "\t about to setupTests" << std::endl;
267     setupTests();
268     LOG(1) << "\t done setupTests" << std::endl;
269 
270     Timer timer;
271     Result* r = new Result(_name);
272     Result::cur = r;
273 
274     for (std::vector<std::shared_ptr<TestHolder>>::iterator i = _tests.begin(); i != _tests.end();
275          i++) {
276         std::shared_ptr<TestHolder>& tc = *i;
277         if (filter.size() && tc->getName().find(filter) == std::string::npos) {
278             LOG(1) << "\t skipping test: " << tc->getName() << " because doesn't match filter"
279                    << std::endl;
280             continue;
281         }
282 
283         r->_tests++;
284 
285         bool passes = false;
286 
287         std::stringstream err;
288         err << tc->getName() << "\t";
289 
290         try {
291             for (int x = 0; x < runsPerTest; x++) {
292                 std::stringstream runTimes;
293                 if (runsPerTest > 1) {
294                     runTimes << "  (" << x + 1 << "/" << runsPerTest << ")";
295                 }
296                 log() << "\t going to run test: " << tc->getName() << runTimes.str();
297                 tc->run();
298             }
299             passes = true;
300         } catch (const TestAssertionFailureException& ae) {
301             err << ae.toString();
302         } catch (const std::exception& e) {
303             err << " std::exception: " << e.what() << " in test " << tc->getName();
304         } catch (int x) {
305             err << " caught int " << x << " in test " << tc->getName();
306         }
307 
308         if (!passes) {
309             std::string s = err.str();
310             log() << "FAIL: " << s << std::endl;
311             r->_fails.push_back(tc->getName());
312             r->_messages.push_back(s);
313         }
314     }
315 
316     if (!r->_fails.empty())
317         r->_rc = 17;
318 
319     r->_millis = timer.millis();
320 
321     log() << "\t DONE running tests" << std::endl;
322 
323     return r;
324 }
325 
run(const std::vector<std::string> & suites,const std::string & filter,int runsPerTest)326 int Suite::run(const std::vector<std::string>& suites, const std::string& filter, int runsPerTest) {
327     if (_allSuites().empty()) {
328         log() << "error: no suites registered.";
329         return EXIT_FAILURE;
330     }
331 
332     for (unsigned int i = 0; i < suites.size(); i++) {
333         if (_allSuites().count(suites[i]) == 0) {
334             log() << "invalid test suite [" << suites[i] << "], use --list to see valid names"
335                   << std::endl;
336             return EXIT_FAILURE;
337         }
338     }
339 
340     std::vector<std::string> torun(suites);
341 
342     if (torun.empty()) {
343         for (SuiteMap::const_iterator i = _allSuites().begin(); i != _allSuites().end(); ++i) {
344             torun.push_back(i->first);
345         }
346     }
347 
348     std::vector<Result*> results;
349 
350     for (std::vector<std::string>::iterator i = torun.begin(); i != torun.end(); i++) {
351         std::string name = *i;
352         std::shared_ptr<Suite>& s = _allSuites()[name];
353         fassert(16145, s != NULL);
354 
355         log() << "going to run suite: " << name << std::endl;
356         results.push_back(s->run(filter, runsPerTest));
357     }
358 
359     log() << "**************************************************" << std::endl;
360 
361     int rc = 0;
362 
363     int tests = 0;
364     int asserts = 0;
365     int millis = 0;
366 
367     Result totals("TOTALS");
368     std::vector<std::string> failedSuites;
369 
370     Result::cur = NULL;
371     for (std::vector<Result*>::iterator i = results.begin(); i != results.end(); i++) {
372         Result* r = *i;
373         log() << r->toString();
374         if (abs(r->rc()) > abs(rc))
375             rc = r->rc();
376 
377         tests += r->_tests;
378         if (!r->_fails.empty()) {
379             failedSuites.push_back(r->toString());
380             for (std::vector<std::string>::const_iterator j = r->_fails.begin();
381                  j != r->_fails.end();
382                  j++) {
383                 const std::string& s = (*j);
384                 totals._fails.push_back(r->_name + "/" + s);
385             }
386         }
387         asserts += r->_asserts;
388         millis += r->_millis;
389 
390         delete r;
391     }
392 
393     totals._tests = tests;
394     totals._asserts = asserts;
395     totals._millis = millis;
396 
397     log() << totals.toString();  // includes endl
398 
399     // summary
400     if (!totals._fails.empty()) {
401         log() << "Failing tests:" << std::endl;
402         for (std::vector<std::string>::const_iterator i = totals._fails.begin();
403              i != totals._fails.end();
404              i++) {
405             const std::string& s = (*i);
406             log() << "\t " << s << " Failed";
407         }
408         log() << "FAILURE - " << totals._fails.size() << " tests in " << failedSuites.size()
409               << " suites failed";
410     } else {
411         log() << "SUCCESS - All tests in all suites passed";
412     }
413 
414     return rc;
415 }
416 
registerSuite(const std::string & name,Suite * s)417 void Suite::registerSuite(const std::string& name, Suite* s) {
418     std::shared_ptr<Suite>& m = _allSuites()[name];
419     fassert(10162, !m);
420     m.reset(s);
421 }
422 
getSuite(const std::string & name)423 Suite* Suite::getSuite(const std::string& name) {
424     std::shared_ptr<Suite>& result = _allSuites()[name];
425     if (!result) {
426         // Suites are self-registering.
427         new Suite(name);
428     }
429     invariant(result);
430     return result.get();
431 }
432 
// Default suite-level setup hook, called at the start of Suite::run(filter, runsPerTest).
// This implementation does nothing.
void Suite::setupTests() {}
434 
/**
 * Records where an assertion failed (`theFile`, `theLine`) and stores the failing expression
 * text as the exception's message.
 */
TestAssertionFailureException::TestAssertionFailureException(
    const std::string& theFile, unsigned theLine, const std::string& theFailingExpression)
    : _file(theFile), _line(theLine), _message(theFailingExpression) {}
438 
toString() const439 std::string TestAssertionFailureException::toString() const {
440     std::ostringstream os;
441     os << getMessage() << " @" << getFile() << ":" << getLine();
442     return os.str();
443 }
444 
/**
 * Creates an unarmed failure for `file`:`line` with `message`. The failure only throws from
 * its destructor after stream() has been called.
 */
TestAssertionFailure::TestAssertionFailure(const std::string& file,
                                           unsigned line,
                                           const std::string& message)
    : _exception(file, line, message), _enabled(false) {}
449 
/**
 * Copies only the pending exception. The copy starts unarmed with an empty stream. Copying a
 * failure whose stream() was already called is a programming error.
 */
TestAssertionFailure::TestAssertionFailure(const TestAssertionFailure& other)
    : _exception(other._exception), _enabled(false) {
    invariant(!other._enabled);
}
454 
/**
 * Replaces the pending exception with `other`'s. Neither side may be armed. The stream
 * contents of *this are left untouched.
 */
TestAssertionFailure& TestAssertionFailure::operator=(const TestAssertionFailure& other) {
    invariant(!_enabled);
    invariant(!other._enabled);
    _exception = other._exception;
    return *this;
}
461 
/**
 * Throws the recorded failure if it was armed, which is why the destructor is noexcept(false).
 *
 * A failure is armed when stream() has been called. The destructor then appends any streamed
 * text to the exception message, logs it at error level, and throws the
 * TestAssertionFailureException. An unarmed failure must have no streamed text and is
 * destroyed quietly.
 */
TestAssertionFailure::~TestAssertionFailure() noexcept(false) {
    if (!_enabled) {
        // Nothing may be streamed into an unarmed failure.
        invariant(_stream.str().empty());
        return;
    }
    if (!_stream.str().empty()) {
        // Extra context streamed by the caller is appended after a single space.
        _exception.setMessage(_exception.getMessage() + " " + _stream.str());
    }
    error() << "Throwing exception: " << _exception;
    throw _exception;
}
473 
/**
 * Arms the failure so it throws when destroyed. Returns the stream callers use to append extra
 * context to the message. May be called at most once per failure.
 */
std::ostream& TestAssertionFailure::stream() {
    invariant(!_enabled);
    _enabled = true;
    return _stream;
}
479 
getAllSuiteNames()480 std::vector<std::string> getAllSuiteNames() {
481     std::vector<std::string> result;
482     for (SuiteMap::const_iterator i = _allSuites().begin(); i != _allSuites().end(); ++i) {
483         result.push_back(i->first);
484     }
485     return result;
486 }
487 
488 }  // namespace unittest
489 }  // namespace mongo
490