1 /*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
#include <cctype>
#include <iostream>
#include <iterator>
#include <sstream>
#include <string>

#include "data_type_utils.h"
11
12 namespace ONNX_NAMESPACE {
13 namespace Utils {
14
// Singleton holding the tables of primitive element types.
//
// The instance is created the first time GetTypesWrapper() is called
// (construct on first use) so that the tables are guaranteed to exist before
// any other static object touches them; ops registration does not work
// properly without this.
class TypesWrapper final {
 public:
  // Returns the single shared instance.
  static TypesWrapper& GetTypesWrapper();

  // Set of every element-type string known to the tables below.
  std::unordered_set<std::string>& GetAllowedDataTypes();

  // Lookup table: element-type string -> tensor data type value.
  std::unordered_map<std::string, int32_t>& TypeStrToTensorDataType();

  // Lookup table: tensor data type value -> element-type string.
  std::unordered_map<int32_t, std::string>& TensorDataTypeToTypeStr();

  ~TypesWrapper() = default;
  // Not copyable: there must be exactly one instance.
  TypesWrapper(const TypesWrapper&) = delete;
  void operator=(const TypesWrapper&) = delete;

 private:
  // Only GetTypesWrapper() may construct the instance.
  TypesWrapper();

  // The value type is int32_t so that the members have exactly the type the
  // accessors above are declared to return.
  std::unordered_map<std::string, int32_t> type_str_to_tensor_data_type_;
  std::unordered_map<int32_t, std::string> tensor_data_type_to_type_str_;
  std::unordered_set<std::string> allowed_data_types_;
};
40
// Non-owning view over a slice of an external character buffer.
//
// The view is a (pointer, length) pair describing the part of the buffer that
// is still "valid"; the strip operations shrink it from either end without
// copying. The caller must keep the underlying storage alive for as long as
// the StringRange is used.
class StringRange final {
 public:
  // --- Construction -------------------------------------------------------
  // The default constructor yields an empty range. The other constructors
  // are intentionally implicit (callers pass string literals where a
  // StringRange is expected) and trim whitespace from both ends.
  StringRange();
  StringRange(const char* data, size_t size);
  StringRange(const std::string& str);
  StringRange(const char* data);

  // --- Observers ----------------------------------------------------------
  const char* Data() const;  // first character of the valid range
  size_t Size() const;       // number of characters in the valid range
  bool Empty() const;        // true when Size() == 0
  char operator[](size_t idx) const;  // unchecked access

  // --- Re-targeting -------------------------------------------------------
  // Point the range at a different buffer (or at nothing) and restart the
  // capture. Unlike the constructors, these do not trim whitespace.
  void Reset();
  void Reset(const char* data, size_t size);
  void Reset(const std::string& str);

  // --- Prefix / suffix tests ----------------------------------------------
  bool StartsWith(const StringRange& str) const;
  bool EndsWith(const StringRange& str) const;

  // --- Stripping ----------------------------------------------------------
  // No-argument forms remove whitespace; the size_t forms remove a fixed
  // number of characters; the StringRange forms remove the given text if the
  // range starts (LStrip) or ends (RStrip) with it.
  // The whitespace and text forms return true when something was removed;
  // the size_t forms return true whenever `size` fits in the range.
  bool LStrip();
  bool LStrip(size_t size);
  bool LStrip(StringRange str);
  bool RStrip();
  bool RStrip(size_t size);
  bool RStrip(StringRange str);
  // Whitespace on both ends; true if either side changed.
  bool LAndRStrip();
  // Removes one enclosing "(" ... ")" pair together with surrounding
  // whitespace.
  void ParensWhitespaceStrip();

  // Index of the first occurrence of `ch`, or std::string::npos.
  size_t Find(char ch) const;

  // --- Capture ------------------------------------------------------------
  // These methods provide a way to get back the part of the string that
  // LStrip() discarded since the last restart.
  StringRange GetCaptured();
  void RestartCapture();

 private:
  // data_ and size_ describe the valid range of the external buffer.
  // NOTE: the declaration order of the members matters, the constructors
  // initialize start_ and end_ from data_.
  const char* data_;
  size_t size_;

  // [start_, end_) is the captured range; end_ advances when LStrip()
  // removes characters.
  const char* start_;
  const char* end_;
};
86
87 std::unordered_map<std::string, TypeProto>&
GetTypeStrToProtoMap()88 DataTypeUtils::GetTypeStrToProtoMap() {
89 static std::unordered_map<std::string, TypeProto> map;
90 return map;
91 }
92
GetTypeStrLock()93 std::mutex& DataTypeUtils::GetTypeStrLock() {
94 static std::mutex lock;
95 return lock;
96 }
97
ToType(const TypeProto & type_proto)98 DataType DataTypeUtils::ToType(const TypeProto& type_proto) {
99 auto typeStr = ToString(type_proto);
100 std::lock_guard<std::mutex> lock(GetTypeStrLock());
101 if (GetTypeStrToProtoMap().find(typeStr) == GetTypeStrToProtoMap().end()) {
102 TypeProto type;
103 FromString(typeStr, type);
104 GetTypeStrToProtoMap()[typeStr] = type;
105 }
106 return &(GetTypeStrToProtoMap().find(typeStr)->first);
107 }
108
ToType(const std::string & type_str)109 DataType DataTypeUtils::ToType(const std::string& type_str) {
110 TypeProto type;
111 FromString(type_str, type);
112 return ToType(type);
113 }
114
ToTypeProto(const DataType & data_type)115 const TypeProto& DataTypeUtils::ToTypeProto(const DataType& data_type) {
116 std::lock_guard<std::mutex> lock(GetTypeStrLock());
117 auto it = GetTypeStrToProtoMap().find(*data_type);
118 if (GetTypeStrToProtoMap().end() == it) {
119 ONNX_THROW_EX(std::invalid_argument("Invalid data type " + *data_type));
120 }
121 return it->second;
122 }
123
ToString(const TypeProto & type_proto,const std::string & left,const std::string & right)124 std::string DataTypeUtils::ToString(
125 const TypeProto& type_proto,
126 const std::string& left,
127 const std::string& right) {
128 switch (type_proto.value_case()) {
129 case TypeProto::ValueCase::kTensorType: {
130 // Note: We do not distinguish tensors with zero rank (a shape consisting
131 // of an empty sequence of dimensions) here.
132 return left + "tensor(" +
133 ToDataTypeString(type_proto.tensor_type().elem_type()) + ")" + right;
134 }
135 case TypeProto::ValueCase::kSequenceType: {
136 return ToString(
137 type_proto.sequence_type().elem_type(), left + "seq(", ")" + right);
138 }
139 case TypeProto::ValueCase::kOptionalType: {
140 return ToString(
141 type_proto.optional_type().elem_type(), left + "optional(", ")" + right);
142 }
143 case TypeProto::ValueCase::kMapType: {
144 std::string map_str =
145 "map(" + ToDataTypeString(type_proto.map_type().key_type()) + ",";
146 return ToString(
147 type_proto.map_type().value_type(), left + map_str, ")" + right);
148 }
149 #ifdef ONNX_ML
150 case TypeProto::ValueCase::kOpaqueType: {
151 static const std::string empty;
152 std::string result;
153 const auto& op_type = type_proto.opaque_type();
154 result.append(left).append("opaque(");
155 if (op_type.has_domain() && !op_type.domain().empty()) {
156 result.append(op_type.domain()).append(",");
157 }
158 if (op_type.has_name() && !op_type.name().empty()) {
159 result.append(op_type.name());
160 }
161 result.append(")").append(right);
162 return result;
163 }
164 #endif
165 case TypeProto::ValueCase::kSparseTensorType: {
166 // Note: We do not distinguish tensors with zero rank (a shape consisting
167 // of an empty sequence of dimensions) here.
168 return left + "sparse_tensor(" +
169 ToDataTypeString(type_proto.sparse_tensor_type().elem_type()) + ")" +
170 right;
171 }
172 default:
173 ONNX_THROW_EX(std::invalid_argument("Unsuported type proto value case."));
174 }
175 }
176
ToDataTypeString(int32_t tensor_data_type)177 std::string DataTypeUtils::ToDataTypeString(int32_t tensor_data_type) {
178 TypesWrapper& t = TypesWrapper::GetTypesWrapper();
179 auto iter = t.TensorDataTypeToTypeStr().find(tensor_data_type);
180 if (t.TensorDataTypeToTypeStr().end() == iter) {
181 ONNX_THROW_EX(std::invalid_argument("Invalid tensor data type "));
182 }
183 return iter->second;
184 }
185
FromString(const std::string & type_str,TypeProto & type_proto)186 void DataTypeUtils::FromString(
187 const std::string& type_str,
188 TypeProto& type_proto) {
189 StringRange s(type_str);
190 type_proto.Clear();
191 if (s.LStrip("seq")) {
192 s.ParensWhitespaceStrip();
193 return FromString(
194 std::string(s.Data(), s.Size()),
195 *type_proto.mutable_sequence_type()->mutable_elem_type());
196 } else if (s.LStrip("optional")) {
197 s.ParensWhitespaceStrip();
198 return FromString(
199 std::string(s.Data(), s.Size()),
200 *type_proto.mutable_optional_type()->mutable_elem_type());
201 } else if (s.LStrip("map")) {
202 s.ParensWhitespaceStrip();
203 size_t key_size = s.Find(',');
204 StringRange k(s.Data(), key_size);
205 std::string key(k.Data(), k.Size());
206 s.LStrip(key_size);
207 s.LStrip(",");
208 StringRange v(s.Data(), s.Size());
209 int32_t key_type;
210 FromDataTypeString(key, key_type);
211 type_proto.mutable_map_type()->set_key_type(key_type);
212 return FromString(
213 std::string(v.Data(), v.Size()),
214 *type_proto.mutable_map_type()->mutable_value_type());
215 } else
216 #ifdef ONNX_ML
217 if (s.LStrip("opaque")) {
218 auto* opaque_type = type_proto.mutable_opaque_type();
219 s.ParensWhitespaceStrip();
220 if (!s.Empty()) {
221 size_t cm = s.Find(',');
222 if (cm != std::string::npos) {
223 if (cm > 0) {
224 opaque_type->mutable_domain()->assign(s.Data(), cm);
225 }
226 s.LStrip(cm + 1); // skip comma
227 }
228 if (!s.Empty()) {
229 opaque_type->mutable_name()->assign(s.Data(), s.Size());
230 }
231 }
232 } else
233 #endif
234 if (s.LStrip("sparse_tensor")) {
235 s.ParensWhitespaceStrip();
236 int32_t e;
237 FromDataTypeString(std::string(s.Data(), s.Size()), e);
238 type_proto.mutable_sparse_tensor_type()->set_elem_type(e);
239 } else if (s.LStrip("tensor")) {
240 s.ParensWhitespaceStrip();
241 int32_t e;
242 FromDataTypeString(std::string(s.Data(), s.Size()), e);
243 type_proto.mutable_tensor_type()->set_elem_type(e);
244 } else {
245 // Scalar
246 int32_t e;
247 FromDataTypeString(std::string(s.Data(), s.Size()), e);
248 TypeProto::Tensor* t = type_proto.mutable_tensor_type();
249 t->set_elem_type(e);
250 // Call mutable_shape() to initialize a shape with no dimension.
251 t->mutable_shape();
252 }
253 } // namespace Utils
254
IsValidDataTypeString(const std::string & type_str)255 bool DataTypeUtils::IsValidDataTypeString(const std::string& type_str) {
256 TypesWrapper& t = TypesWrapper::GetTypesWrapper();
257 const auto& allowedSet = t.GetAllowedDataTypes();
258 return (allowedSet.find(type_str) != allowedSet.end());
259 }
260
FromDataTypeString(const std::string & type_str,int32_t & tensor_data_type)261 void DataTypeUtils::FromDataTypeString(
262 const std::string& type_str,
263 int32_t& tensor_data_type) {
264 if (!IsValidDataTypeString(type_str)) {
265 ONNX_THROW_EX(std::invalid_argument("DataTypeUtils::FromDataTypeString - Received invalid data type string " + type_str));
266 }
267
268 TypesWrapper& t = TypesWrapper::GetTypesWrapper();
269 tensor_data_type = t.TypeStrToTensorDataType()[type_str];
270 }
271
StringRange()272 StringRange::StringRange() : data_(""), size_(0), start_(data_), end_(data_) {}
273
StringRange(const char * p_data,size_t p_size)274 StringRange::StringRange(const char* p_data, size_t p_size)
275 : data_(p_data), size_(p_size), start_(data_), end_(data_) {
276 assert(p_data != nullptr);
277 LAndRStrip();
278 }
279
StringRange(const std::string & p_str)280 StringRange::StringRange(const std::string& p_str)
281 : data_(p_str.data()), size_(p_str.size()), start_(data_), end_(data_) {
282 LAndRStrip();
283 }
284
StringRange(const char * p_data)285 StringRange::StringRange(const char* p_data)
286 : data_(p_data), size_(strlen(p_data)), start_(data_), end_(data_) {
287 LAndRStrip();
288 }
289
Data() const290 const char* StringRange::Data() const {
291 return data_;
292 }
293
Size() const294 size_t StringRange::Size() const {
295 return size_;
296 }
297
Empty() const298 bool StringRange::Empty() const {
299 return size_ == 0;
300 }
301
operator [](size_t idx) const302 char StringRange::operator[](size_t idx) const {
303 return data_[idx];
304 }
305
Reset()306 void StringRange::Reset() {
307 data_ = "";
308 size_ = 0;
309 start_ = end_ = data_;
310 }
311
Reset(const char * data,size_t size)312 void StringRange::Reset(const char* data, size_t size) {
313 data_ = data;
314 size_ = size;
315 start_ = end_ = data_;
316 }
317
Reset(const std::string & str)318 void StringRange::Reset(const std::string& str) {
319 data_ = str.data();
320 size_ = str.size();
321 start_ = end_ = data_;
322 }
323
StartsWith(const StringRange & str) const324 bool StringRange::StartsWith(const StringRange& str) const {
325 return ((size_ >= str.size_) && (memcmp(data_, str.data_, str.size_) == 0));
326 }
327
EndsWith(const StringRange & str) const328 bool StringRange::EndsWith(const StringRange& str) const {
329 return (
330 (size_ >= str.size_) &&
331 (memcmp(data_ + (size_ - str.size_), str.data_, str.size_) == 0));
332 }
333
LStrip()334 bool StringRange::LStrip() {
335 size_t count = 0;
336 const char* ptr = data_;
337 while (count < size_ && isspace(*ptr)) {
338 count++;
339 ptr++;
340 }
341
342 if (count > 0) {
343 return LStrip(count);
344 }
345 return false;
346 }
347
LStrip(size_t size)348 bool StringRange::LStrip(size_t size) {
349 if (size <= size_) {
350 data_ += size;
351 size_ -= size;
352 end_ += size;
353 return true;
354 }
355 return false;
356 }
357
LStrip(StringRange str)358 bool StringRange::LStrip(StringRange str) {
359 if (StartsWith(str)) {
360 return LStrip(str.size_);
361 }
362 return false;
363 }
364
RStrip()365 bool StringRange::RStrip() {
366 size_t count = 0;
367 const char* ptr = data_ + size_ - 1;
368 while (count < size_ && isspace(*ptr)) {
369 ++count;
370 --ptr;
371 }
372
373 if (count > 0) {
374 return RStrip(count);
375 }
376 return false;
377 }
378
RStrip(size_t size)379 bool StringRange::RStrip(size_t size) {
380 if (size_ >= size) {
381 size_ -= size;
382 return true;
383 }
384 return false;
385 }
386
RStrip(StringRange str)387 bool StringRange::RStrip(StringRange str) {
388 if (EndsWith(str)) {
389 return RStrip(str.size_);
390 }
391 return false;
392 }
393
LAndRStrip()394 bool StringRange::LAndRStrip() {
395 bool l = LStrip();
396 bool r = RStrip();
397 return l || r;
398 }
399
ParensWhitespaceStrip()400 void StringRange::ParensWhitespaceStrip() {
401 LStrip();
402 LStrip("(");
403 LAndRStrip();
404 RStrip(")");
405 RStrip();
406 }
407
Find(const char ch) const408 size_t StringRange::Find(const char ch) const {
409 size_t idx = 0;
410 while (idx < size_) {
411 if (data_[idx] == ch) {
412 return idx;
413 }
414 idx++;
415 }
416 return std::string::npos;
417 }
418
RestartCapture()419 void StringRange::RestartCapture() {
420 start_ = data_;
421 end_ = data_;
422 }
423
GetCaptured()424 StringRange StringRange::GetCaptured() {
425 return StringRange(start_, end_ - start_);
426 }
427
GetTypesWrapper()428 TypesWrapper& TypesWrapper::GetTypesWrapper() {
429 static TypesWrapper types;
430 return types;
431 }
432
GetAllowedDataTypes()433 std::unordered_set<std::string>& TypesWrapper::GetAllowedDataTypes() {
434 return allowed_data_types_;
435 }
436
TypeStrToTensorDataType()437 std::unordered_map<std::string, int>& TypesWrapper::TypeStrToTensorDataType() {
438 return type_str_to_tensor_data_type_;
439 }
440
TensorDataTypeToTypeStr()441 std::unordered_map<int, std::string>& TypesWrapper::TensorDataTypeToTypeStr() {
442 return tensor_data_type_to_type_str_;
443 }
444
TypesWrapper()445 TypesWrapper::TypesWrapper() {
446 // DataType strings. These should match the DataTypes defined in onnx.proto
447 type_str_to_tensor_data_type_["float"] = TensorProto_DataType_FLOAT;
448 type_str_to_tensor_data_type_["float16"] = TensorProto_DataType_FLOAT16;
449 type_str_to_tensor_data_type_["bfloat16"] = TensorProto_DataType_BFLOAT16;
450 type_str_to_tensor_data_type_["double"] = TensorProto_DataType_DOUBLE;
451 type_str_to_tensor_data_type_["int8"] = TensorProto_DataType_INT8;
452 type_str_to_tensor_data_type_["int16"] = TensorProto_DataType_INT16;
453 type_str_to_tensor_data_type_["int32"] = TensorProto_DataType_INT32;
454 type_str_to_tensor_data_type_["int64"] = TensorProto_DataType_INT64;
455 type_str_to_tensor_data_type_["uint8"] = TensorProto_DataType_UINT8;
456 type_str_to_tensor_data_type_["uint16"] = TensorProto_DataType_UINT16;
457 type_str_to_tensor_data_type_["uint32"] = TensorProto_DataType_UINT32;
458 type_str_to_tensor_data_type_["uint64"] = TensorProto_DataType_UINT64;
459 type_str_to_tensor_data_type_["complex64"] = TensorProto_DataType_COMPLEX64;
460 type_str_to_tensor_data_type_["complex128"] = TensorProto_DataType_COMPLEX128;
461 type_str_to_tensor_data_type_["string"] = TensorProto_DataType_STRING;
462 type_str_to_tensor_data_type_["bool"] = TensorProto_DataType_BOOL;
463
464 for (auto& str_type_pair : type_str_to_tensor_data_type_) {
465 tensor_data_type_to_type_str_[str_type_pair.second] = str_type_pair.first;
466 allowed_data_types_.insert(str_type_pair.first);
467 }
468 }
469 } // namespace Utils
470 } // namespace ONNX_NAMESPACE
471