1 //
2 // dataTransformer.cpp
3 // MNN
4 //
5 // Created by MNN on 2019/05/05.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
#include <algorithm>
#include <cstring>
#include <fstream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include <MNN/ImageProcess.hpp>
#include <MNN/Interpreter.hpp>
#include "MNN_generated.h"
#include "rapidjson/document.h"
using namespace MNN;
using namespace MNN::CV;
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
19
main(int argc,const char * argv[])20 int main(int argc, const char* argv[]) {
21 rapidjson::Document document;
22 if (argc < 3) {
23 MNN_ERROR("Usage: ./dataTransformer.out mobilenet.alinn picpath.json storage.bin\n");
24 return 0;
25 }
26 FUNC_PRINT_ALL(argv[1], s);
27 FUNC_PRINT_ALL(argv[2], s);
28 FUNC_PRINT_ALL(argv[3], s);
29 std::unique_ptr<Interpreter> net(Interpreter::createFromFile(argv[1]));
30 ScheduleConfig scheduleConfig;
31 auto session = net->createSession(scheduleConfig);
32 auto dataTensor = net->getSessionInput(session, nullptr);
33 auto probTensor = net->getSessionOutput(session, nullptr);
34
35 {
36 std::ifstream fileNames(argv[2]);
37 std::ostringstream output;
38 output << fileNames.rdbuf();
39 auto outputStr = output.str();
40 document.Parse(outputStr.c_str());
41 if (document.HasParseError()) {
42 MNN_ERROR("Invalid json\n");
43 return 0;
44 }
45 }
46 auto picObj = document.GetObject();
47 ImageProcess::Config config;
48 config.destFormat = BGR;
49 {
50 if (picObj.HasMember("format")) {
51 auto format = picObj["format"].GetString();
52 static std::map<std::string, ImageFormat> formatMap{{"BGR", BGR}, {"RGB", RGB}, {"GRAY", GRAY}};
53 if (formatMap.find(format) != formatMap.end()) {
54 config.destFormat = formatMap.find(format)->second;
55 }
56 }
57 }
58 config.sourceFormat = RGBA;
59 {
60 if (picObj.HasMember("mean")) {
61 auto mean = picObj["mean"].GetArray();
62 int cur = 0;
63 for (auto iter = mean.begin(); iter != mean.end(); iter++) {
64 config.mean[cur++] = iter->GetFloat();
65 }
66 }
67 if (picObj.HasMember("normal")) {
68 auto normal = picObj["normal"].GetArray();
69 int cur = 0;
70 for (auto iter = normal.begin(); iter != normal.end(); iter++) {
71 config.normal[cur++] = iter->GetFloat();
72 }
73 }
74 }
75 std::shared_ptr<ImageProcess> process(ImageProcess::create(config));
76 std::vector<std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>> result;
77
78 auto pathArray = picObj["path"].GetArray();
79 for (auto iter = pathArray.begin(); iter != pathArray.end(); iter++) {
80 auto path = iter->GetString();
81 // FUNC_PRINT_ALL(path, s);
82 int width, height, channel;
83 auto inputImage = stbi_load(path, &width, &height, &channel, 4);
84 if (nullptr == inputImage) {
85 MNN_ERROR("Invalid Path: %s\n", path);
86 continue;
87 }
88 Matrix m;
89 m.setScale((float)width / dataTensor->width(), (float)height / dataTensor->height());
90
91 process->setMatrix(m);
92 process->convert(inputImage, width, height, 0, dataTensor);
93 std::shared_ptr<Tensor> userTensor(new Tensor(dataTensor));
94 dataTensor->copyToHostTensor(userTensor.get());
95
96 net->runSession(session);
97
98 std::shared_ptr<Tensor> probUserTensor(new Tensor(probTensor, probTensor->getDimensionType()));
99 probTensor->copyToHostTensor(probUserTensor.get());
100 // FUNC_PRINT(probTensor->elementSize());
101
102 result.emplace_back(std::make_pair(userTensor, probUserTensor));
103 stbi_image_free(inputImage);
104 }
105 {
106 std::unique_ptr<NetT> data(new NetT);
107 data->tensorName = {net->getSessionInputAll(session).begin()->first,
108 net->getSessionOutputAll(session).begin()->first + "_Compare"};
109 {
110 std::unique_ptr<OpT> newOp(new OpT);
111 newOp->type = OpType_Const;
112 newOp->name = data->tensorName[0];
113 newOp->outputIndexes = {0};
114 newOp->main.type = OpParameter_Blob;
115 auto blobT = new BlobT;
116 blobT->dims = {(int)result.size(), dataTensor->channel(), dataTensor->height(), dataTensor->width()};
117 size_t totalSize = 1;
118 for (int i = 0; i < blobT->dims.size(); ++i) {
119 totalSize *= blobT->dims[i];
120 }
121 blobT->float32s.resize(totalSize);
122 switch (dataTensor->getDimensionType()) {
123 case MNN::Tensor::CAFFE:
124 blobT->dataFormat = MNN_DATA_FORMAT_NCHW;
125 break;
126 case MNN::Tensor::TENSORFLOW:
127 blobT->dataFormat = MNN_DATA_FORMAT_NHWC;
128 break;
129 default:
130 break;
131 }
132 for (int i = 0; i < result.size(); ++i) {
133 auto tensor = result[i].first.get();
134 auto dst = blobT->float32s.data() + i * tensor->elementSize();
135 auto src = tensor->host<float>();
136 ::memcpy(dst, src, tensor->size());
137 }
138 newOp->main.value = blobT;
139 data->oplists.emplace_back(std::move(newOp));
140 }
141 {
142 std::unique_ptr<OpT> newOp(new OpT);
143 newOp->type = OpType_Const;
144 newOp->name = data->tensorName[1];
145 newOp->outputIndexes = {1};
146 newOp->main.type = OpParameter_Blob;
147 auto blobT = new BlobT;
148 for (int i = 0; i < probTensor->dimensions(); ++i) {
149 blobT->dims.emplace_back(probTensor->length(i));
150 }
151 blobT->dims[0] = result.size();
152 size_t totalSize = 1;
153 for (int i = 0; i < blobT->dims.size(); ++i) {
154 totalSize *= blobT->dims[i];
155 }
156 switch (probTensor->getDimensionType()) {
157 case MNN::Tensor::CAFFE:
158 blobT->dataFormat = MNN_DATA_FORMAT_NCHW;
159 break;
160 case MNN::Tensor::TENSORFLOW:
161 blobT->dataFormat = MNN_DATA_FORMAT_NHWC;
162 break;
163 default:
164 break;
165 }
166 blobT->float32s.resize(totalSize);
167 for (int i = 0; i < result.size(); ++i) {
168 auto tensor = result[i].second.get();
169 auto dst = blobT->float32s.data() + i * tensor->elementSize();
170 auto src = tensor->host<float>();
171 ::memcpy(dst, src, tensor->size());
172 }
173 newOp->main.value = blobT;
174 data->oplists.emplace_back(std::move(newOp));
175 }
176 flatbuffers::FlatBufferBuilder builder(1024);
177 auto offset = Net::Pack(builder, data.get());
178 builder.Finish(offset);
179 std::ofstream os(argv[3]);
180 os.write((const char*)builder.GetBufferPointer(), builder.GetSize());
181 }
182
183 return 0;
184 }
185