1 #pragma once
2 
3 #include "onnx/defs/data_type_utils.h"
4 #include "onnx/proto_utils.h"
5 #include "onnx/string_utils.h"
6 
7 namespace ONNX_NAMESPACE {
8 
// Abstract interface used to run type/shape inference on a subgraph
// (e.g. the body of a control-flow node) from within an operator's
// inference function. Implementations are obtained via
// InferenceContext::getGraphAttributeInferencer.
class GraphInferencer {
 public:
  // Perform inferencing on the graph contained in GraphInferencer.
  // inputTypes: types of the values fed to the subgraph inputs.
  // inputData: statically-known input tensors, when available —
  //            presumably aligned with inputTypes; confirm against callers.
  // Returns the graph output types post-inferencing.
  virtual std::vector<const TypeProto*> doInferencing(
      const std::vector<const TypeProto*>& inputTypes,
      const std::vector<const TensorProto*>& inputData) = 0;
  virtual ~GraphInferencer() = default;
};
18 
// Exception class used for handling errors in type and shape inference

class InferenceError final : public std::runtime_error {
 public:
  using std::runtime_error::runtime_error;

  InferenceError(const std::string& message) : std::runtime_error(message) {}

  // Returns the expanded message (original message plus appended context)
  // once AppendContext has been called; otherwise the original message.
  const char* what() const noexcept override {
    if (!expanded_message_.empty()) {
      return expanded_message_.c_str();
    }
    return std::runtime_error::what();
  }

  // Attaches contextual information (e.g. which node was being inferred)
  // to the message reported by what(). Note: the expansion is rebuilt from
  // the *original* message each time, so a second call replaces — rather
  // than extends — previously appended context.
  void AppendContext(const std::string& context) {
    expanded_message_ = ONNX_NAMESPACE::MakeString(
        std::runtime_error::what(), "\n\n==> Context: ", context);
  }

 private:
  // Lazily-built "message + context" string returned by what().
  std::string expanded_message_;
};
42 
// Convenience macros to abort inference with an InferenceError whose
// message is prefixed with the error category.
// Fix: the macro bodies previously ended with a semicolon, so an
// invocation written as `if (cond) fail_type_inference(...); else ...`
// expanded to `throw ...;;` — the extra empty statement detaches the
// `else` and breaks compilation. The trailing semicolon now comes from
// the call site, as with any statement-like macro.
#define fail_type_inference(...)        \
  throw ONNX_NAMESPACE::InferenceError( \
      ONNX_NAMESPACE::MakeString("[TypeInferenceError] ", __VA_ARGS__))

#define fail_shape_inference(...)       \
  throw ONNX_NAMESPACE::InferenceError( \
      ONNX_NAMESPACE::MakeString("[ShapeInferenceError] ", __VA_ARGS__))
50 
// Abstract interface giving an operator's type/shape inference function
// access to the node being inferred: its attributes, input types,
// statically-known input values, and mutable output type slots.
struct InferenceContext {
  // Returns the attribute with the given name, or nullptr when absent.
  virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
  virtual size_t getNumInputs() const = 0;
  // Type of input `index`; may be nullptr (callers in this file null-check).
  virtual const TypeProto* getInputType(size_t index) const = 0;
  // Statically-known value of input `index`; presumably nullptr when the
  // value is not known at inference time — confirm against implementations.
  virtual const TensorProto* getInputData(size_t index) const = 0;
  virtual size_t getNumOutputs() const = 0;
  // Mutable type slot for output `index`; inference writes its result here.
  virtual TypeProto* getOutputType(size_t index) = 0;
  // Inferencer for a graph-valued attribute, identified by attribute name.
  virtual GraphInferencer* getGraphAttributeInferencer(
      const std::string& attribute_name) = 0;
  virtual ~InferenceContext() {}
};
62 
63 using InferenceFunction = std::function<void(InferenceContext&)>;
64 
65 // This no-op inference function is used for operators without an
66 // inference implementation.
dummyInferenceFunction(InferenceContext &)67 inline void dummyInferenceFunction(InferenceContext&){};
68 
69 template <typename T>
getRepeatedAttribute(InferenceContext & ctx,std::string attr_name,std::vector<T> & values)70 inline bool getRepeatedAttribute(
71     InferenceContext& ctx,
72     std::string attr_name,
73     std::vector<T>& values) {
74   const auto* attr = ctx.getAttribute(attr_name);
75   if (attr) {
76     values = RetrieveValues<T>(*attr);
77     return true;
78   } else {
79     return false;
80   }
81 }
82 
getAttribute(InferenceContext & ctx,const std::string & attributeName,int64_t defaultValue)83 inline int64_t getAttribute(
84     InferenceContext& ctx,
85     const std::string& attributeName,
86     int64_t defaultValue) {
87   auto attr_proto = ctx.getAttribute(attributeName);
88   if ((nullptr != attr_proto) && attr_proto->has_i())
89     return attr_proto->i();
90   return defaultValue;
91 }
92 
getAttribute(InferenceContext & ctx,const std::string & attributeName,const std::string & defaultValue)93 inline std::string getAttribute(
94     InferenceContext& ctx,
95     const std::string& attributeName,
96     const std::string& defaultValue) {
97   auto attr_proto = ctx.getAttribute(attributeName);
98   if ((nullptr != attr_proto) && attr_proto->has_s())
99     return attr_proto->s();
100   return defaultValue;
101 }
102 
103 inline TensorShapeProto::Dimension operator*(
104     TensorShapeProto::Dimension dim1,
105     TensorShapeProto::Dimension dim2) {
106   TensorShapeProto::Dimension result;
107   if (dim1.has_dim_value() && dim2.has_dim_value()) {
108     result.set_dim_value(dim1.dim_value() * dim2.dim_value());
109   } else if (dim1.has_dim_value() && (dim1.dim_value() == 1)) {
110     return dim2;
111   } else if (dim2.has_dim_value() && (dim2.dim_value() == 1)) {
112     return dim1;
113   }
114   return result;
115 }
116 
117 inline TensorShapeProto::Dimension operator*(
118     TensorShapeProto::Dimension dim1,
119     int64_t dim2) {
120   TensorShapeProto::Dimension result;
121   if (dim1.has_dim_value()) {
122     result.set_dim_value(dim1.dim_value() * dim2);
123   } else if (dim2 == 1) {
124     return dim1;
125   }
126   return result;
127 }
128 
129 inline TensorShapeProto::Dimension operator/(
130     TensorShapeProto::Dimension dim1,
131     int64_t dim2) {
132   TensorShapeProto::Dimension result;
133   if (dim1.has_dim_value()) {
134     result.set_dim_value(dim1.dim_value() / dim2);
135   } else if (dim2 == 1) {
136     return dim1;
137   }
138   return result;
139 }
140 
141 // if from >= upto_exclusive, return 1.
142 // Caller must make sure upto_exclusive is less than or equal to shape.size()
143 // Caller must make sure from>=0
144 inline TensorShapeProto::Dimension
multiplyDims(const TensorShapeProto & shape,int from,int upto_exclusive)145 multiplyDims(const TensorShapeProto& shape, int from, int upto_exclusive) {
146   TensorShapeProto::Dimension dim;
147   dim.set_dim_value(1);
148   for (int i = from; i < upto_exclusive; ++i) {
149     dim = dim * shape.dim(i);
150   }
151   return dim;
152 }
153 
154 // propagate the element type from an input type to an output type.
155 // if an existing output element type exists, validate it matches.
propagateElemTypeWithValidation(const TypeProto * input_type,TypeProto * output_type)156 inline void propagateElemTypeWithValidation(
157     const TypeProto* input_type,
158     TypeProto* output_type) {
159   if (nullptr == input_type) {
160     fail_type_inference("Input type was null");
161   }
162 
163   if (input_type->value_case() != TypeProto::kTensorType) {
164     fail_type_inference(
165         "Input was expected to have tensor type. Got ",
166         input_type->value_case());
167   }
168 
169   if (input_type->tensor_type().elem_type() == TensorProto::UNDEFINED) {
170     fail_type_inference("Element type of input was unknown");
171   }
172 
173   if (output_type->value_case() == TypeProto::VALUE_NOT_SET) {
174     output_type->mutable_tensor_type()->set_elem_type(
175         input_type->tensor_type().elem_type());
176   } else if (output_type->value_case() == TypeProto::kTensorType) {
177     if (output_type->tensor_type().has_elem_type()) {
178       if (input_type->tensor_type().elem_type() !=
179           output_type->tensor_type().elem_type()) {
180         fail_type_inference(
181             "Input element type of ",
182             input_type->tensor_type().elem_type(),
183             " does not match existing output type of ",
184             output_type->tensor_type().elem_type());
185       }
186     } else {
187       output_type->mutable_tensor_type()->set_elem_type(
188           input_type->tensor_type().elem_type());
189     }
190   } else {
191     // This is not expected to happen
192     fail_type_inference(
193         "Output was expected to have tensor type. Got ",
194         output_type->value_case());
195   }
196 }
197 
198 // Note: for all methods below for propagating type or shape, callers are
199 // responsible to handle optional inputs/outputs and ensure that the specified
200 // index value is less than NumInputs/NumOutputs.
201 
propagateElemTypeFromInputToOutput(InferenceContext & ctx,size_t inputIndex,size_t outputIndex)202 inline void propagateElemTypeFromInputToOutput(
203     InferenceContext& ctx,
204     size_t inputIndex,
205     size_t outputIndex) {
206   auto input_type = ctx.getInputType(inputIndex);
207   if (nullptr == input_type ||
208       input_type->value_case() != TypeProto::kTensorType) {
209     fail_type_inference("Input ", inputIndex, " expected to have tensor type");
210   }
211   if (input_type->tensor_type().elem_type() == TensorProto::UNDEFINED) {
212     fail_type_inference("Element type of input ", inputIndex, " unknown");
213   }
214   auto output_type = ctx.getOutputType(outputIndex);
215   if (output_type->value_case() == TypeProto::kTensorType ||
216       output_type->value_case() == TypeProto::VALUE_NOT_SET) {
217     output_type->mutable_tensor_type()->set_elem_type(
218         input_type->tensor_type().elem_type());
219   } else {
220     // This is not expected to happen
221     fail_type_inference(
222         "Output ", outputIndex, " expected to have tensor type");
223   }
224 }
225 
propagateElemTypeFromDtypeToOutput(InferenceContext & ctx,const int & data_type,size_t outputIndex)226 inline void propagateElemTypeFromDtypeToOutput(
227     InferenceContext& ctx,
228     const int& data_type,
229     size_t outputIndex) {
230   auto attribute_tensor_datatype = data_type;
231   auto output_type = ctx.getOutputType(outputIndex);
232   if (output_type->value_case() == TypeProto::kTensorType ||
233       output_type->value_case() == TypeProto::VALUE_NOT_SET) {
234     output_type->mutable_tensor_type()->set_elem_type(
235         attribute_tensor_datatype);
236   } else {
237     // This is not expected to happen
238     fail_type_inference(
239         "Output ", outputIndex, " expected to have tensor type");
240   }
241 }
242 
propagateElemTypeFromDtypeToOutput(InferenceContext & ctx,const AttributeProto * attr,size_t outputIndex)243 inline void propagateElemTypeFromDtypeToOutput(
244     InferenceContext& ctx,
245     const AttributeProto* attr,
246     size_t outputIndex) {
247   if (attr->type() != AttributeProto::TENSOR) {
248     fail_type_inference("Attribute expected to have tensor type");
249   }
250   if (attr->t().dims().size() != 1) {
251     fail_type_inference("Attribute expected to have a one-dim tensor");
252   }
253   auto attribute_tensor_datatype = attr->t().data_type();
254   propagateElemTypeFromDtypeToOutput(
255       ctx, attribute_tensor_datatype, outputIndex);
256 }
257 
hasInputShape(InferenceContext & ctx,size_t n)258 inline bool hasInputShape(InferenceContext& ctx, size_t n) {
259   return ctx.getNumInputs() > static_cast<size_t>(n) && ctx.getInputType(n) &&
260       ctx.getInputType(n)->has_tensor_type() &&
261       ctx.getInputType(n)->tensor_type().has_shape();
262 }
263 
hasNInputShapes(InferenceContext & ctx,size_t n)264 inline bool hasNInputShapes(InferenceContext& ctx, size_t n) {
265   for (size_t i = 0; i < n; i++) {
266     if (!hasInputShape(ctx, i)) {
267       return false;
268     }
269   }
270   return true;
271 }
272 
getInputShape(InferenceContext & ctx,size_t n)273 inline const TensorShapeProto& getInputShape(InferenceContext& ctx, size_t n) {
274   return ctx.getInputType(n)->tensor_type().shape();
275 }
276 
277 // Caller must make sure fromDimIndex is strictly less than shape.dim_size()
appendSingleDimCopiedFromInputTypeToOutputType(InferenceContext & ctx,size_t inputIndex,size_t outputIndex,size_t fromDimIndex)278 inline void appendSingleDimCopiedFromInputTypeToOutputType(
279     InferenceContext& ctx,
280     size_t inputIndex,
281     size_t outputIndex,
282     size_t fromDimIndex) {
283   auto output_type = ctx.getOutputType(outputIndex);
284   auto input_type = ctx.getInputType(inputIndex);
285   if (TypeProto::kTensorType != output_type->value_case()) {
286     fail_type_inference(
287         "Output ", outputIndex, " expected to have tensor type");
288   }
289   if (TypeProto::kTensorType != input_type->value_case()) {
290     fail_type_inference("Input ", inputIndex, " expected to have tensor type");
291   }
292   auto* dim = ctx.getOutputType(outputIndex)
293                   ->mutable_tensor_type()
294                   ->mutable_shape()
295                   ->add_dim();
296   *dim = input_type->tensor_type().shape().dim(static_cast<int>(fromDimIndex));
297 }
298 
propagateShapeFromInputToOutput(InferenceContext & ctx,size_t inputIndex,size_t outputIndex)299 inline void propagateShapeFromInputToOutput(
300     InferenceContext& ctx,
301     size_t inputIndex,
302     size_t outputIndex) {
303   auto output_type = ctx.getOutputType(outputIndex);
304   auto input_type = ctx.getInputType(inputIndex);
305   if (TypeProto::kTensorType != input_type->value_case() ||
306       TypeProto::kTensorType != output_type->value_case()) {
307     throw std::runtime_error(ONNX_NAMESPACE::to_string(
308         ctx.getInputType(inputIndex)->tensor_type().shape().dim_size()));
309   }
310 
311   *ctx.getOutputType(outputIndex)->mutable_tensor_type()->mutable_shape() =
312       ctx.getInputType(inputIndex)->tensor_type().shape();
313 }
314 
propagateShapeAndTypeFromFirstInput(InferenceContext & ctx)315 inline void propagateShapeAndTypeFromFirstInput(InferenceContext& ctx) {
316   propagateElemTypeFromInputToOutput(ctx, 0, 0);
317   if (!hasNInputShapes(ctx, 1)) {
318     return;
319   }
320   propagateShapeFromInputToOutput(ctx, 0, 0);
321 }
322 
updateOutputElemType(InferenceContext & ctx,size_t outputIndex,int32_t elemType)323 inline void updateOutputElemType(
324     InferenceContext& ctx,
325     size_t outputIndex,
326     int32_t elemType) {
327   auto output_type = ctx.getOutputType(outputIndex);
328   if ((output_type != nullptr) &&
329       (output_type->value_case() == TypeProto::kTensorType ||
330        output_type->value_case() == TypeProto::VALUE_NOT_SET)) {
331     output_type->mutable_tensor_type()->set_elem_type(elemType);
332   } else {
333     // This is not expected to happen
334     fail_type_inference(
335         "Output ", outputIndex, " expected to have tensor type");
336   }
337 }
338 
339 // Infer type of an output from the value of a specified attribute, which is
340 // expected to have a valid value representing a TensorProto_DataType.
341 inline void propagateElemTypeFromAttributeToOutput(
342     InferenceContext& ctx,
343     const std::string& attributeName,
344     size_t outputIndex,
345     TensorProto_DataType default_value = TensorProto::UNDEFINED) {
346   auto attr_proto = ctx.getAttribute(attributeName);
347   if (nullptr == attr_proto) { // attribute not present
348     if (default_value != TensorProto::UNDEFINED) {
349       updateOutputElemType(ctx, outputIndex, default_value);
350       return;
351     } else
352       fail_type_inference(
353           "Value of attribute ", attributeName, " not specified");
354   }
355   if (!attr_proto->has_i()) {
356     fail_type_inference(
357         "Attribute ",
358         attributeName,
359         " should be of integer type and specify a type.");
360   }
361   auto attr_value = attr_proto->i();
362   auto elem_type = static_cast<TensorProto_DataType>(attr_value);
363   if (!TensorProto_DataType_IsValid(elem_type)) {
364     fail_type_inference(
365         "Attribute ", attributeName, " does not specify a valid type.");
366   }
367   updateOutputElemType(ctx, outputIndex, elem_type);
368 }
369 
getOutputShape(InferenceContext & ctx,size_t n)370 inline TensorShapeProto* getOutputShape(InferenceContext& ctx, size_t n) {
371   auto output_type = ctx.getOutputType(n);
372   if ((output_type != nullptr) &&
373       (output_type->value_case() == TypeProto::kTensorType ||
374        output_type->value_case() == TypeProto::VALUE_NOT_SET)) {
375     return output_type->mutable_tensor_type()->mutable_shape();
376   } else
377     fail_type_inference("Output ", n, " expected to have tensor type");
378 }
379 
updateOutputShape(InferenceContext & ctx,size_t outputIndex,const TensorShapeProto & shape)380 inline void updateOutputShape(
381     InferenceContext& ctx,
382     size_t outputIndex,
383     const TensorShapeProto& shape) {
384   auto* output_shape = getOutputShape(ctx, outputIndex);
385   *output_shape = shape;
386 }
387 
updateOutputShape(InferenceContext & ctx,size_t outputIndex,const TensorProto & tensorProto)388 inline void updateOutputShape(
389     InferenceContext& ctx,
390     size_t outputIndex,
391     const TensorProto& tensorProto) {
392   auto* output_shape = getOutputShape(ctx, outputIndex);
393   for (auto d : tensorProto.dims()) {
394     auto* dim = output_shape->add_dim();
395     dim->set_dim_value(d);
396   }
397 }
398 
updateOutputShape(InferenceContext & ctx,size_t outputIndex,std::initializer_list<TensorShapeProto::Dimension> dims)399 inline void updateOutputShape(
400     InferenceContext& ctx,
401     size_t outputIndex,
402     std::initializer_list<TensorShapeProto::Dimension> dims) {
403   auto* output_shape = getOutputShape(ctx, outputIndex);
404   for (auto& d : dims) {
405     auto* dim = output_shape->add_dim();
406     *dim = d;
407   }
408 }
409 
410 // Infer shape of an output from the value of a specified attribute, which is
411 // expected to be a list of integers specifying a valid shape.
propagateShapeFromAttributeToOutput(InferenceContext & ctx,const std::string & attributeName,size_t outputIndex)412 inline void propagateShapeFromAttributeToOutput(
413     InferenceContext& ctx,
414     const std::string& attributeName,
415     size_t outputIndex) {
416   auto attr_proto = ctx.getAttribute(attributeName);
417   if ((nullptr == attr_proto) || (!attr_proto->has_type()) ||
418       (attr_proto->type() != AttributeProto_AttributeType_INTS)) {
419     fail_shape_inference(
420         "Attribute ", attributeName, " should specify a shape");
421   }
422   auto& int_list = attr_proto->ints();
423   TensorShapeProto shape;
424   for (auto dim_size : int_list) {
425     if (dim_size < 0) {
426       fail_shape_inference(
427           "Negative values are not allowed in a shape specification");
428     }
429     shape.add_dim()->set_dim_value(dim_size);
430   }
431 
432   updateOutputShape(ctx, outputIndex, shape);
433 }
434 
// Computes the broadcast result of several shapes into `resultShape`,
// following numpy-style (multidirectional) broadcasting: shapes are
// right-aligned, missing leading dimensions act as 1, and a dimension of
// 1 broadcasts against any other dimension.
inline void multidirectionalBroadcastShapeInference(
    const std::vector<const TensorShapeProto*>& shapes,
    TensorShapeProto& resultShape) {
  int result_shape_size = 0;
  // Get the result shape size: the maximum rank among the inputs.
  for (size_t i = 0; i < shapes.size(); ++i) {
    if (shapes[i]->dim_size() > result_shape_size) {
      result_shape_size = shapes[i]->dim_size();
    }
  }

  for (int i = 0; i < result_shape_size; ++i) {
    // dim_value accumulates the concrete broadcast value for position i;
    // symbolic_dim / num_symbolic_dims track how many *distinct* symbolic
    // (param-named) dimensions appear at this position.
    int64_t dim_value = 1;
    TensorShapeProto_Dimension symbolic_dim;
    int num_symbolic_dims = 0;
    for (size_t j = 0; j < shapes.size(); ++j) {
      if (i < result_shape_size - shapes[j]->dim_size()) {
        // Shape j will be filled with 1 at dimension i;
        continue;
      }

      // Right-aligned index of position i within shape j.
      auto dim_i_j =
          shapes[j]->dim(i - result_shape_size + shapes[j]->dim_size());
      if (dim_i_j.has_dim_value()) {
        if (dim_i_j.dim_value() != 1) {
          // A concrete non-1 dimension must agree with any other concrete
          // non-1 dimension already seen at this position.
          if (dim_value != dim_i_j.dim_value() && dim_value != 1) {
            fail_shape_inference("Incompatible dimensions");
          } else {
            dim_value = dim_i_j.dim_value();
          }
        }
      } else {
        // Symbolic dimension: remember the first one, and count further
        // occurrences only when their dim_param differs from it.
        if (num_symbolic_dims == 0) {
          symbolic_dim = dim_i_j;
          ++num_symbolic_dims;
        } else if (dim_i_j.dim_param() != symbolic_dim.dim_param()) {
          ++num_symbolic_dims;
        }
      }
    }

    // Resolution for position i:
    //  - a concrete value != 1, or no symbolic dims at all -> concrete;
    //  - exactly one distinct symbolic dim -> propagate it as-is;
    //  - several distinct symbolic dims -> unknown dimension.
    if (dim_value != 1 || num_symbolic_dims == 0) {
      resultShape.add_dim()->set_dim_value(dim_value);
    } else if (num_symbolic_dims == 1) {
      *resultShape.add_dim() = symbolic_dim;
    } else {
      resultShape.add_dim();
    }
  }
}
485 
bidirectionalBroadcastShapeInference(const TensorShapeProto & shapeL,const TensorShapeProto & shapeR,TensorShapeProto & resultShape)486 inline void bidirectionalBroadcastShapeInference(
487     const TensorShapeProto& shapeL,
488     const TensorShapeProto& shapeR,
489     TensorShapeProto& resultShape) {
490   std::vector<const TensorShapeProto*> shapes;
491   shapes.push_back(&shapeL);
492   shapes.push_back(&shapeR);
493   multidirectionalBroadcastShapeInference(shapes, resultShape);
494 }
495 
/*
Merge the dimension information from two TensorShapeProto_Dimension instances.
Values are merged into target from source.
If target has no dimension information, copy from source.
If source has no dimension information, ignore source.
If both have dimension information:
 - Prefer values over params. If both have values, values must match.
 - Prefer target param over source param if mismatched.
Fail if there are mismatches in number of dimensions or dimension values.
dim_index is used only to make the failure message more helpful.
*/
inline void mergeInDimensionInfo(
    const TensorShapeProto_Dimension& source_dim,
    TensorShapeProto_Dimension& target_dim,
    int dim_index) {
  // if source has value, merge into target
  // else if target has value, preserve it
  // else merge params
  if (source_dim.has_dim_value()) {
    auto source_value = source_dim.dim_value();
    if (target_dim.has_dim_value()) {
      auto target_value = target_dim.dim_value();
      if (target_value != source_value) {
        // Two concrete but different values: irreconcilable.
        fail_shape_inference(
            "Can't merge shape info. "
            "Both source and target dimension have values but they differ. Source=",
            source_value,
            " Target=",
            target_value,
            " Dimension=",
            dim_index);
      }
    } else {
      // Concrete source value overrides a param-only (or empty) target.
      target_dim.set_dim_value(source_value);
    }
  } else if (target_dim.has_dim_value()) {
    // if target has a value we preserve it so do nothing
  } else if (target_dim.has_dim_param()) {
    // prefer target param over source — a mismatched source param is
    // silently ignored here, by design.
  } else if (source_dim.has_dim_param()) {
    target_dim.set_dim_param(source_dim.dim_param());
  }
}
538 
539 /*
540 Merge shape information from a source shape into a target shape.
541 * merges each TensorShapeProto_Dimension separately.
542 * prefer values over params.
543 * If both have values, values must match.
544 * prefer target param over source param if mismatched.
545 * Fail if there are mismatches in number of dimensions or dimension values.
546 */
mergeInShapeInfo(const TensorShapeProto & source,TensorShapeProto & target)547 inline void mergeInShapeInfo(
548     const TensorShapeProto& source,
549     TensorShapeProto& target) {
550   auto num_source_dims = source.dim_size();
551   auto num_target_dims = target.dim_size();
552   if (num_source_dims != num_target_dims) {
553     fail_shape_inference(
554         "Mismatch between number of source and target dimensions. Source=",
555         num_source_dims,
556         " Target=",
557         num_target_dims);
558   }
559 
560   auto& source_dims = source.dim();
561   auto* target_dims = target.mutable_dim();
562 
563   for (int i = 0, end = source_dims.size(); i < end; ++i) {
564     auto& source_dim = source_dims.Get(i);
565     auto& target_dim = *target_dims->Mutable(i);
566     mergeInDimensionInfo(source_dim, target_dim, i);
567   }
568 }
569 
mergeInShapeInfo(const TensorShapeProto & source_shape,TypeProto_Tensor & target_type)570 inline void mergeInShapeInfo(
571     const TensorShapeProto& source_shape,
572     TypeProto_Tensor& target_type) {
573   if (target_type.has_shape()) {
574     // merge with existing info.
575     mergeInShapeInfo(source_shape, *target_type.mutable_shape());
576   } else {
577     // copy to target
578     (*target_type.mutable_shape()) = source_shape;
579   }
580 }
581 
582 /*
583 Merge the shape information from two TypeProto_Tensor instances.
584 Values are merged into target from source.
585 If target has no shape information, copy from source.
586 If source has no shape information, ignore source.
587 If both have shape information:
588 - merge each TensorShapeProto_Dimension separately.
589 - Prefer values over params. If both have values, values must match.
590 - Prefer target param over source param if mismatched.
591 Fail if there are mismatches in number of dimensions or dimension values.
592 */
mergeInShapeInfo(const TypeProto_Tensor & source,TypeProto_Tensor & target)593 inline void mergeInShapeInfo(
594     const TypeProto_Tensor& source,
595     TypeProto_Tensor& target) {
596   if (source.has_shape())
597     mergeInShapeInfo(source.shape(), target);
598 }
599 
600 // Return a copy of a type, with a specified dimension removed from its shape.
RemoveIthDimensionFromShape(const TypeProto & proto,int removed_dim)601 inline TypeProto RemoveIthDimensionFromShape(
602     const TypeProto& proto,
603     int removed_dim) {
604   TypeProto t(proto);
605   auto mutable_shape = t.mutable_tensor_type()->mutable_shape();
606   mutable_shape->clear_dim();
607 
608   const auto& dims = proto.tensor_type().shape().dim();
609 
610   for (int j = 0, end = dims.size(); j < end; ++j) {
611     if (j != removed_dim)
612       (*mutable_shape->add_dim()) = dims.Get(j);
613   }
614 
615   return t;
616 }
617 
618 // Return a copy of a type, with specified number of dimensions removed from the
619 // beginning.
RemoveDimensionsFromShape(const TypeProto & proto,int num_dimensions)620 inline TypeProto RemoveDimensionsFromShape(
621     const TypeProto& proto,
622     int num_dimensions) {
623   TypeProto t(proto);
624   auto mutable_shape = t.mutable_tensor_type()->mutable_shape();
625   mutable_shape->clear_dim();
626 
627   const auto& dims = proto.tensor_type().shape().dim();
628 
629   // skip first num_dimensions
630   for (int j = num_dimensions, end = dims.size(); j < end; ++j) {
631     (*mutable_shape->add_dim()) = dims.Get(j);
632   }
633 
634   return t;
635 }
636 
// copied from GSL:
// https://github.com/Microsoft/GSL/blob/master/include/gsl/gsl_util
// Explicitly-named narrowing conversion; performs no runtime checking.
template <class T, class U>
static constexpr T narrow_cast(U&& value) noexcept {
  return static_cast<T>(std::forward<U>(value));
}
643 
644 } // namespace ONNX_NAMESPACE
645