1 #include "SoftmaxExecution.hpp"
2
3 namespace MNN {
4 namespace CUDA {
5
SoftmaxExecution(int axis,Backend * backend)6 SoftmaxExecution::SoftmaxExecution(int axis, Backend *backend) : Execution(backend) {
7 auto runtime = static_cast<CUDABackend*>(backend)->getCUDARuntime();
8 cudnn_handle_ = runtime->cudnn_handle();
9
10 cudnn_check(cudnnCreateTensorDescriptor(&input_desc_));
11 cudnn_check(cudnnCreateTensorDescriptor(&output_desc_));
12
13 cudnn_data_type_ = CUDNN_DATA_FLOAT;
14 mAxis = axis;
15 }
16
~SoftmaxExecution()17 SoftmaxExecution::~SoftmaxExecution() {
18 cudnnDestroyTensorDescriptor(input_desc_);
19 cudnnDestroyTensorDescriptor(output_desc_);
20 }
21
onResize(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)22 ErrorCode SoftmaxExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
23 inside = 1;
24 outside = 1;
25 if(mAxis < 0) {
26 mAxis += inputs[0]->dimensions();
27 }
28 axis = inputs[0]->length(mAxis);
29 for (int i=0; i<mAxis; ++i) {
30 outside *= inputs[0]->length(i);
31 }
32 for (int i=mAxis+1; i<inputs[0]->dimensions(); ++i) {
33 inside *= inputs[0]->length(i);
34 }
35
36 std::vector<int> tensor_shape = {outside, axis, inside, 1};
37 cudnn_check(cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, tensor_shape[0],
38 tensor_shape[1], tensor_shape[2], tensor_shape[3]));
39
40 cudnn_check(cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, tensor_shape[0],
41 tensor_shape[1], tensor_shape[2], tensor_shape[3]));
42
43 return NO_ERROR;
44 }
45
onExecute(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)46 ErrorCode SoftmaxExecution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
47 auto input = (void*)inputs[0]->deviceId();
48 auto output = (void*)outputs[0]->deviceId();
49
50 const float alpha = 1;
51 const float beta = 0;
52 cudnn_check(cudnnSoftmaxForward(cudnn_handle_, CUDNN_SOFTMAX_ACCURATE,
53 CUDNN_SOFTMAX_MODE_CHANNEL,
54 &alpha,
55 input_desc_, input,
56 &beta,
57 output_desc_, output));
58
59 return NO_ERROR;
60 }
61
62 class SoftmaxCreator : public CUDABackend::Creator {
63 public:
onCreate(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,const MNN::Op * op,Backend * backend) const64 virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
65 const MNN::Op* op, Backend* backend) const override {
66 auto type = inputs[0]->getType();
67 if (type.code != halide_type_float) {
68 MNN_PRINT("softmax data type:%s not support", type.code);
69 return nullptr;
70 }
71 auto axis = op->main_as_Axis()->axis();
72 return new SoftmaxExecution(axis, backend);
73 }
74 };
75
76 static CUDACreatorRegister<SoftmaxCreator> __init(OpType_Softmax);
77 }
78 }