1 //
2 //  Interpreter.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2018/07/30.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include <math.h>
10 #include <stdio.h>
11 #include <MNN/Interpreter.hpp>
12 #include <algorithm>
13 #include <mutex>
14 #include <vector>
15 #include "MNN_generated.h"
16 #include "core/AutoStorage.h"
17 #include "core/FileLoader.hpp"
18 #include "core/Pipeline.hpp"
19 #include "core/RuntimeFactory.hpp"
20 #include "core/Session.hpp"
21 
22 namespace MNN {
23 
// Shared state behind one Interpreter instance: the owned copy of the
// serialized model, the parsed flatbuffer view into it, all sessions created
// from it, and the optional backend-cache bookkeeping.
struct Content {
    AutoStorage<uint8_t> buffer;                       // owned copy of the serialized model
    const Net* net = nullptr;                          // flatbuffer view into `buffer` (not owned)
    std::vector<std::unique_ptr<Session>> sessions;    // sessions created from this model
    std::map<const Tensor*, const Session*> tensorMap; // tensor -> owning session (for resizeTensor)
    Interpreter::SessionMode callBackMode = Interpreter::Session_Debug;
    Interpreter::SessionMode inputMode    = Interpreter::Session_Input_Inside;
    AutoStorage<uint8_t> cacheBuffer;                  // loaded cache file contents (key + payload)
    size_t cacheOffset = 0;                            // number of model-prefix bytes used as cache key
    std::string cacheFile;                             // path of the backend cache file
    std::mutex lock;                                   // guards sessions/tensorMap across API calls
};
36 
createFromFile(const char * file)37 Interpreter* Interpreter::createFromFile(const char* file) {
38     if (nullptr == file) {
39         MNN_PRINT("NULL file for create interpreter\n");
40         return nullptr;
41     }
42     std::unique_ptr<FileLoader> loader(new FileLoader(file));
43     if (!loader->valid()) {
44         MNN_PRINT("Create interpreter failed, open %s error\n", file);
45         return nullptr;
46     }
47     bool result = loader->read();
48     if (!result) {
49         MNN_PRINT("Read file error\n");
50         return nullptr;
51     }
52     if (loader->size() == 0) {
53         MNN_PRINT("Create interpreter failed, %s is empty\n", file);
54         return nullptr;
55     }
56     auto net     = new Content;
57     bool success = loader->merge(net->buffer);
58     if (!success) {
59         return nullptr;
60     }
61     loader.reset();
62     return createFromBufferInternal(net);
63 }
createFromBuffer(const void * buffer,size_t size)64 Interpreter* Interpreter::createFromBuffer(const void* buffer, size_t size) {
65     if (nullptr == buffer || 0 == size) {
66         MNN_PRINT("Buffer is null for create interpreter\n");
67         return nullptr;
68     }
69     auto net = new Content;
70     net->buffer.reset((int)size);
71     if (nullptr == net->buffer.get()) {
72         MNN_ERROR("Memory not enought!\n");
73         return nullptr;
74     }
75     ::memcpy(net->buffer.get(), buffer, size);
76 
77     return createFromBufferInternal(net);
78 }
79 
// Final construction step shared by createFromFile/createFromBuffer.
// Takes ownership of `net` and deletes it on every failure path.
// Verifies the flatbuffer (unless MNN_BUILD_MINI), then validates that the
// model has an op list and every op has output indexes before wrapping the
// Content in an Interpreter.
Interpreter* Interpreter::createFromBufferInternal(Content* net) {
    if (nullptr == net) {
        MNN_PRINT("Buffer is null for create interpreter\n");
        return nullptr;
    }
#ifndef MNN_BUILD_MINI
    // Structural verification of the flatbuffer before any field access.
    flatbuffers::Verifier verify((const uint8_t*)(net->buffer.get()), net->buffer.size());
    if (false == VerifyNetBuffer(verify)) {
        MNN_PRINT("Invalidate buffer to create interpreter\n");
        delete net;
        return nullptr;
    }
#endif
    // `net->net` is a view into net->buffer; it stays valid as long as the
    // buffer is alive.
    net->net = GetNet(net->buffer.get());
    if (nullptr == net->net->oplists()) {
        MNN_ERROR("Model has no oplist\n");
        delete net;
        return nullptr;
    }
    // Reject models with ops that have no output — later scheduling assumes
    // outputIndexes() is non-null for every op.
    int opSize = net->net->oplists()->size();
    for (int i = 0; i < opSize; ++i) {
        auto op = net->net->oplists()->GetAs<Op>(i);
        if (nullptr == op || nullptr == op->outputIndexes()) {
            MNN_ERROR("Invalid Model, the %d op is empty\n", i);
            delete net;
            return nullptr;
        }
    }
    return new Interpreter(net);
}
110 
setSessionMode(SessionMode mode)111 void Interpreter::setSessionMode(SessionMode mode) {
112     if (mode == Session_Input_Inside || mode == Session_Input_User) {
113         mNet->inputMode = mode;
114     } else {
115         mNet->callBackMode = mode;
116     }
117 }
118 
setCacheFile(const char * cacheFile,size_t keySize)119 void Interpreter::setCacheFile(const char* cacheFile, size_t keySize) {
120     if (nullptr == cacheFile || nullptr == mNet->buffer.get()) {
121         MNN_ERROR("Empty cacheFile or the interpreter invalid\n");
122         return;
123     }
124     mNet->cacheFile   = std::string(cacheFile);
125     mNet->cacheOffset = mNet->buffer.size() > keySize ? keySize : mNet->buffer.size();
126     std::unique_ptr<FileLoader> loader(new FileLoader(cacheFile));
127     if (!loader->valid()) {
128         MNN_ERROR("Load Cache file error.\n");
129         return;
130     }
131     bool result = loader->read();
132     if (!result) {
133         MNN_ERROR("Load Cache file error.\n");
134         return;
135     }
136     if (loader->size() == 0) {
137         MNN_ERROR("Load Cache file error.\n");
138         return;
139     }
140     bool success = loader->merge(mNet->cacheBuffer);
141     if (!success) {
142         MNN_ERROR("Alloc memory for Cache error.\n");
143         return;
144     }
145     if (0 != ::memcmp(mNet->cacheBuffer.get(), mNet->buffer.get(), mNet->cacheOffset)) {
146         MNN_ERROR("Cache model file key does not match.\n");
147         mNet->cacheBuffer.release();
148         return;
149     }
150 }
151 
// Private constructor: takes ownership of a fully-validated Content
// (only reachable via createFromBufferInternal).
Interpreter::Interpreter(Content* net) {
    MNN_ASSERT(nullptr != net);
    mNet = net;
}
156 
Interpreter::~Interpreter() {
    {
        // If the session is running, we must not delete session.
        // Taking mNet->lock serializes destruction against any concurrent
        // API call that also holds the lock; the scope ends (unlocking)
        // before mNet itself is deleted.
        std::unique_lock<std::mutex> _l(mNet->lock);
        mNet->sessions.clear();
        mNet->tensorMap.clear();
    }
    delete mNet;
}
166 
createMultiPathSession(const std::vector<ScheduleConfig> & configs)167 Session* Interpreter::createMultiPathSession(const std::vector<ScheduleConfig>& configs) {
168     RuntimeInfo runtime = createRuntime(configs);
169     if (runtime.first.empty()) {
170         MNN_ERROR("Runtime not valid for create session\n");
171         return nullptr;
172     }
173     return createMultiPathSession(configs, std::move(runtime));
174 }
175 
createMultiPathSession(const std::vector<ScheduleConfig> & configs,const RuntimeInfo & runtime)176 Session* Interpreter::createMultiPathSession(const std::vector<ScheduleConfig>& configs, const RuntimeInfo& runtime) {
177     if (nullptr == mNet->buffer.get()) {
178         MNN_ERROR("The model buffer has been released. Can't create session\n");
179         return nullptr;
180     }
181     if (runtime.first.empty()) {
182         MNN_ERROR("Runtime not valid for create session\n");
183         return nullptr;
184     }
185     std::unique_lock<std::mutex> _l(mNet->lock);
186     auto info           = Schedule::schedule(mNet->net, configs);
187     auto validForResize = info.validForResize;
188     RuntimeInfo rt = runtime;
189     auto newSession =
190         std::unique_ptr<Session>(new Session(std::move(info), mNet->callBackMode, mNet->inputMode, std::move(rt)));
191     if (!newSession->valid()) {
192         MNN_PRINT("Invalide Session!!\n");
193         return nullptr;
194     }
195     auto result = newSession.get();
196     bool valid  = false;
197     if (mNet->cacheBuffer.get() != nullptr) {
198         valid = result->loadCache(mNet->cacheBuffer.get() + mNet->cacheOffset,
199                                   mNet->cacheBuffer.size() - mNet->cacheOffset);
200         if(!valid) {
201             // Reset cache
202             result->loadCache(nullptr, 0);
203             MNN_PRINT("Cache invalid, will be reset\n");
204         }
205     }
206     if (validForResize && mNet->inputMode == Session_Input_Inside) {
207         result->resize(mNet->net->usage() == Usage_INFERENCE_STATIC);
208     }
209     if ((!mNet->cacheFile.empty()) && (!valid)) {
210         // Try to save extra cache
211         auto res = result->getCache();
212         if (res.first != nullptr && res.second > 0) {
213             do {
214                 MNN_PRINT("Write cache to %s, size = %zu\n", mNet->cacheFile.c_str(), res.second);
215                 FILE* f = fopen(mNet->cacheFile.c_str(), "wb");
216                 if (nullptr == f) {
217                     MNN_ERROR("Open %s error\n", mNet->cacheFile.c_str());
218                     break;
219                 }
220                 // Write key
221                 auto tsize = fwrite((const char*)mNet->buffer.get(), 1, mNet->cacheOffset, f);
222                 if (tsize != mNet->cacheOffset) {
223                     MNN_ERROR("Write %s error\n", mNet->cacheFile.c_str());
224                     break;
225                 }
226                 // Write Cache
227                 static const size_t block = 4096;
228                 size_t totalSize          = res.second;
229                 size_t blockSize          = UP_DIV(totalSize, block);
230                 for (size_t i = 0; i < blockSize; ++i) {
231                     size_t sta = block * i;
232                     size_t fin = std::min(sta + block, totalSize);
233                     if (fin > sta) {
234                         auto realSize = fwrite((const char*)(res.first) + sta, 1, fin - sta, f);
235                         if (realSize != fin - sta) {
236                             MNN_ERROR("Write %s error\n", mNet->cacheFile.c_str());
237                             break;
238                         }
239                     }
240                 }
241                 fclose(f);
242             } while (false);
243         }
244     }
245     // Reset cache
246     result->loadCache(nullptr, 0);
247 
248     mNet->sessions.emplace_back(std::move(newSession));
249     return result;
250 }
251 
createSession(const ScheduleConfig & config)252 Session* Interpreter::createSession(const ScheduleConfig& config) {
253     return createMultiPathSession({config});
254 }
255 
createSession(const ScheduleConfig & config,const RuntimeInfo & runtime)256 Session* Interpreter::createSession(const ScheduleConfig& config, const RuntimeInfo& runtime) {
257     return createMultiPathSession({config}, runtime);
258 }
259 
releaseSession(Session * session)260 bool Interpreter::releaseSession(Session* session) {
261     std::unique_lock<std::mutex> _l(mNet->lock);
262     for (auto iter = mNet->sessions.begin(); iter != mNet->sessions.end(); iter++) {
263         // TODO Delete tensormap
264         for (auto tIter = mNet->tensorMap.begin(); tIter != mNet->tensorMap.end();) {
265             if (tIter->second == session) {
266                 tIter = mNet->tensorMap.erase(tIter);
267                 continue;
268             }
269             tIter++;
270         }
271 
272         if ((*iter).get() == session) {
273             mNet->sessions.erase(iter);
274             return true;
275         }
276     }
277     return false;
278 }
279 
// Run the session's whole pipeline once; thin delegate to Session::run().
// NOTE(review): intentionally does not take mNet->lock — concurrency with
// releaseSession/destruction is the caller's responsibility.
ErrorCode Interpreter::runSession(Session* session) const {
    return session->run();
}
283 
getSessionInput(const Session * session,const char * name)284 Tensor* Interpreter::getSessionInput(const Session* session, const char* name) {
285     if (session == nullptr) {
286         return nullptr;
287     }
288     std::unique_lock<std::mutex> _l(mNet->lock);
289     auto tensor = session->getInput(name);
290     mNet->tensorMap.insert(std::make_pair(tensor, session));
291     return tensor;
292 }
293 
getSessionOutput(const Session * session,const char * name)294 Tensor* Interpreter::getSessionOutput(const Session* session, const char* name) {
295     if (session == nullptr) {
296         return nullptr;
297     }
298     std::unique_lock<std::mutex> _l(mNet->lock);
299     auto tensor = session->getOutput(name);
300     mNet->tensorMap.insert(std::make_pair(tensor, session));
301     return tensor;
302 }
303 
getSessionInputAll(const Session * session) const304 const std::map<std::string, Tensor*>& Interpreter::getSessionInputAll(const Session* session) const {
305     std::unique_lock<std::mutex> _l(mNet->lock);
306     auto& tensors = session->getInputAll();
307     for (auto& iter : tensors) {
308         mNet->tensorMap.insert(std::make_pair(iter.second, session));
309     }
310     return tensors;
311 }
312 
getSessionOutputAll(const Session * session) const313 const std::map<std::string, Tensor*>& Interpreter::getSessionOutputAll(const Session* session) const {
314     std::unique_lock<std::mutex> _l(mNet->lock);
315     auto& tensors = session->getOutputAll();
316     for (auto& iter : tensors) {
317         mNet->tensorMap.insert(std::make_pair(iter.second, session));
318     }
319     return tensors;
320 }
321 
resizeSession(Session * session)322 void Interpreter::resizeSession(Session* session) {
323     std::unique_lock<std::mutex> _l(mNet->lock);
324     if (mNet->buffer.get() == nullptr) {
325         MNN_ERROR("The model buffer has been released. Can't resize session\n");
326         return;
327     }
328     session->resize();
329 }
330 
runSessionWithCallBack(const Session * session,const TensorCallBack & before,const TensorCallBack & after,bool sync) const331 ErrorCode Interpreter::runSessionWithCallBack(const Session* session, const TensorCallBack& before,
332                                               const TensorCallBack& after, bool sync) const {
333     auto beforeWrap = [&before](const std::vector<Tensor*>& tensors, const OperatorInfo* info) {
334         return before(tensors, info->name());
335     };
336     auto afterWrap = [&after](const std::vector<Tensor*>& tensors, const OperatorInfo* info) {
337         return after(tensors, info->name());
338     };
339     return runSessionWithCallBackInfo(session, beforeWrap, afterWrap, sync);
340 }
341 
// Run a session with per-op before/after hooks carrying OperatorInfo;
// thin delegate to Session::runWithCallBack.
ErrorCode Interpreter::runSessionWithCallBackInfo(const Session* session, const TensorCallBackWithInfo& before,
                                                  const TensorCallBackWithInfo& callBack, bool sync) const {
    return session->runWithCallBack(before, callBack, sync);
}
346 
// Return the backend that `tensor` was planned onto within `session`;
// thin delegate to Session::getBackEnd.
const Backend* Interpreter::getBackend(const Session* session, const Tensor* tensor) const {
    return session->getBackEnd(tensor);
}
350 
releaseModel()351 void Interpreter::releaseModel() {
352     std::unique_lock<std::mutex> _l(mNet->lock);
353     if (mNet->buffer.get() != nullptr && mNet->net->usage() != Usage_INFERENCE_STATIC) {
354         mNet->buffer.release();
355     }
356     mNet->cacheBuffer.release();
357 }
358 
resizeTensor(Tensor * tensor,int batch,int channel,int height,int width)359 void Interpreter::resizeTensor(Tensor* tensor, int batch, int channel, int height, int width) {
360     if (tensor->getDimensionType() == Tensor::TENSORFLOW) {
361         resizeTensor(tensor, {batch, height, width, channel});
362     } else {
363         resizeTensor(tensor, {batch, channel, height, width});
364     }
365 }
366 
resizeTensor(Tensor * tensor,const std::vector<int> & dims)367 void Interpreter::resizeTensor(Tensor* tensor, const std::vector<int>& dims) {
368     std::unique_lock<std::mutex> _l(mNet->lock);
369     MNN_ASSERT(nullptr != tensor);
370     bool dirty = false;
371     if (tensor->buffer().dimensions != dims.size()) {
372         dirty = true;
373     } else {
374         for (int i = 0; i < dims.size(); ++i) {
375             if (tensor->buffer().dim[i].extent != dims[i]) {
376                 dirty = true;
377                 break;
378             }
379         }
380     }
381 
382     if (!dirty) {
383         return;
384     }
385 
386     tensor->buffer().dimensions = (int)dims.size();
387     for (int i = 0; i < dims.size(); ++i) {
388         tensor->buffer().dim[i].extent = dims[i];
389     }
390 
391     auto relatedSessionIter = mNet->tensorMap.find(tensor);
392     MNN_ASSERT(relatedSessionIter != mNet->tensorMap.end());
393     ((MNN::Session*)relatedSessionIter->second)->setNeedResize();
394 }
395 
bizCode() const396 const char* Interpreter::bizCode() const {
397     const flatbuffers::String* code = mNet->net->bizCode();
398     return code->c_str();
399 }
400 
getModelBuffer() const401 std::pair<const void*, size_t> Interpreter::getModelBuffer() const {
402     return std::make_pair(mNet->buffer.get(), mNet->buffer.size());
403 }
updateSessionToModel(Session * session)404 ErrorCode Interpreter::updateSessionToModel(Session* session) {
405     std::unique_lock<std::mutex> _l(mNet->lock);
406     if (mNet->buffer.get() == nullptr) {
407         MNN_ERROR("Can't updateSessionToModel because you called releaseModel before\n");
408         return INPUT_DATA_ERROR;
409     }
410     return session->updateToModel((Net*)mNet->net);
411 }
412 
getSessionInfo(const Session * session,SessionInfoCode code,void * ptr)413 bool Interpreter::getSessionInfo(const Session* session, SessionInfoCode code, void* ptr) {
414     std::unique_lock<std::mutex> _l(mNet->lock);
415     if (nullptr == session || nullptr == ptr) {
416         return false;
417     }
418     return session->getInfo(code, ptr);
419 }
420 
_getDefaultBackend(RuntimeInfo & rt)421 static void _getDefaultBackend(RuntimeInfo& rt) {
422     auto defaultType = MNN_FORWARD_CPU;
423     if (rt.first.find(defaultType) != rt.first.end()) {
424         rt.second = rt.first[defaultType];
425     }
426     if (rt.second == nullptr) {
427         Backend::Info info;
428         info.type      = defaultType;
429         info.numThread = 1;
430         rt.second.reset(RuntimeFactory::create(info));
431     }
432 }
createRuntime(const std::vector<ScheduleConfig> & configs)433 RuntimeInfo Interpreter::createRuntime(const std::vector<ScheduleConfig>& configs) {
434     RuntimeInfo res;
435     auto& mRuntimes = res.first;
436     for (auto& config : configs) {
437         Backend::Info compute;
438         compute.type      = Schedule::getApprociateType(config);
439         compute.numThread = config.numThread;
440         compute.user      = config.backendConfig;
441         if (mRuntimes.find(compute.type) == mRuntimes.end()) {
442             auto newBn = RuntimeFactory::create(compute);
443             if (nullptr == newBn) {
444                 MNN_ERROR("Can't create Runtime: %s\n", EnumNameForwardType((ForwardType)compute.type));
445                 continue;
446             }
447             mRuntimes[compute.type].reset(newBn);
448         }
449     }
450     _getDefaultBackend(res);
451     return res;
452 }
453 
454 } // namespace MNN
455