#include "SoftmaxExecution.hpp"

namespace MNN {
namespace CUDA {

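// Cache the cuDNN handle owned by the CUDA backend and create the tensor
// descriptors once up front; onResize fills them in for the concrete shapes.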
SoftmaxExecution::SoftmaxExecution(int axis, Backend *backend) : Execution(backend) {
    auto runtime = static_cast<CUDABackend*>(backend)->getCUDARuntime();
    cudnn_handle_ = runtime->cudnn_handle();

    cudnn_check(cudnnCreateTensorDescriptor(&input_desc_));
    cudnn_check(cudnnCreateTensorDescriptor(&output_desc_));

    cudnn_data_type_ = CUDNN_DATA_FLOAT;
    mAxis = axis;
}

SoftmaxExecution::~SoftmaxExecution() {
    cudnnDestroyTensorDescriptor(input_desc_);
    cudnnDestroyTensorDescriptor(output_desc_);
}

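// Softmax over an arbitrary axis is mapped onto cuDNN's 4D channel softmax:
// the tensor is viewed as (outside, axis, inside, 1), where `outside` is the
// product of the dimensions before mAxis and `inside` the product of those
// after it; cuDNN then normalizes over the C dimension, i.e. the softmax axis.
// E.g. a (2, 3, 4, 5) input with mAxis = 2 is viewed as (6, 4, 5, 1).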
ErrorCode SoftmaxExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    inside  = 1;
    outside = 1;
    if (mAxis < 0) {
        mAxis += inputs[0]->dimensions();
    }
    axis = inputs[0]->length(mAxis);
    for (int i = 0; i < mAxis; ++i) {
        outside *= inputs[0]->length(i);
    }
    for (int i = mAxis + 1; i < inputs[0]->dimensions(); ++i) {
        inside *= inputs[0]->length(i);
    }

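    // Input and output share the same NCHW view; CUDNN_SOFTMAX_MODE_CHANNEL in
    // onExecute reduces over the C dimension, which holds the softmax axis here.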
    std::vector<int> tensor_shape = {outside, axis, inside, 1};
    cudnn_check(cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, tensor_shape[0],
                                           tensor_shape[1], tensor_shape[2], tensor_shape[3]));

    cudnn_check(cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, tensor_shape[0],
                                           tensor_shape[1], tensor_shape[2], tensor_shape[3]));

    return NO_ERROR;
}

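// Runs cuDNN's accurate softmax, which subtracts the per-channel maximum
// before exponentiation to avoid overflow (unlike CUDNN_SOFTMAX_FAST).
// For the CUDA backend, Tensor::deviceId() holds the raw device pointer.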
ErrorCode SoftmaxExecution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto input  = (void*)inputs[0]->deviceId();
    auto output = (void*)outputs[0]->deviceId();

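    // cuDNN blending factors: output = alpha * softmax(input) + beta * output,
    // so alpha = 1, beta = 0 simply overwrites the output buffer.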
    const float alpha = 1.0f;
    const float beta  = 0.0f;
    cudnn_check(cudnnSoftmaxForward(cudnn_handle_, CUDNN_SOFTMAX_ACCURATE,
                                    CUDNN_SOFTMAX_MODE_CHANNEL,
                                    &alpha,
                                    input_desc_, input,
                                    &beta,
                                    output_desc_, output));

    return NO_ERROR;
}

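// Creator for the CUDA backend: rejects non-float inputs (only
// CUDNN_DATA_FLOAT descriptors are set up above) and reads the softmax
// axis from the op's Axis parameter.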
class SoftmaxCreator : public CUDABackend::Creator {
public:
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const override {
        auto type = inputs[0]->getType();
        if (type.code != halide_type_float) {
            MNN_PRINT("softmax data type:%d not supported\n", type.code);
            return nullptr;
        }
        auto axis = op->main_as_Axis()->axis();
        return new SoftmaxExecution(axis, backend);
    }
};

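// Static registration: maps OpType_Softmax to SoftmaxCreator so the CUDA
// backend can instantiate this execution when it encounters a softmax op.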
static CUDACreatorRegister<SoftmaxCreator> __init(OpType_Softmax);

} // namespace CUDA
} // namespace MNN