1 //
2 // TestConvertResult.cpp
3 // MNNConverter
4 //
5 // Created by MNN on 2020/01/22.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "MNN_generated.h"
10 #include <MNN/expr/Expr.hpp>
11 #include <MNN/expr/Module.hpp>
12 #include <MNN/expr/ExprCreator.hpp>
13 #include "PostConverter.hpp"
14 #include "rapidjson/document.h"
15 #include <fstream>
16 #include <sstream>
17 #include <cmath>
18 #include "cli.hpp"
19 using namespace MNN::Express;
20 using namespace MNN;
21
compareOutput(VARP output,const std::string & directName,const std::string & name,Dimensionformat dataFormat,int order)22 static bool compareOutput(VARP output, const std::string& directName, const std::string& name, Dimensionformat dataFormat, int order) {
23 auto info = output->getInfo();
24 auto ptr = output->readMap<float>();
25 if (nullptr == info || nullptr == ptr) {
26 MNN_ERROR("TESTERROR ptr / info nullptr\n");
27 return false;
28 }
29 std::ifstream outputOrigin;
30 // First find key
31 {
32 std::ostringstream outputFileOs;
33 outputFileOs << directName << "/" << name <<".txt";
34 outputOrigin.open(outputFileOs.str().c_str());
35 }
36 // Second find order
37 if (outputOrigin.fail()) {
38 std::ostringstream outputFileOs;
39 outputFileOs << directName << "/" << order <<".txt";
40 outputOrigin.open(outputFileOs.str().c_str());
41 }
42 if (info->order == NC4HW4 && info->dim.size() > 1) {
43 output = _Convert(output, dataFormat);
44 info = output->getInfo();
45 }
46 if (info->type.code != halide_type_float) {
47 output = _Cast<float>(output);
48 info = output->getInfo();
49 }
50 auto targetValue = _Input({info->dim}, info->order, info->type);
51 auto targetPtr = targetValue->writeMap<float>();
52 for (int i=0; i<info->size; ++i) {
53 outputOrigin >> targetPtr[i];
54 }
55 auto absMax = _ReduceMax(_Abs(targetValue), {});
56 auto diff = _Abs(targetValue - output);
57 auto diffAbsMax = _ReduceMax(diff);
58 auto absMaxV = absMax->readMap<float>()[0];
59 auto diffAbsMaxV = diffAbsMax->readMap<float>()[0];
60 if (absMaxV * 0.01f < diffAbsMaxV || std::isnan(absMaxV)) {
61 MNN_ERROR("TESTERROR %s value error : absMaxV:%f - DiffMax %f\n", name.c_str(), absMaxV, diffAbsMaxV);
62 return false;
63 }
64 return true;
65 }
main(int argc,char * argv[])66 int main(int argc, char *argv[]) {
67 if (argc < 3) {
68 MNN_ERROR("Usage: ./TestConvertResult [Onnx, Tf, Tflite, Torch] ${Dir}\n");
69 return 0;
70 }
71 std::string inputType = argv[1];
72 std::string directName = argv[2];
73 auto inputModel = modelConfig::ONNX;
74 auto suffix = ".onnx";
75 auto dataFormat = NCHW;
76 if (inputType == "Tf") {
77 inputModel = modelConfig::TENSORFLOW;
78 suffix = ".pb";
79 dataFormat = NHWC;
80 } else if (inputType == "Tflite") {
81 inputModel = modelConfig::TFLITE;
82 suffix = ".tflite";
83 dataFormat = NHWC;
84 } else if (inputType == "Torch") {
85 inputModel = modelConfig::TORCH;
86 suffix = ".pt";
87 }
88 MNN_PRINT("Test %s\n", directName.c_str());
89 std::string defaultCacheFile = ".___temp.mnn";
90 {
91 modelConfig modelPath;
92 modelPath.model = inputModel;
93 std::ostringstream modelNameOs;
94 modelNameOs << directName << "/test" << suffix;
95 modelPath.modelFile = modelNameOs.str();
96 modelPath.MNNModel = defaultCacheFile;
97 Cli::convertModel(modelPath);
98 }
99 bool useControlFlow = false;
100 rapidjson::Document document;
101 std::map<std::string, float> inputInfo;
102 std::map<std::string, std::vector<int>> inputShape;
103 std::vector<std::string> inputNames;
104 std::vector<std::string> outputNames;
105 {
106 std::ostringstream jsonNameOs;
107 jsonNameOs << directName << "/input.json";
108 std::ifstream fileNames(jsonNameOs.str().c_str());
109 std::ostringstream output;
110 output << fileNames.rdbuf();
111 auto outputStr = output.str();
112 document.Parse(outputStr.c_str());
113 if (document.HasParseError()) {
114 MNN_ERROR("Invalid json\n");
115 return 0;
116 }
117 if (document.HasMember("controlflow")) {
118 useControlFlow = document["controlflow"].GetBool();
119 }
120 if (document.HasMember("inputs")) {
121 auto inputsInfo = document["inputs"].GetArray();
122 for (auto iter = inputsInfo.begin(); iter !=inputsInfo.end(); iter++) {
123 auto obj = iter->GetObject();
124 std::string name = obj["name"].GetString();
125 inputNames.emplace_back(name);
126 MNN_PRINT("%s\n", name.c_str());
127 if (obj.HasMember("value")) {
128 float value = obj["value"].GetFloat();
129 inputInfo.insert(std::make_pair(name, value));
130 }
131 if (obj.HasMember("shape")) {
132 auto dims = obj["shape"].GetArray();
133 std::vector<int> shapes;
134 for (auto iter = dims.begin(); iter != dims.end(); iter++) {
135 shapes.emplace_back(iter->GetInt());
136 }
137 inputShape.insert(std::make_pair(name, shapes));
138 }
139 }
140 }
141 if (document.HasMember("outputs")) {
142 auto array = document["outputs"].GetArray();
143 for (auto iter = array.begin(); iter !=array.end(); iter++) {
144 std::string name = iter->GetString();
145 MNN_PRINT("output: %s\n", name.c_str());
146 outputNames.emplace_back(name);
147 }
148 }
149 }
150 #define LOAD_DATA(TYPE)\
151 if (inputInfo.find(inputName) != inputInfo.end()) {\
152 auto value = inputInfo[inputName];\
153 for (int i=0; i<info->size; ++i) {\
154 ptr[i] = value;\
155 }\
156 } else {\
157 std::ostringstream fileNameOs;\
158 fileNameOs << directName << "/" << inputName << ".txt";\
159 auto fileName = fileNameOs.str();\
160 std::ifstream inputOs(fileName.c_str());\
161 if (inputOs.fail()) {\
162 MNN_ERROR("TESTERROR Can't open %s\n", fileName.c_str());\
163 continue;\
164 }\
165 for (int i=0; i<info->size; ++i) {\
166 inputOs >> ptr[i];\
167 }\
168 }
169 // Expr Branch
170 auto varMap = Variable::loadMap(defaultCacheFile.c_str());
171 for (auto inputName : inputNames) {
172 if (varMap.find(inputName) == varMap.end()) {
173 MNN_ERROR("TESTERROR Can't find var: %s\n", inputName.c_str());
174 continue;
175 }
176 // Resize
177 auto shapeIter = inputShape.find(inputName);
178 if (shapeIter != inputShape.end()) {
179 auto s = shapeIter->second;
180 varMap[inputName]->resize(s);
181 }
182 varMap[inputName] = _ChangeInputFormat(varMap[inputName], dataFormat);
183 auto info = varMap[inputName]->getInfo();
184 if (info->type == halide_type_of<float>()){
185 auto ptr = varMap[inputName]->writeMap<float>();
186 LOAD_DATA(float)
187 } else {
188 auto floatVar = _Input(info->dim, info->order, halide_type_of<float>());
189 auto ptr = floatVar->writeMap<float>();
190 LOAD_DATA(float)
191 auto temp = _Cast(floatVar, info->type);
192 varMap[inputName]->input(temp);
193 }
194 }
195 #undef LOAD_DATA
196 bool modelError = false;
197 // Module Branch
198 if (useControlFlow) {
199 std::shared_ptr<Module> net(Module::load(inputNames, outputNames, defaultCacheFile.c_str()));
200 if (net == nullptr) {
201 MNN_PRINT("Error: can't load module\n");
202 return 0;
203 }
204 std::vector<VARP> inputs;
205 for (auto inputName : inputNames) {
206 inputs.emplace_back(varMap[inputName]);
207 }
208 varMap.clear();
209 auto outputs = net->onForward(inputs);
210 for (int i=0; i<outputNames.size(); ++i) {
211 auto output = outputs[i];
212 bool success = compareOutput(output, directName, outputNames[i], dataFormat, i);
213 if (!success) {
214 modelError = true;
215 break;
216 }
217 }
218 if (!modelError) {
219 MNN_PRINT("TEST_SUCCESS\n");
220 }
221 return 0;
222 }
223 // Expr Branch
224 for (int i=0; i<outputNames.size(); ++i) {
225 auto name = outputNames[i];
226 if (varMap.find(name) == varMap.end()) {
227 MNN_ERROR("TESTERROR, Can't find var: %s\n", name.c_str());
228 return 0;
229 }
230 auto output = varMap[name];
231 bool success = compareOutput(output, directName, name, dataFormat, i);
232 if (!success) {
233 modelError = true;
234 break;
235 }
236 }
237 if (modelError) {
238 std::vector<VARP> outputs;
239 MNN_ERROR("Save mnn result to .error director\n");
240 for (int i=0; i<outputNames.size(); ++i) {
241 auto name = outputNames[i];
242 auto v = varMap[name];
243 auto info = v->getInfo();
244 if (nullptr == info) {
245 continue;
246 }
247 if (info->order == NC4HW4 && info->dim.size() > 1) {
248 v = _Convert(v, dataFormat);
249 }
250 if (info->type.code != halide_type_float) {
251 v = _Cast<float>(v);
252 info = v->getInfo();
253 }
254 v.fix(VARP::CONSTANT);
255 info = v->getInfo();
256 std::ofstream _output((".error/" + name + ".txt").c_str());
257 auto ptr = v->readMap<float>();
258 for (int v=0; v<info->size; ++v) {
259 _output << ptr[v] << "\n";
260 }
261 v->setName(name);
262 outputs.emplace_back(v);
263 }
264 Variable::save(outputs, ".Error.mnn");
265 return 0;
266 }
267 MNN_PRINT("TEST_SUCCESS\n");
268 return 0;
269 }
270
271