1 /*
2  * SPDX-License-Identifier: Apache-2.0
3  */
4 
#include <cctype>
#include <iostream>
#include <iterator>
#include <sstream>
#include <string>
9 
10 #include "data_type_utils.h"
11 
12 namespace ONNX_NAMESPACE {
13 namespace Utils {
14 
// Registry of the tensor element type names accepted by this file, together
// with the lookup tables between those names and their tensor data type
// values. The single instance is handed out by GetTypesWrapper() (construct
// on first use) so that the tables are initialized before anything reads
// them; ops registration does not work properly without this.
class TypesWrapper final {
 public:
  // Not copyable: callers always work with the instance returned by
  // GetTypesWrapper().
  TypesWrapper(const TypesWrapper&) = delete;
  void operator=(const TypesWrapper&) = delete;
  ~TypesWrapper() = default;

  // Returns the shared instance.
  static TypesWrapper& GetTypesWrapper();

  // Set of every accepted type name.
  std::unordered_set<std::string>& GetAllowedDataTypes();

  // Lookup table: type name -> tensor data type value.
  std::unordered_map<std::string, int32_t>& TypeStrToTensorDataType();

  // Lookup table: tensor data type value -> type name.
  std::unordered_map<int32_t, std::string>& TensorDataTypeToTypeStr();

 private:
  // Fills the three tables below.
  TypesWrapper();

  std::unordered_map<std::string, int> type_str_to_tensor_data_type_;
  std::unordered_map<int, std::string> tensor_data_type_to_type_str_;
  std::unordered_set<std::string> allowed_data_types_;
};
40 
// Non-owning view over a slice of an external character buffer: a pointer
// plus a length describing the currently "valid" range. Stripping operations
// only move the pointer / shrink the length; the buffer itself is never
// modified or copied.
// The caller must keep the external storage alive for as long as the
// StringRange is used.
class StringRange final {
 public:
  // Empty range.
  StringRange();
  // The three constructors below strip leading and trailing whitespace from
  // the range they are given. The single-argument forms are intentionally
  // implicit so that literals and std::string can be passed where a
  // StringRange is expected (e.g. LStrip("seq")).
  StringRange(const char* data, size_t size);
  StringRange(const std::string& str);
  StringRange(const char* data);

  // Start of the valid range (not necessarily NUL-terminated at Size()).
  const char* Data() const;
  // Number of characters in the valid range.
  size_t Size() const;
  // True when the valid range has no characters.
  bool Empty() const;
  // Unchecked access to the character at idx within the valid range.
  char operator[](size_t idx) const;

  // Re-point the range (and restart the capture). Unlike the constructors,
  // Reset does not strip whitespace.
  void Reset();
  void Reset(const char* data, size_t size);
  void Reset(const std::string& str);

  // Prefix / suffix comparison against another range.
  bool StartsWith(const StringRange& str) const;
  bool EndsWith(const StringRange& str) const;

  // Remove from the left: whitespace, a fixed number of characters, or the
  // given prefix. Each returns true only if something was stripped
  // (the size_t overload also returns true for a size of 0).
  bool LStrip();
  bool LStrip(size_t size);
  bool LStrip(StringRange str);

  // Remove from the right: whitespace, a fixed number of characters, or the
  // given suffix. Same return convention as the LStrip overloads.
  bool RStrip();
  bool RStrip(size_t size);
  bool RStrip(StringRange str);

  // Strips whitespace on both sides; true if either side changed.
  bool LAndRStrip();
  // Strips an optional leading "(" and trailing ")" together with the
  // whitespace around them.
  void ParensWhitespaceStrip();
  // Index of the first occurrence of ch, or std::string::npos if absent.
  size_t Find(const char ch) const;

  // These methods provide a way to return the range of the string
  // which was discarded by LStrip(), i.e. the characters consumed from the
  // left since the last RestartCapture() / Reset().
  StringRange GetCaptured();
  void RestartCapture();

 private:
  // data_ + size_ tracks the "valid" range of the external string buffer.
  const char* data_;
  size_t size_;

  // start_ and end_ track the captured range.
  // end_ advances when LStrip() is called.
  const char* start_;
  const char* end_;
};
86 
87 std::unordered_map<std::string, TypeProto>&
GetTypeStrToProtoMap()88 DataTypeUtils::GetTypeStrToProtoMap() {
89   static std::unordered_map<std::string, TypeProto> map;
90   return map;
91 }
92 
GetTypeStrLock()93 std::mutex& DataTypeUtils::GetTypeStrLock() {
94   static std::mutex lock;
95   return lock;
96 }
97 
ToType(const TypeProto & type_proto)98 DataType DataTypeUtils::ToType(const TypeProto& type_proto) {
99   auto typeStr = ToString(type_proto);
100   std::lock_guard<std::mutex> lock(GetTypeStrLock());
101   if (GetTypeStrToProtoMap().find(typeStr) == GetTypeStrToProtoMap().end()) {
102     TypeProto type;
103     FromString(typeStr, type);
104     GetTypeStrToProtoMap()[typeStr] = type;
105   }
106   return &(GetTypeStrToProtoMap().find(typeStr)->first);
107 }
108 
ToType(const std::string & type_str)109 DataType DataTypeUtils::ToType(const std::string& type_str) {
110   TypeProto type;
111   FromString(type_str, type);
112   return ToType(type);
113 }
114 
ToTypeProto(const DataType & data_type)115 const TypeProto& DataTypeUtils::ToTypeProto(const DataType& data_type) {
116   std::lock_guard<std::mutex> lock(GetTypeStrLock());
117   auto it = GetTypeStrToProtoMap().find(*data_type);
118   if (GetTypeStrToProtoMap().end() == it) {
119     ONNX_THROW_EX(std::invalid_argument("Invalid data type " + *data_type));
120   }
121   return it->second;
122 }
123 
ToString(const TypeProto & type_proto,const std::string & left,const std::string & right)124 std::string DataTypeUtils::ToString(
125     const TypeProto& type_proto,
126     const std::string& left,
127     const std::string& right) {
128   switch (type_proto.value_case()) {
129     case TypeProto::ValueCase::kTensorType: {
130       // Note: We do not distinguish tensors with zero rank (a shape consisting
131       // of an empty sequence of dimensions) here.
132       return left + "tensor(" +
133           ToDataTypeString(type_proto.tensor_type().elem_type()) + ")" + right;
134     }
135     case TypeProto::ValueCase::kSequenceType: {
136       return ToString(
137           type_proto.sequence_type().elem_type(), left + "seq(", ")" + right);
138     }
139     case TypeProto::ValueCase::kOptionalType: {
140       return ToString(
141           type_proto.optional_type().elem_type(), left + "optional(", ")" + right);
142     }
143     case TypeProto::ValueCase::kMapType: {
144       std::string map_str =
145           "map(" + ToDataTypeString(type_proto.map_type().key_type()) + ",";
146       return ToString(
147           type_proto.map_type().value_type(), left + map_str, ")" + right);
148     }
149 #ifdef ONNX_ML
150     case TypeProto::ValueCase::kOpaqueType: {
151       static const std::string empty;
152       std::string result;
153       const auto& op_type = type_proto.opaque_type();
154       result.append(left).append("opaque(");
155       if (op_type.has_domain() && !op_type.domain().empty()) {
156         result.append(op_type.domain()).append(",");
157       }
158       if (op_type.has_name() && !op_type.name().empty()) {
159         result.append(op_type.name());
160       }
161       result.append(")").append(right);
162       return result;
163     }
164 #endif
165     case TypeProto::ValueCase::kSparseTensorType: {
166       // Note: We do not distinguish tensors with zero rank (a shape consisting
167       // of an empty sequence of dimensions) here.
168       return left + "sparse_tensor(" +
169           ToDataTypeString(type_proto.sparse_tensor_type().elem_type()) + ")" +
170           right;
171     }
172     default:
173       ONNX_THROW_EX(std::invalid_argument("Unsuported type proto value case."));
174   }
175 }
176 
ToDataTypeString(int32_t tensor_data_type)177 std::string DataTypeUtils::ToDataTypeString(int32_t tensor_data_type) {
178   TypesWrapper& t = TypesWrapper::GetTypesWrapper();
179   auto iter = t.TensorDataTypeToTypeStr().find(tensor_data_type);
180   if (t.TensorDataTypeToTypeStr().end() == iter) {
181     ONNX_THROW_EX(std::invalid_argument("Invalid tensor data type "));
182   }
183   return iter->second;
184 }
185 
FromString(const std::string & type_str,TypeProto & type_proto)186 void DataTypeUtils::FromString(
187     const std::string& type_str,
188     TypeProto& type_proto) {
189   StringRange s(type_str);
190   type_proto.Clear();
191   if (s.LStrip("seq")) {
192     s.ParensWhitespaceStrip();
193     return FromString(
194         std::string(s.Data(), s.Size()),
195         *type_proto.mutable_sequence_type()->mutable_elem_type());
196   } else if (s.LStrip("optional")) {
197     s.ParensWhitespaceStrip();
198     return FromString(
199         std::string(s.Data(), s.Size()),
200         *type_proto.mutable_optional_type()->mutable_elem_type());
201   } else if (s.LStrip("map")) {
202     s.ParensWhitespaceStrip();
203     size_t key_size = s.Find(',');
204     StringRange k(s.Data(), key_size);
205     std::string key(k.Data(), k.Size());
206     s.LStrip(key_size);
207     s.LStrip(",");
208     StringRange v(s.Data(), s.Size());
209     int32_t key_type;
210     FromDataTypeString(key, key_type);
211     type_proto.mutable_map_type()->set_key_type(key_type);
212     return FromString(
213         std::string(v.Data(), v.Size()),
214         *type_proto.mutable_map_type()->mutable_value_type());
215   } else
216 #ifdef ONNX_ML
217       if (s.LStrip("opaque")) {
218     auto* opaque_type = type_proto.mutable_opaque_type();
219     s.ParensWhitespaceStrip();
220     if (!s.Empty()) {
221       size_t cm = s.Find(',');
222       if (cm != std::string::npos) {
223         if (cm > 0) {
224           opaque_type->mutable_domain()->assign(s.Data(), cm);
225         }
226         s.LStrip(cm + 1); // skip comma
227       }
228       if (!s.Empty()) {
229         opaque_type->mutable_name()->assign(s.Data(), s.Size());
230       }
231     }
232   } else
233 #endif
234   if (s.LStrip("sparse_tensor")) {
235     s.ParensWhitespaceStrip();
236     int32_t e;
237     FromDataTypeString(std::string(s.Data(), s.Size()), e);
238     type_proto.mutable_sparse_tensor_type()->set_elem_type(e);
239   } else if (s.LStrip("tensor")) {
240     s.ParensWhitespaceStrip();
241     int32_t e;
242     FromDataTypeString(std::string(s.Data(), s.Size()), e);
243     type_proto.mutable_tensor_type()->set_elem_type(e);
244   } else {
245     // Scalar
246     int32_t e;
247     FromDataTypeString(std::string(s.Data(), s.Size()), e);
248     TypeProto::Tensor* t = type_proto.mutable_tensor_type();
249     t->set_elem_type(e);
250     // Call mutable_shape() to initialize a shape with no dimension.
251     t->mutable_shape();
252   }
253 } // namespace Utils
254 
IsValidDataTypeString(const std::string & type_str)255 bool DataTypeUtils::IsValidDataTypeString(const std::string& type_str) {
256   TypesWrapper& t = TypesWrapper::GetTypesWrapper();
257   const auto& allowedSet = t.GetAllowedDataTypes();
258   return (allowedSet.find(type_str) != allowedSet.end());
259 }
260 
FromDataTypeString(const std::string & type_str,int32_t & tensor_data_type)261 void DataTypeUtils::FromDataTypeString(
262     const std::string& type_str,
263     int32_t& tensor_data_type) {
264   if (!IsValidDataTypeString(type_str)) {
265     ONNX_THROW_EX(std::invalid_argument("DataTypeUtils::FromDataTypeString - Received invalid data type string " + type_str));
266   }
267 
268   TypesWrapper& t = TypesWrapper::GetTypesWrapper();
269   tensor_data_type = t.TypeStrToTensorDataType()[type_str];
270 }
271 
StringRange()272 StringRange::StringRange() : data_(""), size_(0), start_(data_), end_(data_) {}
273 
StringRange(const char * p_data,size_t p_size)274 StringRange::StringRange(const char* p_data, size_t p_size)
275     : data_(p_data), size_(p_size), start_(data_), end_(data_) {
276   assert(p_data != nullptr);
277   LAndRStrip();
278 }
279 
StringRange(const std::string & p_str)280 StringRange::StringRange(const std::string& p_str)
281     : data_(p_str.data()), size_(p_str.size()), start_(data_), end_(data_) {
282   LAndRStrip();
283 }
284 
StringRange(const char * p_data)285 StringRange::StringRange(const char* p_data)
286     : data_(p_data), size_(strlen(p_data)), start_(data_), end_(data_) {
287   LAndRStrip();
288 }
289 
Data() const290 const char* StringRange::Data() const {
291   return data_;
292 }
293 
Size() const294 size_t StringRange::Size() const {
295   return size_;
296 }
297 
Empty() const298 bool StringRange::Empty() const {
299   return size_ == 0;
300 }
301 
operator [](size_t idx) const302 char StringRange::operator[](size_t idx) const {
303   return data_[idx];
304 }
305 
Reset()306 void StringRange::Reset() {
307   data_ = "";
308   size_ = 0;
309   start_ = end_ = data_;
310 }
311 
Reset(const char * data,size_t size)312 void StringRange::Reset(const char* data, size_t size) {
313   data_ = data;
314   size_ = size;
315   start_ = end_ = data_;
316 }
317 
Reset(const std::string & str)318 void StringRange::Reset(const std::string& str) {
319   data_ = str.data();
320   size_ = str.size();
321   start_ = end_ = data_;
322 }
323 
StartsWith(const StringRange & str) const324 bool StringRange::StartsWith(const StringRange& str) const {
325   return ((size_ >= str.size_) && (memcmp(data_, str.data_, str.size_) == 0));
326 }
327 
EndsWith(const StringRange & str) const328 bool StringRange::EndsWith(const StringRange& str) const {
329   return (
330       (size_ >= str.size_) &&
331       (memcmp(data_ + (size_ - str.size_), str.data_, str.size_) == 0));
332 }
333 
LStrip()334 bool StringRange::LStrip() {
335   size_t count = 0;
336   const char* ptr = data_;
337   while (count < size_ && isspace(*ptr)) {
338     count++;
339     ptr++;
340   }
341 
342   if (count > 0) {
343     return LStrip(count);
344   }
345   return false;
346 }
347 
LStrip(size_t size)348 bool StringRange::LStrip(size_t size) {
349   if (size <= size_) {
350     data_ += size;
351     size_ -= size;
352     end_ += size;
353     return true;
354   }
355   return false;
356 }
357 
LStrip(StringRange str)358 bool StringRange::LStrip(StringRange str) {
359   if (StartsWith(str)) {
360     return LStrip(str.size_);
361   }
362   return false;
363 }
364 
RStrip()365 bool StringRange::RStrip() {
366   size_t count = 0;
367   const char* ptr = data_ + size_ - 1;
368   while (count < size_ && isspace(*ptr)) {
369     ++count;
370     --ptr;
371   }
372 
373   if (count > 0) {
374     return RStrip(count);
375   }
376   return false;
377 }
378 
RStrip(size_t size)379 bool StringRange::RStrip(size_t size) {
380   if (size_ >= size) {
381     size_ -= size;
382     return true;
383   }
384   return false;
385 }
386 
RStrip(StringRange str)387 bool StringRange::RStrip(StringRange str) {
388   if (EndsWith(str)) {
389     return RStrip(str.size_);
390   }
391   return false;
392 }
393 
LAndRStrip()394 bool StringRange::LAndRStrip() {
395   bool l = LStrip();
396   bool r = RStrip();
397   return l || r;
398 }
399 
ParensWhitespaceStrip()400 void StringRange::ParensWhitespaceStrip() {
401   LStrip();
402   LStrip("(");
403   LAndRStrip();
404   RStrip(")");
405   RStrip();
406 }
407 
Find(const char ch) const408 size_t StringRange::Find(const char ch) const {
409   size_t idx = 0;
410   while (idx < size_) {
411     if (data_[idx] == ch) {
412       return idx;
413     }
414     idx++;
415   }
416   return std::string::npos;
417 }
418 
RestartCapture()419 void StringRange::RestartCapture() {
420   start_ = data_;
421   end_ = data_;
422 }
423 
GetCaptured()424 StringRange StringRange::GetCaptured() {
425   return StringRange(start_, end_ - start_);
426 }
427 
GetTypesWrapper()428 TypesWrapper& TypesWrapper::GetTypesWrapper() {
429   static TypesWrapper types;
430   return types;
431 }
432 
GetAllowedDataTypes()433 std::unordered_set<std::string>& TypesWrapper::GetAllowedDataTypes() {
434   return allowed_data_types_;
435 }
436 
TypeStrToTensorDataType()437 std::unordered_map<std::string, int>& TypesWrapper::TypeStrToTensorDataType() {
438   return type_str_to_tensor_data_type_;
439 }
440 
TensorDataTypeToTypeStr()441 std::unordered_map<int, std::string>& TypesWrapper::TensorDataTypeToTypeStr() {
442   return tensor_data_type_to_type_str_;
443 }
444 
TypesWrapper()445 TypesWrapper::TypesWrapper() {
446   // DataType strings. These should match the DataTypes defined in onnx.proto
447   type_str_to_tensor_data_type_["float"] = TensorProto_DataType_FLOAT;
448   type_str_to_tensor_data_type_["float16"] = TensorProto_DataType_FLOAT16;
449   type_str_to_tensor_data_type_["bfloat16"] = TensorProto_DataType_BFLOAT16;
450   type_str_to_tensor_data_type_["double"] = TensorProto_DataType_DOUBLE;
451   type_str_to_tensor_data_type_["int8"] = TensorProto_DataType_INT8;
452   type_str_to_tensor_data_type_["int16"] = TensorProto_DataType_INT16;
453   type_str_to_tensor_data_type_["int32"] = TensorProto_DataType_INT32;
454   type_str_to_tensor_data_type_["int64"] = TensorProto_DataType_INT64;
455   type_str_to_tensor_data_type_["uint8"] = TensorProto_DataType_UINT8;
456   type_str_to_tensor_data_type_["uint16"] = TensorProto_DataType_UINT16;
457   type_str_to_tensor_data_type_["uint32"] = TensorProto_DataType_UINT32;
458   type_str_to_tensor_data_type_["uint64"] = TensorProto_DataType_UINT64;
459   type_str_to_tensor_data_type_["complex64"] = TensorProto_DataType_COMPLEX64;
460   type_str_to_tensor_data_type_["complex128"] = TensorProto_DataType_COMPLEX128;
461   type_str_to_tensor_data_type_["string"] = TensorProto_DataType_STRING;
462   type_str_to_tensor_data_type_["bool"] = TensorProto_DataType_BOOL;
463 
464   for (auto& str_type_pair : type_str_to_tensor_data_type_) {
465     tensor_data_type_to_type_str_[str_type_pair.second] = str_type_pair.first;
466     allowed_data_types_.insert(str_type_pair.first);
467   }
468 }
469 } // namespace Utils
470 } // namespace ONNX_NAMESPACE
471