1 //
2 // TRTOneHot.cpp
3 // MNN
4 //
5 // Created by MNN on 2019/09/11.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "TRTOneHot.hpp"
10 #include <core/TensorUtils.hpp>
11 #include "TRTBackend.hpp"
12 #include "schema/current/MNNPlugin_generated.h"
13
14 using namespace std;
15
16 namespace MNN {
17
TRTOneHot(Backend * b,const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)18 TRTOneHot::TRTOneHot(Backend *b, const Op *op, const std::vector<Tensor *> &inputs,
19 const std::vector<Tensor *> &outputs)
20 : MNN::TRTCommonExecution(b, op) {
21 mAxis = op->main_as_OneHotParam()->axis();
22 }
23
onEncode(const std::vector<ITensor * > & xOp)24 std::vector<ITensor *> TRTOneHot::onEncode(const std::vector<ITensor *> &xOp) {
25 #ifdef TRT_LOG
26 printf("TRTOneHot in\n");
27 #endif
28
29 auto plu = createPluginWithOutput(mOutputs);
30
31 auto indices = mInputs[0];
32 auto onValueTensor = mInputs[2];
33 auto offValueTensor = mInputs[3];
34
35 int axis = mAxis;
36 if (axis < 0) {
37 axis += mOutputs[0]->dimensions();
38 }
39 int outerSize = 1;
40 for (int i = 0; i < axis; ++i) {
41 outerSize *= indices->length(i);
42 }
43
44 const int innerSize = indices->elementSize() / outerSize;
45
46 auto dataType = onValueTensor->getType();
47
48 MNN_ASSERT(offValueTensor->getType() == dataType);
49 MNN_ASSERT(offValueTensor->getType() != halide_type_of<int>());
50
51 plu->main.type = MNNTRTPlugin::Parameter_OneHotInfo;
52 plu->main.value = new MNNTRTPlugin::OneHotInfoT;
53 auto onehotp = plu->main.AsOneHotInfo();
54
55 onehotp->outerSize = outerSize;
56 onehotp->innerSize = innerSize;
57
58 auto interpPlugin = (nvinfer1::IPluginExt *)MNNTRTCreatePlugion(mOp, plu.get());
59 nvinfer1::IPluginLayer *plugin = mTrtBackend->getNetwork()->addPluginExt(&xOp[0], mInputs.size(), *((nvinfer1::IPluginExt *)interpPlugin));
60 if (plugin == nullptr) {
61 printf("Interp plugin == nullptr !!!\n");
62 }
63 mTrtBackend->pushReleaseLayer(interpPlugin);
64 return {plugin->getOutput(0)};
65
66 }
67
68 TRTCreatorRegister<TypedCreator<TRTOneHot>> __ont_hot_op(OpType_OneHot);
69
70 } // namespace MNN
71