1 /*
2  * SPDX-License-Identifier: Apache-2.0
3  */
4 
5 #include "onnx/defs/printer.h"
6 #include "onnx/defs/tensor_proto_util.h"
7 
8 namespace ONNX_NAMESPACE {
9 
/// Writes every element of `coll` to `os`, wrapped in `open` / `close` and
/// separated by `separator`; e.g. print(os, "[", ",", "]", v) with v = {1,2}
/// writes "[1,2]". An empty collection writes just `open` immediately followed
/// by `close`.
///
/// The collection is taken by const reference: the callers in this file pass
/// protobuf repeated fields (shape dims, tensor data, value-info/attribute/node
/// lists), which a by-value parameter would copy on every call.
template <typename Collection>
inline void print(std::ostream& os, const char* open, const char* separator, const char* close, const Collection& coll) {
  const char* sep = ""; // nothing before the first element, `separator` afterwards
  os << open;
  for (const auto& elt : coll) {
    os << sep << elt;
    sep = separator;
  }
  os << close;
}
20 
operator <<(std::ostream & os,const TensorShapeProto_Dimension & dim)21 std::ostream& operator<<(std::ostream& os, const TensorShapeProto_Dimension& dim) {
22   if (dim.has_dim_value())
23     os << dim.dim_value();
24   else if (dim.has_dim_param())
25     os << dim.dim_param();
26   else
27     os << "?";
28   return os;
29 }
30 
operator <<(std::ostream & os,const TensorShapeProto & shape)31 std::ostream& operator<<(std::ostream& os, const TensorShapeProto& shape) {
32   print(os, "[", ",", "]", shape.dim());
33   return os;
34 }
35 
operator <<(std::ostream & os,const TypeProto_Tensor & tensortype)36 std::ostream& operator<<(std::ostream& os, const TypeProto_Tensor& tensortype) {
37   os << PrimitiveTypeNameMap::ToString(tensortype.elem_type());
38   if (tensortype.has_shape()) {
39     if (tensortype.shape().dim_size() > 0)
40       os << tensortype.shape();
41   } else
42     os << "[...]";
43 
44   return os;
45 }
46 
operator <<(std::ostream & os,const TypeProto & type)47 std::ostream& operator<<(std::ostream& os, const TypeProto& type) {
48   if (type.has_tensor_type())
49     os << type.tensor_type();
50   return os;
51 }
52 
operator <<(std::ostream & os,const TensorProto & tensor)53 std::ostream& operator<<(std::ostream& os, const TensorProto& tensor) {
54   os << PrimitiveTypeNameMap::ToString(tensor.data_type());
55   print(os, "[", ",", "]", tensor.dims());
56 
57   // TODO: does not yet handle raw_data or FLOAT16 or externally stored data.
58   // TODO: does not yet handle name of tensor.
59   switch (static_cast<TensorProto::DataType>(tensor.data_type())) {
60     case TensorProto::DataType::TensorProto_DataType_INT8:
61     case TensorProto::DataType::TensorProto_DataType_INT16:
62     case TensorProto::DataType::TensorProto_DataType_INT32:
63     case TensorProto::DataType::TensorProto_DataType_UINT8:
64     case TensorProto::DataType::TensorProto_DataType_UINT16:
65     case TensorProto::DataType::TensorProto_DataType_BOOL:
66       print(os, " {", ",", "}", tensor.int32_data());
67       break;
68     case TensorProto::DataType::TensorProto_DataType_INT64:
69       print(os, " {", ",", "}", tensor.int64_data());
70       break;
71     case TensorProto::DataType::TensorProto_DataType_UINT32:
72     case TensorProto::DataType::TensorProto_DataType_UINT64:
73       print(os, " {", ",", "}", tensor.uint64_data());
74       break;
75     case TensorProto::DataType::TensorProto_DataType_FLOAT:
76       print(os, " {", ",", "}", tensor.float_data());
77       break;
78     case TensorProto::DataType::TensorProto_DataType_DOUBLE:
79       print(os, " {", ",", "}", tensor.double_data());
80       break;
81     case TensorProto::DataType::TensorProto_DataType_STRING: {
82       const char* sep = "{";
83       for (auto& elt : tensor.string_data()) {
84         os << sep << "\"" << elt << "\"";
85         sep = ", ";
86       }
87       os << "}";
88       break;
89     }
90     default:
91       break;
92   }
93   return os;
94 }
95 
operator <<(std::ostream & os,const ValueInfoProto & value_info)96 std::ostream& operator<<(std::ostream& os, const ValueInfoProto& value_info) {
97   os << value_info.type() << " " << value_info.name();
98   return os;
99 }
100 
operator <<(std::ostream & os,const ValueInfoList & vilist)101 std::ostream& operator<<(std::ostream& os, const ValueInfoList& vilist) {
102   print(os, "(", ", ", ")", vilist);
103   return os;
104 }
105 
operator <<(std::ostream & os,const AttributeProto & attr)106 std::ostream& operator<<(std::ostream& os, const AttributeProto& attr) {
107   os << attr.name() << " = ";
108   switch (attr.type()) {
109     case AttributeProto_AttributeType_INT:
110       os << attr.i();
111       break;
112     case AttributeProto_AttributeType_INTS:
113       print(os, "[", ", ", "]", attr.ints());
114       break;
115     case AttributeProto_AttributeType_FLOAT:
116       os << attr.f();
117       break;
118     case AttributeProto_AttributeType_FLOATS:
119       print(os, "[", ", ", "]", attr.floats());
120       break;
121     case AttributeProto_AttributeType_STRING:
122       os << "\"" << attr.s() << "\"";
123       break;
124     case AttributeProto_AttributeType_STRINGS: {
125       const char* sep = "[";
126       for (auto& elt : attr.strings()) {
127         os << sep << "\"" << elt << "\"";
128         sep = ", ";
129       }
130       os << "]";
131       break;
132     }
133     case AttributeProto_AttributeType_GRAPH:
134       os << attr.g();
135       break;
136     default:
137       break;
138   }
139   return os;
140 }
141 
operator <<(std::ostream & os,const AttrList & attrlist)142 std::ostream& operator<<(std::ostream& os, const AttrList& attrlist) {
143   print(os, "<", ", ", ">", attrlist);
144   return os;
145 }
146 
operator <<(std::ostream & os,const NodeProto & node)147 std::ostream& operator<<(std::ostream& os, const NodeProto& node) {
148   print(os, "", ", ", "", node.output());
149   os << " = " << node.op_type();
150   if (node.attribute_size() > 0)
151     os << node.attribute();
152   print(os, "(", ", ", ")", node.input());
153   return os;
154 }
155 
operator <<(std::ostream & os,const NodeList & nodelist)156 std::ostream& operator<<(std::ostream& os, const NodeList& nodelist) {
157   print(os, "{\n", "\n", "\n}\n", nodelist);
158   return os;
159 }
160 
operator <<(std::ostream & os,const GraphProto & graph)161 std::ostream& operator<<(std::ostream& os, const GraphProto& graph) {
162   os << graph.name() << " " << graph.input() << " => " << graph.output() << " ";
163   os << graph.node();
164   return os;
165 }
166 
167 } // namespace ONNX_NAMESPACE