1 //
2 // Interpreter.cpp
3 // MNN
4 //
5 // Created by MNN on 2018/07/30.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include <math.h>
10 #include <stdio.h>
11 #include <MNN/Interpreter.hpp>
12 #include <algorithm>
13 #include <mutex>
14 #include <vector>
15 #include "MNN_generated.h"
16 #include "core/AutoStorage.h"
17 #include "core/FileLoader.hpp"
18 #include "core/Pipeline.hpp"
19 #include "core/RuntimeFactory.hpp"
20 #include "core/Session.hpp"
21
22 namespace MNN {
23
// Shared state behind an Interpreter; owned by the Interpreter and destroyed
// in its destructor.
struct Content {
    // Raw serialized model bytes; `net` points into this storage.
    AutoStorage<uint8_t> buffer;
    // Flatbuffer view parsed from `buffer` (no separate ownership).
    const Net* net = nullptr;
    // Sessions created from this interpreter; owned here.
    std::vector<std::unique_ptr<Session>> sessions;
    // Maps each handed-out tensor to the session it belongs to, so
    // resizeTensor can mark the right session for re-resize.
    std::map<const Tensor*, const Session*> tensorMap;
    Interpreter::SessionMode callBackMode = Interpreter::Session_Debug;
    Interpreter::SessionMode inputMode = Interpreter::Session_Input_Inside;
    // Backend cache loaded via setCacheFile(); the first `cacheOffset`
    // bytes are a key copied from the head of the model buffer.
    AutoStorage<uint8_t> cacheBuffer;
    size_t cacheOffset = 0;
    std::string cacheFile;
    // Guards sessions/tensorMap against concurrent interpreter calls.
    std::mutex lock;
};
36
createFromFile(const char * file)37 Interpreter* Interpreter::createFromFile(const char* file) {
38 if (nullptr == file) {
39 MNN_PRINT("NULL file for create interpreter\n");
40 return nullptr;
41 }
42 std::unique_ptr<FileLoader> loader(new FileLoader(file));
43 if (!loader->valid()) {
44 MNN_PRINT("Create interpreter failed, open %s error\n", file);
45 return nullptr;
46 }
47 bool result = loader->read();
48 if (!result) {
49 MNN_PRINT("Read file error\n");
50 return nullptr;
51 }
52 if (loader->size() == 0) {
53 MNN_PRINT("Create interpreter failed, %s is empty\n", file);
54 return nullptr;
55 }
56 auto net = new Content;
57 bool success = loader->merge(net->buffer);
58 if (!success) {
59 return nullptr;
60 }
61 loader.reset();
62 return createFromBufferInternal(net);
63 }
createFromBuffer(const void * buffer,size_t size)64 Interpreter* Interpreter::createFromBuffer(const void* buffer, size_t size) {
65 if (nullptr == buffer || 0 == size) {
66 MNN_PRINT("Buffer is null for create interpreter\n");
67 return nullptr;
68 }
69 auto net = new Content;
70 net->buffer.reset((int)size);
71 if (nullptr == net->buffer.get()) {
72 MNN_ERROR("Memory not enought!\n");
73 return nullptr;
74 }
75 ::memcpy(net->buffer.get(), buffer, size);
76
77 return createFromBufferInternal(net);
78 }
79
createFromBufferInternal(Content * net)80 Interpreter* Interpreter::createFromBufferInternal(Content* net) {
81 if (nullptr == net) {
82 MNN_PRINT("Buffer is null for create interpreter\n");
83 return nullptr;
84 }
85 #ifndef MNN_BUILD_MINI
86 flatbuffers::Verifier verify((const uint8_t*)(net->buffer.get()), net->buffer.size());
87 if (false == VerifyNetBuffer(verify)) {
88 MNN_PRINT("Invalidate buffer to create interpreter\n");
89 delete net;
90 return nullptr;
91 }
92 #endif
93 net->net = GetNet(net->buffer.get());
94 if (nullptr == net->net->oplists()) {
95 MNN_ERROR("Model has no oplist\n");
96 delete net;
97 return nullptr;
98 }
99 int opSize = net->net->oplists()->size();
100 for (int i = 0; i < opSize; ++i) {
101 auto op = net->net->oplists()->GetAs<Op>(i);
102 if (nullptr == op || nullptr == op->outputIndexes()) {
103 MNN_ERROR("Invalid Model, the %d op is empty\n", i);
104 delete net;
105 return nullptr;
106 }
107 }
108 return new Interpreter(net);
109 }
110
setSessionMode(SessionMode mode)111 void Interpreter::setSessionMode(SessionMode mode) {
112 if (mode == Session_Input_Inside || mode == Session_Input_User) {
113 mNet->inputMode = mode;
114 } else {
115 mNet->callBackMode = mode;
116 }
117 }
118
setCacheFile(const char * cacheFile,size_t keySize)119 void Interpreter::setCacheFile(const char* cacheFile, size_t keySize) {
120 if (nullptr == cacheFile || nullptr == mNet->buffer.get()) {
121 MNN_ERROR("Empty cacheFile or the interpreter invalid\n");
122 return;
123 }
124 mNet->cacheFile = std::string(cacheFile);
125 mNet->cacheOffset = mNet->buffer.size() > keySize ? keySize : mNet->buffer.size();
126 std::unique_ptr<FileLoader> loader(new FileLoader(cacheFile));
127 if (!loader->valid()) {
128 MNN_ERROR("Load Cache file error.\n");
129 return;
130 }
131 bool result = loader->read();
132 if (!result) {
133 MNN_ERROR("Load Cache file error.\n");
134 return;
135 }
136 if (loader->size() == 0) {
137 MNN_ERROR("Load Cache file error.\n");
138 return;
139 }
140 bool success = loader->merge(mNet->cacheBuffer);
141 if (!success) {
142 MNN_ERROR("Alloc memory for Cache error.\n");
143 return;
144 }
145 if (0 != ::memcmp(mNet->cacheBuffer.get(), mNet->buffer.get(), mNet->cacheOffset)) {
146 MNN_ERROR("Cache model file key does not match.\n");
147 mNet->cacheBuffer.release();
148 return;
149 }
150 }
151
Interpreter(Content * net)152 Interpreter::Interpreter(Content* net) {
153 MNN_ASSERT(nullptr != net);
154 mNet = net;
155 }
156
~Interpreter()157 Interpreter::~Interpreter() {
158 {
159 // If the session is running, we must not delete session
160 std::unique_lock<std::mutex> _l(mNet->lock);
161 mNet->sessions.clear();
162 mNet->tensorMap.clear();
163 }
164 delete mNet;
165 }
166
createMultiPathSession(const std::vector<ScheduleConfig> & configs)167 Session* Interpreter::createMultiPathSession(const std::vector<ScheduleConfig>& configs) {
168 RuntimeInfo runtime = createRuntime(configs);
169 if (runtime.first.empty()) {
170 MNN_ERROR("Runtime not valid for create session\n");
171 return nullptr;
172 }
173 return createMultiPathSession(configs, std::move(runtime));
174 }
175
createMultiPathSession(const std::vector<ScheduleConfig> & configs,const RuntimeInfo & runtime)176 Session* Interpreter::createMultiPathSession(const std::vector<ScheduleConfig>& configs, const RuntimeInfo& runtime) {
177 if (nullptr == mNet->buffer.get()) {
178 MNN_ERROR("The model buffer has been released. Can't create session\n");
179 return nullptr;
180 }
181 if (runtime.first.empty()) {
182 MNN_ERROR("Runtime not valid for create session\n");
183 return nullptr;
184 }
185 std::unique_lock<std::mutex> _l(mNet->lock);
186 auto info = Schedule::schedule(mNet->net, configs);
187 auto validForResize = info.validForResize;
188 RuntimeInfo rt = runtime;
189 auto newSession =
190 std::unique_ptr<Session>(new Session(std::move(info), mNet->callBackMode, mNet->inputMode, std::move(rt)));
191 if (!newSession->valid()) {
192 MNN_PRINT("Invalide Session!!\n");
193 return nullptr;
194 }
195 auto result = newSession.get();
196 bool valid = false;
197 if (mNet->cacheBuffer.get() != nullptr) {
198 valid = result->loadCache(mNet->cacheBuffer.get() + mNet->cacheOffset,
199 mNet->cacheBuffer.size() - mNet->cacheOffset);
200 if(!valid) {
201 // Reset cache
202 result->loadCache(nullptr, 0);
203 MNN_PRINT("Cache invalid, will be reset\n");
204 }
205 }
206 if (validForResize && mNet->inputMode == Session_Input_Inside) {
207 result->resize(mNet->net->usage() == Usage_INFERENCE_STATIC);
208 }
209 if ((!mNet->cacheFile.empty()) && (!valid)) {
210 // Try to save extra cache
211 auto res = result->getCache();
212 if (res.first != nullptr && res.second > 0) {
213 do {
214 MNN_PRINT("Write cache to %s, size = %zu\n", mNet->cacheFile.c_str(), res.second);
215 FILE* f = fopen(mNet->cacheFile.c_str(), "wb");
216 if (nullptr == f) {
217 MNN_ERROR("Open %s error\n", mNet->cacheFile.c_str());
218 break;
219 }
220 // Write key
221 auto tsize = fwrite((const char*)mNet->buffer.get(), 1, mNet->cacheOffset, f);
222 if (tsize != mNet->cacheOffset) {
223 MNN_ERROR("Write %s error\n", mNet->cacheFile.c_str());
224 break;
225 }
226 // Write Cache
227 static const size_t block = 4096;
228 size_t totalSize = res.second;
229 size_t blockSize = UP_DIV(totalSize, block);
230 for (size_t i = 0; i < blockSize; ++i) {
231 size_t sta = block * i;
232 size_t fin = std::min(sta + block, totalSize);
233 if (fin > sta) {
234 auto realSize = fwrite((const char*)(res.first) + sta, 1, fin - sta, f);
235 if (realSize != fin - sta) {
236 MNN_ERROR("Write %s error\n", mNet->cacheFile.c_str());
237 break;
238 }
239 }
240 }
241 fclose(f);
242 } while (false);
243 }
244 }
245 // Reset cache
246 result->loadCache(nullptr, 0);
247
248 mNet->sessions.emplace_back(std::move(newSession));
249 return result;
250 }
251
createSession(const ScheduleConfig & config)252 Session* Interpreter::createSession(const ScheduleConfig& config) {
253 return createMultiPathSession({config});
254 }
255
createSession(const ScheduleConfig & config,const RuntimeInfo & runtime)256 Session* Interpreter::createSession(const ScheduleConfig& config, const RuntimeInfo& runtime) {
257 return createMultiPathSession({config}, runtime);
258 }
259
releaseSession(Session * session)260 bool Interpreter::releaseSession(Session* session) {
261 std::unique_lock<std::mutex> _l(mNet->lock);
262 for (auto iter = mNet->sessions.begin(); iter != mNet->sessions.end(); iter++) {
263 // TODO Delete tensormap
264 for (auto tIter = mNet->tensorMap.begin(); tIter != mNet->tensorMap.end();) {
265 if (tIter->second == session) {
266 tIter = mNet->tensorMap.erase(tIter);
267 continue;
268 }
269 tIter++;
270 }
271
272 if ((*iter).get() == session) {
273 mNet->sessions.erase(iter);
274 return true;
275 }
276 }
277 return false;
278 }
279
runSession(Session * session) const280 ErrorCode Interpreter::runSession(Session* session) const {
281 return session->run();
282 }
283
getSessionInput(const Session * session,const char * name)284 Tensor* Interpreter::getSessionInput(const Session* session, const char* name) {
285 if (session == nullptr) {
286 return nullptr;
287 }
288 std::unique_lock<std::mutex> _l(mNet->lock);
289 auto tensor = session->getInput(name);
290 mNet->tensorMap.insert(std::make_pair(tensor, session));
291 return tensor;
292 }
293
getSessionOutput(const Session * session,const char * name)294 Tensor* Interpreter::getSessionOutput(const Session* session, const char* name) {
295 if (session == nullptr) {
296 return nullptr;
297 }
298 std::unique_lock<std::mutex> _l(mNet->lock);
299 auto tensor = session->getOutput(name);
300 mNet->tensorMap.insert(std::make_pair(tensor, session));
301 return tensor;
302 }
303
getSessionInputAll(const Session * session) const304 const std::map<std::string, Tensor*>& Interpreter::getSessionInputAll(const Session* session) const {
305 std::unique_lock<std::mutex> _l(mNet->lock);
306 auto& tensors = session->getInputAll();
307 for (auto& iter : tensors) {
308 mNet->tensorMap.insert(std::make_pair(iter.second, session));
309 }
310 return tensors;
311 }
312
getSessionOutputAll(const Session * session) const313 const std::map<std::string, Tensor*>& Interpreter::getSessionOutputAll(const Session* session) const {
314 std::unique_lock<std::mutex> _l(mNet->lock);
315 auto& tensors = session->getOutputAll();
316 for (auto& iter : tensors) {
317 mNet->tensorMap.insert(std::make_pair(iter.second, session));
318 }
319 return tensors;
320 }
321
resizeSession(Session * session)322 void Interpreter::resizeSession(Session* session) {
323 std::unique_lock<std::mutex> _l(mNet->lock);
324 if (mNet->buffer.get() == nullptr) {
325 MNN_ERROR("The model buffer has been released. Can't resize session\n");
326 return;
327 }
328 session->resize();
329 }
330
runSessionWithCallBack(const Session * session,const TensorCallBack & before,const TensorCallBack & after,bool sync) const331 ErrorCode Interpreter::runSessionWithCallBack(const Session* session, const TensorCallBack& before,
332 const TensorCallBack& after, bool sync) const {
333 auto beforeWrap = [&before](const std::vector<Tensor*>& tensors, const OperatorInfo* info) {
334 return before(tensors, info->name());
335 };
336 auto afterWrap = [&after](const std::vector<Tensor*>& tensors, const OperatorInfo* info) {
337 return after(tensors, info->name());
338 };
339 return runSessionWithCallBackInfo(session, beforeWrap, afterWrap, sync);
340 }
341
runSessionWithCallBackInfo(const Session * session,const TensorCallBackWithInfo & before,const TensorCallBackWithInfo & callBack,bool sync) const342 ErrorCode Interpreter::runSessionWithCallBackInfo(const Session* session, const TensorCallBackWithInfo& before,
343 const TensorCallBackWithInfo& callBack, bool sync) const {
344 return session->runWithCallBack(before, callBack, sync);
345 }
346
getBackend(const Session * session,const Tensor * tensor) const347 const Backend* Interpreter::getBackend(const Session* session, const Tensor* tensor) const {
348 return session->getBackEnd(tensor);
349 }
350
releaseModel()351 void Interpreter::releaseModel() {
352 std::unique_lock<std::mutex> _l(mNet->lock);
353 if (mNet->buffer.get() != nullptr && mNet->net->usage() != Usage_INFERENCE_STATIC) {
354 mNet->buffer.release();
355 }
356 mNet->cacheBuffer.release();
357 }
358
resizeTensor(Tensor * tensor,int batch,int channel,int height,int width)359 void Interpreter::resizeTensor(Tensor* tensor, int batch, int channel, int height, int width) {
360 if (tensor->getDimensionType() == Tensor::TENSORFLOW) {
361 resizeTensor(tensor, {batch, height, width, channel});
362 } else {
363 resizeTensor(tensor, {batch, channel, height, width});
364 }
365 }
366
resizeTensor(Tensor * tensor,const std::vector<int> & dims)367 void Interpreter::resizeTensor(Tensor* tensor, const std::vector<int>& dims) {
368 std::unique_lock<std::mutex> _l(mNet->lock);
369 MNN_ASSERT(nullptr != tensor);
370 bool dirty = false;
371 if (tensor->buffer().dimensions != dims.size()) {
372 dirty = true;
373 } else {
374 for (int i = 0; i < dims.size(); ++i) {
375 if (tensor->buffer().dim[i].extent != dims[i]) {
376 dirty = true;
377 break;
378 }
379 }
380 }
381
382 if (!dirty) {
383 return;
384 }
385
386 tensor->buffer().dimensions = (int)dims.size();
387 for (int i = 0; i < dims.size(); ++i) {
388 tensor->buffer().dim[i].extent = dims[i];
389 }
390
391 auto relatedSessionIter = mNet->tensorMap.find(tensor);
392 MNN_ASSERT(relatedSessionIter != mNet->tensorMap.end());
393 ((MNN::Session*)relatedSessionIter->second)->setNeedResize();
394 }
395
bizCode() const396 const char* Interpreter::bizCode() const {
397 const flatbuffers::String* code = mNet->net->bizCode();
398 return code->c_str();
399 }
400
getModelBuffer() const401 std::pair<const void*, size_t> Interpreter::getModelBuffer() const {
402 return std::make_pair(mNet->buffer.get(), mNet->buffer.size());
403 }
updateSessionToModel(Session * session)404 ErrorCode Interpreter::updateSessionToModel(Session* session) {
405 std::unique_lock<std::mutex> _l(mNet->lock);
406 if (mNet->buffer.get() == nullptr) {
407 MNN_ERROR("Can't updateSessionToModel because you called releaseModel before\n");
408 return INPUT_DATA_ERROR;
409 }
410 return session->updateToModel((Net*)mNet->net);
411 }
412
getSessionInfo(const Session * session,SessionInfoCode code,void * ptr)413 bool Interpreter::getSessionInfo(const Session* session, SessionInfoCode code, void* ptr) {
414 std::unique_lock<std::mutex> _l(mNet->lock);
415 if (nullptr == session || nullptr == ptr) {
416 return false;
417 }
418 return session->getInfo(code, ptr);
419 }
420
_getDefaultBackend(RuntimeInfo & rt)421 static void _getDefaultBackend(RuntimeInfo& rt) {
422 auto defaultType = MNN_FORWARD_CPU;
423 if (rt.first.find(defaultType) != rt.first.end()) {
424 rt.second = rt.first[defaultType];
425 }
426 if (rt.second == nullptr) {
427 Backend::Info info;
428 info.type = defaultType;
429 info.numThread = 1;
430 rt.second.reset(RuntimeFactory::create(info));
431 }
432 }
createRuntime(const std::vector<ScheduleConfig> & configs)433 RuntimeInfo Interpreter::createRuntime(const std::vector<ScheduleConfig>& configs) {
434 RuntimeInfo res;
435 auto& mRuntimes = res.first;
436 for (auto& config : configs) {
437 Backend::Info compute;
438 compute.type = Schedule::getApprociateType(config);
439 compute.numThread = config.numThread;
440 compute.user = config.backendConfig;
441 if (mRuntimes.find(compute.type) == mRuntimes.end()) {
442 auto newBn = RuntimeFactory::create(compute);
443 if (nullptr == newBn) {
444 MNN_ERROR("Can't create Runtime: %s\n", EnumNameForwardType((ForwardType)compute.type));
445 continue;
446 }
447 mRuntimes[compute.type].reset(newBn);
448 }
449 }
450 _getDefaultBackend(res);
451 return res;
452 }
453
454 } // namespace MNN
455