1 /* SPDX-License-Identifier: BSL-1.0 OR BSD-3-Clause */
2
3 #ifndef MPT_TEST_TEST_HPP
4 #define MPT_TEST_TEST_HPP
5
6
7
8 #include "mpt/base/detect.hpp"
9 #include "mpt/base/namespace.hpp"
10 #include "mpt/base/source_location.hpp"
11
12 #include <functional>
13 #include <iostream>
14 #include <map>
15 #include <sstream>
16 #include <stdexcept>
17 #include <type_traits>
18 #include <typeinfo>
19 #include <utility>
20 #include <variant>
21
22 #include <cstddef>
23 #include <cstdlib>
24
25
26
27 namespace mpt {
28 inline namespace MPT_INLINE_NS {
29
30
31
32 namespace test {
33
34
35
/// Type trait: true when an expression `s << t` is well-formed for an lvalue
/// `s` of type S and a value `t` of type T (as produced by std::declval<T>()).
///
/// Primary template: the insertion is not available.
template <typename S, typename T, typename = void>
struct is_to_stream_writable
	: std::false_type {
};

/// Selected by SFINAE only when the insertion expression compiles; the
/// expression is cast to void so the third template argument is always void.
template <typename S, typename T>
struct is_to_stream_writable<S, T, decltype(static_cast<void>(std::declval<S &>() << std::declval<T>()))>
	: std::true_type {
};
41
42 template <typename T>
format(const T & x)43 inline auto format(const T & x) -> typename std::enable_if<mpt::test::is_to_stream_writable<std::ostringstream, T>::value, std::string>::type {
44 std::ostringstream s;
45 s << x;
46 return s.str();
47 }
48
49 template <typename T>
format(const T & x)50 inline auto format(const T & x) -> typename std::enable_if<!mpt::test::is_to_stream_writable<std::ostringstream, T>::value, std::string>::type {
51 return typeid(x).name();
52 }
53
/// Describes the exception currently being handled.
///
/// Works by rethrowing the in-flight exception with a bare `throw;`, so it is
/// only meaningful while an exception is being handled (inside a catch
/// handler or a function called from one); a bare rethrow with no exception
/// being handled calls std::terminate.
/// @return e.what() for exceptions derived from std::exception, otherwise
///         the literal "unknown exception".
inline std::string get_exception_text() {
	try {
		// cppcheck false-positive
		// cppcheck-suppress rethrowNoCurrentException
		throw;
	} catch (const std::exception & e) {
		return e.what();
	} catch (...) {
		return "unknown exception";
	}
}
67
/// Outcome alternative: the check passed.
struct result_success {};

/// Outcome alternative: the check ran but did not hold.
/// text optionally describes the compared values (may be empty).
struct result_failure {
	std::string text;
};

/// Outcome alternative: evaluating the check threw something it was not
/// expected to throw; text holds the exception description.
struct result_unexpected_exception {
	std::string text;
};

/// Outcome of a single check. A default-constructed result holds
/// std::monostate, meaning no outcome has been recorded yet.
struct result {
	std::variant<std::monostate, result_success, result_failure, result_unexpected_exception> info;
};
80
/// Outcome tally for a set of checks, tests or groups.
///
/// All counters start at zero. run, successes, failures,
/// unexpected_exceptions and completed are maintained by the runner code in
/// this header; total is incremented alongside run.
struct statistics_counters {
	std::size_t total = 0;
	std::size_t run = 0;
	std::size_t successes = 0;
	std::size_t failures = 0;
	std::size_t unexpected_exceptions = 0;
	std::size_t completed = 0;

	/// Adds every counter of rhs to the matching counter of *this.
	/// @return *this, to allow chaining
	constexpr statistics_counters & operator+=(const statistics_counters & rhs) noexcept {
		total += rhs.total;
		run += rhs.run;
		successes += rhs.successes;
		failures += rhs.failures;
		unexpected_exceptions += rhs.unexpected_exceptions;
		completed += rhs.completed;
		return *this;
	}
};
98
99 struct group_statistics {
100 statistics_counters tests{};
101 statistics_counters cases{};
102 statistics_counters local_cases{};
103 };
104
105 struct global_statistics {
106 statistics_counters groups{};
107 statistics_counters tests{};
108 statistics_counters cases{};
109 std::map<std::string, group_statistics> individual_group_statistics{};
operator boolmpt::MPT_INLINE_NS::test::global_statistics110 explicit constexpr operator bool() noexcept {
111 return succeeded();
112 }
operator !mpt::MPT_INLINE_NS::test::global_statistics113 constexpr bool operator!() noexcept {
114 return failed();
115 }
succeededmpt::MPT_INLINE_NS::test::global_statistics116 constexpr bool succeeded() noexcept {
117 return groups.successes == groups.run;
118 }
failedmpt::MPT_INLINE_NS::test::global_statistics119 constexpr bool failed() noexcept {
120 return groups.failures > 0 || groups.unexpected_exceptions > 0;
121 }
122 };
123
124 class reporter_interface {
125 protected:
126 virtual ~reporter_interface() = default;
127
128 public:
129 virtual void run_begin(const mpt::source_location & loc) = 0;
130 virtual void group_begin(const mpt::source_location & loc, const char * name) = 0;
131 virtual void test_begin(const mpt::source_location & loc, const char * name) = 0;
132 virtual void case_run(const mpt::source_location & loc) = 0;
133 virtual void case_run(const mpt::source_location & loc, const char * text_e) = 0;
134 virtual void case_run(const mpt::source_location & loc, const char * text_ex, const char * text_e) = 0;
135 virtual void case_run(const mpt::source_location & loc, const char * text_a, const char * text_cmp, const char * text_b) = 0;
136 virtual void case_result(const mpt::source_location & loc, const mpt::test::result & result) = 0;
137 virtual void test_end(const mpt::source_location & loc, const char * name, const statistics_counters & counters) = 0;
138 virtual void group_end(const mpt::source_location & loc, const char * name, const group_statistics & statistics) = 0;
139 virtual void run_end(const mpt::source_location & loc, const global_statistics & statistics) = 0;
140 virtual void immediate_breakpoint() = 0;
141 };
142
143 class silent_reporter
144 : public reporter_interface {
145 public:
146 silent_reporter() = default;
147 ~silent_reporter() override = default;
148
149 public:
run_begin(const mpt::source_location &)150 virtual void run_begin(const mpt::source_location &) override {
151 }
group_begin(const mpt::source_location &,const char *)152 virtual void group_begin(const mpt::source_location &, const char *) override {
153 }
test_begin(const mpt::source_location &,const char *)154 virtual void test_begin(const mpt::source_location &, const char *) override {
155 }
case_run(const mpt::source_location &)156 virtual void case_run(const mpt::source_location &) override {
157 }
case_run(const mpt::source_location &,const char *)158 virtual void case_run(const mpt::source_location &, const char *) override {
159 }
case_run(const mpt::source_location &,const char *,const char *)160 virtual void case_run(const mpt::source_location &, const char *, const char *) override {
161 }
case_run(const mpt::source_location &,const char *,const char *,const char *)162 virtual void case_run(const mpt::source_location &, const char *, const char *, const char *) override {
163 }
case_result(const mpt::source_location &,const mpt::test::result &)164 virtual void case_result(const mpt::source_location &, const mpt::test::result &) override {
165 }
test_end(const mpt::source_location &,const char *,const statistics_counters &)166 virtual void test_end(const mpt::source_location &, const char *, const statistics_counters &) override {
167 }
group_end(const mpt::source_location &,const char *,const group_statistics &)168 virtual void group_end(const mpt::source_location &, const char *, const group_statistics &) override {
169 }
run_end(const mpt::source_location &,const global_statistics &)170 virtual void run_end(const mpt::source_location &, const global_statistics &) override {
171 }
immediate_breakpoint()172 virtual void immediate_breakpoint() override {
173 }
174 };
175
176 class simple_reporter : public reporter_interface {
177 private:
178 std::ostream & s;
179
180 public:
simple_reporter(std::ostream & s_)181 simple_reporter(std::ostream & s_)
182 : s(s_) {
183 s.flush();
184 }
~simple_reporter()185 ~simple_reporter() override {
186 s.flush();
187 }
188
189 public:
run_begin(const mpt::source_location & loc)190 void run_begin(const mpt::source_location & loc) override {
191 static_cast<void>(loc);
192 s << "Running test suite ..." << std::endl;
193 }
group_begin(const mpt::source_location & loc,const char * name)194 void group_begin(const mpt::source_location & loc, const char * name) override {
195 static_cast<void>(loc);
196 s << "Running group '" << name << "' ..." << std::endl;
197 }
test_begin(const mpt::source_location & loc,const char * name)198 void test_begin(const mpt::source_location & loc, const char * name) override {
199 static_cast<void>(loc);
200 s << " Running test '" << name << "' ..." << std::endl;
201 }
case_run(const mpt::source_location & loc)202 void case_run(const mpt::source_location & loc) override {
203 static_cast<void>(loc);
204 s << " Checking ..." << std::endl;
205 }
case_run(const mpt::source_location & loc,const char * text_e)206 void case_run(const mpt::source_location & loc, const char * text_e) override {
207 static_cast<void>(loc);
208 s << " Checking '" << text_e << "' ..." << std::endl;
209 }
case_run(const mpt::source_location & loc,const char * text_ex,const char * text_e)210 void case_run(const mpt::source_location & loc, const char * text_ex, const char * text_e) override {
211 static_cast<void>(loc);
212 if (text_ex) {
213 s << " Checking '" << text_e << " throws " << text_ex << "' ..." << std::endl;
214 } else {
215 s << " Checking '" << text_e << " throws' ..." << std::endl;
216 }
217 }
case_run(const mpt::source_location & loc,const char * text_a,const char * text_cmp,const char * text_b)218 void case_run(const mpt::source_location & loc, const char * text_a, const char * text_cmp, const char * text_b) override {
219 static_cast<void>(loc);
220 s << " Checking '" << text_a << " " << text_cmp << " " << text_b << "' ..." << std::endl;
221 }
case_result(const mpt::source_location & loc,const mpt::test::result & result)222 void case_result(const mpt::source_location & loc, const mpt::test::result & result) override {
223 static_cast<void>(loc);
224 s << " Checking done: ";
225 if (std::holds_alternative<result_success>(result.info)) {
226 s << "Success.";
227 } else if (std::holds_alternative<result_failure>(result.info)) {
228 s << "FAILURE: " << std::get<result_failure>(result.info).text;
229 } else if (std::holds_alternative<result_unexpected_exception>(result.info)) {
230 s << "UNEXPECTED EXCEPTION: " << std::get<result_unexpected_exception>(result.info).text;
231 }
232 s << std::endl;
233 }
test_end(const mpt::source_location & loc,const char * name,const statistics_counters & counters)234 void test_end(const mpt::source_location & loc, const char * name, const statistics_counters & counters) override {
235 static_cast<void>(loc);
236 static_cast<void>(counters);
237 s << " Running test '" << name << "' done." << std::endl;
238 }
group_end(const mpt::source_location & loc,const char * name,const group_statistics & statistics)239 void group_end(const mpt::source_location & loc, const char * name, const group_statistics & statistics) override {
240 static_cast<void>(loc);
241 static_cast<void>(statistics);
242 s << "Running group '" << name << "' done." << std::endl;
243 }
run_end(const mpt::source_location & loc,const global_statistics & statistics)244 void run_end(const mpt::source_location & loc, const global_statistics & statistics) override {
245 static_cast<void>(loc);
246 s << "Running test suite done." << std::endl;
247 s << "groups: " << statistics.groups.total << " | " << statistics.groups.successes << " passed";
248 if (statistics.groups.failures || statistics.groups.unexpected_exceptions) {
249 s << " | " << statistics.groups.failures << " FAILED";
250 if (statistics.groups.unexpected_exceptions) {
251 s << " | " << statistics.groups.unexpected_exceptions << " UNEXPECTED EXCEPTIONS";
252 }
253 }
254 s << std::endl;
255 s << "tests: " << statistics.tests.total << " | " << statistics.tests.successes << " passed";
256 if (statistics.tests.failures || statistics.tests.unexpected_exceptions) {
257 s << " | " << statistics.tests.failures << " FAILED";
258 if (statistics.tests.unexpected_exceptions) {
259 s << " | " << statistics.tests.unexpected_exceptions << " UNEXPECTED EXCEPTIONS";
260 }
261 }
262 s << std::endl;
263 s << "checks: " << statistics.cases.total << " | " << statistics.cases.successes << " passed";
264 if (statistics.cases.failures || statistics.cases.unexpected_exceptions) {
265 s << " | " << statistics.cases.failures << " FAILED";
266 if (statistics.cases.unexpected_exceptions) {
267 s << " | " << statistics.cases.unexpected_exceptions << " UNEXPECTED EXCEPTIONS";
268 }
269 }
270 s << std::endl;
271 }
immediate_breakpoint()272 void immediate_breakpoint() override {
273 return;
274 }
275 };
276
277 struct group;
278
279 struct context {
280 mpt::test::group & group;
281 mpt::test::reporter_interface & reporter;
282 mpt::test::group_statistics statistics{};
283 };
284
285 using void_context_function = void (*)(mpt::test::context &);
286
287 struct group {
288 group * next{nullptr};
289 const char * name{""};
290 void_context_function func{nullptr};
groupmpt::MPT_INLINE_NS::test::group291 inline group(const char * name_, void_context_function f)
292 : name(name_)
293 , func(f) {
294 next = group_list();
295 group_list() = this;
296 }
runmpt::MPT_INLINE_NS::test::group297 group_statistics run(mpt::test::reporter_interface & reporter, const mpt::source_location & loc = mpt::source_location::current()) {
298 mpt::test::context context{*this, reporter};
299 context.reporter.group_begin(loc, name);
300 if (func) {
301 func(context);
302 }
303 context.reporter.group_end(loc, name, context.statistics);
304 return context.statistics;
305 }
306
307 public:
group_listmpt::MPT_INLINE_NS::test::group308 [[nodiscard]] static inline group *& group_list() noexcept {
309 static group * group_list = nullptr;
310 return group_list;
311 }
312 };
313
run_all(mpt::test::reporter_interface & reporter,const mpt::source_location & loc=mpt::source_location::current ())314 inline global_statistics run_all(mpt::test::reporter_interface & reporter, const mpt::source_location & loc = mpt::source_location::current()) {
315 global_statistics statistics{};
316 reporter.run_begin(loc);
317 for (group * g = group::group_list(); g; g = g->next) {
318 statistics.groups.total++;
319 statistics.groups.run++;
320 group_statistics s = g->run(reporter, loc);
321 if (s.tests.unexpected_exceptions) {
322 statistics.groups.unexpected_exceptions++;
323 } else if (s.tests.failures) {
324 statistics.groups.failures++;
325 } else {
326 statistics.groups.successes++;
327 }
328 statistics.tests += s.tests;
329 statistics.cases += s.cases;
330 statistics.groups.completed++;
331 statistics.individual_group_statistics[g->name] = s;
332 }
333 reporter.run_end(loc, statistics);
334 return statistics;
335 }
336
337 struct test {
338
339 mpt::test::context & context;
340 const char * name{""};
341 mpt::source_location source_location{mpt::source_location::current()};
342 void (*breakpoint)(void){nullptr};
343
344 test(const test &) = delete;
345 test & operator=(const test &) = delete;
346
testmpt::MPT_INLINE_NS::test::test347 inline test(mpt::test::context & context_, void (*breakpoint_)(void) = nullptr, const mpt::source_location & source_location_ = mpt::source_location::current())
348 : context(context_)
349 , source_location(source_location_)
350 , breakpoint(breakpoint_) {
351 report_test_begin();
352 }
testmpt::MPT_INLINE_NS::test::test353 inline test(mpt::test::context & context_, const char * name_, void (*breakpoint_)(void) = nullptr, const mpt::source_location & source_location_ = mpt::source_location::current())
354 : context(context_)
355 , name(name_)
356 , source_location(source_location_)
357 , breakpoint(breakpoint_) {
358 report_test_begin();
359 }
360
~testmpt::MPT_INLINE_NS::test::test361 inline ~test() {
362 report_test_end();
363 }
364
immediate_breakpointmpt::MPT_INLINE_NS::test::test365 inline void immediate_breakpoint() {
366 if (breakpoint) {
367 breakpoint();
368 } else {
369 context.reporter.immediate_breakpoint();
370 }
371 }
372
report_test_beginmpt::MPT_INLINE_NS::test::test373 void report_test_begin() {
374 context.statistics.tests.total++;
375 context.statistics.tests.run++;
376 context.statistics.local_cases = statistics_counters{};
377 context.reporter.test_begin(source_location, name);
378 }
379
report_runmpt::MPT_INLINE_NS::test::test380 void report_run() {
381 context.statistics.local_cases.total++;
382 context.statistics.local_cases.run++;
383 context.reporter.case_run(source_location);
384 }
report_runmpt::MPT_INLINE_NS::test::test385 void report_run(const char * text_e) {
386 context.statistics.local_cases.total++;
387 context.statistics.local_cases.run++;
388 context.reporter.case_run(source_location, text_e);
389 }
report_runmpt::MPT_INLINE_NS::test::test390 void report_run(const char * text_ex, const char * text_e) {
391 context.statistics.local_cases.total++;
392 context.statistics.local_cases.run++;
393 context.reporter.case_run(source_location, text_ex, text_e);
394 }
report_runmpt::MPT_INLINE_NS::test::test395 void report_run(const char * text_a, const char * text_cmp, const char * text_b) {
396 context.statistics.local_cases.total++;
397 context.statistics.local_cases.run++;
398 context.reporter.case_run(source_location, text_a, text_cmp, text_b);
399 }
400
report_resultmpt::MPT_INLINE_NS::test::test401 void report_result(mpt::test::result result) {
402 if (std::holds_alternative<result_success>(result.info)) {
403 context.statistics.local_cases.successes++;
404 } else if (std::holds_alternative<result_failure>(result.info)) {
405 context.statistics.local_cases.failures++;
406 } else if (std::holds_alternative<result_unexpected_exception>(result.info)) {
407 context.statistics.local_cases.unexpected_exceptions++;
408 }
409 context.statistics.local_cases.completed++;
410 context.reporter.case_result(source_location, result);
411 }
412
report_test_endmpt::MPT_INLINE_NS::test::test413 void report_test_end() {
414 context.statistics.cases += context.statistics.local_cases;
415 if (context.statistics.local_cases.unexpected_exceptions) {
416 context.statistics.tests.unexpected_exceptions++;
417 } else if (context.statistics.local_cases.failures) {
418 context.statistics.tests.failures++;
419 } else {
420 context.statistics.tests.successes++;
421 }
422 context.statistics.tests.completed++;
423 context.reporter.test_end(source_location, name, context.statistics.local_cases);
424 }
425
426 template <typename Texception, typename Tcallable, typename std::enable_if<std::is_invocable<Tcallable>::value, bool>::type = true>
expect_throwsmpt::MPT_INLINE_NS::test::test427 inline test & expect_throws(Tcallable c, const char * text_ex = nullptr, const char * text_e = nullptr) {
428 const std::type_info & tiexception = typeid(Texception);
429 const std::type_info & tic = typeid(decltype(c()));
430 report_run(text_ex ? text_ex : tiexception.name(), text_e ? text_e : tic.name());
431 mpt::test::result result;
432 try {
433 c();
434 immediate_breakpoint();
435 result.info = mpt::test::result_failure{};
436 } catch (const Texception &) {
437 result.info = mpt::test::result_success{};
438 } catch (...) {
439 immediate_breakpoint();
440 result.info = mpt::test::result_unexpected_exception{mpt::test::get_exception_text()};
441 }
442 report_result(result);
443 return *this;
444 }
445
446 template <typename Tcallable, typename std::enable_if<std::is_invocable<Tcallable>::value, bool>::type = true>
expect_throws_anympt::MPT_INLINE_NS::test::test447 inline test & expect_throws_any(Tcallable c, const char * text_e = nullptr) {
448 const std::type_info & tic = typeid(decltype(c()));
449 report_run(nullptr, text_e ? text_e : tic.name());
450 mpt::test::result result;
451 try {
452 c();
453 immediate_breakpoint();
454 result.info = mpt::test::result_failure{};
455 } catch (...) {
456 result.info = mpt::test::result_success{};
457 }
458 report_result(result);
459 return *this;
460 }
461
462 template <typename Texpr, typename std::enable_if<std::is_invocable<Texpr>::value, bool>::type = true>
expectmpt::MPT_INLINE_NS::test::test463 inline test & expect(Texpr e, const char * text_e = nullptr) {
464 const std::type_info & tie = typeid(decltype(std::invoke(e)));
465 report_run(text_e ? text_e : tie.name());
466 mpt::test::result result;
467 try {
468 const auto ve = std::invoke(e);
469 if (!ve) {
470 immediate_breakpoint();
471 result.info = mpt::test::result_failure{/*mpt::test::format(ve)*/};
472 } else {
473 result.info = mpt::test::result_success{};
474 }
475 } catch (...) {
476 immediate_breakpoint();
477 result.info = mpt::test::result_unexpected_exception{mpt::test::get_exception_text()};
478 }
479 report_result(result);
480 return *this;
481 }
482
483 template <typename Ta, typename Tcmp, typename Tb, typename std::enable_if<std::is_invocable<Ta>::value, bool>::type = true, typename std::enable_if<std::is_invocable<Tb>::value, bool>::type = true>
expectmpt::MPT_INLINE_NS::test::test484 inline test & expect(Ta && a, Tcmp cmp, Tb && b, const char * text_a = nullptr, const char * text_cmp = nullptr, const char * text_b = nullptr) {
485 const std::type_info & tia = typeid(decltype(std::invoke(a)));
486 const std::type_info & ticmp = typeid(decltype(cmp));
487 const std::type_info & tib = typeid(decltype(std::invoke(b)));
488 report_run(text_a ? text_a : tia.name(), text_cmp ? text_cmp : ticmp.name(), text_b ? text_b : tib.name());
489 mpt::test::result result;
490 try {
491 const auto va = std::invoke(a);
492 const auto vb = std::invoke(b);
493 if (!cmp(va, vb)) {
494 immediate_breakpoint();
495 result.info = mpt::test::result_failure{mpt::test::format(va) + " " + mpt::test::format(cmp) + " " + mpt::test::format(vb)};
496 } else {
497 result.info = mpt::test::result_success{};
498 }
499 } catch (...) {
500 immediate_breakpoint();
501 result.info = mpt::test::result_unexpected_exception{mpt::test::get_exception_text()};
502 }
503 report_result(result);
504 return *this;
505 }
506
507 template <typename Texpr, typename std::enable_if<!std::is_invocable<Texpr>::value, bool>::type = true>
expectmpt::MPT_INLINE_NS::test::test508 inline test & expect(Texpr && e, const char * text_e = nullptr) {
509 const std::type_info & tie = typeid(decltype(std::forward<Texpr>(e)));
510 report_run(text_e ? text_e : tie.name());
511 mpt::test::result result;
512 try {
513 const auto ve = std::forward<Texpr>(e);
514 if (!ve) {
515 immediate_breakpoint();
516 result.info = mpt::test::result_failure{/*mpt::test::format(ve)*/};
517 } else {
518 result.info = mpt::test::result_success{};
519 }
520 } catch (...) {
521 immediate_breakpoint();
522 result.info = mpt::test::result_unexpected_exception{mpt::test::get_exception_text()};
523 }
524 report_result(result);
525 return *this;
526 }
527
528 template <typename Ta, typename Tcmp, typename Tb, typename std::enable_if<!std::is_invocable<Ta>::value, bool>::type = true, typename std::enable_if<!std::is_invocable<Tb>::value, bool>::type = true>
expectmpt::MPT_INLINE_NS::test::test529 inline test & expect(Ta && a, Tcmp cmp, Tb && b, const char * text_a = nullptr, const char * text_cmp = nullptr, const char * text_b = nullptr) {
530 const std::type_info & tia = typeid(decltype(std::forward<Ta>(a)));
531 const std::type_info & ticmp = typeid(decltype(cmp));
532 const std::type_info & tib = typeid(decltype(std::forward<Tb>(b)));
533 report_run(text_a ? text_a : tia.name(), text_cmp ? text_cmp : ticmp.name(), text_b ? text_b : tib.name());
534 mpt::test::result result;
535 try {
536 const auto va = std::forward<Ta>(a);
537 const auto vb = std::forward<Tb>(b);
538 if (!cmp(va, vb)) {
539 immediate_breakpoint();
540 result.info = mpt::test::result_failure{mpt::test::format(va) + " " + mpt::test::format(cmp) + " " + mpt::test::format(vb)};
541 } else {
542 result.info = mpt::test::result_success{};
543 }
544 } catch (...) {
545 immediate_breakpoint();
546 result.info = mpt::test::result_unexpected_exception{mpt::test::get_exception_text()};
547 }
548 report_result(result);
549 return *this;
550 }
551 };
552
553
554
555 } // namespace test
556
557
558
559 } // namespace MPT_INLINE_NS
560 } // namespace mpt
561
562
563
564 #endif // MPT_TEST_TEST_HPP
565