1 // 2 // Created by alibaba on 2019/9/11. 3 // 4 5 #include "TRTSoftmax.hpp" 6 #include "TRTBackend.hpp" 7 #include <core/TensorUtils.hpp> 8 9 using namespace std; 10 11 namespace MNN { 12 TRTSoftmax(Backend * b,const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)13TRTSoftmax::TRTSoftmax(Backend *b, const Op *op, const std::vector<Tensor *> &inputs, 14 const std::vector<Tensor *> &outputs) 15 : MNN::TRTCommonExecution(b,op) { 16 int axis = mOp->main_as_Axis()->axis(); 17 mAxis = axis < 0 ? axis + outputs[0]->dimensions(): axis; 18 } 19 onEncode(const std::vector<ITensor * > & xOp)20std::vector<ITensor *> TRTSoftmax::onEncode(const std::vector<ITensor *> &xOp) { 21 22 auto softmax_layer = mTrtBackend->getNetwork()->addSoftMax(*(xOp[0])); 23 softmax_layer->setAxes(1U << mAxis); 24 return {softmax_layer->getOutput(0)}; 25 } 26 27 TRTCreatorRegister<TypedCreator<TRTSoftmax>> __softmax_op(OpType_Softmax); 28 29 } // namespace MNN 30