1 //
2 //  MNNTestSuite.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/01/10.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "MNNTestSuite.h"
10 #include <stdlib.h>
11 
12 MNNTestSuite* MNNTestSuite::gInstance = NULL;
13 
get()14 MNNTestSuite* MNNTestSuite::get() {
15     if (gInstance == NULL)
16         gInstance = new MNNTestSuite;
17     return gInstance;
18 }
19 
~MNNTestSuite()20 MNNTestSuite::~MNNTestSuite() {
21     for (int i = 0; i < mTests.size(); ++i) {
22         delete mTests[i];
23     }
24     mTests.clear();
25 }
26 
add(MNNTestCase * test,const char * name)27 void MNNTestSuite::add(MNNTestCase* test, const char* name) {
28     test->name = name;
29     mTests.push_back(test);
30 }
31 
run(const char * key,int precision)32 void MNNTestSuite::run(const char* key, int precision) {
33     if (key == NULL || strlen(key) == 0)
34         return;
35 
36     auto suite         = MNNTestSuite::get();
37     std::string prefix = key;
38     std::vector<std::string> wrongs;
39     size_t runUnit = 0;
40     for (int i = 0; i < suite->mTests.size(); ++i) {
41         MNNTestCase* test = suite->mTests[i];
42         if (test->name.find(prefix) == 0) {
43             runUnit++;
44             printf("\trunning %s.\n", test->name.c_str());
45             auto res = test->run(precision);
46             if (!res) {
47                 wrongs.emplace_back(test->name);
48             }
49         }
50     }
51     if (wrongs.empty()) {
52         printf("√√√ all <%s> tests passed.\n", key);
53     }
54     for (auto& wrong : wrongs) {
55         printf("Error: %s\n", wrong.c_str());
56     }
57     printf("### Wrong/Total: %zu / %zu ###\n", wrongs.size(), runUnit);
58 }
59 
runAll(int precision)60 void MNNTestSuite::runAll(int precision) {
61     auto suite = MNNTestSuite::get();
62     std::vector<std::string> wrongs;
63     for (int i = 0; i < suite->mTests.size(); ++i) {
64         MNNTestCase* test = suite->mTests[i];
65         if (test->name.find("speed") != std::string::npos) {
66             // Don't test for speed because cost
67             continue;
68         }
69         if (test->name.find("model") != std::string::npos) {
70             // Don't test for model because need resource
71             continue;
72         }
73         printf("\trunning %s.\n", test->name.c_str());
74         auto res = test->run(precision);
75         if (!res) {
76             wrongs.emplace_back(test->name);
77         }
78     }
79     if (wrongs.empty()) {
80         printf("√√√ all tests passed.\n");
81     }
82     for (auto& wrong : wrongs) {
83         printf("Error: %s\n", wrong.c_str());
84     }
85     printf("### Wrong/Total: %zu / %zu ###\n", wrongs.size(), suite->mTests.size());
86 }
87