//
//  DeconvSingleInputExecution.cpp
//  MNN
//
//  Created by MNN on 2020/08/22.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "DeconvSingleInputExecution.hpp"

namespace MNN {
namespace CUDA {

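// cutPad: grid-stride kernel that crops a padded NCHW tensor back to its requested
// spatial size. Each thread copies one output element from the corresponding
// (old_height x old_width) padded plane, skipping pad_top rows and pad_left columns.
// It is used after cudnnConvolutionBackwardData when the deconvolution carries
// non-zero padding (see onResize/onExecute below).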
template <typename T>
__global__ void cutPad(const size_t size, const T* input, const int old_height,
                    const int old_width, const int height, const int width, const int pad_top,
                    const int pad_left, T* output) {
    for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
        int block_num = pos / (width*height);
        int left = pos % (width*height);
        const int out_w = left % width;
        const int out_h = left / width % height;

        output[pos] = input[(block_num * old_height + out_h + pad_top) * old_width + out_w + pad_left];
    }
    return;
}

DeconvSingleInputExecution::DeconvSingleInputExecution(Backend* backend, const MNN::Op* op) : Execution(backend), mOp(op) {
    //MNN_PRINT("cuda DeconvSingleInput onInit in\n");
    auto conv       = op->main_as_Convolution2D();
    auto common     = conv->common();

    mKernelInfo.groups         = common->group();
    mKernelInfo.kernelX        = common->kernelX();
    mKernelInfo.kernelY        = common->kernelY();
    mKernelInfo.padMode        = common->padMode();
    mKernelInfo.padX           = common->padX();
    mKernelInfo.padY           = common->padY();

    if (nullptr != common->pads()) {
        mKernelInfo.padX = common->pads()->data()[1];
        mKernelInfo.padY = common->pads()->data()[0];
    }
    pad_left_   = mKernelInfo.padX;
    pad_right_  = mKernelInfo.padX;
    pad_top_    = mKernelInfo.padY;
    pad_bottom_ = mKernelInfo.padY;

    mKernelInfo.strideX        = common->strideX();
    mKernelInfo.strideY        = common->strideY();
    mKernelInfo.dilateX        = common->dilateX();
    mKernelInfo.dilateY        = common->dilateY();
    mKernelInfo.activationType = common->relu() ? 1 : (common->relu6() ? 2 : 0);

    use_relu_  = (mKernelInfo.activationType == 1);
    use_relu6_ = (mKernelInfo.activationType == 2);

    cudnn_handle_        = nullptr;
    input_desc_          = nullptr;
    output_desc_         = nullptr;
    filter_desc_         = nullptr;
    conv_desc_           = nullptr;
    padded_desc_         = nullptr;
    cudnn_data_type_     = CUDNN_DATA_FLOAT;
    cudnn_data_type_len_ = 0;

    auto runtime = static_cast<CUDABackend*>(backend)->getCUDARuntime();
    cudnn_handle_ = runtime->cudnn_handle();
    cudnn_check(cudnnCreateTensorDescriptor(&input_desc_));
    cudnn_check(cudnnCreateTensorDescriptor(&output_desc_));
    cudnn_check(cudnnCreateTensorDescriptor(&padded_desc_));
    cudnn_check(cudnnCreateTensorDescriptor(&bias_desc_));
    cudnn_check(cudnnCreateFilterDescriptor(&filter_desc_));
    cudnn_check(cudnnCreateConvolutionDescriptor(&conv_desc_));
    cudnn_check(cudnnCreateActivationDescriptor(&act_desc_));

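    // Resolve the float weights (ConvolutionCommon::getConvParameters fills filterDataPtr,
    // going through quanCommon when the weights are stored compressed) and upload them
    // to a statically allocated device buffer.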
    //weight host->device
    const float* filterDataPtr = nullptr;
    int weightSize = 0;
    std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
    ConvolutionCommon::getConvParameters(&quanCommon, conv, &filterDataPtr, &weightSize);
    weightTensor.reset(Tensor::createDevice<float>({weightSize}));
    backend->onAcquireBuffer(weightTensor.get(), Backend::STATIC);
    mFilter = (void *)weightTensor.get()->buffer().device;
    cuda_check(cudaMemcpy(mFilter, filterDataPtr, weightSize*sizeof(float), cudaMemcpyHostToDevice));

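    // The bias (if any) is uploaded once and described as a 1 x C x 1 x 1 tensor so
    // cudnnAddTensor can broadcast it over the N, H and W dimensions of the output.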
    if(conv->bias()->size() != 0) {
        int biasSize = conv->bias()->size();
        biasTensor.reset(Tensor::createDevice<float>({biasSize}));
        backend->onAcquireBuffer(biasTensor.get(), Backend::STATIC);
        mBias = (void *)biasTensor.get()->buffer().device;

        cuda_check(cudaMemcpy(mBias, conv->bias()->data(), conv->bias()->size()*sizeof(float), cudaMemcpyHostToDevice));

        int bias_size = conv->bias()->size();
        int dim_bias[] = {1, bias_size, 1, 1};
        int stride_bias[] = {bias_size, 1, 1, 1};
        if(cudnn_data_type_ == CUDNN_DATA_FLOAT) {
            cudnn_check(cudnnSetTensorNdDescriptor(bias_desc_, CUDNN_DATA_FLOAT, 4, dim_bias, stride_bias));
        } else if(cudnn_data_type_ == CUDNN_DATA_HALF) {
            cudnn_check(cudnnSetTensorNdDescriptor(bias_desc_, CUDNN_DATA_HALF, 4, dim_bias, stride_bias));
        } else {
            MNN_PRINT("only supports fp32/fp16 data type!!!\n");
        }
        use_bias_ = true;
    }
}

DeconvSingleInputExecution::~DeconvSingleInputExecution() {
    cudnn_check(cudnnDestroyConvolutionDescriptor(conv_desc_));
    cudnn_check(cudnnDestroyFilterDescriptor(filter_desc_));
    cudnn_check(cudnnDestroyTensorDescriptor(padded_desc_));
    cudnn_check(cudnnDestroyTensorDescriptor(output_desc_));
    cudnn_check(cudnnDestroyTensorDescriptor(input_desc_));
    cudnn_check(cudnnDestroyTensorDescriptor(bias_desc_));
    cudnn_check(cudnnDestroyActivationDescriptor(act_desc_));

    if (nullptr != weightTensor) {
        backend()->onReleaseBuffer(weightTensor.get(), Backend::STATIC);
    }
    if(use_bias_ && nullptr != biasTensor) {
        backend()->onReleaseBuffer(biasTensor.get(), Backend::STATIC);
    }
    if(workspace_size_!=0 && nullptr != workspaceTensor) {
        backend()->onReleaseBuffer(workspaceTensor.get(), Backend::DYNAMIC_SEPERATE);
    }
}

ErrorCode DeconvSingleInputExecution::onResize(const std::vector<Tensor*> &inputs, const std::vector<Tensor*> &outputs) {
    // prepare
    //MNN_PRINT("cuda DeconvSingleInput onResize in, pad:%d\n", mKernelInfo.padX);
    auto input = inputs[0], output = outputs[0];

    mIOInfo.iw = input->width();
    mIOInfo.ih = input->height();
    mIOInfo.ic = input->channel();
    mIOInfo.ib = input->batch();

    mIOInfo.ow = output->width();
    mIOInfo.oh = output->height();
    mIOInfo.oc = output->channel();
    mIOInfo.ob = output->batch();

    mKernelInfo.kernelN = output->channel();
    mKernelInfo.kernelC = input->channel() / mKernelInfo.groups;

    std::vector<int> in_shape     = {mIOInfo.ib, mIOInfo.ic, mIOInfo.ih, mIOInfo.iw};
    std::vector<int> output_shape = {mIOInfo.ob, mIOInfo.oc, mIOInfo.oh, mIOInfo.ow};
    std::vector<int> filter_shape = {mKernelInfo.kernelC, mKernelInfo.kernelN, mKernelInfo.kernelY, mKernelInfo.kernelX}; //deconv (ic oc kh kw)

    // printf("filter:%d %d %d %d\n", filter_shape[0], filter_shape[1], filter_shape[2], filter_shape[3]);
    // printf("input:%d %d %d %d\n", in_shape[0], in_shape[1], in_shape[2], in_shape[3]);
    // printf("output:%d %d %d %d\n", output_shape[0], output_shape[1], output_shape[2], output_shape[3]);
    cudnn_check(cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, in_shape[0],
                                in_shape[1], in_shape[2], in_shape[3]));

    cudnn_check(cudnnSetFilter4dDescriptor(filter_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, filter_shape[0],
                                filter_shape[1], filter_shape[2], filter_shape[3]));
    cudnn_check(cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, output_shape[0],
                                output_shape[1], output_shape[2], output_shape[3]));

    cudnnTensorDescriptor_t input_descriptor_real = nullptr;

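    // PadMode_SAME: derive the total padding from the deconvolution shape relation
    // out = (in - 1) * stride + dilated_kernel - pad_total, i.e.
    // pad_total = (in - 1) * stride + dilated_kernel - out,
    // then split it between the left/right and top/bottom sides.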
    if (mKernelInfo.padMode == PadMode_SAME) {
        int kernelWidthSize = (mKernelInfo.kernelX - 1) * mKernelInfo.dilateX + 1;
        int kernelHeightSize = (mKernelInfo.kernelY - 1) * mKernelInfo.dilateY + 1;
        int pw = (mIOInfo.iw - 1) * mKernelInfo.strideX + kernelWidthSize - mIOInfo.ow;
        int ph = (mIOInfo.ih - 1) * mKernelInfo.strideY + kernelHeightSize - mIOInfo.oh;
        pad_left_   = pw / 2;
        pad_right_  = pw - pad_left_;
        pad_top_    = ph / 2;
        pad_bottom_ = ph - pad_top_;
    }

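    // Padding strategy: the convolution descriptor below is set up with zero padding, so
    // cuDNN's backward-data call produces the full (un-cropped) deconvolution result.
    // When any pad is non-zero, that result goes into a temporary padded buffer and the
    // cutPad kernel later copies out the central region of the requested output size.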
    use_pad_ = (pad_left_ != 0 || pad_right_ != 0 || pad_top_ != 0 || pad_bottom_ != 0);

    if(use_pad_) {
        int totalSize = output_shape[0]*output_shape[1]*(output_shape[2]+pad_top_+pad_bottom_)*(output_shape[3]+pad_left_+pad_right_);
        padTensor.reset(Tensor::createDevice<float>({totalSize}));
        backend()->onAcquireBuffer(padTensor.get(), Backend::DYNAMIC);
        mPadPtr = (void *)padTensor.get()->buffer().device;

        //dynamic memory release
        backend()->onReleaseBuffer(padTensor.get(), Backend::DYNAMIC);

        cudnn_check(cudnnSetTensor4dDescriptor(padded_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, output_shape[0], output_shape[1],
            output_shape[2] + pad_top_ + pad_bottom_, output_shape[3] + pad_left_ + pad_right_));
    }
    input_descriptor_real = use_pad_ ? padded_desc_ : input_desc_;

    cudnn_check(cudnnSetConvolution2dDescriptor(conv_desc_, 0, 0, mKernelInfo.strideY, mKernelInfo.strideX,
                                mKernelInfo.dilateY, mKernelInfo.dilateX, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT));
    if (cudnn_data_type_ == CUDNN_DATA_HALF) {
        cudnn_check(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
    }
    //set group num
    cudnn_check(cudnnSetConvolutionGroupCount(conv_desc_, mKernelInfo.groups));

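    // Ask cuDNN's heuristics for backward-data algorithms and take the top-ranked
    // candidate; its workspace requirement is queried right after and allocated below.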
    // algorithm
    constexpr int requested_algo_count = 1;
    int returned_algo_count;
    cudnnConvolutionBwdDataAlgoPerf_t perf_results;
    cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnn_handle_, filter_desc_, input_descriptor_real, conv_desc_,
        output_desc_, requested_algo_count, &returned_algo_count, &perf_results));
    conv_bwd_algo_ = perf_results.algo;

    // workspace
    cudnn_check(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle_, filter_desc_, input_descriptor_real, conv_desc_, output_desc_,
        conv_bwd_algo_, &workspace_size_));

    if (workspace_size_ != 0) {
        int workspaceSize = workspace_size_;
        workspaceTensor.reset(Tensor::createDevice<float>({workspaceSize}));
        //cudnn not support workspace memory reuse
        backend()->onAcquireBuffer(workspaceTensor.get(), Backend::DYNAMIC_SEPERATE);
        mWorkSpace = (void *)workspaceTensor.get()->buffer().device;
    }

    if(use_relu_) {
        cudnn_check(cudnnSetActivationDescriptor(act_desc_, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0));
    } else if(use_relu6_) {
        cudnn_check(cudnnSetActivationDescriptor(act_desc_, CUDNN_ACTIVATION_CLIPPED_RELU, CUDNN_NOT_PROPAGATE_NAN, 6.0));
    } else {
        //do nothing
    }
    //MNN_PRINT("cuda DeconvSingleInput onResize out\n");
    return NO_ERROR;
}


ErrorCode DeconvSingleInputExecution::onExecute(const std::vector<Tensor*> &inputs, const std::vector<Tensor*> &outputs) {
    //MNN_PRINT("cuda DeconvSingleInput onExecute in, inputsize:%d %d\n", (int)inputs.size(), workspace_size_);

    MNN_ASSERT(inputs.size() == 1);
    MNN_ASSERT(outputs.size() == 1);

    auto runtime = static_cast<CUDABackend*>(backend())->getCUDARuntime();
    const void *input_addr = (const void*)inputs[0]->deviceId();
    const void *filter_addr = mFilter;
    const void *bias_addr = mBias;

    void *output_addr = (void*)outputs[0]->deviceId();
    void *workspace_addr = nullptr;
    if (workspace_size_ != 0) {
        workspace_addr = mWorkSpace;
    }

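    // cuDNN blends results as y = alpha * op(x) + beta * y; alpha = 1 with beta = 0
    // simply overwrites the destination, while the bias add below passes alpha for
    // both factors so the bias is accumulated onto the existing output.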
    const float alpha = 1;
    const float beta = 0;

    if(use_pad_) {
        cudnn_check(cudnnConvolutionBackwardData(cudnn_handle_, &alpha, filter_desc_, filter_addr, input_desc_, input_addr, conv_desc_,
            conv_bwd_algo_, workspace_addr, workspace_size_, &beta, padded_desc_, mPadPtr));

        std::vector<int> out_shape = {mIOInfo.ob, mIOInfo.oc, mIOInfo.oh, mIOInfo.ow};

        int size = out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3];
        int block_num = runtime->blocks_num(size);
        int threads_num = runtime->threads_num();

        cutPad<<<block_num, threads_num>>>(size, (float*)mPadPtr, out_shape[2]+pad_top_+pad_bottom_, out_shape[3]+pad_left_+pad_right_,
            out_shape[2], out_shape[3], pad_top_, pad_left_, (float*)output_addr);
    } else {
        cudnn_check(cudnnConvolutionBackwardData(cudnn_handle_, &alpha, filter_desc_, filter_addr, input_desc_, input_addr, conv_desc_,
            conv_bwd_algo_, workspace_addr, workspace_size_, &beta, output_desc_, output_addr));
    }

    if(use_bias_) {
        cudnn_check(cudnnAddTensor(cudnn_handle_, &alpha, bias_desc_, bias_addr, &alpha, output_desc_, output_addr));
    }
    if(use_relu_ || use_relu6_) {
        cudnn_check(cudnnActivationForward(cudnn_handle_, act_desc_, &alpha, output_desc_, output_addr, &beta, output_desc_, output_addr));
    }
    return NO_ERROR;
}

class CUDADeconvolutionCreator : public CUDABackend::Creator {
public:
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
            const MNN::Op* op, Backend* backend) const override {
        if (nullptr != op->main_as_Convolution2D()->quanParameter()) {
            auto quan = op->main_as_Convolution2D()->quanParameter();
            if (1 == quan->type() || 2 == quan->type()) {
                MNN_PRINT("cuda Deconv quant type 1 or 2 not support\n");
                return nullptr;
            }
        }

        if(inputs.size() == 3) {
            MNN_PRINT("Deconv inputs size:3 not support\n");
            return nullptr;
        } else if(inputs.size() == 1) {
            return new DeconvSingleInputExecution(backend, op);
        } else {
            MNN_PRINT("Deconv inputs size:%d not support\n", (int)inputs.size());
            return nullptr;
        }
    }
};

CUDACreatorRegister<CUDADeconvolutionCreator> __DeConvExecution(OpType_Deconvolution);

}// namespace CUDA
}// namespace MNN