1 // Copyright (C) 2020 by Yuri Victorovich. All rights reserved.
2 
3 #pragma once
4 
5 //
6 // PluginInterface is the interface of all plugins, common for all NN model types
7 //
8 
#include <cstddef>
#include <cstdint>
#include <functional>
#include <ostream>
#include <string>
#include <vector>

#include "nn-types.h"
#include "tensor.h"
16 
17 class PluginInterface {
18 
19 public:
~PluginInterface()20 	virtual ~PluginInterface() { }
21 
22 	// capability flags
23 	enum Capabilities {Capability_CanWrite=0x00000001};
24 
25 	// types related to the plugin interface
26 	typedef unsigned TensorId;
27 	typedef unsigned OperatorId;
28 	enum OperatorKind { // all distinct operator kinds should be listed here
29 		// XXX each value here has to be mirrored in plugin-interface.cpp (CASE)
30 		KindConv2D,
31 		KindDepthwiseConv2D,
32 		KindPad,
33 		KindMirrorPad,
34 		KindFullyConnected,
35 		KindLocalResponseNormalization,
36 		KindMaxPool,
37 		KindAveragePool,
38 		// activation functions
39 		KindRelu,
40 		KindRelu6,
41 		KindLeakyRelu,
42 		KindTanh,
43 		KindLogistic,
44 		KindHardSwish,
45 		// misc math functions
46 		KindRSqrt,
47 		//
48 		KindAdd,
49 		KindSub,
50 		KindMul,
51 		KindDiv,
52 		KindMaximum,
53 		KindMinimum,
54 		KindTranspose,
55 		KindReshape,
56 		KindSoftmax,
57 		KindConcatenation,
58 		KindSplit,
59 		KindStridedSlice,
60 		KindMean,
61 		// Data manipulations
62 		KindDequantize,  // convert any type of qint8, quint8, qint32, qint16, quint16 float
63 		// Misc
64 		KindArgMax,
65 		KindArgMin,
66 		KindSquaredDifference,
67 		// Resizes
68 		KindResizeBilinear,
69 		KindResizeNearestNeighbor,
70 		//
71 		KindUnknown
72 	};
73 
74 	enum DataType {
75 		DataType_Float16,
76 		DataType_Float32,
77 		DataType_Float64,
78 		DataType_Int8,
79 		DataType_UInt8,
80 		DataType_Int16,
81 		DataType_Int32,
82 		DataType_Int64
83 		// TODO? BOOL, STRING, COMPLEX64 are also supported in TfLite specification
84 	};
85 
86 	enum PaddingType {
87 		PaddingType_SAME,    // pad with zeros where data isn't available, result has the same shape
88 		PaddingType_VALID    // no padding, iterate only when all data is available for the extent of the kernel, result has a smaller shape
89 	};
90 
91 	enum ActivationFunction {
92 		ActivationFunction_NONE,
93 		ActivationFunction_RELU,
94 		ActivationFunction_RELU_N1_TO_1,
95 		ActivationFunction_RELU6,
96 		ActivationFunction_TANH,
97 		ActivationFunction_SIGN_BIT
98 	};
99 
100 	// OperatorOptionName represents a unique option value with a specific meaning each available only for some operators
101 	enum OperatorOptionName {
102 		OperatorOption_UNKNOWN,
103 		// XXX for now options are a list of all options that occur as fiels for operator options in schema.fbs
104 		// corresponds to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs 58719c1 Dec 17, 2019
105 		// Regenerate from schema.fbs: 'flatc --jsonschema schema.fb', then prepend with 'var fbsSchema = ' and append the following:
106 #if 0
107 if (false) {
108         Object.keys(fbsSchema.definitions).forEach(function(name) {
109                 if (name.length>7 && name.substring(name.length-7)=="Options" && name!="tflite_BuiltinOptions") {
110                         print("struct: "+name);
111                         Object.keys(fbsSchema.definitions[name].properties).forEach(function(p) {
112                                 print("... opt: "+p);
113                         });
114                 }
115         });
116 }
117 #endif
118 		OperatorOption_ALIGN_CORNERS,
119 		OperatorOption_ALPHA,
120 		OperatorOption_AXIS,
121 		OperatorOption_BATCH_DIM,
122 		OperatorOption_BEGIN_MASK,
123 		OperatorOption_BETA,
124 		OperatorOption_BIAS,
125 		OperatorOption_BLOCK_SIZE,
126 		OperatorOption_BODY_SUBGRAPH_INDEX,
127 		OperatorOption_CELL_CLIP,
128 		OperatorOption_COMBINER,
129 		OperatorOption_COND_SUBGRAPH_INDEX,
130 		OperatorOption_DEPTH_MULTIPLIER,
131 		OperatorOption_DILATION_H_FACTOR,
132 		OperatorOption_DILATION_W_FACTOR,
133 		OperatorOption_ELLIPSIS_MASK,
134 		OperatorOption_ELSE_SUBGRAPH_INDEX,
135 		OperatorOption_EMBEDDING_DIM_PER_CHANNEL,
136 		OperatorOption_END_MASK,
137 		OperatorOption_FILTER_HEIGHT,
138 		OperatorOption_FILTER_WIDTH,
139 		OperatorOption_FUSED_ACTIVATION_FUNCTION,
140 		OperatorOption_IDX_OUT_TYPE,
141 		OperatorOption_IN_DATA_TYPE,
142 		OperatorOption_INCLUDE_ALL_NGRAMS,
143 		OperatorOption_KEEP_DIMS,
144 		OperatorOption_KEEP_NUM_DIMS,
145 		OperatorOption_KERNEL_TYPE,
146 		OperatorOption_MAX,
147 		OperatorOption_MAX_SKIP_SIZE,
148 		OperatorOption_MERGE_OUTPUTS,
149 		OperatorOption_MIN,
150 		OperatorOption_MODE,
151 		OperatorOption_NARROW_RANGE,
152 		OperatorOption_NEW_AXIS_MASK,
153 		OperatorOption_NEW_HEIGHT,
154 		OperatorOption_NEW_SHAPE,
155 		OperatorOption_NEW_WIDTH,
156 		OperatorOption_NGRAM_SIZE,
157 		OperatorOption_NUM,
158 		OperatorOption_NUM_BITS,
159 		OperatorOption_NUM_CHANNELS,
160 		OperatorOption_NUM_COLUMNS_PER_CHANNEL,
161 		OperatorOption_NUM_SPLITS,
162 		OperatorOption_OUT_DATA_TYPE,
163 		OperatorOption_OUT_TYPE,
164 		OperatorOption_OUTPUT_TYPE,
165 		OperatorOption_PADDING,
166 		OperatorOption_PROJ_CLIP,
167 		OperatorOption_RADIUS,
168 		OperatorOption_RANK,
169 		OperatorOption_SEQ_DIM,
170 		OperatorOption_SHRINK_AXIS_MASK,
171 		OperatorOption_SQUEEZE_DIMS,
172 		OperatorOption_STRIDE_H,
173 		OperatorOption_STRIDE_W,
174 		OperatorOption_SUBGRAPH,
175 		OperatorOption_THEN_SUBGRAPH_INDEX,
176 		OperatorOption_TIME_MAJOR,
177 		OperatorOption_TYPE,
178 		OperatorOption_VALIDATE_INDICES,
179 		OperatorOption_VALUES_COUNT,
180 		OperatorOption_WEIGHTS_FORMAT
181 	};
182 
183 	enum OperatorOptionType {
184 		OperatorOption_TypeBool,
185 		OperatorOption_TypeFloat,
186 		OperatorOption_TypeInt,
187 		OperatorOption_TypeUInt,
188 		OperatorOption_TypeIntArray,
189 		OperatorOption_TypePaddingType,
190 		OperatorOption_TypeActivationFunction
191 	};
192 	struct OperatorOptionValue {
193 		OperatorOptionType type;
194 		bool     b;
195 		float    f;
196 		int32_t  i;
197 		uint32_t u;
198 		PaddingType        paddingType;
199 		ActivationFunction activationFunction;
200 		std::vector<int32_t>  ii;
OperatorOptionValueOperatorOptionValue201 		OperatorOptionValue(bool b_)                                : type(OperatorOption_TypeBool), b(b_) { }
OperatorOptionValueOperatorOptionValue202 		OperatorOptionValue(float f_)                               : type(OperatorOption_TypeFloat), f(f_) { }
OperatorOptionValueOperatorOptionValue203 		OperatorOptionValue(int32_t i_)                             : type(OperatorOption_TypeInt), i(i_) { }
OperatorOptionValueOperatorOptionValue204 		OperatorOptionValue(uint32_t u_)                            : type(OperatorOption_TypeUInt), u(u_) { }
OperatorOptionValueOperatorOptionValue205 		OperatorOptionValue(const std::vector<int32_t> &ii_)        : type(OperatorOption_TypeIntArray), ii(ii_) { }
OperatorOptionValueOperatorOptionValue206 		OperatorOptionValue(PaddingType paddingType_) : type(OperatorOption_TypePaddingType), paddingType(paddingType_) { }
OperatorOptionValueOperatorOptionValue207 		OperatorOptionValue(ActivationFunction activationFunction_) : type(OperatorOption_TypeActivationFunction), activationFunction(activationFunction_) { }
208 
209 		// templetized getter
210 		template<typename T> T as() const; // not implemented by default
211 	};
212 
213 	// OperatorOption is what is returned by the plugin for individual operators
214 	struct OperatorOption { // represents a "variable": type name = value;
215 		OperatorOptionName  name;   // like a variable name, name is assigned a fixed meaning across operators
216 		OperatorOptionValue value;  // like a variable type and value
217 	};
218 
219 	typedef std::vector<OperatorOption> OperatorOptionsList;
220 
221 	friend std::ostream& operator<<(std::ostream &os, OperatorKind okind);
222 	friend std::ostream& operator<<(std::ostream &os, DataType dataType);
223 	friend std::ostream& operator<<(std::ostream &os, PaddingType paddingType);
224 	friend std::ostream& operator<<(std::ostream &os, ActivationFunction afunc);
225 	friend std::ostream& operator<<(std::ostream &os, OperatorOptionName optName);
226 	friend std::ostream& operator<<(std::ostream &os, OperatorOptionType optType);
227 	friend std::ostream& operator<<(std::ostream &os, const OperatorOptionValue &optValue);
228 
229 	// inner-classes
230 	class Model { // Model represents one of potentially many models contained in the file
231 	public:
~Model()232 		virtual ~Model() { } // has to be inlined for plugins to contain it too // not to be called by users, hence 'protected'
233 	public: // interface
234 		virtual unsigned                numInputs() const = 0;                                                          // how many inputs does this model have
235 		virtual std::vector<TensorId>   getInputs() const = 0;                                                          // input indexes
236 		virtual unsigned                numOutputs() const = 0;                                                         // how many outputs does this model have
237 		virtual std::vector<TensorId>   getOutputs() const = 0;                                                         // output indexes
238 		virtual unsigned                numOperators() const = 0;                                                       // how many operators does this model have
239 		virtual void                    getOperatorIo(unsigned operatorIdx, std::vector<TensorId> &inputs, std::vector<TensorId> &outputs) const = 0;
240 		virtual OperatorKind            getOperatorKind(unsigned operatorIdx) const = 0;
241 		virtual OperatorOptionsList*    getOperatorOptions(unsigned operatorIdx) const = 0;
242 		virtual unsigned                numTensors() const = 0;                                                         // number of tensors in this model
243 		virtual TensorShape             getTensorShape(TensorId tensorId) const = 0;
244 		virtual DataType                getTensorType(TensorId tensorId) const = 0;
245 		virtual std::string             getTensorName(TensorId tensorId) const = 0;
246 		virtual bool                    getTensorHasData(TensorId tensorId) const = 0;                                  // tensors that are fixed have buffers
247 		virtual const void*             getTensorData(TensorId tensorId) const = 0;                                     // can only be called when getTensorHasData()=true
248 		virtual const float*            getTensorDataF32(TensorId tensorId) const = 0;                                  // can only be called when getTensorHasData()=true
249 		virtual bool                    getTensorIsVariableFlag(TensorId tensorId) const = 0;                           // some tensors are variables that can be altered
250 
251 	public: // convenience functions
252 		bool isTensorComputed(TensorId tensorId) const;
253 	};
254 
255 	// plugin interface
256 	virtual uint32_t capabilities() const = 0;
257 	virtual std::string filePath() const = 0;                                        // returns back the file name that it was opened from
258 	virtual std::string modelDescription() const = 0;                                // description of the model from the NN file
259 	virtual bool open(const std::string &filePath_) = 0;                             // open the model (can only be done once per object, close is implicit on destruction for simplicity)
260 	virtual std::string errorMessage() const = 0;                                    // returns the error of the last operation if it has failed
261 	virtual size_t numModels() const = 0;                                            // how many models does this file contain
262 	virtual const Model* getModel(unsigned index) const = 0;                         // access to one model, the Model object is owned by the plugin
263 	virtual void write(const Model *model, const std::string &fileName) const = 0;  // write the model to disk
264 };
265 
266 // gcc-9 needs explicit template specializations to be outside of class scope
as()267 template<> inline bool PluginInterface::OperatorOptionValue::as() const {return b;}
as()268 template<> inline float PluginInterface::OperatorOptionValue::as() const {return f;}
as()269 template<> inline int32_t PluginInterface::OperatorOptionValue::as() const {return i;}
as()270 template<> inline PluginInterface::PaddingType PluginInterface::OperatorOptionValue::as() const {return paddingType;}
as()271 template<> inline PluginInterface::ActivationFunction PluginInterface::OperatorOptionValue::as() const {return activationFunction;}
272 
273