1 //
2 //  TRTCommonExecution.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/02/28.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef TRTCommonExecution_hpp
10 #define TRTCommonExecution_hpp
11 #include "TRTBackend.hpp"
12 #include "core/Execution.hpp"
13 #include "schema/current/MNNPlugin_generated.h"
14 using namespace std;
15 namespace MNN {
16 
createPluginWithOutput(const std::vector<Tensor * > & outputs)17 inline static std::shared_ptr<MNNTRTPlugin::PluginT> createPluginWithOutput(const std::vector<Tensor *> &outputs) {
18     std::shared_ptr<MNNTRTPlugin::PluginT> plu(new MNNTRTPlugin::PluginT);
19     plu->outputs.resize(outputs.size());
20     for (int i = 0; i < outputs.size(); ++i) {
21         auto shape = outputs[0]->shape();
22         plu->outputs[i].reset(new MNNTRTPlugin::ShapeT);
23         plu->outputs[i]->dim   = shape;
24         plu->outputs[i]->bytes = outputs[i]->getType().bytes();
25         plu->outputs[i]->type  = outputs[i]->getType().code;
26     }
27     return plu;
28 }
29 
30 class TRTCommonExecution : public Execution {
31 public:
32     TRTCommonExecution(Backend *backend, const Op *op);
33     virtual ~TRTCommonExecution() = default;
34 
35     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
36     virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
37 
38 protected:
39     TRTBackend *mTrtBackend;
40     const Op *mOp;
41     std::vector<Tensor *> mInputs;
42     std::vector<Tensor *> mOutputs;
43 
44     virtual std::vector<ITensor *> onEncode(const std::vector<ITensor *> &inputs) = 0;
45 };
46 
47 } // namespace MNN
48 #endif /* TRTCommonExecution_hpp */
49