//
//  dataTransformer.cpp
//  MNN
//
//  Created by MNN on 2019/05/05.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <cstring>
#include <fstream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include <MNN/ImageProcess.hpp>
#include <MNN/Interpreter.hpp>
#include "MNN_generated.h"
#include "rapidjson/document.h"
using namespace MNN;
using namespace MNN::CV;
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"

main(int argc,const char * argv[])20 int main(int argc, const char* argv[]) {
21     rapidjson::Document document;
22     if (argc < 3) {
23         MNN_ERROR("Usage: ./dataTransformer.out mobilenet.alinn picpath.json storage.bin\n");
24         return 0;
25     }
26     FUNC_PRINT_ALL(argv[1], s);
27     FUNC_PRINT_ALL(argv[2], s);
28     FUNC_PRINT_ALL(argv[3], s);
29     std::unique_ptr<Interpreter> net(Interpreter::createFromFile(argv[1]));
30     ScheduleConfig scheduleConfig;
31     auto session    = net->createSession(scheduleConfig);
32     auto dataTensor = net->getSessionInput(session, nullptr);
33     auto probTensor = net->getSessionOutput(session, nullptr);
34 
35     {
36         std::ifstream fileNames(argv[2]);
37         std::ostringstream output;
38         output << fileNames.rdbuf();
39         auto outputStr = output.str();
40         document.Parse(outputStr.c_str());
41         if (document.HasParseError()) {
42             MNN_ERROR("Invalid json\n");
43             return 0;
44         }
45     }
46     auto picObj = document.GetObject();
47     ImageProcess::Config config;
48     config.destFormat = BGR;
49     {
50         if (picObj.HasMember("format")) {
51             auto format = picObj["format"].GetString();
52             static std::map<std::string, ImageFormat> formatMap{{"BGR", BGR}, {"RGB", RGB}, {"GRAY", GRAY}};
53             if (formatMap.find(format) != formatMap.end()) {
54                 config.destFormat = formatMap.find(format)->second;
55             }
56         }
57     }
58     config.sourceFormat = RGBA;
59     {
60         if (picObj.HasMember("mean")) {
61             auto mean = picObj["mean"].GetArray();
62             int cur   = 0;
63             for (auto iter = mean.begin(); iter != mean.end(); iter++) {
64                 config.mean[cur++] = iter->GetFloat();
65             }
66         }
67         if (picObj.HasMember("normal")) {
68             auto normal = picObj["normal"].GetArray();
69             int cur     = 0;
70             for (auto iter = normal.begin(); iter != normal.end(); iter++) {
71                 config.normal[cur++] = iter->GetFloat();
72             }
73         }
74     }
75     std::shared_ptr<ImageProcess> process(ImageProcess::create(config));
76     std::vector<std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>> result;
77 
78     auto pathArray = picObj["path"].GetArray();
79     for (auto iter = pathArray.begin(); iter != pathArray.end(); iter++) {
80         auto path = iter->GetString();
81         // FUNC_PRINT_ALL(path, s);
82         int width, height, channel;
83         auto inputImage = stbi_load(path, &width, &height, &channel, 4);
84         if (nullptr == inputImage) {
85             MNN_ERROR("Invalid Path: %s\n", path);
86             continue;
87         }
88         Matrix m;
89         m.setScale((float)width / dataTensor->width(), (float)height / dataTensor->height());
90 
91         process->setMatrix(m);
92         process->convert(inputImage, width, height, 0, dataTensor);
93         std::shared_ptr<Tensor> userTensor(new Tensor(dataTensor));
94         dataTensor->copyToHostTensor(userTensor.get());
95 
96         net->runSession(session);
97 
98         std::shared_ptr<Tensor> probUserTensor(new Tensor(probTensor, probTensor->getDimensionType()));
99         probTensor->copyToHostTensor(probUserTensor.get());
100         // FUNC_PRINT(probTensor->elementSize());
101 
102         result.emplace_back(std::make_pair(userTensor, probUserTensor));
103         stbi_image_free(inputImage);
104     }
105     {
106         std::unique_ptr<NetT> data(new NetT);
107         data->tensorName = {net->getSessionInputAll(session).begin()->first,
108                             net->getSessionOutputAll(session).begin()->first + "_Compare"};
109         {
110             std::unique_ptr<OpT> newOp(new OpT);
111             newOp->type          = OpType_Const;
112             newOp->name          = data->tensorName[0];
113             newOp->outputIndexes = {0};
114             newOp->main.type     = OpParameter_Blob;
115             auto blobT           = new BlobT;
116             blobT->dims      = {(int)result.size(), dataTensor->channel(), dataTensor->height(), dataTensor->width()};
117             size_t totalSize = 1;
118             for (int i = 0; i < blobT->dims.size(); ++i) {
119                 totalSize *= blobT->dims[i];
120             }
121             blobT->float32s.resize(totalSize);
122             switch (dataTensor->getDimensionType()) {
123                 case MNN::Tensor::CAFFE:
124                     blobT->dataFormat = MNN_DATA_FORMAT_NCHW;
125                     break;
126                 case MNN::Tensor::TENSORFLOW:
127                     blobT->dataFormat = MNN_DATA_FORMAT_NHWC;
128                     break;
129                 default:
130                     break;
131             }
132             for (int i = 0; i < result.size(); ++i) {
133                 auto tensor = result[i].first.get();
134                 auto dst    = blobT->float32s.data() + i * tensor->elementSize();
135                 auto src    = tensor->host<float>();
136                 ::memcpy(dst, src, tensor->size());
137             }
138             newOp->main.value = blobT;
139             data->oplists.emplace_back(std::move(newOp));
140         }
141         {
142             std::unique_ptr<OpT> newOp(new OpT);
143             newOp->type          = OpType_Const;
144             newOp->name          = data->tensorName[1];
145             newOp->outputIndexes = {1};
146             newOp->main.type     = OpParameter_Blob;
147             auto blobT           = new BlobT;
148             for (int i = 0; i < probTensor->dimensions(); ++i) {
149                 blobT->dims.emplace_back(probTensor->length(i));
150             }
151             blobT->dims[0]   = result.size();
152             size_t totalSize = 1;
153             for (int i = 0; i < blobT->dims.size(); ++i) {
154                 totalSize *= blobT->dims[i];
155             }
156             switch (probTensor->getDimensionType()) {
157                 case MNN::Tensor::CAFFE:
158                     blobT->dataFormat = MNN_DATA_FORMAT_NCHW;
159                     break;
160                 case MNN::Tensor::TENSORFLOW:
161                     blobT->dataFormat = MNN_DATA_FORMAT_NHWC;
162                     break;
163                 default:
164                     break;
165             }
166             blobT->float32s.resize(totalSize);
167             for (int i = 0; i < result.size(); ++i) {
168                 auto tensor = result[i].second.get();
169                 auto dst    = blobT->float32s.data() + i * tensor->elementSize();
170                 auto src    = tensor->host<float>();
171                 ::memcpy(dst, src, tensor->size());
172             }
173             newOp->main.value = blobT;
174             data->oplists.emplace_back(std::move(newOp));
175         }
176         flatbuffers::FlatBufferBuilder builder(1024);
177         auto offset = Net::Pack(builder, data.get());
178         builder.Finish(offset);
179         std::ofstream os(argv[3]);
180         os.write((const char*)builder.GetBufferPointer(), builder.GetSize());
181     }
182 
183     return 0;
184 }
185