1 //
// TRTRaster.cpp
3 // MNN
4 //
5 // Created by MNN on 2019/09/11.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "TRTRaster.hpp"
10 #include <core/TensorUtils.hpp>
11 #include "TRTBackend.hpp"
12 #include "schema/current/MNNPlugin_generated.h"
13
14 using namespace std;
15
16 namespace MNN {
17
18
TRTRaster(Backend * b,const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)19 TRTRaster::TRTRaster(Backend *b, const Op *op, const std::vector<Tensor *> &inputs,
20 const std::vector<Tensor *> &outputs)
21 : MNN::TRTCommonExecution(b, op) {
22 // Do nothing
23 }
24
onEncode(const std::vector<ITensor * > & xOp)25 std::vector<ITensor *> TRTRaster::onEncode(const std::vector<ITensor *> &xOp) {
26 #ifdef TRT_LOG
27 MNN_PRINT("TRTRaster in\n");
28 #endif
29 std::vector<ITensor *> inputTensors;
30 std::map<const Tensor *, int> tensorMap;
31 auto des = TensorUtils::getDescribe(mInputs[0]);
32 for (auto ® : des->regions) {
33 if (tensorMap.find(reg.origin) == tensorMap.end()) {
34 tensorMap.insert(std::make_pair(reg.origin, tensorMap.size()));
35 }
36 }
37 inputTensors.resize(tensorMap.size());
38 for (auto &iter : tensorMap) {
39 inputTensors[iter.second] = mTrtBackend->getTensorOps(iter.first);
40 }
41 auto plu = createPluginWithOutput(mOutputs);
42 plu->main.type = MNNTRTPlugin::Parameter_RasterInfo;
43 plu->main.value = new MNNTRTPlugin::RasterInfoT;
44 auto raster = plu->main.AsRasterInfo();
45 raster->regions.resize(des->regions.size());
46 for (int i = 0; i < des->regions.size(); ++i) {
47 raster->regions[i].reset(new MNNTRTPlugin::RegionT);
48 auto &dst = raster->regions[i];
49 auto &src = des->regions[i];
50 dst->src.reset(new MNNTRTPlugin::ViewT);
51 dst->dst.reset(new MNNTRTPlugin::ViewT);
52 dst->size = {src.size[0], src.size[1], src.size[2]};
53 dst->index = tensorMap[src.origin];
54 dst->src->offset = src.src.offset;
55 dst->src->stride = {src.src.stride[0], src.src.stride[1], src.src.stride[2]};
56 dst->dst->offset = src.dst.offset;
57 dst->dst->stride = {src.dst.stride[0], src.dst.stride[1], src.dst.stride[2]};
58 }
59 raster->extra = MNNTRTPlugin::ExtraType_Normal;
60 if (!TensorUtils::regionIsFull(mInputs[0])) {
61 raster->extra = MNNTRTPlugin::ExtraType_Fill;
62 }
63 auto preluPlugin = (nvinfer1::IPluginExt *)MNNTRTCreatePlugion(mOp, plu.get());
64 nvinfer1::IPluginLayer *plugin = mTrtBackend->getNetwork()->addPluginExt(&inputTensors[0], inputTensors.size(),
65 *((nvinfer1::IPluginExt *)preluPlugin));
66 if (plugin == nullptr) {
67 MNN_PRINT("plugin == nullptr !!!");
68 }
69 // delete preluPlugin;
70 #ifdef TRT_LOG
71 MNN_PRINT("TRTRaster out\n");
72 #endif
73 mTrtBackend->pushReleaseLayer(preluPlugin);
74
75 return {plugin->getOutput(0)};
76 }
77
78 TRTCreatorRegister<TypedCreator<TRTRaster>> __raster_op(OpType_Raster);
79
80 } // namespace MNN
81
82