1 #include "ConvDepthWiseExecution.hpp"
2 #include "core/ConvolutionCommon.hpp"
3 namespace MNN {
4 namespace CUDA {
// Kernel parameters for the depthwise conv / deconv kernels.
// One instance per execution lives in the backend's static buffer pool
// (mConstBuffer) and is filled on the host in onResize(), then copied to
// device memory and read by CONV_DW / DECONV_DW.
struct constBuffer {
    int pad[2];          // padding (x, y)
    int kernelSize[2];   // kernel size (x, y)
    int stride[2];       // stride (x, y)
    int dilate[2];       // dilation (x, y)
    int inputSize[2];    // input (width, height)
    int outputSize[2];   // output (width, height)
    int channel;         // batch * channel (flattened outer dimension)
    int subChannel;      // channel count of a single batch
    int total;           // channel * outputSize[0] * outputSize[1]
    int activationType;  // 0: none, 1: relu, 2: relu6
} uConstant;
17
// Uploads the convolution weight (and optional bias) from the flatbuffer op
// to device memory and reserves a static-pool slice for the kernel params.
ConvDepthWiseExecution::ConvDepthWiseExecution(const Op* op, Backend* bn) : Execution(bn) {
    mOp = op;
    auto pool = static_cast<CUDABackend*>(bn)->getStaticBufferPool();
    mConstBuffer = pool->alloc(sizeof(constBuffer));

    auto conv = mOp->main_as_Convolution2D();
    // Give mBias a defined value even when the op carries no weight at all,
    // since onExecute() passes it to the kernel in the single-input case.
    mBias = nullptr;
    // weight host->device
    if (nullptr != conv->weight()) {
        int weightSize = conv->weight()->size();
        weightTensor.reset(Tensor::createDevice<float>({weightSize}));
        backend()->onAcquireBuffer(weightTensor.get(), Backend::STATIC);
        mFilter = (void*)weightTensor.get()->buffer().device;
        cuda_check(cudaMemcpy(mFilter, conv->weight()->data(), weightSize * sizeof(float), cudaMemcpyHostToDevice));

        if (conv->bias()->size() != 0) {
            int biasSize = conv->bias()->size();
            biasTensor.reset(Tensor::createDevice<float>({biasSize}));
            backend()->onAcquireBuffer(biasTensor.get(), Backend::STATIC);
            mBias = (void*)biasTensor.get()->buffer().device;
            cuda_check(cudaMemcpy(mBias, conv->bias()->data(), biasSize * sizeof(float), cudaMemcpyHostToDevice));
            use_bias_ = true;
        }
    }
}
// Returns the constant-parameter slice to the static pool and releases the
// device weight/bias tensors acquired in the constructor.
ConvDepthWiseExecution::~ConvDepthWiseExecution() {
    auto pool = static_cast<CUDABackend*>(backend())->getStaticBufferPool();
    pool->free(mConstBuffer);
    if (nullptr != weightTensor) {
        backend()->onReleaseBuffer(weightTensor.get(), Backend::STATIC);
    }
    if (use_bias_ && nullptr != biasTensor) {
        backend()->onReleaseBuffer(biasTensor.get(), Backend::STATIC);
    }
}
53
// Fills a constBuffer with the conv geometry for the current input/output
// shapes and copies it into the device-side slice reserved at construction.
ErrorCode ConvDepthWiseExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto pad = ConvolutionCommon::convolutionPad(inputs[0], outputs[0], mOp->main_as_Convolution2D()->common());
    auto convCommon = mOp->main_as_Convolution2D()->common();
    constBuffer parameters;
    parameters.pad[0] = pad.first;
    parameters.pad[1] = pad.second;
    parameters.kernelSize[0] = convCommon->kernelX();
    parameters.kernelSize[1] = convCommon->kernelY();
    parameters.stride[0] = convCommon->strideX();
    parameters.stride[1] = convCommon->strideY();
    parameters.dilate[0] = convCommon->dilateX();
    parameters.dilate[1] = convCommon->dilateY();
    parameters.inputSize[0] = inputs[0]->width();
    parameters.inputSize[1] = inputs[0]->height();
    // batch and channel are flattened into one outer dimension for the kernel
    parameters.channel = inputs[0]->batch() * inputs[0]->channel();
    parameters.outputSize[0] = outputs[0]->width();
    parameters.outputSize[1] = outputs[0]->height();
    parameters.total = parameters.channel * parameters.outputSize[1] * parameters.outputSize[0];
    parameters.subChannel = inputs[0]->channel();
    parameters.activationType = convCommon->relu() ? 1 : (convCommon->relu6() ? 2 : 0);

    auto runtime = static_cast<CUDABackend*>(backend())->getCUDARuntime();
    runtime->memcpy((uint8_t*)mConstBuffer.first + mConstBuffer.second, &parameters, sizeof(constBuffer), MNNMemcpyHostToDevice);
    mTotalCount = parameters.total;

    return NO_ERROR;
}
82
// Depthwise convolution kernel. One thread handles one output element via a
// grid-stride loop; layout is channel-major: index = (oz * oh + oy) * ow + ox,
// where oz runs over batch*channel and the kernel is indexed per subChannel.
__global__ void CONV_DW(const float* input, const float* kernel, const float* bias, float *output, const constBuffer* uConstant) {
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < uConstant->total; i += blockDim.x * gridDim.x) {
        int iw = uConstant->inputSize[0];
        int ih = uConstant->inputSize[1];
        int ow = uConstant->outputSize[0];
        int oh = uConstant->outputSize[1];
        int kw = uConstant->kernelSize[0];
        int kh = uConstant->kernelSize[1];
        int dw = uConstant->dilate[0];
        int dh = uConstant->dilate[1];
        int sw = uConstant->stride[0];
        int sh = uConstant->stride[1];
        int pw = uConstant->pad[0];
        int ph = uConstant->pad[1];
        int acttype = uConstant->activationType;

        // Decompose the flat index into (oz, oy, ox).
        int oz  = i / (ow * oh);
        int tmp = i % (ow * oh);
        int oy  = tmp / ow;
        int ox  = tmp % ow;
        // Weights are shared across the batch dimension.
        int kz  = oz % uConstant->subChannel;

        // Top-left corner of the receptive field in input coordinates.
        int ix = ox * sw - pw;
        int iy = oy * sh - ph;
        float color = 0.0f;
        if (bias != nullptr) {
            color = bias[kz];
        }

        for (int fy = 0; fy < kh; ++fy) {
            int sy = fy * dh + iy;
            if (sy >= ih || sy < 0) {
                continue; // zero padding: skip out-of-bounds rows
            }
            for (int fx = 0; fx < kw; ++fx) {
                int sx = fx * dw + ix;
                if (sx >= iw || sx < 0) {
                    continue; // zero padding: skip out-of-bounds columns
                }
                float inputValue = input[sx + sy * iw + oz * iw * ih];
                float k          = kernel[fx + fy * kw + kz * kw * kh];
                color += k * inputValue;
            }
        }
        // Apply activation with float intrinsics (avoids double promotion).
        if (acttype == 1) {
            color = fmaxf(color, 0.0f);              // relu
        } else if (acttype == 2) {
            color = fminf(fmaxf(color, 0.0f), 6.0f); // relu6
        }
        output[ox + oy * ow + oz * ow * oh] = color;
    }
}
148
149
// Launches CONV_DW. Three input layouts are supported:
//   1 input : weight/bias were uploaded in the constructor (mFilter/mBias)
//   3 inputs: input, weight, bias supplied as tensors
//   2 inputs: input, weight supplied as tensors; no bias
ErrorCode ConvDepthWiseExecution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto runtime = static_cast<CUDABackend*>(backend())->getCUDARuntime();
    int block_num = runtime->blocks_num(mTotalCount);
    int threads_num = runtime->threads_num();
    auto constPtr = (uint8_t*)mConstBuffer.first + mConstBuffer.second;
    if (inputs.size() == 1) {
        CONV_DW<<<block_num, threads_num>>>((const float*)inputs[0]->deviceId(), (const float*)mFilter,
                                            (const float*)mBias, (float*)outputs[0]->deviceId(), (const constBuffer*)constPtr);
    } else if (inputs.size() == 3) {
        CONV_DW<<<block_num, threads_num>>>((const float*)inputs[0]->deviceId(), (const float*)inputs[1]->deviceId(),
                                            (const float*)inputs[2]->deviceId(), (float*)outputs[0]->deviceId(), (const constBuffer*)constPtr);
    } else {
        MNN_ASSERT(inputs.size() == 2);
        CONV_DW<<<block_num, threads_num>>>((const float*)inputs[0]->deviceId(), (const float*)inputs[1]->deviceId(),
                                            nullptr, (float*)outputs[0]->deviceId(), (const constBuffer*)constPtr);
    }
    return NO_ERROR;
}
168
169
170
// Depthwise transposed-convolution (deconvolution) kernel. One thread handles
// one output element via a grid-stride loop. For each output position it
// gathers the input positions that would have contributed to it in the
// forward pass: (ix - fx*dw) must be an exact multiple of the stride.
// Layout matches CONV_DW: index = (oz * oh + oy) * ow + ox.
__global__ void DECONV_DW(const float* input, const float* kernel, const float* bias, float *output, const constBuffer* uConstant) {
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < uConstant->total; i += blockDim.x * gridDim.x) {
        int iw = uConstant->inputSize[0];
        int ih = uConstant->inputSize[1];
        int ow = uConstant->outputSize[0];
        int oh = uConstant->outputSize[1];
        int kw = uConstant->kernelSize[0];
        int kh = uConstant->kernelSize[1];
        int dw = uConstant->dilate[0];
        int dh = uConstant->dilate[1];
        int sw = uConstant->stride[0];
        int sh = uConstant->stride[1];
        int pw = uConstant->pad[0];
        int ph = uConstant->pad[1];

        // Decompose the flat index into (oz, oy, ox).
        int oz  = i / (ow * oh);
        int tmp = i % (ow * oh);
        int oy  = tmp / ow;
        int ox  = tmp % ow;
        // Weights are shared across the batch dimension.
        int kz  = oz % uConstant->subChannel;

        int ix = ox + pw;
        int iy = oy + ph;
        float color = 0.0f;
        if (bias != nullptr) {
            color = bias[kz];
        }

        for (int fy = 0; fy < kh; ++fy) {
            int sy = iy - fy * dh;
            int y  = sy / sh;
            // Only stride-aligned, in-bounds rows contributed in the forward pass.
            if (sy % sh == 0 && y >= 0 && y < ih) {
                for (int fx = 0; fx < kw; ++fx) {
                    int sx = ix - fx * dw;
                    int x  = sx / sw;
                    if (sx % sw == 0 && x >= 0 && x < iw) {
                        float inputValue = input[x + y * iw + oz * iw * ih];
                        float k          = kernel[fx + fy * kw + kz * kw * kh];
                        color += k * inputValue;
                    }
                }
            }
        }
        output[ox + oy * ow + oz * ow * oh] = color;
    }
}
234
235
// Fills a constBuffer with the transposed-conv geometry and copies it into
// the device-side slice. activationType is zeroed explicitly: DECONV_DW does
// not apply an activation, but the full struct is copied to the device and
// must not carry an uninitialized field.
ErrorCode DeconvDepthWiseExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto convCommon = mOp->main_as_Convolution2D()->common();
    auto pad = ConvolutionCommon::convolutionTransposePad(inputs[0], outputs[0], convCommon);
    constBuffer parameters;
    parameters.pad[0] = pad.first;
    parameters.pad[1] = pad.second;
    parameters.kernelSize[0] = convCommon->kernelX();
    parameters.kernelSize[1] = convCommon->kernelY();
    parameters.stride[0] = convCommon->strideX();
    parameters.stride[1] = convCommon->strideY();
    parameters.dilate[0] = convCommon->dilateX();
    parameters.dilate[1] = convCommon->dilateY();
    parameters.inputSize[0] = inputs[0]->width();
    parameters.inputSize[1] = inputs[0]->height();
    // batch and channel are flattened into one outer dimension for the kernel
    parameters.channel = inputs[0]->batch() * inputs[0]->channel();
    parameters.outputSize[0] = outputs[0]->width();
    parameters.outputSize[1] = outputs[0]->height();
    parameters.total = parameters.channel * parameters.outputSize[1] * parameters.outputSize[0];
    parameters.subChannel = inputs[0]->channel();
    parameters.activationType = 0;
    auto constPtr = (uint8_t*)mConstBuffer.first + mConstBuffer.second;

    auto runtime = static_cast<CUDABackend*>(backend())->getCUDARuntime();
    runtime->memcpy(constPtr, &parameters, sizeof(constBuffer), MNNMemcpyHostToDevice);
    mTotalCount = parameters.total;
    return NO_ERROR;
}
262
// Launches DECONV_DW. Weight always comes from inputs[1]; inputs[2], when
// present, supplies the bias (otherwise the kernel receives nullptr).
ErrorCode DeconvDepthWiseExecution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto runtime = static_cast<CUDABackend*>(backend())->getCUDARuntime();
    int block_num = runtime->blocks_num(mTotalCount);
    int threads_num = runtime->threads_num();
    auto constPtr = (uint8_t*)mConstBuffer.first + mConstBuffer.second;
    if (inputs.size() > 2) {
        DECONV_DW<<<block_num, threads_num>>>((const float*)inputs[0]->deviceId(), (const float*)inputs[1]->deviceId(),
                                              (const float*)inputs[2]->deviceId(), (float*)outputs[0]->deviceId(), (const constBuffer*)constPtr);
    } else {
        DECONV_DW<<<block_num, threads_num>>>((const float*)inputs[0]->deviceId(), (const float*)inputs[1]->deviceId(),
                                              nullptr, (float*)outputs[0]->deviceId(), (const constBuffer*)constPtr);
    }
    return NO_ERROR;
}
277
278
// Factory registered for both depthwise conv and depthwise deconv op types.
class ConvDepthWiseExecutionCreator : public CUDABackend::Creator {
public:
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const override {
        if (OpType_ConvolutionDepthwise == op->type()) {
            return new ConvDepthWiseExecution(op, backend);
        }
        // Deconv path requires the weight to arrive as a tensor input.
        if (inputs.size() == 1) {
            MNN_PRINT("deconv depthwise not support 1 input yet\n");
            return nullptr;
        }
        return new DeconvDepthWiseExecution(op, backend);
    }
};
293
// Register the creator for both op types handled above.
static CUDACreatorRegister<ConvDepthWiseExecutionCreator> __init(OpType_ConvolutionDepthwise);
static CUDACreatorRegister<ConvDepthWiseExecutionCreator> __init2(OpType_DeconvolutionDepthwise);
296 }
297 }