1 /*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5 #include "shape_inference.h"
6
7 namespace ONNX_NAMESPACE {
8
9 /// <summary>
10 /// Utility function for UnionShapeInfoForTensor.
11 /// Both shapes must be of the same rank
12 /// </summary>
13 /// <param name="source_shape"></param>
14 /// <param name="target_shape">destination shape</param>
UnionShapeInfo(const TensorShapeProto & source_shape,TensorShapeProto & target_shape)15 void UnionShapeInfo(const TensorShapeProto& source_shape, TensorShapeProto& target_shape) {
16 auto source_rank = source_shape.dim_size();
17 for (int i = 0; i < source_rank; ++i) {
18 const auto source_dim = source_shape.dim(i);
19 const auto target_dim = target_shape.dim(i);
20 bool is_dims_conflict = [&]() {
21 if (source_dim.has_dim_value()) {
22 if (target_dim.has_dim_value() && target_dim.dim_value() == source_dim.dim_value()) {
23 return false;
24 }
25 return true;
26 }
27
28 if (source_dim.has_dim_param()) {
29 if (target_dim.has_dim_param() && target_dim.dim_param() == source_dim.dim_param()) {
30 return false;
31 }
32 return true;
33 }
34
35 return (target_dim.has_dim_value() || target_dim.has_dim_param());
36 }();
37 if (is_dims_conflict && (target_dim.has_dim_value() || target_dim.has_dim_param())) {
38 auto dim = target_shape.mutable_dim(i);
39 dim->clear_dim_value();
40 dim->clear_dim_param();
41 }
42 }
43 }
44
45 template<typename TENSOR_TYPE>
UnionShapeInfoForTensor(const TensorShapeProto & source_shape,TENSOR_TYPE & target_type)46 void UnionShapeInfoForTensor(const TensorShapeProto& source_shape, TENSOR_TYPE& target_type) {
47 if (target_type.has_shape()) {
48 TensorShapeProto* target_shape = target_type.mutable_shape();
49
50 auto source_rank = source_shape.dim_size();
51 auto target_rank = target_shape->dim_size();
52 if (source_rank != target_rank) {
53 target_type.clear_shape();
54 return;
55 }
56
57 UnionShapeInfo(source_shape, *target_shape);
58 }
59 }
60
UnionShapeInfo(const TensorShapeProto & source_shape,TypeProto_Tensor & target_type)61 void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type) {
62 UnionShapeInfoForTensor(source_shape, target_type);
63 }
64
UnionShapeInfo(const TensorShapeProto & source_shape,TypeProto_SparseTensor & target_type)65 void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type) {
66 UnionShapeInfoForTensor(source_shape, target_type);
67 }
68
69
UnionTypeInfo(const TypeProto & source_type,TypeProto & target_type)70 void UnionTypeInfo(const TypeProto& source_type, TypeProto& target_type) {
71 if (source_type.value_case() != target_type.value_case()) {
72 fail_type_inference("Mismatched type:", " source=", source_type.value_case(), " target=", target_type.value_case());
73 }
74
75 const auto target_case = target_type.value_case();
76 if (target_case == TypeProto::ValueCase::kTensorType) {
77 auto source_elem_type = source_type.tensor_type().elem_type();
78 auto target_elem_type = target_type.tensor_type().elem_type();
79
80 if (source_elem_type != target_elem_type) {
81 fail_type_inference(
82 "Mismatched tensor element type:", " source=", source_elem_type, " target=", target_elem_type);
83 }
84
85 UnionShapeInfoForTensor(source_type.tensor_type().shape(), *target_type.mutable_tensor_type());
86 } else if (target_case == TypeProto::ValueCase::kSparseTensorType) {
87 auto source_elem_type = source_type.sparse_tensor_type().elem_type();
88 auto target_elem_type = target_type.sparse_tensor_type().elem_type();
89 if (source_elem_type != target_elem_type) {
90 fail_type_inference(
91 "Mismatched sparse tensor element type:", " source=", source_elem_type, " target=", target_elem_type);
92 }
93
94 UnionShapeInfoForTensor(source_type.sparse_tensor_type().shape(), *target_type.mutable_sparse_tensor_type());
95 } else if (target_case == TypeProto::ValueCase::kSequenceType) {
96 if (!source_type.sequence_type().has_elem_type()) {
97 fail_type_inference("source sequence type missing element type.");
98 }
99 if (!target_type.sequence_type().has_elem_type()) {
100 fail_type_inference("target sequence type missing element type.");
101 }
102 UnionTypeInfo(source_type.sequence_type().elem_type(), *target_type.mutable_sequence_type()->mutable_elem_type());
103 }
104 }
105
106
107 } // namespace ONNX_NAMESPACE