1 //
2 //  TestConvertResult.cpp
3 //  MNNConverter
4 //
5 //  Created by MNN on 2020/01/22.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "MNN_generated.h"
10 #include <MNN/expr/Expr.hpp>
11 #include <MNN/expr/Module.hpp>
12 #include <MNN/expr/ExprCreator.hpp>
13 #include "PostConverter.hpp"
14 #include "rapidjson/document.h"
15 #include <fstream>
16 #include <sstream>
17 #include <cmath>
18 #include "cli.hpp"
19 using namespace MNN::Express;
20 using namespace MNN;
21 
compareOutput(VARP output,const std::string & directName,const std::string & name,Dimensionformat dataFormat,int order)22 static bool compareOutput(VARP output, const std::string& directName, const std::string& name, Dimensionformat dataFormat, int order) {
23     auto info = output->getInfo();
24     auto ptr = output->readMap<float>();
25     if (nullptr == info || nullptr == ptr) {
26         MNN_ERROR("TESTERROR ptr / info nullptr\n");
27         return false;
28     }
29     std::ifstream outputOrigin;
30     // First find key
31     {
32         std::ostringstream outputFileOs;
33         outputFileOs << directName << "/" << name <<".txt";
34         outputOrigin.open(outputFileOs.str().c_str());
35     }
36     // Second find order
37     if (outputOrigin.fail()) {
38         std::ostringstream outputFileOs;
39         outputFileOs << directName << "/" << order <<".txt";
40         outputOrigin.open(outputFileOs.str().c_str());
41     }
42     if (info->order == NC4HW4 && info->dim.size() > 1) {
43         output = _Convert(output, dataFormat);
44         info = output->getInfo();
45     }
46     if (info->type.code != halide_type_float) {
47         output = _Cast<float>(output);
48         info = output->getInfo();
49     }
50     auto targetValue = _Input({info->dim}, info->order, info->type);
51     auto targetPtr = targetValue->writeMap<float>();
52     for (int i=0; i<info->size; ++i) {
53         outputOrigin >> targetPtr[i];
54     }
55     auto absMax = _ReduceMax(_Abs(targetValue), {});
56     auto diff = _Abs(targetValue - output);
57     auto diffAbsMax = _ReduceMax(diff);
58     auto absMaxV = absMax->readMap<float>()[0];
59     auto diffAbsMaxV = diffAbsMax->readMap<float>()[0];
60     if (absMaxV * 0.01f < diffAbsMaxV || std::isnan(absMaxV)) {
61         MNN_ERROR("TESTERROR %s value error : absMaxV:%f - DiffMax %f\n", name.c_str(), absMaxV, diffAbsMaxV);
62         return false;
63     }
64     return true;
65 }
main(int argc,char * argv[])66 int main(int argc, char *argv[]) {
67     if (argc < 3) {
68         MNN_ERROR("Usage: ./TestConvertResult [Onnx, Tf, Tflite, Torch] ${Dir}\n");
69         return 0;
70     }
71     std::string inputType = argv[1];
72     std::string directName = argv[2];
73     auto inputModel = modelConfig::ONNX;
74     auto suffix = ".onnx";
75     auto dataFormat = NCHW;
76     if (inputType == "Tf") {
77         inputModel = modelConfig::TENSORFLOW;
78         suffix = ".pb";
79         dataFormat = NHWC;
80     } else if (inputType == "Tflite") {
81         inputModel = modelConfig::TFLITE;
82         suffix = ".tflite";
83         dataFormat = NHWC;
84     } else if (inputType == "Torch") {
85         inputModel = modelConfig::TORCH;
86         suffix = ".pt";
87     }
88     MNN_PRINT("Test %s\n", directName.c_str());
89     std::string defaultCacheFile = ".___temp.mnn";
90     {
91         modelConfig modelPath;
92         modelPath.model = inputModel;
93         std::ostringstream modelNameOs;
94         modelNameOs << directName << "/test" << suffix;
95         modelPath.modelFile = modelNameOs.str();
96         modelPath.MNNModel = defaultCacheFile;
97         Cli::convertModel(modelPath);
98     }
99     bool useControlFlow = false;
100     rapidjson::Document document;
101     std::map<std::string, float> inputInfo;
102     std::map<std::string, std::vector<int>> inputShape;
103     std::vector<std::string> inputNames;
104     std::vector<std::string> outputNames;
105     {
106         std::ostringstream jsonNameOs;
107         jsonNameOs << directName << "/input.json";
108         std::ifstream fileNames(jsonNameOs.str().c_str());
109         std::ostringstream output;
110         output << fileNames.rdbuf();
111         auto outputStr = output.str();
112         document.Parse(outputStr.c_str());
113         if (document.HasParseError()) {
114             MNN_ERROR("Invalid json\n");
115             return 0;
116         }
117         if (document.HasMember("controlflow")) {
118             useControlFlow = document["controlflow"].GetBool();
119         }
120         if (document.HasMember("inputs")) {
121             auto inputsInfo = document["inputs"].GetArray();
122             for (auto iter = inputsInfo.begin(); iter !=inputsInfo.end(); iter++) {
123                 auto obj = iter->GetObject();
124                 std::string name = obj["name"].GetString();
125                 inputNames.emplace_back(name);
126                 MNN_PRINT("%s\n", name.c_str());
127                 if (obj.HasMember("value")) {
128                     float value = obj["value"].GetFloat();
129                     inputInfo.insert(std::make_pair(name, value));
130                 }
131                 if (obj.HasMember("shape")) {
132                     auto dims = obj["shape"].GetArray();
133                     std::vector<int> shapes;
134                     for (auto iter = dims.begin(); iter != dims.end(); iter++) {
135                         shapes.emplace_back(iter->GetInt());
136                     }
137                     inputShape.insert(std::make_pair(name, shapes));
138                 }
139             }
140         }
141         if (document.HasMember("outputs")) {
142             auto array = document["outputs"].GetArray();
143             for (auto iter = array.begin(); iter !=array.end(); iter++) {
144                 std::string name = iter->GetString();
145                 MNN_PRINT("output: %s\n", name.c_str());
146                 outputNames.emplace_back(name);
147             }
148         }
149     }
150 #define LOAD_DATA(TYPE)\
151     if (inputInfo.find(inputName) != inputInfo.end()) {\
152         auto value = inputInfo[inputName];\
153         for (int i=0; i<info->size; ++i) {\
154             ptr[i] = value;\
155         }\
156     } else {\
157         std::ostringstream fileNameOs;\
158         fileNameOs << directName << "/" << inputName << ".txt";\
159         auto fileName = fileNameOs.str();\
160         std::ifstream inputOs(fileName.c_str());\
161         if (inputOs.fail()) {\
162             MNN_ERROR("TESTERROR Can't open %s\n", fileName.c_str());\
163             continue;\
164         }\
165         for (int i=0; i<info->size; ++i) {\
166             inputOs >> ptr[i];\
167         }\
168     }
169     // Expr Branch
170     auto varMap = Variable::loadMap(defaultCacheFile.c_str());
171     for (auto inputName : inputNames) {
172         if (varMap.find(inputName) == varMap.end()) {
173             MNN_ERROR("TESTERROR Can't find var: %s\n", inputName.c_str());
174             continue;
175         }
176         // Resize
177         auto shapeIter = inputShape.find(inputName);
178         if (shapeIter != inputShape.end()) {
179             auto s = shapeIter->second;
180             varMap[inputName]->resize(s);
181         }
182         varMap[inputName] = _ChangeInputFormat(varMap[inputName], dataFormat);
183         auto info = varMap[inputName]->getInfo();
184         if (info->type == halide_type_of<float>()){
185             auto ptr = varMap[inputName]->writeMap<float>();
186             LOAD_DATA(float)
187         } else {
188             auto floatVar = _Input(info->dim, info->order, halide_type_of<float>());
189             auto ptr = floatVar->writeMap<float>();
190             LOAD_DATA(float)
191             auto temp = _Cast(floatVar, info->type);
192             varMap[inputName]->input(temp);
193         }
194     }
195 #undef LOAD_DATA
196     bool modelError = false;
197     // Module Branch
198     if (useControlFlow) {
199         std::shared_ptr<Module> net(Module::load(inputNames, outputNames, defaultCacheFile.c_str()));
200         if (net == nullptr) {
201             MNN_PRINT("Error: can't load module\n");
202             return 0;
203         }
204         std::vector<VARP> inputs;
205         for (auto inputName : inputNames) {
206             inputs.emplace_back(varMap[inputName]);
207         }
208         varMap.clear();
209         auto outputs = net->onForward(inputs);
210         for (int i=0; i<outputNames.size(); ++i) {
211             auto output = outputs[i];
212             bool success = compareOutput(output, directName, outputNames[i], dataFormat, i);
213             if (!success) {
214                 modelError = true;
215                 break;
216             }
217         }
218         if (!modelError) {
219             MNN_PRINT("TEST_SUCCESS\n");
220         }
221         return 0;
222     }
223     // Expr Branch
224     for (int i=0; i<outputNames.size(); ++i) {
225         auto name = outputNames[i];
226         if (varMap.find(name) == varMap.end()) {
227             MNN_ERROR("TESTERROR, Can't find var: %s\n", name.c_str());
228             return 0;
229         }
230         auto output = varMap[name];
231         bool success = compareOutput(output, directName, name, dataFormat, i);
232         if (!success) {
233             modelError = true;
234             break;
235         }
236     }
237     if (modelError) {
238         std::vector<VARP> outputs;
239         MNN_ERROR("Save mnn result to  .error director\n");
240         for (int i=0; i<outputNames.size(); ++i) {
241             auto name = outputNames[i];
242             auto v = varMap[name];
243             auto info = v->getInfo();
244             if (nullptr == info) {
245                 continue;
246             }
247             if (info->order == NC4HW4 && info->dim.size() > 1) {
248                 v = _Convert(v, dataFormat);
249             }
250             if (info->type.code != halide_type_float) {
251                 v = _Cast<float>(v);
252                 info = v->getInfo();
253             }
254             v.fix(VARP::CONSTANT);
255             info = v->getInfo();
256             std::ofstream _output((".error/" + name + ".txt").c_str());
257             auto ptr = v->readMap<float>();
258             for (int v=0; v<info->size; ++v) {
259                 _output << ptr[v] << "\n";
260             }
261             v->setName(name);
262             outputs.emplace_back(v);
263         }
264         Variable::save(outputs, ".Error.mnn");
265         return 0;
266     }
267     MNN_PRINT("TEST_SUCCESS\n");
268     return 0;
269 }
270 
271