//
// Created by alibaba on 2019/9/11.
//

#include "TRTSoftmax.hpp"
#include "TRTBackend.hpp"
#include <core/TensorUtils.hpp>

using namespace std;

namespace MNN {

TRTSoftmax::TRTSoftmax(Backend *b, const Op *op, const std::vector<Tensor *> &inputs,
                       const std::vector<Tensor *> &outputs)
    : MNN::TRTCommonExecution(b, op) {
    // Read the softmax axis from the op's Axis parameter and normalize a
    // negative axis (counted from the last dimension) to a non-negative index.
    int axis = mOp->main_as_Axis()->axis();
    mAxis    = axis < 0 ? axis + outputs[0]->dimensions() : axis;
}
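
// Example of the axis handling above: for a 4-D output and axis == -1, mAxis
// becomes -1 + 4 = 3, and onEncode() below calls setAxes(1U << 3), restricting
// the softmax reduction to the last dimension.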

std::vector<ITensor *> TRTSoftmax::onEncode(const std::vector<ITensor *> &xOp) {
    // Add a TensorRT softmax layer over the first encoded input and restrict it
    // to the normalized axis; ISoftMaxLayer::setAxes() takes a bitmask of axes.
    auto softmax_layer = mTrtBackend->getNetwork()->addSoftMax(*(xOp[0]));
    softmax_layer->setAxes(1U << mAxis);
    return {softmax_layer->getOutput(0)};
}
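
// Assumption (following the common pattern of the other TRT executions): the
// backend maps the returned ITensor to this op's output tensor, so downstream
// executions can reference it when they encode their own layers.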

// Register TRTSoftmax as the creator for Softmax ops in the TensorRT backend.
TRTCreatorRegister<TypedCreator<TRTSoftmax>> __softmax_op(OpType_Softmax);

} // namespace MNN