1 // Copyright (C) 2020 by Yuri Victorovich. All rights reserved.
2
3 #pragma once
4
5 //
// PluginInterface is the interface implemented by every plugin; it is common to all NN model types
7 //
8
#include <cstddef>
#include <cstdint>
#include <functional>
#include <ostream>
#include <string>
#include <vector>

#include "nn-types.h"
#include "tensor.h"
16
17 class PluginInterface {
18
19 public:
~PluginInterface()20 virtual ~PluginInterface() { }
21
22 // capability flags
23 enum Capabilities {Capability_CanWrite=0x00000001};
24
25 // types related to the plugin interface
26 typedef unsigned TensorId;
27 typedef unsigned OperatorId;
28 enum OperatorKind { // all distinct operator kinds should be listed here
29 // XXX each value here has to be mirrored in plugin-interface.cpp (CASE)
30 KindConv2D,
31 KindDepthwiseConv2D,
32 KindPad,
33 KindMirrorPad,
34 KindFullyConnected,
35 KindLocalResponseNormalization,
36 KindMaxPool,
37 KindAveragePool,
38 // activation functions
39 KindRelu,
40 KindRelu6,
41 KindLeakyRelu,
42 KindTanh,
43 KindLogistic,
44 KindHardSwish,
45 // misc math functions
46 KindRSqrt,
47 //
48 KindAdd,
49 KindSub,
50 KindMul,
51 KindDiv,
52 KindMaximum,
53 KindMinimum,
54 KindTranspose,
55 KindReshape,
56 KindSoftmax,
57 KindConcatenation,
58 KindSplit,
59 KindStridedSlice,
60 KindMean,
61 // Data manipulations
62 KindDequantize, // convert any type of qint8, quint8, qint32, qint16, quint16 float
63 // Misc
64 KindArgMax,
65 KindArgMin,
66 KindSquaredDifference,
67 // Resizes
68 KindResizeBilinear,
69 KindResizeNearestNeighbor,
70 //
71 KindUnknown
72 };
73
74 enum DataType {
75 DataType_Float16,
76 DataType_Float32,
77 DataType_Float64,
78 DataType_Int8,
79 DataType_UInt8,
80 DataType_Int16,
81 DataType_Int32,
82 DataType_Int64
83 // TODO? BOOL, STRING, COMPLEX64 are also supported in TfLite specification
84 };
85
86 enum PaddingType {
87 PaddingType_SAME, // pad with zeros where data isn't available, result has the same shape
88 PaddingType_VALID // no padding, iterate only when all data is available for the extent of the kernel, result has a smaller shape
89 };
90
91 enum ActivationFunction {
92 ActivationFunction_NONE,
93 ActivationFunction_RELU,
94 ActivationFunction_RELU_N1_TO_1,
95 ActivationFunction_RELU6,
96 ActivationFunction_TANH,
97 ActivationFunction_SIGN_BIT
98 };
99
100 // OperatorOptionName represents a unique option value with a specific meaning each available only for some operators
101 enum OperatorOptionName {
102 OperatorOption_UNKNOWN,
103 // XXX for now options are a list of all options that occur as fiels for operator options in schema.fbs
104 // corresponds to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs 58719c1 Dec 17, 2019
105 // Regenerate from schema.fbs: 'flatc --jsonschema schema.fb', then prepend with 'var fbsSchema = ' and append the following:
106 #if 0
107 if (false) {
108 Object.keys(fbsSchema.definitions).forEach(function(name) {
109 if (name.length>7 && name.substring(name.length-7)=="Options" && name!="tflite_BuiltinOptions") {
110 print("struct: "+name);
111 Object.keys(fbsSchema.definitions[name].properties).forEach(function(p) {
112 print("... opt: "+p);
113 });
114 }
115 });
116 }
117 #endif
118 OperatorOption_ALIGN_CORNERS,
119 OperatorOption_ALPHA,
120 OperatorOption_AXIS,
121 OperatorOption_BATCH_DIM,
122 OperatorOption_BEGIN_MASK,
123 OperatorOption_BETA,
124 OperatorOption_BIAS,
125 OperatorOption_BLOCK_SIZE,
126 OperatorOption_BODY_SUBGRAPH_INDEX,
127 OperatorOption_CELL_CLIP,
128 OperatorOption_COMBINER,
129 OperatorOption_COND_SUBGRAPH_INDEX,
130 OperatorOption_DEPTH_MULTIPLIER,
131 OperatorOption_DILATION_H_FACTOR,
132 OperatorOption_DILATION_W_FACTOR,
133 OperatorOption_ELLIPSIS_MASK,
134 OperatorOption_ELSE_SUBGRAPH_INDEX,
135 OperatorOption_EMBEDDING_DIM_PER_CHANNEL,
136 OperatorOption_END_MASK,
137 OperatorOption_FILTER_HEIGHT,
138 OperatorOption_FILTER_WIDTH,
139 OperatorOption_FUSED_ACTIVATION_FUNCTION,
140 OperatorOption_IDX_OUT_TYPE,
141 OperatorOption_IN_DATA_TYPE,
142 OperatorOption_INCLUDE_ALL_NGRAMS,
143 OperatorOption_KEEP_DIMS,
144 OperatorOption_KEEP_NUM_DIMS,
145 OperatorOption_KERNEL_TYPE,
146 OperatorOption_MAX,
147 OperatorOption_MAX_SKIP_SIZE,
148 OperatorOption_MERGE_OUTPUTS,
149 OperatorOption_MIN,
150 OperatorOption_MODE,
151 OperatorOption_NARROW_RANGE,
152 OperatorOption_NEW_AXIS_MASK,
153 OperatorOption_NEW_HEIGHT,
154 OperatorOption_NEW_SHAPE,
155 OperatorOption_NEW_WIDTH,
156 OperatorOption_NGRAM_SIZE,
157 OperatorOption_NUM,
158 OperatorOption_NUM_BITS,
159 OperatorOption_NUM_CHANNELS,
160 OperatorOption_NUM_COLUMNS_PER_CHANNEL,
161 OperatorOption_NUM_SPLITS,
162 OperatorOption_OUT_DATA_TYPE,
163 OperatorOption_OUT_TYPE,
164 OperatorOption_OUTPUT_TYPE,
165 OperatorOption_PADDING,
166 OperatorOption_PROJ_CLIP,
167 OperatorOption_RADIUS,
168 OperatorOption_RANK,
169 OperatorOption_SEQ_DIM,
170 OperatorOption_SHRINK_AXIS_MASK,
171 OperatorOption_SQUEEZE_DIMS,
172 OperatorOption_STRIDE_H,
173 OperatorOption_STRIDE_W,
174 OperatorOption_SUBGRAPH,
175 OperatorOption_THEN_SUBGRAPH_INDEX,
176 OperatorOption_TIME_MAJOR,
177 OperatorOption_TYPE,
178 OperatorOption_VALIDATE_INDICES,
179 OperatorOption_VALUES_COUNT,
180 OperatorOption_WEIGHTS_FORMAT
181 };
182
183 enum OperatorOptionType {
184 OperatorOption_TypeBool,
185 OperatorOption_TypeFloat,
186 OperatorOption_TypeInt,
187 OperatorOption_TypeUInt,
188 OperatorOption_TypeIntArray,
189 OperatorOption_TypePaddingType,
190 OperatorOption_TypeActivationFunction
191 };
192 struct OperatorOptionValue {
193 OperatorOptionType type;
194 bool b;
195 float f;
196 int32_t i;
197 uint32_t u;
198 PaddingType paddingType;
199 ActivationFunction activationFunction;
200 std::vector<int32_t> ii;
OperatorOptionValueOperatorOptionValue201 OperatorOptionValue(bool b_) : type(OperatorOption_TypeBool), b(b_) { }
OperatorOptionValueOperatorOptionValue202 OperatorOptionValue(float f_) : type(OperatorOption_TypeFloat), f(f_) { }
OperatorOptionValueOperatorOptionValue203 OperatorOptionValue(int32_t i_) : type(OperatorOption_TypeInt), i(i_) { }
OperatorOptionValueOperatorOptionValue204 OperatorOptionValue(uint32_t u_) : type(OperatorOption_TypeUInt), u(u_) { }
OperatorOptionValueOperatorOptionValue205 OperatorOptionValue(const std::vector<int32_t> &ii_) : type(OperatorOption_TypeIntArray), ii(ii_) { }
OperatorOptionValueOperatorOptionValue206 OperatorOptionValue(PaddingType paddingType_) : type(OperatorOption_TypePaddingType), paddingType(paddingType_) { }
OperatorOptionValueOperatorOptionValue207 OperatorOptionValue(ActivationFunction activationFunction_) : type(OperatorOption_TypeActivationFunction), activationFunction(activationFunction_) { }
208
209 // templetized getter
210 template<typename T> T as() const; // not implemented by default
211 };
212
213 // OperatorOption is what is returned by the plugin for individual operators
214 struct OperatorOption { // represents a "variable": type name = value;
215 OperatorOptionName name; // like a variable name, name is assigned a fixed meaning across operators
216 OperatorOptionValue value; // like a variable type and value
217 };
218
219 typedef std::vector<OperatorOption> OperatorOptionsList;
220
221 friend std::ostream& operator<<(std::ostream &os, OperatorKind okind);
222 friend std::ostream& operator<<(std::ostream &os, DataType dataType);
223 friend std::ostream& operator<<(std::ostream &os, PaddingType paddingType);
224 friend std::ostream& operator<<(std::ostream &os, ActivationFunction afunc);
225 friend std::ostream& operator<<(std::ostream &os, OperatorOptionName optName);
226 friend std::ostream& operator<<(std::ostream &os, OperatorOptionType optType);
227 friend std::ostream& operator<<(std::ostream &os, const OperatorOptionValue &optValue);
228
229 // inner-classes
230 class Model { // Model represents one of potentially many models contained in the file
231 public:
~Model()232 virtual ~Model() { } // has to be inlined for plugins to contain it too // not to be called by users, hence 'protected'
233 public: // interface
234 virtual unsigned numInputs() const = 0; // how many inputs does this model have
235 virtual std::vector<TensorId> getInputs() const = 0; // input indexes
236 virtual unsigned numOutputs() const = 0; // how many outputs does this model have
237 virtual std::vector<TensorId> getOutputs() const = 0; // output indexes
238 virtual unsigned numOperators() const = 0; // how many operators does this model have
239 virtual void getOperatorIo(unsigned operatorIdx, std::vector<TensorId> &inputs, std::vector<TensorId> &outputs) const = 0;
240 virtual OperatorKind getOperatorKind(unsigned operatorIdx) const = 0;
241 virtual OperatorOptionsList* getOperatorOptions(unsigned operatorIdx) const = 0;
242 virtual unsigned numTensors() const = 0; // number of tensors in this model
243 virtual TensorShape getTensorShape(TensorId tensorId) const = 0;
244 virtual DataType getTensorType(TensorId tensorId) const = 0;
245 virtual std::string getTensorName(TensorId tensorId) const = 0;
246 virtual bool getTensorHasData(TensorId tensorId) const = 0; // tensors that are fixed have buffers
247 virtual const void* getTensorData(TensorId tensorId) const = 0; // can only be called when getTensorHasData()=true
248 virtual const float* getTensorDataF32(TensorId tensorId) const = 0; // can only be called when getTensorHasData()=true
249 virtual bool getTensorIsVariableFlag(TensorId tensorId) const = 0; // some tensors are variables that can be altered
250
251 public: // convenience functions
252 bool isTensorComputed(TensorId tensorId) const;
253 };
254
255 // plugin interface
256 virtual uint32_t capabilities() const = 0;
257 virtual std::string filePath() const = 0; // returns back the file name that it was opened from
258 virtual std::string modelDescription() const = 0; // description of the model from the NN file
259 virtual bool open(const std::string &filePath_) = 0; // open the model (can only be done once per object, close is implicit on destruction for simplicity)
260 virtual std::string errorMessage() const = 0; // returns the error of the last operation if it has failed
261 virtual size_t numModels() const = 0; // how many models does this file contain
262 virtual const Model* getModel(unsigned index) const = 0; // access to one model, the Model object is owned by the plugin
263 virtual void write(const Model *model, const std::string &fileName) const = 0; // write the model to disk
264 };
265
266 // gcc-9 needs explicit template specializations to be outside of class scope
as()267 template<> inline bool PluginInterface::OperatorOptionValue::as() const {return b;}
as()268 template<> inline float PluginInterface::OperatorOptionValue::as() const {return f;}
as()269 template<> inline int32_t PluginInterface::OperatorOptionValue::as() const {return i;}
as()270 template<> inline PluginInterface::PaddingType PluginInterface::OperatorOptionValue::as() const {return paddingType;}
as()271 template<> inline PluginInterface::ActivationFunction PluginInterface::OperatorOptionValue::as() const {return activationFunction;}
272
273