1 //
2 //  TRTOneHot.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/09/11.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "TRTOneHot.hpp"
10 #include <core/TensorUtils.hpp>
11 #include "TRTBackend.hpp"
12 #include "schema/current/MNNPlugin_generated.h"
13 
14 using namespace std;
15 
16 namespace MNN {
17 
TRTOneHot(Backend * b,const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)18 TRTOneHot::TRTOneHot(Backend *b, const Op *op, const std::vector<Tensor *> &inputs,
19                        const std::vector<Tensor *> &outputs)
20     : MNN::TRTCommonExecution(b, op) {
21     mAxis = op->main_as_OneHotParam()->axis();
22 }
23 
onEncode(const std::vector<ITensor * > & xOp)24 std::vector<ITensor *> TRTOneHot::onEncode(const std::vector<ITensor *> &xOp) {
25 #ifdef TRT_LOG
26     printf("TRTOneHot in\n");
27 #endif
28 
29     auto plu = createPluginWithOutput(mOutputs);
30 
31     auto indices        = mInputs[0];
32     auto onValueTensor  = mInputs[2];
33     auto offValueTensor = mInputs[3];
34 
35     int axis = mAxis;
36     if (axis < 0) {
37         axis += mOutputs[0]->dimensions();
38     }
39     int outerSize = 1;
40     for (int i = 0; i < axis; ++i) {
41         outerSize *= indices->length(i);
42     }
43 
44     const int innerSize   = indices->elementSize() / outerSize;
45 
46     auto dataType    = onValueTensor->getType();
47 
48     MNN_ASSERT(offValueTensor->getType() == dataType);
49     MNN_ASSERT(offValueTensor->getType() != halide_type_of<int>());
50 
51     plu->main.type  = MNNTRTPlugin::Parameter_OneHotInfo;
52     plu->main.value = new MNNTRTPlugin::OneHotInfoT;
53     auto onehotp     = plu->main.AsOneHotInfo();
54 
55     onehotp->outerSize   = outerSize;
56     onehotp->innerSize   = innerSize;
57 
58     auto interpPlugin = (nvinfer1::IPluginExt *)MNNTRTCreatePlugion(mOp, plu.get());
59     nvinfer1::IPluginLayer *plugin = mTrtBackend->getNetwork()->addPluginExt(&xOp[0], mInputs.size(), *((nvinfer1::IPluginExt *)interpPlugin));
60     if (plugin == nullptr) {
61         printf("Interp plugin == nullptr !!!\n");
62     }
63     mTrtBackend->pushReleaseLayer(interpPlugin);
64     return {plugin->getOutput(0)};
65 
66 }
67 
68 TRTCreatorRegister<TypedCreator<TRTOneHot>> __ont_hot_op(OpType_OneHot);
69 
70 } // namespace MNN
71