1 // 2 // TRTCommonExecution.hpp 3 // MNN 4 // 5 // Created by MNN on 2019/02/28. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #ifndef TRTCommonExecution_hpp 10 #define TRTCommonExecution_hpp 11 #include "TRTBackend.hpp" 12 #include "core/Execution.hpp" 13 #include "schema/current/MNNPlugin_generated.h" 14 using namespace std; 15 namespace MNN { 16 createPluginWithOutput(const std::vector<Tensor * > & outputs)17inline static std::shared_ptr<MNNTRTPlugin::PluginT> createPluginWithOutput(const std::vector<Tensor *> &outputs) { 18 std::shared_ptr<MNNTRTPlugin::PluginT> plu(new MNNTRTPlugin::PluginT); 19 plu->outputs.resize(outputs.size()); 20 for (int i = 0; i < outputs.size(); ++i) { 21 auto shape = outputs[0]->shape(); 22 plu->outputs[i].reset(new MNNTRTPlugin::ShapeT); 23 plu->outputs[i]->dim = shape; 24 plu->outputs[i]->bytes = outputs[i]->getType().bytes(); 25 plu->outputs[i]->type = outputs[i]->getType().code; 26 } 27 return plu; 28 } 29 30 class TRTCommonExecution : public Execution { 31 public: 32 TRTCommonExecution(Backend *backend, const Op *op); 33 virtual ~TRTCommonExecution() = default; 34 35 virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; 36 virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; 37 38 protected: 39 TRTBackend *mTrtBackend; 40 const Op *mOp; 41 std::vector<Tensor *> mInputs; 42 std::vector<Tensor *> mOutputs; 43 44 virtual std::vector<ITensor *> onEncode(const std::vector<ITensor *> &inputs) = 0; 45 }; 46 47 } // namespace MNN 48 #endif /* TRTCommonExecution_hpp */ 49