1 //
2 // CPUDeconvolution.cpp
3 // MNN
4 //
5 // Created by MNN on 2018/07/20.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "CPUDeconvolution.hpp"
10 #include "core/BufferAllocator.hpp"
11 #include "CPUBackend.hpp"
12 #include "core/Concurrency.h"
13 #include "core/Macro.h"
14 #include "core/AutoStorage.h"
15 #include "math/Matrix.hpp"
16 #include "core/TensorUtils.hpp"
17 #include "core/ConvolutionCommon.hpp"
18 #include "compute/CommonOptFunction.h"
19 #include "compute/ConvOpt.h"
20 #include "compute/DeconvolutionWithStride.hpp"
21 //#define MNN_OPEN_TIME_TRACE
22 #include <MNN/AutoTime.hpp>
23
24 namespace MNN {
CPUDeconvolutionBasic(const Tensor * input,const Op * convOp,Backend * b)25 CPUDeconvolutionBasic::CPUDeconvolutionBasic(const Tensor* input, const Op* convOp, Backend* b)
26 : CPUConvolution(convOp->main_as_Convolution2D()->common(), b) {
27 mSrcCount = input->channel();
28 mPostParameters = getPostParameters();
29 }
30
onResize(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)31 ErrorCode CPUDeconvolutionBasic::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
32 auto input = inputs[0];
33 auto output = outputs[0];
34 auto pad = ConvolutionCommon::convolutionTransposePad(input, output, mCommon);
35 mPadY = pad.second;
36 mPadX = pad.first;
37 return NO_ERROR;
38 }
39
CPUDeconvolutionCommon(const Tensor * input,const Op * convOp,Backend * b)40 CPUDeconvolutionCommon::CPUDeconvolutionCommon(const Tensor* input, const Op* convOp, Backend* b)
41 : CPUDeconvolutionBasic(input, convOp, b) {
42 auto conv2D = convOp->main_as_Convolution2D();
43 int outputCount = mCommon->outputCount();
44 auto core = static_cast<CPUBackend*>(b)->functions();
45 mBias.reset(Tensor::createDevice<float>(std::vector<int>{UP_DIV(outputCount, core->pack) * core->pack}));
46 bool success = b->onAcquireBuffer(mBias.get(), Backend::STATIC);
47 if (!success) {
48 mValid = false;
49 return;
50 }
51 ::memset(mBias->host<float>(), 0, mBias->length(0) * core->bytes);
52 if (core->bytes == 4) {
53 ::memcpy(mBias->host<float>(), conv2D->bias()->data(), conv2D->bias()->size() * sizeof(float));
54 } else {
55 core->MNNFp32ToLowp(conv2D->bias()->data(), mBias->host<int16_t>(), conv2D->bias()->size());
56 }
57 }
58
~CPUDeconvolutionCommon()59 CPUDeconvolutionCommon::~CPUDeconvolutionCommon() {
60 backend()->onReleaseBuffer(mBias.get(), Backend::STATIC);
61 }
62
_transformWeight(const uint8_t * tempWeight,uint8_t * dest,int outputCount,int srcCount,int fh,int fw,uint8_t * cache,const CoreFunctions * core)63 static void _transformWeight(const uint8_t* tempWeight, uint8_t* dest, int outputCount, int srcCount, int fh, int fw,
64 uint8_t* cache, const CoreFunctions* core) {
65 auto outputC4 = UP_DIV(outputCount, core->pack);
66 // c, n, h, w-> c, n/4 * 4, h, w
67 for (int c=0; c<srcCount; ++c) {
68 auto dst = cache + c * outputC4 * fw * fh * core->pack * core->bytes;
69 auto src = tempWeight + c * outputCount * fw * fh * core->bytes;
70 core->MNNPackCUnit((float*)dst, (const float*)src, fw*fh, outputCount);
71 }
72 //printf("%d - %d - %d - %d\n", outputCount, srcCount, fh, fw);
73 core->MNNPackForMatMul_B((float*)dest, (const float*)cache, outputC4 * fw * fh * core->pack, srcCount, false);
74 }
75
CPUDeconvolution(const Tensor * input,const Op * convOp,Backend * backend)76 CPUDeconvolution::CPUDeconvolution(const Tensor* input, const Op* convOp, Backend* backend)
77 : MNN::CPUDeconvolutionCommon(input, convOp, backend) {
78 auto layer = convOp->main_as_Convolution2D()->common();
79 auto core = static_cast<CPUBackend*>(backend)->functions();
80
81 const float* tempWeight = nullptr;
82 int tempWeightSize = 0;
83 std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
84 ConvolutionCommon::getConvParameters(&quanCommon, convOp->main_as_Convolution2D(), &tempWeight, &tempWeightSize);
85
86 int fw = layer->kernelX();
87 int fh = layer->kernelY();
88 int srcCount = mSrcCount;
89 int eP, lP, hP;
90 core->MNNGetMatMulPackMode(&eP, &lP, &hP);
91 auto outputAlign = UP_DIV(layer->outputCount(), core->pack) * core->pack * fw * fh;
92 mWeight.reset(Tensor::createDevice<float>(std::vector<int>{UP_DIV(outputAlign, hP), UP_DIV(srcCount, lP) * lP, hP}));
93 std::shared_ptr<Tensor> cache(Tensor::createDevice<float>({outputAlign * srcCount}));
94 bool success = backend->onAcquireBuffer(mWeight.get(), Backend::STATIC) &&
95 backend->onAcquireBuffer(cache.get(), Backend::STATIC);
96 if (!success) {
97 mValid = false;
98 return;
99 }
100 auto dest = mWeight->host<uint8_t>();
101 int outputCount = layer->outputCount();
102 AutoStorage<uint8_t> lowpWeight;
103 if (core->bytes < 4) {
104 lowpWeight.reset(outputCount * srcCount * fh * fw * core->bytes);
105 if (lowpWeight.get() == nullptr) {
106 mValid = false;
107 return;
108 }
109 core->MNNFp32ToLowp(tempWeight, (int16_t*)lowpWeight.get(), outputCount * srcCount * fh * fw);
110 tempWeight = (float*)lowpWeight.get();
111 }
112 _transformWeight((uint8_t*)tempWeight, dest, outputCount, srcCount, fh, fw, cache->host<uint8_t>(), core);
113 backend->onReleaseBuffer(cache.get(), Backend::STATIC);
114 mOrigin.reset(new CPUDeconvolutionOrigin(input, convOp, backend));
115 }
116
~CPUDeconvolution()117 CPUDeconvolution::~CPUDeconvolution() {
118 backend()->onReleaseBuffer(mWeight.get(), Backend::STATIC);
119 }
120
121
onResize(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)122 ErrorCode CPUDeconvolutionOrigin::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
123 CPUDeconvolutionBasic::onResize(inputs, outputs);
124 auto core = static_cast<CPUBackend*>(backend())->functions();
125 auto input = inputs[0];
126 auto output = outputs[0];
127 auto oc = output->channel();
128 if (UP_DIV(oc, core->pack) * core->pack != inputs[2]->length(0)) {
129 return INPUT_DATA_ERROR;
130 }
131
132 auto ocC4 = UP_DIV(output->channel(), core->pack);
133 auto icC4 = UP_DIV(input->channel(), core->pack);
134 auto kw = mCommon->kernelX();
135 auto kh = mCommon->kernelY();
136 auto dilateX = mCommon->dilateX();
137 auto dilateY = mCommon->dilateY();
138 auto strideX = mCommon->strideX();
139 auto strideY = mCommon->strideY();
140 auto padX = mPadX;
141 auto padY = mPadY;
142 auto width = input->width();
143 auto height = input->height();
144 auto src_height = output->height();
145 auto src_width = output->width();
146
147 auto kernelCount = ocC4 * mCommon->kernelX() * mCommon->kernelY();
148 mPreFunctions.clear();
149 mPostFunctions.clear();
150 auto plane = width * height;
151 const int maxDepth = 5;
152 AutoRelease<Tensor> tempColTotalBuffer(Tensor::createDevice<float>({kernelCount, plane, core->pack}));
153 auto res = backend()->onAcquireBuffer(tempColTotalBuffer.get(), Backend::DYNAMIC);
154 if (!res) {
155 return OUT_OF_MEMORY;
156 }
157 auto colBufferPtr = tempColTotalBuffer->host<float>();
158 auto biasPtr = inputs[2]->host<float>();
159 auto inputPtr = input->host<float>();
160 AutoRelease<Tensor> tempInputBuffer(
161 Tensor::create<float>({icC4, plane, core->pack}, inputPtr));
162 AutoRelease<Tensor> tempInput(Tensor::createDevice<float>({icC4, plane, core->pack}));
163 auto threadNumber = ((CPUBackend*)backend())->threadNumber();
164 if (input->batch() != 1) {
165 res = backend()->onAcquireBuffer(tempInput.get(), Backend::DYNAMIC);
166 if (!res) {
167 return OUT_OF_MEMORY;
168 }
169 auto newInputPtr = tempInput->host<uint8_t>();
170 // Copy Batch
171 mPreFunctions.emplace_back(std::make_pair([newInputPtr, icC4, plane, threadNumber, core](const float* srcBatch, int tId) {
172 for (int c = tId; c<icC4; c+=threadNumber) {
173 auto srcDepth = ((uint8_t*)srcBatch) + c * plane * core->pack * core->bytes;
174 auto dstDepth = newInputPtr + c * plane * core->pack * core->bytes;
175 ::memcpy(dstDepth, srcDepth, plane * core->pack * core->bytes);
176 }
177 }, threadNumber));
178 } else {
179 tempInput->buffer().host = (uint8_t*)inputPtr;
180 }
181 mMatMul.reset(new StrassenMatrixComputor(backend(), true, maxDepth));
182 mMatMul->onEncode({tempInput.get(), inputs[1]}, {tempColTotalBuffer.get()});
183 mPostFunctions.emplace_back(std::make_pair([colBufferPtr, ocC4, width, height, kh, kw, padY, padX, dilateY, dilateX, strideY,
184 strideX, threadNumber, src_width, src_height, plane, biasPtr, this, core](float* outputPtr, int tId) {
185 auto unitBytes = core->pack * core->bytes;
186 for (int z = (tId); z < ocC4; z += threadNumber) {
187 auto dstZ = (uint8_t*)outputPtr + z * src_height * src_width * unitBytes;
188 auto srcZ = (uint8_t*)colBufferPtr + kw * kh * plane * z * unitBytes;
189 auto dstB = dstZ;
190 ::memset(dstB, 0, src_width * src_height * unitBytes);
191 auto srcB = srcZ;
192 for (int oy = 0; oy < height; ++oy) {
193 for (int ox = 0; ox < width; ++ox) {
194 int srcStartX = ox * strideX - padX;
195 int srcStartY = oy * strideY - padY;
196
197 int sfy = ALIMAX(0, (UP_DIV(-srcStartY, dilateY)));
198 int efy = ALIMIN(kh, UP_DIV(src_height - srcStartY, dilateY));
199
200 int sfx = ALIMAX(0, (UP_DIV(-srcStartX, dilateX)));
201 int efx = ALIMIN(kw, UP_DIV(src_width - srcStartX, dilateX));
202
203 auto dstStart = dstB + srcStartX * unitBytes + srcStartY * src_width * unitBytes;
204 auto srcStart = srcB + unitBytes * (ox + oy * width);
205 if (sfy >= efy || sfx >= efx) {
206 continue;
207 }
208
209 for (int fy = sfy; fy < efy; ++fy) {
210 auto dstY = dstStart + fy * unitBytes * dilateY * src_width;
211 auto srcY = srcStart + fy * kw * plane * unitBytes;
212 core->MNNAddC4WithStride((const float*)(srcY + sfx * plane * unitBytes), (float*)(dstY + sfx * dilateX * unitBytes), plane * core->pack, dilateX * core->pack, efx - sfx);
213 }
214 }
215 }
216 core->MNNAxByClampBroadcastUnit((float*)dstZ, (float*)dstZ, (const float*)((uint8_t*)biasPtr + unitBytes * z), src_height * src_width, 0, 0, 1, mPostParameters.data());
217 }
218 }, threadNumber));
219 if (tempInput->host<float>() != inputPtr) {
220 backend()->onReleaseBuffer(tempInput.get(), Backend::DYNAMIC);
221 }
222 backend()->onReleaseBuffer(tempColTotalBuffer.get(), Backend::DYNAMIC);
223 return NO_ERROR;
224 }
225
onExecute(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)226 ErrorCode CPUDeconvolutionOrigin::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
227 auto batch = inputs[0]->batch();
228 auto core = static_cast<CPUBackend*>(backend())->functions();
229 auto input = inputs[0];
230 auto output = outputs[0];
231 auto oc = output->channel();
232 auto ocC4 = UP_DIV(output->channel(), core->pack);
233 auto icC4 = UP_DIV(input->channel(), core->pack);
234 auto width = output->width();
235 auto height = output->height();
236 auto src_height = input->height();
237 auto src_width = input->width();
238 for (int i=0; i<batch; ++i) {
239 auto inputPtr = inputs[0]->host<uint8_t>() + i * src_width * src_height * icC4 * core->pack * core->bytes;
240 auto outputPtr = outputs[0]->host<uint8_t>() + i * width * height * ocC4 * core->pack * core->bytes;
241 for (auto& unit : mPreFunctions) {
242 MNN_CONCURRENCY_BEGIN(tId, unit.second) {
243 unit.first((float*)inputPtr, (int)tId);
244 }
245 MNN_CONCURRENCY_END();
246 }
247 mMatMul->onExecute();
248 for (auto& unit : mPostFunctions) {
249 MNN_CONCURRENCY_BEGIN(tId, unit.second) {
250 unit.first((float*)outputPtr, (int)tId);
251 }
252 MNN_CONCURRENCY_END();
253 }
254 }
255 return NO_ERROR;
256 }
257 class CPUDeconvolutionCreator : public CPUBackend::Creator {
258 public:
onCreate(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,const MNN::Op * op,Backend * backend) const259 virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
260 const MNN::Op* op, Backend* backend) const {
261 auto convOp = op->main_as_Convolution2D();
262 auto common = convOp->common();
263 if (backend->type() == MNN_FORWARD_CPU) {
264 if (common->strideY() > 1 || common->strideX() > 1) {
265 if (common->dilateX() == 1 && common->dilateY() == 1) {
266 if (common->kernelX() / common->strideX() > 2 || common->kernelY() / common->strideY() > 2) {
267 return new DeconvolutionWithStride(inputs[0], op, backend);
268 }
269 }
270 }
271 }
272 return new CPUDeconvolution(inputs[0], op, backend);
273 }
274 };
275
276 REGISTER_CPU_OP_CREATOR(CPUDeconvolutionCreator, OpType_Deconvolution);
277 } // namespace MNN
278