//
//  Session.cpp
//  MNN
//
//  Created by MNN on 2018/07/30.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "core/Session.hpp"
#include <string.h>
#include <MNN/AutoTime.hpp>
#include <map>
#include <set>
#include "MNN_generated.h"
#include "core/AutoStorage.h"
#include "core/RuntimeFactory.hpp"
#include "core/TensorUtils.hpp"
#include "core/WrapExecution.hpp"

using namespace std;

namespace MNN {
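// Build the session from the scheduled info: take ownership of the tensors and
// runtimes, then create one Pipeline per schedule entry. Each pipeline gets a
// primary backend plus a CPU backend that serves as the fallback for ops the
// primary backend cannot run (both are the same object when the primary
// backend is already CPU).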
Session::Session(Schedule::ScheduleInfo&& info, Interpreter::SessionMode callBackMode,
                 Interpreter::SessionMode inputMode, RuntimeInfo&& runtime) {
    mRuntime = std::move(runtime);
    if (info.pipelineInfo.empty()) {
        mValid = false;
        return;
    }
    mTensors = std::move(info.allTensors);
    for (auto& iter : info.pipelineInfo) {
        auto rt         = mRuntime.first.find(iter.first.type)->second.get();
        auto cpuRuntime = mRuntime.second;
        std::shared_ptr<Backend> first(rt->onCreate(iter.first.user));
        std::shared_ptr<Backend> second;
        if (first->type() == MNN_FORWARD_CPU) {
            second = first;
        } else {
            BackendConfig defaultConfig;
            defaultConfig.flags = 4;
            second.reset(cpuRuntime->onCreate(&defaultConfig));
        }
        std::shared_ptr<Pipeline> newPipeline(new Pipeline(std::move(iter.second), first, second,
                                                           inputMode == Interpreter::Session_Input_Inside,
                                                           rt->onGetCompilerType()));
        mPipelines.emplace_back(std::move(newPipeline));
    }
    mInputs       = std::move(info.inputTensors);
    mOutputs      = std::move(info.outputTensor);
    mCallBackMode = callBackMode;
}

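// Release every tensor's handle data before the pipelines and runtimes are
// torn down.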
Session::~Session() {
    for (auto& t : mTensors) {
        TensorUtils::clearHandleData(t.second.get());
    }
    mPipelines.clear();
    mRuntime.first.clear();
    mTensors.clear();
    mRuntime.second = nullptr;
}

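// Offer the serialized cache buffer to each runtime; the first runtime that
// accepts it wins.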
bool Session::loadCache(const void* buffer, size_t size) {
    for (auto iter : mRuntime.first) {
        auto res = iter.second->onSetCache(buffer, size);
        if (res) {
            return true;
        }
    }
    return false;
}

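// Return the first non-empty cache buffer reported by a runtime, or
// {nullptr, 0} when no runtime has one.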
std::pair<const void*, size_t> Session::getCache() {
    for (auto iter : mRuntime.first) {
        auto res = iter.second->onGetCache();
        if (res.first != nullptr) {
            return res;
        }
    }
    return std::make_pair(nullptr, 0);
}

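// Share compiled executions across sessions: clone a cache of executions into,
// or read the execution cache out of, the pipeline at the given index.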
void Session::cloneExecution(const std::map<const Op*, std::shared_ptr<Execution>>& cache, int pipelineIndex) {
    mPipelines[pipelineIndex]->cloneExecution(cache);
}

const std::map<const Op*, std::shared_ptr<Execution>>& Session::getExecution(int pipelineIndex) {
    return mPipelines[pipelineIndex]->getCache();
}

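// Execute every pipeline in order; fails with COMPUTE_SIZE_ERROR when the
// session still needs a resize.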
ErrorCode Session::run() const {
    if (mNeedResize) {
        MNN_ERROR("Can't run session because it hasn't been resized\n");
        return COMPUTE_SIZE_ERROR;
    }
    for (auto& iter : mPipelines) {
        auto error = iter->execute();
        if (NO_ERROR != error) {
            return error;
        }
    }
    return NO_ERROR;
}

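// Same as run(), but invoke the `before` callback ahead of each op and `end`
// after it, e.g. for per-op debugging or profiling; a callback returning false
// short-circuits that op or the remaining run.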
ErrorCode Session::runWithCallBack(const TensorCallBackWithInfo& before, const TensorCallBackWithInfo& end,
                                   bool sync) const {
    if (mNeedResize) {
        MNN_ERROR("Can't run session because it hasn't been resized\n");
        return COMPUTE_SIZE_ERROR;
    }
    for (auto& iter : mPipelines) {
        auto error = iter->executeCallBack(before, end);
        if (NO_ERROR != error) {
            return error;
        }
    }
    return NO_ERROR;
}

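// Reset every tensor's descriptor (handle data, use count, backend, regions)
// so the next resize starts from a clean state.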
void Session::_clearCache() {
    for (auto& t : mTensors) {
        auto describe = TensorUtils::getDescribe(t.second.get());
        TensorUtils::clearHandleData(t.second.get());
        describe->useCount = 0;
        describe->backend  = nullptr;
        describe->regions.clear();
    }
}

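// Resize in two phases: when the geometry changed, re-encode every pipeline
// (clearing per-tensor state first for non-static models); then (re)allocate
// the memory the pipelines need and let each runtime release unused resources.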
ErrorCode Session::resize(bool isStatic) {
    if (mNeedResize) {
        if (!isStatic) {
            _clearCache();
        }
        bool debug = mCallBackMode == Interpreter::Session_Debug;
        for (auto& iter : mPipelines) {
            auto error = iter->encode(isStatic, debug);
            if (NO_ERROR != error) {
                return error;
            }
        }
        mNeedResize = false;
        mNeedMalloc = true;
    }
    if (mNeedMalloc) {
        // Set mNeedResize = true so that a later runSession reports the error
        // state if allocation fails below.
        mNeedResize = true;
        // Turn each Pipeline into a command buffer and allocate its resources.
        // TODO: Separate scheduling from allocation.
        for (auto& iter : mPipelines) {
            auto error = iter->allocMemory();
            if (NO_ERROR != error) {
                return error;
            }
        }
        for (auto& iter : mRuntime.first) {
            iter.second->onGabageCollect(0); // (sic: the runtime API's actual name)
        }
        mNeedMalloc = false;
        mNeedResize = false;
    }
    return NO_ERROR;
}

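// Query session statistics: total memory footprint in MB across all runtimes,
// the forward type of each runtime, or the summed FLOPs of all pipelines.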
bool Session::getInfo(Interpreter::SessionInfoCode code, void* ptr) const {
    switch (code) {
        case Interpreter::MEMORY: {
            auto dst     = (float*)ptr;
            float summer = mRuntime.second->onGetMemoryInMB();
            for (auto& r : mRuntime.first) {
                if (r.second.get() != mRuntime.second.get()) {
                    summer += r.second->onGetMemoryInMB();
                }
            }
            *dst = summer;
            return true;
        } break;
        case Interpreter::BACKENDS: {
            int pos  = 0;
            auto res = (int32_t*)ptr;
            for (auto& r : mRuntime.first) {
                res[pos++] = r.first;
            }
            return true;
        } break;
        case Interpreter::FLOPS: {
            float flo = 0.0f;
            for (auto& iter : mPipelines) {
                flo += iter->flops();
            }
            auto dst = (float*)ptr;
            *dst     = flo;
            return true;
        } break;
        // TODO: Support other debug info
        default:
            break;
    }
    return false;
}

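// Return the backend that currently owns the given tensor (nullptr if none).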
const Backend* Session::getBackEnd(const Tensor* tensor) const {
    return TensorUtils::getDescribe(tensor)->backend;
}

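// Look up an input tensor by name; a null name returns the first input.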
Tensor* Session::getInput(const char* name) const {
    //MNN_ASSERT(!mInputs.empty());
    if (nullptr == name) {
        return mInputs.begin()->second;
    }
    auto iter = mInputs.find(name);
    if (iter == mInputs.end()) {
        MNN_PRINT("Error: can't find input: %s\n", name);
        return nullptr;
    }
    return iter->second;
}

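// Look up an output tensor by name; a null name returns the first output.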
Tensor* Session::getOutput(const char* name) const {
    MNN_ASSERT(!mOutputs.empty());
    if (nullptr == name) {
        return mOutputs.begin()->second;
    }
    auto iter = mOutputs.find(name);
    if (iter == mOutputs.end()) {
        MNN_PRINT("Error: can't find output: %s\n", name);
        return nullptr;
    }
    return iter->second;
}

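// Expose the full name -> tensor maps for inputs and outputs.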
const std::map<std::string, Tensor*>& Session::getInputAll() const {
    return mInputs;
}

const std::map<std::string, Tensor*>& Session::getOutputAll() const {
    return mOutputs;
}

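// Write trained parameters back into the model: for each single-output Const
// (inference usage) or TrainableParam (training usage) op holding float data,
// copy the tensor's current contents into the op's blob, first downloading
// from device memory when the tensor lives on a non-CPU backend.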
ErrorCode Session::updateToModel(Net* net) const {
    if (mNeedResize) {
        return NOT_SUPPORT;
    }
    int opSize = net->oplists()->size();
    for (int i = 0; i < opSize; ++i) {
        auto op = net->oplists()->GetAs<Op>(i);
        if ((net->usage() == Usage_INFERENCE || net->usage() == Usage_INFERENCE_STATIC) && op->type() != OpType_Const) {
            continue;
        }
        if (net->usage() == Usage_TRAIN && op->type() != OpType_TrainableParam) {
            continue;
        }
        if (!op->outputIndexes() || op->outputIndexes()->size() != 1) {
            continue;
        }
        auto index = op->outputIndexes()->data()[0];
        auto blob  = op->main_as_Blob();
        if (blob->dataType() != DataType_DT_FLOAT) {
            continue;
        }
        std::shared_ptr<Tensor> tensor = mTensors[index].second;
        if (tensor->host<void>() == nullptr && tensor->deviceId() != 0) {
            // The tensor lives on a device backend; copy it to host first.
            tensor.reset(Tensor::createHostTensorFromDevice(tensor.get(), true));
            if (tensor.get() == nullptr) {
                MNN_ERROR("failed to copy trained param from device to host\n");
                return INVALID_VALUE;
            }
        }
        ::memcpy((void*)blob->float32s()->data(), tensor->host<float>(), tensor->size());
    }

    return NO_ERROR;
}

} // namespace MNN