1 /*
2  * SPDX-License-Identifier: Apache-2.0
3  */
4 
5 #include "shape_inference.h"
6 
7 namespace ONNX_NAMESPACE {
8 
9 /// <summary>
10 /// Utility function for UnionShapeInfoForTensor.
11 /// Both shapes must be of the same rank
12 /// </summary>
13 /// <param name="source_shape"></param>
14 /// <param name="target_shape">destination shape</param>
UnionShapeInfo(const TensorShapeProto & source_shape,TensorShapeProto & target_shape)15 void UnionShapeInfo(const TensorShapeProto& source_shape, TensorShapeProto& target_shape) {
16   auto source_rank = source_shape.dim_size();
17   for (int i = 0; i < source_rank; ++i) {
18     const auto source_dim = source_shape.dim(i);
19     const auto target_dim = target_shape.dim(i);
20     bool is_dims_conflict = [&]() {
21       if (source_dim.has_dim_value()) {
22         if (target_dim.has_dim_value() && target_dim.dim_value() == source_dim.dim_value()) {
23           return false;
24         }
25         return true;
26       }
27 
28       if (source_dim.has_dim_param()) {
29         if (target_dim.has_dim_param() && target_dim.dim_param() == source_dim.dim_param()) {
30           return false;
31         }
32         return true;
33       }
34 
35       return (target_dim.has_dim_value() || target_dim.has_dim_param());
36     }();
37     if (is_dims_conflict && (target_dim.has_dim_value() || target_dim.has_dim_param())) {
38       auto dim = target_shape.mutable_dim(i);
39       dim->clear_dim_value();
40       dim->clear_dim_param();
41     }
42   }
43 }
44 
45 template<typename TENSOR_TYPE>
UnionShapeInfoForTensor(const TensorShapeProto & source_shape,TENSOR_TYPE & target_type)46 void UnionShapeInfoForTensor(const TensorShapeProto& source_shape, TENSOR_TYPE& target_type) {
47   if (target_type.has_shape()) {
48     TensorShapeProto* target_shape = target_type.mutable_shape();
49 
50     auto source_rank = source_shape.dim_size();
51     auto target_rank = target_shape->dim_size();
52     if (source_rank != target_rank) {
53       target_type.clear_shape();
54       return;
55     }
56 
57     UnionShapeInfo(source_shape, *target_shape);
58   }
59 }
60 
UnionShapeInfo(const TensorShapeProto & source_shape,TypeProto_Tensor & target_type)61 void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type) {
62   UnionShapeInfoForTensor(source_shape, target_type);
63 }
64 
UnionShapeInfo(const TensorShapeProto & source_shape,TypeProto_SparseTensor & target_type)65 void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type) {
66   UnionShapeInfoForTensor(source_shape, target_type);
67 }
68 
69 
UnionTypeInfo(const TypeProto & source_type,TypeProto & target_type)70 void UnionTypeInfo(const TypeProto& source_type, TypeProto& target_type) {
71   if (source_type.value_case() != target_type.value_case()) {
72     fail_type_inference("Mismatched type:", " source=", source_type.value_case(), " target=", target_type.value_case());
73   }
74 
75   const auto target_case = target_type.value_case();
76   if (target_case == TypeProto::ValueCase::kTensorType) {
77     auto source_elem_type = source_type.tensor_type().elem_type();
78     auto target_elem_type = target_type.tensor_type().elem_type();
79 
80     if (source_elem_type != target_elem_type) {
81       fail_type_inference(
82           "Mismatched tensor element type:", " source=", source_elem_type, " target=", target_elem_type);
83     }
84 
85     UnionShapeInfoForTensor(source_type.tensor_type().shape(), *target_type.mutable_tensor_type());
86   } else if (target_case == TypeProto::ValueCase::kSparseTensorType) {
87     auto source_elem_type = source_type.sparse_tensor_type().elem_type();
88     auto target_elem_type = target_type.sparse_tensor_type().elem_type();
89     if (source_elem_type != target_elem_type) {
90       fail_type_inference(
91           "Mismatched sparse tensor element type:", " source=", source_elem_type, " target=", target_elem_type);
92     }
93 
94     UnionShapeInfoForTensor(source_type.sparse_tensor_type().shape(), *target_type.mutable_sparse_tensor_type());
95   } else if (target_case == TypeProto::ValueCase::kSequenceType) {
96     if (!source_type.sequence_type().has_elem_type()) {
97       fail_type_inference("source sequence type missing element type.");
98     }
99     if (!target_type.sequence_type().has_elem_type()) {
100       fail_type_inference("target sequence type missing element type.");
101     }
102     UnionTypeInfo(source_type.sequence_type().elem_type(), *target_type.mutable_sequence_type()->mutable_elem_type());
103   }
104 }
105 
106 
107 } // namespace ONNX_NAMESPACE