//
//  CPUDeconvolutionDepthwise.cpp
//  MNN
//
//  Created by MNN on 2018/07/23.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "CPUDeconvolutionDepthwise.hpp"
#include <string.h>
#include "backend/cpu/CPUBackend.hpp"
#include "core/Macro.h"
#include "compute/CommonOptFunction.h"
#include "core/Concurrency.h"

namespace MNN {
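// Constant-weight depthwise deconvolution: the constructor reads the stored weights
// (dequantizing them if they are kept in a quantized form), converts them to the backend's
// precision when it is below fp32, and packs them into the backend's channel-packed layout.
// The per-resize/execute work is handled by the CPUDeconvolutionDepthwiseBasic instance in mOrigin.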
CPUDeconvolutionDepthwise::CPUDeconvolutionDepthwise(const Tensor* input, const Op* convOp, Backend* b)
    : MNN::CPUDeconvolutionCommon(input, convOp, b) {
    auto conv = convOp->main_as_Convolution2D();
    auto layer = convOp->main_as_Convolution2D()->common();
    int kw = layer->kernelX();
    int kh = layer->kernelY();
    int outputCount = layer->outputCount();
    auto core = static_cast<CPUBackend*>(backend())->functions();
    int depthQuad = UP_DIV(outputCount, core->pack);
    const float* tempWeight = nullptr;
    int tempWeightSize = 0;
    std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
    ConvolutionCommon::getConvParameters(&quanCommon, conv, &tempWeight, &tempWeightSize);

    // Reorder weight from whc -> pwhc4
    int kernelSize = depthQuad * core->pack * kw * kh;
    mWeight.reset(Tensor::createDevice<float>(std::vector<int>{kernelSize}));
    auto success = backend()->onAcquireBuffer(mWeight.get(), Backend::STATIC);
    if (!success) {
        mValid = false;
        return;
    }
    AutoStorage<uint8_t> weightTempStorage;
    if (core->bytes < 4) {
        // Lower-precision backend (e.g. fp16): convert the fp32 weights before packing.
        weightTempStorage.reset(kernelSize * core->bytes);
        if (weightTempStorage.get() == nullptr) {
            mValid = false;
            return;
        }
        core->MNNFp32ToLowp(tempWeight, (int16_t*)weightTempStorage.get(), kernelSize);
        tempWeight = (const float*)weightTempStorage.get();
    }
    auto weight = mWeight->host<float>();
    core->MNNPackCUnit(weight, tempWeight, kw * kh, outputCount);
    mOrigin.reset(new CPUDeconvolutionDepthwiseBasic(input, convOp, b));
}

CPUDeconvolutionDepthwise::~CPUDeconvolutionDepthwise() {
    backend()->onReleaseBuffer(mWeight.get(), Backend::STATIC);
}

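// Variant used when weight/bias are passed as extra op inputs: temporary packed buffers are
// acquired so the basic implementation can plan with them, then released again so later ops
// may reuse the memory; onExecute refills them from inputs[1]/inputs[2] on every run.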
ErrorCode CPUDeconvolutionDepthwiseMultiInput::onResize(const std::vector<Tensor*>& inputs,
                                                        const std::vector<Tensor*>& outputs) {
    auto kw = mCommon->kernelX();
    auto kh = mCommon->kernelY();
    auto core = static_cast<CPUBackend*>(backend())->functions();
    mWeight.reset(Tensor::createDevice<float>({UP_DIV(inputs[0]->channel(), core->pack), kh, kw, core->pack}));
    mBias.reset(Tensor::createDevice<float>({UP_DIV(inputs[0]->channel(), core->pack), core->pack}));
    backend()->onAcquireBuffer(mWeight.get(), Backend::DYNAMIC);
    backend()->onAcquireBuffer(mBias.get(), Backend::DYNAMIC);
    mInputs = {inputs[0], mWeight.get(), mBias.get()};
    auto code = CPUDeconvolutionDepthwiseBasic::onResize(mInputs, outputs);
    backend()->onReleaseBuffer(mWeight.get(), Backend::DYNAMIC);
    backend()->onReleaseBuffer(mBias.get(), Backend::DYNAMIC);
    return code;
}

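// On every execute the runtime weight (inputs[1]) and optional bias (inputs[2]) are zero-filled
// and repacked into the backend's channel-packed layout before running the shared basic path.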
ErrorCode CPUDeconvolutionDepthwiseMultiInput::onExecute(const std::vector<Tensor*>& inputs,
                                                         const std::vector<Tensor*>& outputs) {
    auto core = static_cast<CPUBackend*>(backend())->functions();
    ::memset(mBias->host<float>(), 0, mBias->elementSize() * core->bytes);
    if (inputs.size() > 2) {
        ::memcpy(mBias->host<float>(), inputs[2]->host<float>(), inputs[2]->elementSize() * core->bytes);
    }
    ::memset(mWeight->host<float>(), 0, mWeight->elementSize() * core->bytes);
    auto weight = mWeight->host<float>();
    auto outputCount = inputs[0]->channel();
    auto kh = mWeight->length(1);
    auto kw = mWeight->length(2);
    auto tempWeight = inputs[1]->host<float>();
    core->MNNPackCUnit(weight, tempWeight, kw * kh, outputCount);
    return CPUDeconvolutionDepthwiseBasic::onExecute(mInputs, outputs);
}

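// The depthwise deconvolution is computed as the transpose of a depthwise convolution: the
// geometry below is set up as if convolving outputs[0] (called "src" here) down to inputs[0]
// (called "dst"), but the kernels read from dst and accumulate into src. The deconv output is
// therefore zero-initialized and every input value scatters weighted contributions into it.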
ErrorCode CPUDeconvolutionDepthwiseBasic::onResize(const std::vector<Tensor*>& inputs,
                                                   const std::vector<Tensor*>& outputs) {
    CPUDeconvolutionBasic::onResize(inputs, outputs);
    auto core = static_cast<CPUBackend*>(backend())->functions();
    auto layer = mCommon;
    auto inputTensor = outputs[0];
    auto outputTensor = inputs[0];
    int src_width = inputTensor->width();
    int src_height = inputTensor->height();
    int dst_width = outputTensor->width();
    int dst_height = outputTensor->height();
    int dst_depth_quad = UP_DIV(layer->outputCount(), core->pack);
    int dst_z_step = dst_width * dst_height * core->pack;
    int src_z_step = src_width * src_height * core->pack;
    int dst_y_step = dst_width * core->pack;
    int src_y_step = src_width * core->pack;
    int strideY = layer->strideY();
    int strideX = layer->strideX();
    int dilateX = layer->dilateX();
    int dilateY = layer->dilateY();
    int dilateY_step = dilateY * src_width * core->pack;
    int dilateX_step = dilateX * core->pack;
    int kernel_height = layer->kernelY();
    int kernel_width = layer->kernelX();
    int padX = mPadX;
    int padY = mPadY;
    int weight_z_step = kernel_height * kernel_width * core->pack;
    // Compute the mid rect: the interior region [l, r) x [t, b) of the dst plane where the
    // whole (dilated) kernel window stays inside the src plane, so it can be processed
    // without per-pixel bounds checks.
    int l = 0, t = 0, r = dst_width, b = dst_height;
    for (; l * strideX - padX < 0 && l < dst_width; l++) {
        // do nothing
    }
    for (; t * strideY - padY < 0 && t < dst_height; t++) {
        // do nothing
    }
    for (; (r - 1) * strideX - padX + (kernel_width - 1) * dilateX >= src_width && r > l; r--) {
        // do nothing
    }
    for (; (b - 1) * strideY - padY + (kernel_height - 1) * dilateY >= src_height && b > t; b--) {
        // do nothing
    }

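// RUN_BASIC handles a rectangular border region [L, R) x [T, B) of the dst plane: for each
// position it clamps the kernel extent (sfx..efx, sfy..efy) so every access stays inside the
// src plane, then lets MNNDeconvRunForUnitDepthWise accumulate that unit's contribution.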
#define RUN_BASIC(L, T, R, B)                                                                     \
    for (int dy = T; dy < B; ++dy) {                                                              \
        auto dst_y    = dst_z + dy * dst_y_step * core->bytes;                                    \
        int srcStartY = dy * strideY - padY;                                                      \
        auto src_dy   = src_z + srcStartY * src_y_step * core->bytes;                             \
        int sfy       = ALIMAX(0, (UP_DIV(-srcStartY, dilateY)));                                 \
        int efy       = ALIMIN(kernel_height, UP_DIV(src_height - srcStartY, dilateY));           \
        for (int dx = L; dx < R; ++dx) {                                                          \
            auto dst_x    = dst_y + core->pack * core->bytes * dx;                                \
            int srcStartX = dx * strideX - padX;                                                  \
            auto src_dx   = src_dy + srcStartX * core->pack * core->bytes;                        \
            int sfx       = ALIMAX(0, (UP_DIV(-srcStartX, dilateX)));                             \
            int efx       = ALIMIN(kernel_width, UP_DIV(src_width - srcStartX, dilateX));         \
            core->MNNDeconvRunForUnitDepthWise(                                                   \
                (const float*)dst_x,                                                              \
                (float*)(src_dx + (sfx * dilateX + sfy * dilateY * src_width) * core->bytes * core->pack), \
                (const float*)(weight_dz + core->pack * core->bytes * (kernel_width * sfy + sfx)),         \
                efx - sfx, efy - sfy, core->pack * kernel_width, dilateX_step, dilateY_step);     \
        }                                                                                         \
    }
    auto weight = inputs[1];
    auto bias = inputs[2];
    int batch = inputs[0]->batch();
    int totalSize = batch * dst_depth_quad;
    int numberThread = ((CPUBackend*)backend())->threadNumber();

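    // Each thread processes a strided subset of the (batch x depth-quad) planes: clear the
    // src (deconv output) plane, scatter the four border regions with RUN_BASIC, run the
    // vectorized line kernel over the interior, then add bias and apply the activation clamp.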
    mFunction = [=](const uint8_t* dstOrigin, uint8_t* srcOrigin, int tId) {
        for (int dz = tId; dz < totalSize; dz += numberThread) {
            auto zPos      = dz % dst_depth_quad;
            auto dst_z     = dstOrigin + dst_z_step * dz * core->bytes;
            auto src_z     = srcOrigin + src_z_step * dz * core->bytes;
            auto weight_dz = weight->host<uint8_t>() + zPos * weight_z_step * core->bytes;
            ::memset(src_z, 0, src_width * src_height * core->bytes * core->pack);

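            // Border rows/columns: the kernel window may fall partly outside the src plane,
            // so each pixel clamps its kernel extent inside RUN_BASIC.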
            RUN_BASIC(0, 0, dst_width, t);
            RUN_BASIC(0, b, dst_width, dst_height);

            RUN_BASIC(0, t, l, b);
            RUN_BASIC(r, t, dst_width, b);

            if (r > l) {
                // Interior: the whole kernel window is in bounds, use the fast line routine.
                for (int dy = t; dy < b; ++dy) {
                    auto dst_y    = dst_z + dy * dst_y_step * core->bytes;
                    int srcStartY = dy * strideY - padY;
                    auto src_dy   = src_z + srcStartY * src_y_step * core->bytes;
                    core->MNNDeconvRunForLineDepthwise((const float*)(dst_y + l * core->pack * core->bytes),
                                                       (float*)(src_dy + (l * strideX - padX) * core->bytes * core->pack),
                                                       (const float*)weight_dz, r - l, strideX * core->pack,
                                                       kernel_width, kernel_height, dilateX_step, dilateY_step);
                }
            }
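            // Add the per-channel bias over the whole plane and clamp to the activation
            // range carried in mPostParameters (e.g. ReLU / ReLU6).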
            core->MNNAxByClampBroadcastUnit((float*)src_z, (float*)src_z,
                                            (const float*)(bias->host<uint8_t>() + zPos * core->pack * core->bytes),
                                            src_width * src_height, 0, 0, 1, mPostParameters.data());
        }
    };
#undef RUN_BASIC

    return NO_ERROR;
}

ErrorCode CPUDeconvolutionDepthwiseBasic::onExecute(const std::vector<Tensor*>& inputs,
                                                    const std::vector<Tensor*>& outputs) {
    // Swap input and output, then run the deconvolution as a reversed convolution.
    auto inputTensor = outputs[0];
    auto outputTensor = inputs[0];
    int numberThread = ((CPUBackend*)backend())->threadNumber();
    auto srcOrigin = inputTensor->host<uint8_t>();
    auto dstOrigin = outputTensor->host<uint8_t>();
    MNN_CONCURRENCY_BEGIN(tId, numberThread) {
        mFunction(dstOrigin, srcOrigin, tId);
    };
    MNN_CONCURRENCY_END();
    return NO_ERROR;
}

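// Creator: if the weight (and optionally bias) arrive as extra inputs, use the multi-input
// variant that repacks them on every execute; otherwise build the constant-weight version
// that packs the weights once at creation time.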
class CPUDeconvolutionDepthwiseCreator : public CPUBackend::Creator {
public:
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const {
        if (1 < inputs.size()) {
            return new CPUDeconvolutionDepthwiseMultiInput(inputs[0], op, backend);
        }
        return new CPUDeconvolutionDepthwise(inputs[0], op, backend);
    }
};

REGISTER_CPU_OP_CREATOR(CPUDeconvolutionDepthwiseCreator, OpType_DeconvolutionDepthwise);

} // namespace MNN