1 //
//  TRTRaster.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/09/11.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "TRTRaster.hpp"
10 #include <core/TensorUtils.hpp>
11 #include "TRTBackend.hpp"
12 #include "schema/current/MNNPlugin_generated.h"
13 
14 using namespace std;
15 
16 namespace MNN {
17 
18 
TRTRaster(Backend * b,const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)19 TRTRaster::TRTRaster(Backend *b, const Op *op, const std::vector<Tensor *> &inputs,
20                      const std::vector<Tensor *> &outputs)
21     : MNN::TRTCommonExecution(b, op) {
22     // Do nothing
23 }
24 
onEncode(const std::vector<ITensor * > & xOp)25 std::vector<ITensor *> TRTRaster::onEncode(const std::vector<ITensor *> &xOp) {
26 #ifdef TRT_LOG
27     MNN_PRINT("TRTRaster in\n");
28 #endif
29     std::vector<ITensor *> inputTensors;
30     std::map<const Tensor *, int> tensorMap;
31     auto des = TensorUtils::getDescribe(mInputs[0]);
32     for (auto &reg : des->regions) {
33         if (tensorMap.find(reg.origin) == tensorMap.end()) {
34             tensorMap.insert(std::make_pair(reg.origin, tensorMap.size()));
35         }
36     }
37     inputTensors.resize(tensorMap.size());
38     for (auto &iter : tensorMap) {
39         inputTensors[iter.second] = mTrtBackend->getTensorOps(iter.first);
40     }
41     auto plu        = createPluginWithOutput(mOutputs);
42     plu->main.type  = MNNTRTPlugin::Parameter_RasterInfo;
43     plu->main.value = new MNNTRTPlugin::RasterInfoT;
44     auto raster     = plu->main.AsRasterInfo();
45     raster->regions.resize(des->regions.size());
46     for (int i = 0; i < des->regions.size(); ++i) {
47         raster->regions[i].reset(new MNNTRTPlugin::RegionT);
48         auto &dst = raster->regions[i];
49         auto &src = des->regions[i];
50         dst->src.reset(new MNNTRTPlugin::ViewT);
51         dst->dst.reset(new MNNTRTPlugin::ViewT);
52         dst->size        = {src.size[0], src.size[1], src.size[2]};
53         dst->index       = tensorMap[src.origin];
54         dst->src->offset = src.src.offset;
55         dst->src->stride = {src.src.stride[0], src.src.stride[1], src.src.stride[2]};
56         dst->dst->offset = src.dst.offset;
57         dst->dst->stride = {src.dst.stride[0], src.dst.stride[1], src.dst.stride[2]};
58     }
59     raster->extra = MNNTRTPlugin::ExtraType_Normal;
60     if (!TensorUtils::regionIsFull(mInputs[0])) {
61         raster->extra = MNNTRTPlugin::ExtraType_Fill;
62     }
63     auto preluPlugin               = (nvinfer1::IPluginExt *)MNNTRTCreatePlugion(mOp, plu.get());
64     nvinfer1::IPluginLayer *plugin = mTrtBackend->getNetwork()->addPluginExt(&inputTensors[0], inputTensors.size(),
65                                                                              *((nvinfer1::IPluginExt *)preluPlugin));
66     if (plugin == nullptr) {
67         MNN_PRINT("plugin == nullptr !!!");
68     }
69     // delete preluPlugin;
70 #ifdef TRT_LOG
71     MNN_PRINT("TRTRaster out\n");
72 #endif
73     mTrtBackend->pushReleaseLayer(preluPlugin);
74 
75     return {plugin->getOutput(0)};
76 }
77 
78 TRTCreatorRegister<TypedCreator<TRTRaster>> __raster_op(OpType_Raster);
79 
80 } // namespace MNN
81 
82