1 /*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5 #include "onnx/defs/printer.h"
6 #include "onnx/defs/tensor_proto_util.h"
7
8 namespace ONNX_NAMESPACE {
9
// Writes every element of `coll` to `os`, preceded by `open`, followed by
// `close`, and with `separator` between consecutive elements (nothing is
// written before the first or after the last element). Elements are streamed
// with `operator<<`.
//
// The collection is taken by const reference so that printing does not copy
// the container (and all of its elements) on every call.
template <typename Collection>
inline void print(std::ostream& os, const char* open, const char* separator, const char* close, const Collection& coll) {
  const char* sep = "";
  os << open;
  for (const auto& elt : coll) {
    os << sep << elt;
    // From the second element onwards, prefix each element with the separator.
    sep = separator;
  }
  os << close;
}
20
operator <<(std::ostream & os,const TensorShapeProto_Dimension & dim)21 std::ostream& operator<<(std::ostream& os, const TensorShapeProto_Dimension& dim) {
22 if (dim.has_dim_value())
23 os << dim.dim_value();
24 else if (dim.has_dim_param())
25 os << dim.dim_param();
26 else
27 os << "?";
28 return os;
29 }
30
operator <<(std::ostream & os,const TensorShapeProto & shape)31 std::ostream& operator<<(std::ostream& os, const TensorShapeProto& shape) {
32 print(os, "[", ",", "]", shape.dim());
33 return os;
34 }
35
operator <<(std::ostream & os,const TypeProto_Tensor & tensortype)36 std::ostream& operator<<(std::ostream& os, const TypeProto_Tensor& tensortype) {
37 os << PrimitiveTypeNameMap::ToString(tensortype.elem_type());
38 if (tensortype.has_shape()) {
39 if (tensortype.shape().dim_size() > 0)
40 os << tensortype.shape();
41 } else
42 os << "[...]";
43
44 return os;
45 }
46
operator <<(std::ostream & os,const TypeProto & type)47 std::ostream& operator<<(std::ostream& os, const TypeProto& type) {
48 if (type.has_tensor_type())
49 os << type.tensor_type();
50 return os;
51 }
52
operator <<(std::ostream & os,const TensorProto & tensor)53 std::ostream& operator<<(std::ostream& os, const TensorProto& tensor) {
54 os << PrimitiveTypeNameMap::ToString(tensor.data_type());
55 print(os, "[", ",", "]", tensor.dims());
56
57 // TODO: does not yet handle raw_data or FLOAT16 or externally stored data.
58 // TODO: does not yet handle name of tensor.
59 switch (static_cast<TensorProto::DataType>(tensor.data_type())) {
60 case TensorProto::DataType::TensorProto_DataType_INT8:
61 case TensorProto::DataType::TensorProto_DataType_INT16:
62 case TensorProto::DataType::TensorProto_DataType_INT32:
63 case TensorProto::DataType::TensorProto_DataType_UINT8:
64 case TensorProto::DataType::TensorProto_DataType_UINT16:
65 case TensorProto::DataType::TensorProto_DataType_BOOL:
66 print(os, " {", ",", "}", tensor.int32_data());
67 break;
68 case TensorProto::DataType::TensorProto_DataType_INT64:
69 print(os, " {", ",", "}", tensor.int64_data());
70 break;
71 case TensorProto::DataType::TensorProto_DataType_UINT32:
72 case TensorProto::DataType::TensorProto_DataType_UINT64:
73 print(os, " {", ",", "}", tensor.uint64_data());
74 break;
75 case TensorProto::DataType::TensorProto_DataType_FLOAT:
76 print(os, " {", ",", "}", tensor.float_data());
77 break;
78 case TensorProto::DataType::TensorProto_DataType_DOUBLE:
79 print(os, " {", ",", "}", tensor.double_data());
80 break;
81 case TensorProto::DataType::TensorProto_DataType_STRING: {
82 const char* sep = "{";
83 for (auto& elt : tensor.string_data()) {
84 os << sep << "\"" << elt << "\"";
85 sep = ", ";
86 }
87 os << "}";
88 break;
89 }
90 default:
91 break;
92 }
93 return os;
94 }
95
operator <<(std::ostream & os,const ValueInfoProto & value_info)96 std::ostream& operator<<(std::ostream& os, const ValueInfoProto& value_info) {
97 os << value_info.type() << " " << value_info.name();
98 return os;
99 }
100
operator <<(std::ostream & os,const ValueInfoList & vilist)101 std::ostream& operator<<(std::ostream& os, const ValueInfoList& vilist) {
102 print(os, "(", ", ", ")", vilist);
103 return os;
104 }
105
operator <<(std::ostream & os,const AttributeProto & attr)106 std::ostream& operator<<(std::ostream& os, const AttributeProto& attr) {
107 os << attr.name() << " = ";
108 switch (attr.type()) {
109 case AttributeProto_AttributeType_INT:
110 os << attr.i();
111 break;
112 case AttributeProto_AttributeType_INTS:
113 print(os, "[", ", ", "]", attr.ints());
114 break;
115 case AttributeProto_AttributeType_FLOAT:
116 os << attr.f();
117 break;
118 case AttributeProto_AttributeType_FLOATS:
119 print(os, "[", ", ", "]", attr.floats());
120 break;
121 case AttributeProto_AttributeType_STRING:
122 os << "\"" << attr.s() << "\"";
123 break;
124 case AttributeProto_AttributeType_STRINGS: {
125 const char* sep = "[";
126 for (auto& elt : attr.strings()) {
127 os << sep << "\"" << elt << "\"";
128 sep = ", ";
129 }
130 os << "]";
131 break;
132 }
133 case AttributeProto_AttributeType_GRAPH:
134 os << attr.g();
135 break;
136 default:
137 break;
138 }
139 return os;
140 }
141
operator <<(std::ostream & os,const AttrList & attrlist)142 std::ostream& operator<<(std::ostream& os, const AttrList& attrlist) {
143 print(os, "<", ", ", ">", attrlist);
144 return os;
145 }
146
operator <<(std::ostream & os,const NodeProto & node)147 std::ostream& operator<<(std::ostream& os, const NodeProto& node) {
148 print(os, "", ", ", "", node.output());
149 os << " = " << node.op_type();
150 if (node.attribute_size() > 0)
151 os << node.attribute();
152 print(os, "(", ", ", ")", node.input());
153 return os;
154 }
155
operator <<(std::ostream & os,const NodeList & nodelist)156 std::ostream& operator<<(std::ostream& os, const NodeList& nodelist) {
157 print(os, "{\n", "\n", "\n}\n", nodelist);
158 return os;
159 }
160
operator <<(std::ostream & os,const GraphProto & graph)161 std::ostream& operator<<(std::ostream& os, const GraphProto& graph) {
162 os << graph.name() << " " << graph.input() << " => " << graph.output() << " ";
163 os << graph.node();
164 return os;
165 }
166
167 } // namespace ONNX_NAMESPACE