1 #pragma once
2
3 #include "onnx/defs/data_type_utils.h"
4 #include "onnx/proto_utils.h"
5 #include "onnx/string_utils.h"
6
7 namespace ONNX_NAMESPACE {
8
9 class GraphInferencer {
10 public:
11 // Perform inferencing on the graph contained in GraphInferencer.
12 // Returns the graph output types post-inferencing.
13 virtual std::vector<const TypeProto*> doInferencing(
14 const std::vector<const TypeProto*>& inputTypes,
15 const std::vector<const TensorProto*>& inputData) = 0;
16 virtual ~GraphInferencer() = default;
17 };
18
19 // Exception class used for handling errors in type and shape inference
20
21 class InferenceError final : public std::runtime_error {
22 public:
23 using std::runtime_error::runtime_error;
24
InferenceError(const std::string & message)25 InferenceError(const std::string& message) : std::runtime_error(message) {}
26
what()27 const char* what() const noexcept override {
28 if (!expanded_message_.empty()) {
29 return expanded_message_.c_str();
30 }
31 return std::runtime_error::what();
32 }
33
AppendContext(const std::string & context)34 void AppendContext(const std::string& context) {
35 expanded_message_ = ONNX_NAMESPACE::MakeString(
36 std::runtime_error::what(), "\n\n==> Context: ", context);
37 }
38
39 private:
40 std::string expanded_message_;
41 };
42
// Throws an InferenceError whose message is "[TypeInferenceError] " followed by
// the MakeString concatenation of the macro arguments.
// NOTE(review): the expansion already ends in ';', so an unbraced
// `if (c) fail_type_inference(x); else ...` expands to `throw ...;; else` and
// fails to compile — brace such branches. The ';' is kept because removing it
// would break any call site written without a trailing semicolon.
#define fail_type_inference(...) \
  throw ONNX_NAMESPACE::InferenceError( \
      ONNX_NAMESPACE::MakeString("[TypeInferenceError] ", __VA_ARGS__));

// Same as fail_type_inference, but with the "[ShapeInferenceError] " prefix
// (and the same trailing-';' caveat).
#define fail_shape_inference(...) \
  throw ONNX_NAMESPACE::InferenceError( \
      ONNX_NAMESPACE::MakeString("[ShapeInferenceError] ", __VA_ARGS__));
50
51 struct InferenceContext {
52 virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
53 virtual size_t getNumInputs() const = 0;
54 virtual const TypeProto* getInputType(size_t index) const = 0;
55 virtual const TensorProto* getInputData(size_t index) const = 0;
56 virtual size_t getNumOutputs() const = 0;
57 virtual TypeProto* getOutputType(size_t index) = 0;
58 virtual GraphInferencer* getGraphAttributeInferencer(
59 const std::string& attribute_name) = 0;
~InferenceContextInferenceContext60 virtual ~InferenceContext() {}
61 };
62
63 using InferenceFunction = std::function<void(InferenceContext&)>;
64
65 // This no-op inference function is used for operators without an
66 // inference implementation.
dummyInferenceFunction(InferenceContext &)67 inline void dummyInferenceFunction(InferenceContext&){};
68
69 template <typename T>
getRepeatedAttribute(InferenceContext & ctx,std::string attr_name,std::vector<T> & values)70 inline bool getRepeatedAttribute(
71 InferenceContext& ctx,
72 std::string attr_name,
73 std::vector<T>& values) {
74 const auto* attr = ctx.getAttribute(attr_name);
75 if (attr) {
76 values = RetrieveValues<T>(*attr);
77 return true;
78 } else {
79 return false;
80 }
81 }
82
getAttribute(InferenceContext & ctx,const std::string & attributeName,int64_t defaultValue)83 inline int64_t getAttribute(
84 InferenceContext& ctx,
85 const std::string& attributeName,
86 int64_t defaultValue) {
87 auto attr_proto = ctx.getAttribute(attributeName);
88 if ((nullptr != attr_proto) && attr_proto->has_i())
89 return attr_proto->i();
90 return defaultValue;
91 }
92
getAttribute(InferenceContext & ctx,const std::string & attributeName,const std::string & defaultValue)93 inline std::string getAttribute(
94 InferenceContext& ctx,
95 const std::string& attributeName,
96 const std::string& defaultValue) {
97 auto attr_proto = ctx.getAttribute(attributeName);
98 if ((nullptr != attr_proto) && attr_proto->has_s())
99 return attr_proto->s();
100 return defaultValue;
101 }
102
103 inline TensorShapeProto::Dimension operator*(
104 TensorShapeProto::Dimension dim1,
105 TensorShapeProto::Dimension dim2) {
106 TensorShapeProto::Dimension result;
107 if (dim1.has_dim_value() && dim2.has_dim_value()) {
108 result.set_dim_value(dim1.dim_value() * dim2.dim_value());
109 } else if (dim1.has_dim_value() && (dim1.dim_value() == 1)) {
110 return dim2;
111 } else if (dim2.has_dim_value() && (dim2.dim_value() == 1)) {
112 return dim1;
113 }
114 return result;
115 }
116
117 inline TensorShapeProto::Dimension operator*(
118 TensorShapeProto::Dimension dim1,
119 int64_t dim2) {
120 TensorShapeProto::Dimension result;
121 if (dim1.has_dim_value()) {
122 result.set_dim_value(dim1.dim_value() * dim2);
123 } else if (dim2 == 1) {
124 return dim1;
125 }
126 return result;
127 }
128
129 inline TensorShapeProto::Dimension operator/(
130 TensorShapeProto::Dimension dim1,
131 int64_t dim2) {
132 TensorShapeProto::Dimension result;
133 if (dim1.has_dim_value()) {
134 result.set_dim_value(dim1.dim_value() / dim2);
135 } else if (dim2 == 1) {
136 return dim1;
137 }
138 return result;
139 }
140
141 // if from >= upto_exclusive, return 1.
142 // Caller must make sure upto_exclusive is less than or equal to shape.size()
143 // Caller must make sure from>=0
144 inline TensorShapeProto::Dimension
multiplyDims(const TensorShapeProto & shape,int from,int upto_exclusive)145 multiplyDims(const TensorShapeProto& shape, int from, int upto_exclusive) {
146 TensorShapeProto::Dimension dim;
147 dim.set_dim_value(1);
148 for (int i = from; i < upto_exclusive; ++i) {
149 dim = dim * shape.dim(i);
150 }
151 return dim;
152 }
153
154 // propagate the element type from an input type to an output type.
155 // if an existing output element type exists, validate it matches.
propagateElemTypeWithValidation(const TypeProto * input_type,TypeProto * output_type)156 inline void propagateElemTypeWithValidation(
157 const TypeProto* input_type,
158 TypeProto* output_type) {
159 if (nullptr == input_type) {
160 fail_type_inference("Input type was null");
161 }
162
163 if (input_type->value_case() != TypeProto::kTensorType) {
164 fail_type_inference(
165 "Input was expected to have tensor type. Got ",
166 input_type->value_case());
167 }
168
169 if (input_type->tensor_type().elem_type() == TensorProto::UNDEFINED) {
170 fail_type_inference("Element type of input was unknown");
171 }
172
173 if (output_type->value_case() == TypeProto::VALUE_NOT_SET) {
174 output_type->mutable_tensor_type()->set_elem_type(
175 input_type->tensor_type().elem_type());
176 } else if (output_type->value_case() == TypeProto::kTensorType) {
177 if (output_type->tensor_type().has_elem_type()) {
178 if (input_type->tensor_type().elem_type() !=
179 output_type->tensor_type().elem_type()) {
180 fail_type_inference(
181 "Input element type of ",
182 input_type->tensor_type().elem_type(),
183 " does not match existing output type of ",
184 output_type->tensor_type().elem_type());
185 }
186 } else {
187 output_type->mutable_tensor_type()->set_elem_type(
188 input_type->tensor_type().elem_type());
189 }
190 } else {
191 // This is not expected to happen
192 fail_type_inference(
193 "Output was expected to have tensor type. Got ",
194 output_type->value_case());
195 }
196 }
197
// Note: for all methods below that propagate a type or shape, callers are
// responsible for handling optional inputs/outputs and for ensuring that the
// specified index is less than getNumInputs()/getNumOutputs().
201
propagateElemTypeFromInputToOutput(InferenceContext & ctx,size_t inputIndex,size_t outputIndex)202 inline void propagateElemTypeFromInputToOutput(
203 InferenceContext& ctx,
204 size_t inputIndex,
205 size_t outputIndex) {
206 auto input_type = ctx.getInputType(inputIndex);
207 if (nullptr == input_type ||
208 input_type->value_case() != TypeProto::kTensorType) {
209 fail_type_inference("Input ", inputIndex, " expected to have tensor type");
210 }
211 if (input_type->tensor_type().elem_type() == TensorProto::UNDEFINED) {
212 fail_type_inference("Element type of input ", inputIndex, " unknown");
213 }
214 auto output_type = ctx.getOutputType(outputIndex);
215 if (output_type->value_case() == TypeProto::kTensorType ||
216 output_type->value_case() == TypeProto::VALUE_NOT_SET) {
217 output_type->mutable_tensor_type()->set_elem_type(
218 input_type->tensor_type().elem_type());
219 } else {
220 // This is not expected to happen
221 fail_type_inference(
222 "Output ", outputIndex, " expected to have tensor type");
223 }
224 }
225
propagateElemTypeFromDtypeToOutput(InferenceContext & ctx,const int & data_type,size_t outputIndex)226 inline void propagateElemTypeFromDtypeToOutput(
227 InferenceContext& ctx,
228 const int& data_type,
229 size_t outputIndex) {
230 auto attribute_tensor_datatype = data_type;
231 auto output_type = ctx.getOutputType(outputIndex);
232 if (output_type->value_case() == TypeProto::kTensorType ||
233 output_type->value_case() == TypeProto::VALUE_NOT_SET) {
234 output_type->mutable_tensor_type()->set_elem_type(
235 attribute_tensor_datatype);
236 } else {
237 // This is not expected to happen
238 fail_type_inference(
239 "Output ", outputIndex, " expected to have tensor type");
240 }
241 }
242
propagateElemTypeFromDtypeToOutput(InferenceContext & ctx,const AttributeProto * attr,size_t outputIndex)243 inline void propagateElemTypeFromDtypeToOutput(
244 InferenceContext& ctx,
245 const AttributeProto* attr,
246 size_t outputIndex) {
247 if (attr->type() != AttributeProto::TENSOR) {
248 fail_type_inference("Attribute expected to have tensor type");
249 }
250 if (attr->t().dims().size() != 1) {
251 fail_type_inference("Attribute expected to have a one-dim tensor");
252 }
253 auto attribute_tensor_datatype = attr->t().data_type();
254 propagateElemTypeFromDtypeToOutput(
255 ctx, attribute_tensor_datatype, outputIndex);
256 }
257
hasInputShape(InferenceContext & ctx,size_t n)258 inline bool hasInputShape(InferenceContext& ctx, size_t n) {
259 return ctx.getNumInputs() > static_cast<size_t>(n) && ctx.getInputType(n) &&
260 ctx.getInputType(n)->has_tensor_type() &&
261 ctx.getInputType(n)->tensor_type().has_shape();
262 }
263
hasNInputShapes(InferenceContext & ctx,size_t n)264 inline bool hasNInputShapes(InferenceContext& ctx, size_t n) {
265 for (size_t i = 0; i < n; i++) {
266 if (!hasInputShape(ctx, i)) {
267 return false;
268 }
269 }
270 return true;
271 }
272
getInputShape(InferenceContext & ctx,size_t n)273 inline const TensorShapeProto& getInputShape(InferenceContext& ctx, size_t n) {
274 return ctx.getInputType(n)->tensor_type().shape();
275 }
276
277 // Caller must make sure fromDimIndex is strictly less than shape.dim_size()
appendSingleDimCopiedFromInputTypeToOutputType(InferenceContext & ctx,size_t inputIndex,size_t outputIndex,size_t fromDimIndex)278 inline void appendSingleDimCopiedFromInputTypeToOutputType(
279 InferenceContext& ctx,
280 size_t inputIndex,
281 size_t outputIndex,
282 size_t fromDimIndex) {
283 auto output_type = ctx.getOutputType(outputIndex);
284 auto input_type = ctx.getInputType(inputIndex);
285 if (TypeProto::kTensorType != output_type->value_case()) {
286 fail_type_inference(
287 "Output ", outputIndex, " expected to have tensor type");
288 }
289 if (TypeProto::kTensorType != input_type->value_case()) {
290 fail_type_inference("Input ", inputIndex, " expected to have tensor type");
291 }
292 auto* dim = ctx.getOutputType(outputIndex)
293 ->mutable_tensor_type()
294 ->mutable_shape()
295 ->add_dim();
296 *dim = input_type->tensor_type().shape().dim(static_cast<int>(fromDimIndex));
297 }
298
propagateShapeFromInputToOutput(InferenceContext & ctx,size_t inputIndex,size_t outputIndex)299 inline void propagateShapeFromInputToOutput(
300 InferenceContext& ctx,
301 size_t inputIndex,
302 size_t outputIndex) {
303 auto output_type = ctx.getOutputType(outputIndex);
304 auto input_type = ctx.getInputType(inputIndex);
305 if (TypeProto::kTensorType != input_type->value_case() ||
306 TypeProto::kTensorType != output_type->value_case()) {
307 throw std::runtime_error(ONNX_NAMESPACE::to_string(
308 ctx.getInputType(inputIndex)->tensor_type().shape().dim_size()));
309 }
310
311 *ctx.getOutputType(outputIndex)->mutable_tensor_type()->mutable_shape() =
312 ctx.getInputType(inputIndex)->tensor_type().shape();
313 }
314
propagateShapeAndTypeFromFirstInput(InferenceContext & ctx)315 inline void propagateShapeAndTypeFromFirstInput(InferenceContext& ctx) {
316 propagateElemTypeFromInputToOutput(ctx, 0, 0);
317 if (!hasNInputShapes(ctx, 1)) {
318 return;
319 }
320 propagateShapeFromInputToOutput(ctx, 0, 0);
321 }
322
updateOutputElemType(InferenceContext & ctx,size_t outputIndex,int32_t elemType)323 inline void updateOutputElemType(
324 InferenceContext& ctx,
325 size_t outputIndex,
326 int32_t elemType) {
327 auto output_type = ctx.getOutputType(outputIndex);
328 if ((output_type != nullptr) &&
329 (output_type->value_case() == TypeProto::kTensorType ||
330 output_type->value_case() == TypeProto::VALUE_NOT_SET)) {
331 output_type->mutable_tensor_type()->set_elem_type(elemType);
332 } else {
333 // This is not expected to happen
334 fail_type_inference(
335 "Output ", outputIndex, " expected to have tensor type");
336 }
337 }
338
339 // Infer type of an output from the value of a specified attribute, which is
340 // expected to have a valid value representing a TensorProto_DataType.
341 inline void propagateElemTypeFromAttributeToOutput(
342 InferenceContext& ctx,
343 const std::string& attributeName,
344 size_t outputIndex,
345 TensorProto_DataType default_value = TensorProto::UNDEFINED) {
346 auto attr_proto = ctx.getAttribute(attributeName);
347 if (nullptr == attr_proto) { // attribute not present
348 if (default_value != TensorProto::UNDEFINED) {
349 updateOutputElemType(ctx, outputIndex, default_value);
350 return;
351 } else
352 fail_type_inference(
353 "Value of attribute ", attributeName, " not specified");
354 }
355 if (!attr_proto->has_i()) {
356 fail_type_inference(
357 "Attribute ",
358 attributeName,
359 " should be of integer type and specify a type.");
360 }
361 auto attr_value = attr_proto->i();
362 auto elem_type = static_cast<TensorProto_DataType>(attr_value);
363 if (!TensorProto_DataType_IsValid(elem_type)) {
364 fail_type_inference(
365 "Attribute ", attributeName, " does not specify a valid type.");
366 }
367 updateOutputElemType(ctx, outputIndex, elem_type);
368 }
369
getOutputShape(InferenceContext & ctx,size_t n)370 inline TensorShapeProto* getOutputShape(InferenceContext& ctx, size_t n) {
371 auto output_type = ctx.getOutputType(n);
372 if ((output_type != nullptr) &&
373 (output_type->value_case() == TypeProto::kTensorType ||
374 output_type->value_case() == TypeProto::VALUE_NOT_SET)) {
375 return output_type->mutable_tensor_type()->mutable_shape();
376 } else
377 fail_type_inference("Output ", n, " expected to have tensor type");
378 }
379
updateOutputShape(InferenceContext & ctx,size_t outputIndex,const TensorShapeProto & shape)380 inline void updateOutputShape(
381 InferenceContext& ctx,
382 size_t outputIndex,
383 const TensorShapeProto& shape) {
384 auto* output_shape = getOutputShape(ctx, outputIndex);
385 *output_shape = shape;
386 }
387
updateOutputShape(InferenceContext & ctx,size_t outputIndex,const TensorProto & tensorProto)388 inline void updateOutputShape(
389 InferenceContext& ctx,
390 size_t outputIndex,
391 const TensorProto& tensorProto) {
392 auto* output_shape = getOutputShape(ctx, outputIndex);
393 for (auto d : tensorProto.dims()) {
394 auto* dim = output_shape->add_dim();
395 dim->set_dim_value(d);
396 }
397 }
398
updateOutputShape(InferenceContext & ctx,size_t outputIndex,std::initializer_list<TensorShapeProto::Dimension> dims)399 inline void updateOutputShape(
400 InferenceContext& ctx,
401 size_t outputIndex,
402 std::initializer_list<TensorShapeProto::Dimension> dims) {
403 auto* output_shape = getOutputShape(ctx, outputIndex);
404 for (auto& d : dims) {
405 auto* dim = output_shape->add_dim();
406 *dim = d;
407 }
408 }
409
410 // Infer shape of an output from the value of a specified attribute, which is
411 // expected to be a list of integers specifying a valid shape.
propagateShapeFromAttributeToOutput(InferenceContext & ctx,const std::string & attributeName,size_t outputIndex)412 inline void propagateShapeFromAttributeToOutput(
413 InferenceContext& ctx,
414 const std::string& attributeName,
415 size_t outputIndex) {
416 auto attr_proto = ctx.getAttribute(attributeName);
417 if ((nullptr == attr_proto) || (!attr_proto->has_type()) ||
418 (attr_proto->type() != AttributeProto_AttributeType_INTS)) {
419 fail_shape_inference(
420 "Attribute ", attributeName, " should specify a shape");
421 }
422 auto& int_list = attr_proto->ints();
423 TensorShapeProto shape;
424 for (auto dim_size : int_list) {
425 if (dim_size < 0) {
426 fail_shape_inference(
427 "Negative values are not allowed in a shape specification");
428 }
429 shape.add_dim()->set_dim_value(dim_size);
430 }
431
432 updateOutputShape(ctx, outputIndex, shape);
433 }
434
multidirectionalBroadcastShapeInference(const std::vector<const TensorShapeProto * > & shapes,TensorShapeProto & resultShape)435 inline void multidirectionalBroadcastShapeInference(
436 const std::vector<const TensorShapeProto*>& shapes,
437 TensorShapeProto& resultShape) {
438 int result_shape_size = 0;
439 // Get the result shape size.
440 for (size_t i = 0; i < shapes.size(); ++i) {
441 if (shapes[i]->dim_size() > result_shape_size) {
442 result_shape_size = shapes[i]->dim_size();
443 }
444 }
445
446 for (int i = 0; i < result_shape_size; ++i) {
447 int64_t dim_value = 1;
448 TensorShapeProto_Dimension symbolic_dim;
449 int num_symbolic_dims = 0;
450 for (size_t j = 0; j < shapes.size(); ++j) {
451 if (i < result_shape_size - shapes[j]->dim_size()) {
452 // Shape j will be filled with 1 at dimension i;
453 continue;
454 }
455
456 auto dim_i_j =
457 shapes[j]->dim(i - result_shape_size + shapes[j]->dim_size());
458 if (dim_i_j.has_dim_value()) {
459 if (dim_i_j.dim_value() != 1) {
460 if (dim_value != dim_i_j.dim_value() && dim_value != 1) {
461 fail_shape_inference("Incompatible dimensions");
462 } else {
463 dim_value = dim_i_j.dim_value();
464 }
465 }
466 } else {
467 if (num_symbolic_dims == 0) {
468 symbolic_dim = dim_i_j;
469 ++num_symbolic_dims;
470 } else if (dim_i_j.dim_param() != symbolic_dim.dim_param()) {
471 ++num_symbolic_dims;
472 }
473 }
474 }
475
476 if (dim_value != 1 || num_symbolic_dims == 0) {
477 resultShape.add_dim()->set_dim_value(dim_value);
478 } else if (num_symbolic_dims == 1) {
479 *resultShape.add_dim() = symbolic_dim;
480 } else {
481 resultShape.add_dim();
482 }
483 }
484 }
485
bidirectionalBroadcastShapeInference(const TensorShapeProto & shapeL,const TensorShapeProto & shapeR,TensorShapeProto & resultShape)486 inline void bidirectionalBroadcastShapeInference(
487 const TensorShapeProto& shapeL,
488 const TensorShapeProto& shapeR,
489 TensorShapeProto& resultShape) {
490 std::vector<const TensorShapeProto*> shapes;
491 shapes.push_back(&shapeL);
492 shapes.push_back(&shapeR);
493 multidirectionalBroadcastShapeInference(shapes, resultShape);
494 }
495
496 /*
497 Merge the dimension information from two TensorShapeProto_Dimension instances.
498 Values are merged into target from source.
499 If target has no dimension information, copy from source.
500 If source has no dimension information, ignore source.
501 If both have dimension information:
502 - Prefer values over params. If both have values, values must match.
503 - Prefer target param over source param if mismatched.
504 Fail if there are mismatches in number of dimensions or dimension values.
505 */
mergeInDimensionInfo(const TensorShapeProto_Dimension & source_dim,TensorShapeProto_Dimension & target_dim,int dim_index)506 inline void mergeInDimensionInfo(
507 const TensorShapeProto_Dimension& source_dim,
508 TensorShapeProto_Dimension& target_dim,
509 int dim_index) {
510 // if source has value, merge into target
511 // else if target has value, preserve it
512 // else merge params
513 if (source_dim.has_dim_value()) {
514 auto source_value = source_dim.dim_value();
515 if (target_dim.has_dim_value()) {
516 auto target_value = target_dim.dim_value();
517 if (target_value != source_value) {
518 fail_shape_inference(
519 "Can't merge shape info. "
520 "Both source and target dimension have values but they differ. Source=",
521 source_value,
522 " Target=",
523 target_value,
524 " Dimension=",
525 dim_index);
526 }
527 } else {
528 target_dim.set_dim_value(source_value);
529 }
530 } else if (target_dim.has_dim_value()) {
531 // if target has a value we preserve it so do nothing
532 } else if (target_dim.has_dim_param()) {
533 // prefer target param over source
534 } else if (source_dim.has_dim_param()) {
535 target_dim.set_dim_param(source_dim.dim_param());
536 }
537 }
538
539 /*
540 Merge shape information from a source shape into a target shape.
541 * merges each TensorShapeProto_Dimension separately.
542 * prefer values over params.
543 * If both have values, values must match.
544 * prefer target param over source param if mismatched.
545 * Fail if there are mismatches in number of dimensions or dimension values.
546 */
mergeInShapeInfo(const TensorShapeProto & source,TensorShapeProto & target)547 inline void mergeInShapeInfo(
548 const TensorShapeProto& source,
549 TensorShapeProto& target) {
550 auto num_source_dims = source.dim_size();
551 auto num_target_dims = target.dim_size();
552 if (num_source_dims != num_target_dims) {
553 fail_shape_inference(
554 "Mismatch between number of source and target dimensions. Source=",
555 num_source_dims,
556 " Target=",
557 num_target_dims);
558 }
559
560 auto& source_dims = source.dim();
561 auto* target_dims = target.mutable_dim();
562
563 for (int i = 0, end = source_dims.size(); i < end; ++i) {
564 auto& source_dim = source_dims.Get(i);
565 auto& target_dim = *target_dims->Mutable(i);
566 mergeInDimensionInfo(source_dim, target_dim, i);
567 }
568 }
569
mergeInShapeInfo(const TensorShapeProto & source_shape,TypeProto_Tensor & target_type)570 inline void mergeInShapeInfo(
571 const TensorShapeProto& source_shape,
572 TypeProto_Tensor& target_type) {
573 if (target_type.has_shape()) {
574 // merge with existing info.
575 mergeInShapeInfo(source_shape, *target_type.mutable_shape());
576 } else {
577 // copy to target
578 (*target_type.mutable_shape()) = source_shape;
579 }
580 }
581
582 /*
583 Merge the shape information from two TypeProto_Tensor instances.
584 Values are merged into target from source.
585 If target has no shape information, copy from source.
586 If source has no shape information, ignore source.
587 If both have shape information:
588 - merge each TensorShapeProto_Dimension separately.
589 - Prefer values over params. If both have values, values must match.
590 - Prefer target param over source param if mismatched.
591 Fail if there are mismatches in number of dimensions or dimension values.
592 */
mergeInShapeInfo(const TypeProto_Tensor & source,TypeProto_Tensor & target)593 inline void mergeInShapeInfo(
594 const TypeProto_Tensor& source,
595 TypeProto_Tensor& target) {
596 if (source.has_shape())
597 mergeInShapeInfo(source.shape(), target);
598 }
599
600 // Return a copy of a type, with a specified dimension removed from its shape.
RemoveIthDimensionFromShape(const TypeProto & proto,int removed_dim)601 inline TypeProto RemoveIthDimensionFromShape(
602 const TypeProto& proto,
603 int removed_dim) {
604 TypeProto t(proto);
605 auto mutable_shape = t.mutable_tensor_type()->mutable_shape();
606 mutable_shape->clear_dim();
607
608 const auto& dims = proto.tensor_type().shape().dim();
609
610 for (int j = 0, end = dims.size(); j < end; ++j) {
611 if (j != removed_dim)
612 (*mutable_shape->add_dim()) = dims.Get(j);
613 }
614
615 return t;
616 }
617
618 // Return a copy of a type, with specified number of dimensions removed from the
619 // beginning.
RemoveDimensionsFromShape(const TypeProto & proto,int num_dimensions)620 inline TypeProto RemoveDimensionsFromShape(
621 const TypeProto& proto,
622 int num_dimensions) {
623 TypeProto t(proto);
624 auto mutable_shape = t.mutable_tensor_type()->mutable_shape();
625 mutable_shape->clear_dim();
626
627 const auto& dims = proto.tensor_type().shape().dim();
628
629 // skip first num_dimensions
630 for (int j = num_dimensions, end = dims.size(); j < end; ++j) {
631 (*mutable_shape->add_dim()) = dims.Get(j);
632 }
633
634 return t;
635 }
636
// Adapted from gsl::narrow_cast in the C++ Guidelines Support Library
// (https://github.com/microsoft/GSL/blob/main/include/gsl/util).
// A static_cast that marks an intentional narrowing conversion. No range check
// is performed; out-of-range values convert exactly as static_cast would.
// Not `static`: a function template defined in a header needs neither `static`
// nor `inline`, and `static` would give every including translation unit its
// own internal-linkage copy.
template <class T, class U>
constexpr T narrow_cast(U&& u) noexcept {
  return static_cast<T>(std::forward<U>(u));
}
643
644 } // namespace ONNX_NAMESPACE
645