1 //
2 //  DeconvSingleInputExecution.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2020/08/22.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef DeconvSingleInputExecution_hpp
10 #define DeconvSingleInputExecution_hpp
11 
12 #include "backend/cuda/core/CUDABackend.hpp"
13 #include "core/Execution.hpp"
14 #include "half.hpp"
15 
16 namespace MNN {
17 namespace CUDA {
18 
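// Input/output tensor shapes in NCHW order (batch, channel, height, width).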
19 struct IOInfo {
20     int ib, ic, ih, iw;
21     int ob, oc, oh, ow;
22 };
23 struct KernelInfo {
24     int groups         = 0;
25     int kernelN        = 0;
26     int kernelC        = 0;
27     int kernelX        = 0;
28     int kernelY        = 0;
29     PadMode padMode    = PadMode_CAFFE;
30     int padX           = 0;
31     int padY           = 0;
32     int strideX        = 0;
33     int strideY        = 0;
34     int dilateX        = 0;
35     int dilateY        = 0;
36     int activationType = 0;
37 };//
38 
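// Deconvolution (transposed convolution) with a single input tensor on the
// CUDA backend, built on cuDNN's backward-data convolution path (see the
// cudnn* descriptor and algorithm members below): onResize() performs the
// shape-dependent setup and onExecute() launches the computation.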
39 extern "C"
40 class DeconvSingleInputExecution : public Execution {
41 public:
42     DeconvSingleInputExecution(Backend* backend, const MNN::Op* op);
43     virtual ~DeconvSingleInputExecution();
44     virtual ErrorCode onResize(const std::vector<Tensor*> &inputs, const std::vector<Tensor*> &outputs) override;
45     virtual ErrorCode onExecute(const std::vector<Tensor*> &inputs, const std::vector<Tensor*> &outputs) override;
46 
47 private:
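    // cuDNN handle and descriptors used for the backward-data (deconvolution) call.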
48     cudnnHandle_t cudnn_handle_;
49     cudnnTensorDescriptor_t input_desc_;
50     cudnnTensorDescriptor_t output_desc_;
51     cudnnFilterDescriptor_t filter_desc_;
52     cudnnConvolutionBwdDataAlgo_t conv_bwd_algo_;
53     cudnnConvolutionDescriptor_t conv_desc_;
54     cudnnTensorDescriptor_t bias_desc_;
55     cudnnTensorDescriptor_t padded_desc_;
56     cudnnActivationDescriptor_t act_desc_;
57 
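    // cuDNN element type (and its size) plus explicit per-side padding amounts.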
58     cudnnDataType_t cudnn_data_type_;
59     int cudnn_data_type_len_;
60     bool use_pad_ = false;
61     int pad_top_ = 0;
62     int pad_bottom_ = 0;
63     int pad_left_ = 0;
64     int pad_right_ = 0;
65 
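    // Post-processing options: optional bias add and ReLU / ReLU6 activation.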
66     bool use_bias_ = false;
67     bool use_relu_ = false;
68     bool use_relu6_ = false;
69 
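    // Device pointers and the tensors that own the corresponding buffers.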
70     void* mPadPtr;
71     void* mFilter;
72     void* mBias;
73     void* mWorkSpace;
74     std::shared_ptr<Tensor> weightTensor;
75     std::shared_ptr<Tensor> biasTensor;
76     std::shared_ptr<Tensor> padTensor;
77     std::shared_ptr<Tensor> workspaceTensor;
78 
79     std::shared_ptr<Tensor> mPad;
80     std::shared_ptr<Tensor> mWorkspaceForward;
81 
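    // Sizes of the input, filter, output, padded and workspace buffers.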
82     size_t input_size_;
83     size_t filter_size_;
84     size_t output_size_;
85     size_t padded_size_;
86     size_t workspace_size_;
87 
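    // Source op, cached kernel and I/O shape info, and a temporary input buffer.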
88     const MNN::Op* mOp;
89     KernelInfo mKernelInfo;
90     IOInfo mIOInfo;
91     std::shared_ptr<Tensor> mTempInput;
92 };
93 
94 } // namespace CUDA
95 } // namespace MNN
96 
#endif /* DeconvSingleInputExecution_hpp */
98