1 //
2 // Session.cpp
3 // MNN
4 //
5 // Created by MNN on 2018/07/30.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "core/Session.hpp"
10 #include <string.h>
11 #include <MNN/AutoTime.hpp>
12 #include <map>
13 #include <set>
14 #include "MNN_generated.h"
15 #include "core/AutoStorage.h"
16 #include "core/RuntimeFactory.hpp"
17 #include "core/TensorUtils.hpp"
18 #include "core/WrapExecution.hpp"
19
20 using namespace std;
21
22 namespace MNN {
Session(Schedule::ScheduleInfo && info,Interpreter::SessionMode callBackMode,Interpreter::SessionMode inputMode,RuntimeInfo && runtime)23 Session::Session(Schedule::ScheduleInfo&& info, Interpreter::SessionMode callBackMode,
24 Interpreter::SessionMode inputMode, RuntimeInfo&& runtime) {
25 mRuntime = std::move(runtime);
26 if (info.pipelineInfo.empty()) {
27 mValid = false;
28 return;
29 }
30 mTensors = std::move(info.allTensors);
31 for (auto& iter : info.pipelineInfo) {
32 auto rt = mRuntime.first.find(iter.first.type)->second.get();
33 auto cpuRuntime = mRuntime.second;
34 std::shared_ptr<Backend> first(rt->onCreate(iter.first.user));
35 std::shared_ptr<Backend> second;
36 if (first->type() == MNN_FORWARD_CPU) {
37 second = first;
38 } else {
39 BackendConfig defaultConfig;
40 defaultConfig.flags = 4;
41 second.reset(cpuRuntime->onCreate(&defaultConfig));
42 }
43 std::shared_ptr<Pipeline> newPipeline(new Pipeline(std::move(iter.second), first, second, inputMode == Interpreter::Session_Input_Inside, rt->onGetCompilerType()));
44 mPipelines.emplace_back(std::move(newPipeline));
45 }
46 mInputs = std::move(info.inputTensors);
47 mOutputs = std::move(info.outputTensor);
48 mCallBackMode = callBackMode;
49 }
50
~Session()51 Session::~Session() {
52 for (auto& t : mTensors) {
53 TensorUtils::clearHandleData(t.second.get());
54 }
55 mPipelines.clear();
56 mRuntime.first.clear();
57 mTensors.clear();
58 mRuntime.second = nullptr;
59 }
60
loadCache(const void * buffer,size_t size)61 bool Session::loadCache(const void* buffer, size_t size) {
62 for (auto iter : mRuntime.first) {
63 auto res = iter.second->onSetCache(buffer, size);
64 if (res) {
65 return true;
66 }
67 }
68 return false;
69 }
70
getCache()71 std::pair<const void*, size_t> Session::getCache() {
72 for (auto iter : mRuntime.first) {
73 auto res = iter.second->onGetCache();
74 if (res.first != nullptr) {
75 return res;
76 }
77 }
78 return std::make_pair(nullptr, 0);
79 }
cloneExecution(const std::map<const Op *,std::shared_ptr<Execution>> & cache,int pipelineIndex)80 void Session::cloneExecution(const std::map<const Op*, std::shared_ptr<Execution>>& cache, int pipelineIndex) {
81 mPipelines[pipelineIndex]->cloneExecution(cache);
82 }
getExecution(int pipelineIndex)83 const std::map<const Op*, std::shared_ptr<Execution>>& Session::getExecution(int pipelineIndex) {
84 return mPipelines[pipelineIndex]->getCache();
85 }
86
run() const87 ErrorCode Session::run() const {
88 if (mNeedResize) {
89 MNN_ERROR("Can't run session because not resized\n");
90 return COMPUTE_SIZE_ERROR;
91 }
92 for (auto& iter : mPipelines) {
93 auto error = iter->execute();
94 if (NO_ERROR != error) {
95 return error;
96 }
97 }
98 return NO_ERROR;
99 }
100
runWithCallBack(const TensorCallBackWithInfo & before,const TensorCallBackWithInfo & end,bool sync) const101 ErrorCode Session::runWithCallBack(const TensorCallBackWithInfo& before, const TensorCallBackWithInfo& end,
102 bool sync) const {
103 if (mNeedResize) {
104 MNN_ERROR("Can't run session because not resized\n");
105 return COMPUTE_SIZE_ERROR;
106 }
107 for (auto& iter : mPipelines) {
108 auto error = iter->executeCallBack(before, end);
109 if (NO_ERROR != error) {
110 return error;
111 }
112 }
113 return NO_ERROR;
114 }
115
_clearCache()116 void Session::_clearCache() {
117 for (auto& t : mTensors) {
118 auto describe = TensorUtils::getDescribe(t.second.get());
119 TensorUtils::clearHandleData(t.second.get());
120 describe->useCount = 0;
121 describe->backend = nullptr;
122 describe->regions.clear();
123 }
124 }
125
// Two-phase (re)size of the session:
//   Phase 1 (mNeedResize): re-encode each pipeline's command buffer for the
//     current input shapes. For non-static models the per-tensor caches are
//     cleared first so stale backend bindings don't leak into the new plan.
//   Phase 2 (mNeedMalloc): allocate backend memory for the encoded commands,
//     then let each runtime garbage-collect transient buffers.
// Returns NO_ERROR on success, or the first pipeline error encountered.
ErrorCode Session::resize(bool isStatic) {
    if (mNeedResize) {
        if (!isStatic) {
            _clearCache();
        }
        // Debug mode forces per-op encoding so callbacks can observe every op.
        bool debug = mCallBackMode == Interpreter::Session_Debug;
        for (auto& iter : mPipelines) {
            auto error = iter->encode(isStatic, debug);
            if (NO_ERROR != error) {
                return error;
            }
        }
        mNeedResize = false;
        mNeedMalloc = true;
    }
    if (mNeedMalloc) {
        // Set needResize = true for easy for judge in runSession when error
        // (i.e. if allocation fails below, run() will refuse to execute).
        mNeedResize = true;
        // Turn Pipeline to Command Buffer and Malloc resource
        // TODO: Seperate Schedule and Malloc
        for (auto& iter : mPipelines) {
            auto error = iter->allocMemory();
            if (NO_ERROR != error) {
                return error;
            }
        }
        // Let runtimes release transient buffers used during allocation.
        for (auto& iter : mRuntime.first) {
            iter.second->onGabageCollect(0);
        }
        mNeedMalloc = false;
        mNeedResize = false;
    }
    return NO_ERROR;
}
getInfo(Interpreter::SessionInfoCode code,void * ptr) const160 bool Session::getInfo(Interpreter::SessionInfoCode code, void* ptr) const {
161 switch (code) {
162 case Interpreter::MEMORY: {
163 auto dst = (float*)ptr;
164 float summer = mRuntime.second->onGetMemoryInMB();
165 for (auto& r : mRuntime.first) {
166 if (r.second.get() != mRuntime.second.get()) {
167 summer += r.second->onGetMemoryInMB();
168 }
169 }
170 *dst = summer;
171 return true;
172 } break;
173 case Interpreter::BACKENDS: {
174 int pos = 0;
175 auto res = (int32_t*)ptr;
176 for (auto& r : mRuntime.first) {
177 res[pos++] = r.first;
178 }
179 return true;
180 } break;
181 case Interpreter::FLOPS: {
182 float flo = 0.0f;
183 for (auto& iter : mPipelines) {
184 flo += iter->flops();
185 }
186 auto dst = (float*)ptr;
187 *dst = flo;
188 return true;
189 } break;
190 // TODO: Support other debug info
191 default:
192 break;
193 }
194 return false;
195 }
196
getBackEnd(const Tensor * tensor) const197 const Backend* Session::getBackEnd(const Tensor* tensor) const {
198 return TensorUtils::getDescribe(tensor)->backend;
199 }
200
getInput(const char * name) const201 Tensor* Session::getInput(const char* name) const {
202 //MNN_ASSERT(!mInputs.empty());
203 if (nullptr == name) {
204 return mInputs.begin()->second;
205 }
206 auto iter = mInputs.find(name);
207 if (iter == mInputs.end()) {
208 MNN_PRINT("Error: can't find input: %s\n", name);
209 return nullptr;
210 }
211 return iter->second;
212 }
213
getOutput(const char * name) const214 Tensor* Session::getOutput(const char* name) const {
215 MNN_ASSERT(!mOutputs.empty());
216 if (nullptr == name) {
217 return mOutputs.begin()->second;
218 }
219
220 auto iter = mOutputs.find(name);
221 if (iter == mOutputs.end()) {
222 MNN_PRINT("Error: can't find output: %s\n", name);
223 return nullptr;
224 }
225 return iter->second;
226 }
227
// Return the full name -> tensor map of session inputs. The reference stays
// valid for the lifetime of the session.
const std::map<std::string, Tensor*>& Session::getInputAll() const {
    return mInputs;
}
231
// Return the full name -> tensor map of session outputs. The reference stays
// valid for the lifetime of the session.
const std::map<std::string, Tensor*>& Session::getOutputAll() const {
    return mOutputs;
}
235
// Write current tensor values back into the flatbuffer model `net`:
// Const op blobs for inference nets, TrainableParam blobs for training nets.
// Mutates the model buffer in place. Returns NOT_SUPPORT if the session
// still needs a resize, INVALID_VALUE if a device->host copy fails.
ErrorCode Session::updateToModel(Net* net) const {
    if (mNeedResize) {
        return NOT_SUPPORT;
    }
    int opSize = net->oplists()->size();
    for (int i = 0; i < opSize; ++i) {
        auto op = net->oplists()->GetAs<Op>(i);
        // Inference nets: only Const ops carry values to sync back.
        if ((net->usage() == Usage_INFERENCE || net->usage() == Usage_INFERENCE_STATIC) && op->type() != OpType_Const) {
            continue;
        }
        // Training nets: only trainable parameters are synced.
        if (net->usage() == Usage_TRAIN && op->type() != OpType_TrainableParam) {
            continue;
        }
        // Only single-output ops map 1:1 onto a tensor we can copy from.
        if (!op->outputIndexes() || op->outputIndexes()->size() != 1) {
            continue;
        }
        auto index = op->outputIndexes()->data()[0];
        auto blob  = op->main_as_Blob();
        // NOTE(review): assumes main is a Blob for these op types — a non-Blob
        // main would make blob null here; confirm the schema guarantees this.
        if (blob->dataType() != DataType_DT_FLOAT) {
            continue;
        }
        std::shared_ptr<Tensor> tensor = mTensors[index].second;
        if (tensor->host<void>() == nullptr && tensor->deviceId() != 0) {
            // Tensor lives on a device backend: stage a host copy first.
            tensor.reset(Tensor::createHostTensorFromDevice(tensor.get(), true));
            if (tensor.get() == nullptr) {
                MNN_ERROR("failed to copy trained param from device to host\n");
                return INVALID_VALUE;
            }
        }
        // Intentionally casts away flatbuffer const-ness to overwrite the
        // blob's float payload in place; assumes the blob's float32s storage
        // is at least tensor->size() bytes — TODO confirm.
        ::memcpy((void*)blob->float32s()->data(), tensor->host<float>(), tensor->size());
    }

    return NO_ERROR;
}
270
271 } // namespace MNN
272