//
//  WrapExecution.cpp
//  MNN
//
//  Created by MNN on 2018/09/03.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "core/WrapExecution.hpp"
#include "core/TensorUtils.hpp"
#include "backend/cpu/CPUBackend.hpp"
#include "backend/cpu/compute/CommonOptFunction.h"
namespace MNN {
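// Decide whether `input` must be copied ("wrapped") before `curBackend` can
// consume it. No wrap is needed when the target backend is MNN_FORWARD_NN,
// when source and target backend types match, or when two CPU-family backends
// agree on the element byte width and (for NC4HW4 tensors) on the pack unit.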
bool WrapExecution::needWrap(const Tensor* input, Backend* curBackend) {
    if (curBackend->type() == MNN_FORWARD_NN) {
        return false;
    }
    auto des = TensorUtils::getDescribe(input);
    auto bn = des->backend;
    // A tensor without a backend is treated as an ordinary CPU tensor.
    MNNForwardType type = MNN_FORWARD_CPU;
    int pack = 4;
    int bytes = 4;
    if (nullptr != bn) {
        type = bn->type();
        if (type == MNN_FORWARD_CPU_EXTENSION) {
            auto core = static_cast<CPUBackend*>(bn)->functions();
            pack = core->pack;
            bytes = core->bytes;
        }
    }
    if (type == curBackend->type()) {
        return false;
    }
    bool srcCpu = (type == MNN_FORWARD_CPU_EXTENSION || type == MNN_FORWARD_CPU);
    bool dstCpu = ((curBackend->type() == MNN_FORWARD_CPU_EXTENSION) || (curBackend->type() == MNN_FORWARD_CPU));
    if (srcCpu && dstCpu) {
        // Two CPU-family backends can share a tensor when they use the same
        // element byte width and, for NC4HW4 layouts, the same pack unit.
        auto dstCore = static_cast<CPUBackend*>(curBackend)->functions();
        if (dstCore->bytes == bytes) {
            if (dstCore->pack == pack || des->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) {
                return false;
            }
        }
    }
    return true;
}

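// A minimal wiring sketch (hypothetical call site; the actual scheduling lives
// in the Session/Pipeline code): wrap an execution whenever needWrap reports
// that one of its inputs must be converted, staging copies through a CPU
// backend:
//
//   if (WrapExecution::needWrap(input, curBackend)) {
//       execution.reset(new WrapExecution(cpuBackend, execution, isStatic));
//   }
//
// `isStatic` marks sessions whose constant inputs can be copied once during
// onResize instead of on every onExecute.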
WrapExecution::WrapExecution(Backend* CPUBackend, std::shared_ptr<Execution> execution, bool isStatic)
    : Execution(execution->backend()), mCPUBackend(CPUBackend), mExecution(execution) {
    mValid  = execution->valid();
    mStatic = isStatic;
}

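// Return a tensor on the destination backend that will hold a copy of
// `inputTensor`, caching one wrap tensor per distinct input. Each mInputMaps
// entry is a tuple of (backend that allocates/releases the wrap tensor,
// backend whose onCopyBuffer performs the copy, the wrap tensor itself).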
Tensor* WrapExecution::_getCopyTensor(Tensor* inputTensor) {
    auto dstBackend = mExecution->backend();
    auto inputDes   = TensorUtils::getDescribe(inputTensor);
    auto srcBackend = inputDes->backend;
    if (nullptr == srcBackend) {
        srcBackend = mCPUBackend;
    }
    // CPU -> CPU or XPU -> XPU: backends of the same type need no copy.
    if (srcBackend->type() == dstBackend->type()) {
        return inputTensor;
    }
    // Reuse the wrap tensor if this input has already been converted.
    auto iter = mInputMaps.find(inputTensor);
    if (iter != mInputMaps.end()) {
        return std::get<2>(iter->second).get();
    }
    // CPU -> XPU
    if (srcBackend->type() == mCPUBackend->type()) {
        std::shared_ptr<Tensor> wrapTensor(new Tensor);
        TensorUtils::copyShape(inputTensor, wrapTensor.get(), true);
        TensorUtils::adjustTensorForCompability(wrapTensor.get());
        wrapTensor->buffer().type = inputTensor->buffer().type;
        TensorUtils::getDescribe(wrapTensor.get())->quantAttr = TensorUtils::getDescribe(inputTensor)->quantAttr;
        mInputMaps.insert(std::make_pair(inputTensor, std::make_tuple(dstBackend, dstBackend, wrapTensor)));
        return wrapTensor.get();
    }
    // XPU -> CPU
    if (dstBackend->type() == mCPUBackend->type()) {
        std::shared_ptr<Tensor> wrapTensor(new Tensor);
        TensorUtils::copyShape(inputTensor, wrapTensor.get(), true);
        wrapTensor->buffer().type = inputTensor->buffer().type;
        TensorUtils::adjustTensorForCompability(wrapTensor.get());
        TensorUtils::getDescribe(wrapTensor.get())->quantAttr = TensorUtils::getDescribe(inputTensor)->quantAttr;
        mInputMaps.insert(std::make_pair(inputTensor, std::make_tuple(mCPUBackend, srcBackend, wrapTensor)));
        return wrapTensor.get();
    }
    // XPU -> CPU -> XPU': different device backends never copy to each other
    // directly, so stage the transfer through an intermediate CPU tensor.
    std::shared_ptr<Tensor> midTensor(new Tensor);
    std::shared_ptr<Tensor> wrapTensor(new Tensor);
    TensorUtils::copyShape(inputTensor, midTensor.get(), true);
    TensorUtils::copyShape(inputTensor, wrapTensor.get(), true);
    TensorUtils::adjustTensorForCompability(wrapTensor.get());
    TensorUtils::adjustTensorForCompability(midTensor.get());
    TensorUtils::getDescribe(midTensor.get())->usage = TensorUtils::getDescribe(inputTensor)->usage;
    TensorUtils::getDescribe(midTensor.get())->quantAttr = TensorUtils::getDescribe(inputTensor)->quantAttr;
    midTensor->buffer().type  = inputTensor->buffer().type;
    wrapTensor->buffer().type = inputTensor->buffer().type;
    mInputMaps.insert(std::make_pair(inputTensor, std::make_tuple(mCPUBackend, srcBackend, midTensor)));
    mInputMaps.insert(std::make_pair(midTensor.get(), std::make_tuple(dstBackend, dstBackend, wrapTensor)));
    return wrapTensor.get();
}

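// Build a wrap tensor for every input (recursing into the region origins when
// the input is a virtual raster tensor), allocate the wrap buffers, copy
// constant inputs up front for static sessions, resize the wrapped execution,
// then release the temporary buffers.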
ErrorCode WrapExecution::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    mWrapInputTensors.resize(inputs.size());
    mInputMaps.clear();

    auto dstBackend = mExecution->backend();
    for (int i = 0; i < inputs.size(); ++i) {
        auto inputTensor = inputs[i];
        auto des         = TensorUtils::getDescribe(inputTensor);
        if (des->memoryType == Tensor::InsideDescribe::MEMORY_VIRTUAL) {
            // Raster input: clone the virtual tensor and redirect each
            // region's origin to its copy on the destination backend.
            MNN_ASSERT(inputs.size() == 1);
            mWrapForRaster.reset(new Tensor);
            TensorUtils::copyShape(inputTensor, mWrapForRaster.get(), true);
            mWrapForRaster->buffer().type = inputTensor->buffer().type;
            auto wrapDes                  = TensorUtils::getDescribe(mWrapForRaster.get());
            wrapDes->memoryType           = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            wrapDes->regions              = des->regions;
            for (auto& r : wrapDes->regions) {
                r.origin = _getCopyTensor(r.origin);
            }
            mWrapInputTensors[i] = mWrapForRaster.get();
        } else {
            mWrapInputTensors[i] = _getCopyTensor(inputTensor);
        }
    }

    for (int i = 0; i < outputs.size(); ++i) {
        MNN_ASSERT(TensorUtils::getDescribe(outputs[i])->backend == dstBackend);
    }
    bool memoryAllocSuccess = true;
    // acquire memory, copy const tensors
    for (auto& iter : mInputMaps) {
        auto backend   = std::get<0>(iter.second);
        auto converter = std::get<1>(iter.second);
        auto src       = iter.first;
        auto dst       = std::get<2>(iter.second).get();

        if (TensorUtils::getDescribe(src)->usage == TensorUsage::CONSTANT && mStatic) {
            // Static session: copy the constant once now, into memory that is
            // allocated separately so dynamic allocations cannot reuse it.
            memoryAllocSuccess = backend->onAcquireBuffer(dst, Backend::DYNAMIC_SEPERATE);
            if (memoryAllocSuccess) {
                converter->onCopyBuffer(src, dst);
                TensorUtils::getDescribe(dst)->usage = TensorUtils::getDescribe(src)->usage;
            }
        } else {
            memoryAllocSuccess = backend->onAcquireBuffer(dst, Backend::DYNAMIC);
        }
    }
    if (!memoryAllocSuccess) {
        return OUT_OF_MEMORY;
    }

    // do resize
    auto result = mExecution->onResize(mWrapInputTensors, outputs);

    // release memory
    for (auto& iter : mInputMaps) {
        auto backend = std::get<0>(iter.second);
        auto dst     = std::get<2>(iter.second).get();

        if (TensorUtils::getDescribe(dst)->usage == TensorUsage::CONSTANT && mStatic) {
            backend->onReleaseBuffer(dst, Backend::DYNAMIC_SEPERATE);
        } else {
            backend->onReleaseBuffer(dst, Backend::DYNAMIC);
        }
    }
    return result;
}

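// Refresh every non-constant wrap tensor from its source before running the
// wrapped execution; constants in a static session were already copied during
// onResize.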
ErrorCode WrapExecution::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    MNN_ASSERT(mWrapInputTensors.size() == inputs.size());

    // copy variant tensors
    for (auto& iter : mInputMaps) {
        auto converter = std::get<1>(iter.second);
        auto src       = iter.first;
        auto dst       = std::get<2>(iter.second).get();
        if (TensorUtils::getDescribe(src)->usage != TensorUsage::CONSTANT || (!mStatic)) {
            converter->onCopyBuffer(src, dst);
        }
    }
    auto code = mExecution->onExecute(mWrapInputTensors, outputs);
    return code;
}

} // namespace MNN