//
//  WrapExecution.cpp
//  MNN
//
//  Created by MNN on 2018/09/03.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "core/WrapExecution.hpp"
#include "core/TensorUtils.hpp"
#include "backend/cpu/CPUBackend.hpp"
#include "backend/cpu/compute/CommonOptFunction.h"
namespace MNN {
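// Return true if `input`, whose memory is owned by the backend recorded in its describe
// (treated as CPU when none is recorded), must be copied into a wrap tensor before
// `curBackend` can consume it.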
bool WrapExecution::needWrap(const Tensor* input, Backend* curBackend) {
    if (curBackend->type() == MNN_FORWARD_NN) {
        return false;
    }
    auto des = TensorUtils::getDescribe(input);
    auto bn = des->backend;
    MNNForwardType type = MNN_FORWARD_CPU;
    int pack = 4;
    int bytes = 4;
    if (nullptr != bn) {
        type = bn->type();
        if (type == MNN_FORWARD_CPU_EXTENSION) {
            auto core = static_cast<CPUBackend*>(bn)->functions();
            pack = core->pack;
            bytes = core->bytes;
        }
    }
    if (type == curBackend->type()) {
        return false;
    }
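    // Both backends are CPU-family: the buffer can be consumed directly as long as the
    // element size matches and either the channel packing matches or the tensor is not
    // stored in the packed NC4HW4 layout.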
    bool srcCpu = (type == MNN_FORWARD_CPU_EXTENSION || type == MNN_FORWARD_CPU);
    bool dstCpu = ((curBackend->type() == MNN_FORWARD_CPU_EXTENSION) || (curBackend->type() == MNN_FORWARD_CPU));
    if (srcCpu && dstCpu) {
        auto dstCore = static_cast<CPUBackend*>(curBackend)->functions();
        if (dstCore->bytes == bytes) {
            if (dstCore->pack == pack || des->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) {
                return false;
            }
        }
    }
    return true;
}

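// Wrap `execution`, which runs on its own backend, so that inputs owned by other backends
// are copied into tensors it can read. When `isStatic` is true, constant inputs are copied
// once during onResize and skipped in onExecute.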
WrapExecution::WrapExecution(Backend* CPUBackend, std::shared_ptr<Execution> execution, bool isStatic)
    : Execution(execution->backend()), mCPUBackend(CPUBackend), mExecution(execution) {
    mValid = execution->valid();
    mStatic = isStatic;
}

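// Return the tensor the wrapped execution should actually read: the input itself when the
// source and destination backend types match, otherwise a cached wrap tensor. Each entry in
// mInputMaps maps a source tensor to (backend that allocates the wrap tensor, backend that
// performs the copy, wrap tensor); onResize and onExecute consume these entries.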
Tensor* WrapExecution::_getCopyTensor(Tensor* inputTensor) {
    auto dstBackend = mExecution->backend();
    auto inputDes = TensorUtils::getDescribe(inputTensor);
    auto srcBackend = inputDes->backend;
    if (nullptr == srcBackend) {
        srcBackend = mCPUBackend;
    }
    // CPU -> CPU or XPU -> XPU
    //if (srcBackend == dstBackend) {
    if (srcBackend->type() == dstBackend->type()) {
        return inputTensor;
    }
    auto iter = mInputMaps.find(inputTensor);
    if (iter != mInputMaps.end()) {
        return std::get<2>(iter->second).get();
    }
    // CPU -> XPU
    if (srcBackend->type() == mCPUBackend->type()) {
        std::shared_ptr<Tensor> wrapTensor(new Tensor);
        TensorUtils::copyShape(inputTensor, wrapTensor.get(), true);
        TensorUtils::adjustTensorForCompability(wrapTensor.get());
        wrapTensor->buffer().type = inputTensor->buffer().type;
        TensorUtils::getDescribe(wrapTensor.get())->quantAttr = TensorUtils::getDescribe(inputTensor)->quantAttr;
        mInputMaps.insert(std::make_pair(inputTensor, std::make_tuple(dstBackend, dstBackend, wrapTensor)));
        return wrapTensor.get();
    }
    // XPU -> CPU
    if (dstBackend->type() == mCPUBackend->type()) {
        std::shared_ptr<Tensor> wrapTensor(new Tensor);
        TensorUtils::copyShape(inputTensor, wrapTensor.get(), true);
        wrapTensor->buffer().type = inputTensor->buffer().type;
        TensorUtils::adjustTensorForCompability(wrapTensor.get());
        TensorUtils::getDescribe(wrapTensor.get())->quantAttr = TensorUtils::getDescribe(inputTensor)->quantAttr;
        mInputMaps.insert(std::make_pair(inputTensor, std::make_tuple(mCPUBackend, srcBackend, wrapTensor)));
        return wrapTensor.get();
    }
    // XPU -> CPU -> XPU'
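    // Different device backends: stage the data through host memory. The source backend
    // copies device -> host into midTensor, then the destination backend copies
    // host -> device into wrapTensor; both hops are registered in mInputMaps.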
    std::shared_ptr<Tensor> midTensor(new Tensor);
    std::shared_ptr<Tensor> wrapTensor(new Tensor);
    TensorUtils::copyShape(inputTensor, midTensor.get(), true);
    TensorUtils::copyShape(inputTensor, wrapTensor.get(), true);
    TensorUtils::adjustTensorForCompability(wrapTensor.get());
    TensorUtils::adjustTensorForCompability(midTensor.get());
    TensorUtils::getDescribe(midTensor.get())->usage = TensorUtils::getDescribe(inputTensor)->usage;
    TensorUtils::getDescribe(midTensor.get())->quantAttr = TensorUtils::getDescribe(inputTensor)->quantAttr;
    midTensor->buffer().type = inputTensor->buffer().type;
    wrapTensor->buffer().type = inputTensor->buffer().type;
    mInputMaps.insert(std::make_pair(inputTensor, std::make_tuple(mCPUBackend, srcBackend, midTensor)));
    mInputMaps.insert(std::make_pair(midTensor.get(), std::make_tuple(dstBackend, dstBackend, wrapTensor)));
    return wrapTensor.get();
}

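// Build the wrap tensors for inputs that live on other backends, allocate their memory,
// copy constant inputs up front in static mode, and resize the wrapped execution on the
// wrapped inputs.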
ErrorCode WrapExecution::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    mWrapInputTensors.resize(inputs.size());
    mInputMaps.clear();

    auto dstBackend = mExecution->backend();
    for (int i = 0; i < inputs.size(); ++i) {
        auto inputTensor = inputs[i];
        auto des = TensorUtils::getDescribe(inputTensor);
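        // Raster (virtual) inputs carry no data of their own: clone the region description
        // and wrap each region's origin tensor instead of the raster tensor itself.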
        if (des->memoryType == Tensor::InsideDescribe::MEMORY_VIRTUAL) {
            MNN_ASSERT(inputs.size() == 1);
            mWrapForRaster.reset(new Tensor);
            TensorUtils::copyShape(inputTensor, mWrapForRaster.get(), true);
            mWrapForRaster->buffer().type = inputTensor->buffer().type;
            auto wrapDes = TensorUtils::getDescribe(mWrapForRaster.get());
            wrapDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            wrapDes->regions = des->regions;
            for (auto& r : wrapDes->regions) {
                r.origin = _getCopyTensor(r.origin);
            }
            mWrapInputTensors[i] = mWrapForRaster.get();
        } else {
            mWrapInputTensors[i] = _getCopyTensor(inputTensor);
        }
    }

    for (int i = 0; i < outputs.size(); ++i) {
        MNN_ASSERT(TensorUtils::getDescribe(outputs[i])->backend == dstBackend);
    }
    bool memoryAllocSuccess = true;
    // acquire memory, copy const tensors
    for (auto& iter : mInputMaps) {
        auto backend = std::get<0>(iter.second);
        auto converter = std::get<1>(iter.second);
        auto src = iter.first;
        auto dst = std::get<2>(iter.second).get();

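        // Static constant inputs get separately managed memory and are copied once here;
        // all other inputs get dynamic memory and are refreshed every onExecute.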
        if (TensorUtils::getDescribe(src)->usage == TensorUsage::CONSTANT && mStatic) {
            memoryAllocSuccess = backend->onAcquireBuffer(dst, Backend::DYNAMIC_SEPERATE);
            if (memoryAllocSuccess) {
                converter->onCopyBuffer(src, dst);
                TensorUtils::getDescribe(dst)->usage = TensorUtils::getDescribe(src)->usage;
            }
        } else {
            memoryAllocSuccess = backend->onAcquireBuffer(dst, Backend::DYNAMIC);
        }
    }
    if (!memoryAllocSuccess) {
        return OUT_OF_MEMORY;
    }

    // do resize
    auto result = mExecution->onResize(mWrapInputTensors, outputs);

    // release memory
    for (auto& iter : mInputMaps) {
        auto backend = std::get<0>(iter.second);
        auto dst = std::get<2>(iter.second).get();

        if (TensorUtils::getDescribe(dst)->usage == TensorUsage::CONSTANT && mStatic) {
            backend->onReleaseBuffer(dst, Backend::DYNAMIC_SEPERATE);
        } else {
            backend->onReleaseBuffer(dst, Backend::DYNAMIC);
        }
    }
    return result;
}

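// Refresh the wrap tensors whose contents may change between runs (everything except
// static constants), then run the wrapped execution on them.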
ErrorCode WrapExecution::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    MNN_ASSERT(mWrapInputTensors.size() == inputs.size());

    // copy variant tensors
    for (auto& iter : mInputMaps) {
        auto converter = std::get<1>(iter.second);
        auto src = iter.first;
        auto dst = std::get<2>(iter.second).get();
        if (TensorUtils::getDescribe(src)->usage != TensorUsage::CONSTANT || (!mStatic)) {
            converter->onCopyBuffer(src, dst);
        }
    }
    auto code = mExecution->onExecute(mWrapInputTensors, outputs);
    return code;
}

} // namespace MNN