1 #include <cctype>
2 #include <iostream>
3 #include <iterator>
4 #include <sstream>
5 
6 #include "data_type_utils.h"
7 
8 namespace ONNX_NAMESPACE {
9 namespace Utils {
10 
// Singleton holding the tables of allowed tensor element types.
// The instance is constructed on first use, which ensures the tables are
// initialized before any other static object reads them. Ops registration
// does not work properly without this.
class TypesWrapper final {
 public:
  // Returns the single shared instance.
  static TypesWrapper& GetTypesWrapper();

  // Set of every element-type name stored in the tables.
  std::unordered_set<std::string>& GetAllowedDataTypes();

  // Element-type name -> tensor data type value.
  std::unordered_map<std::string, int32_t>& TypeStrToTensorDataType();

  // Tensor data type value -> element-type name.
  std::unordered_map<int32_t, std::string>& TensorDataTypeToTypeStr();

  ~TypesWrapper() = default;
  TypesWrapper(const TypesWrapper&) = delete;
  void operator=(const TypesWrapper&) = delete;

 private:
  // Private: instances are only obtained through GetTypesWrapper().
  TypesWrapper();

  // Spelled with int32_t so the storage types are exactly the types the
  // accessors above hand out references to.
  std::unordered_map<std::string, int32_t> type_str_to_tensor_data_type_;
  std::unordered_map<int32_t, std::string> tensor_data_type_to_type_str_;
  std::unordered_set<std::string> allowed_data_types_;
};
36 
// Non-owning view of a character buffer that lives elsewhere, stored as a
// pointer plus a length. The view can be narrowed from either end to track
// the still-"valid" slice of the text. The caller must make sure the
// underlying storage outlives every use of the StringRange.
class StringRange final {
 public:
  // Construction. The single-argument constructors are intentionally
  // implicit: call sites in this file pass string literals where a
  // StringRange parameter is expected.
  StringRange();
  StringRange(const char* data, size_t size);
  StringRange(const std::string& str);
  StringRange(const char* data);

  // Observers.
  const char* Data() const;
  size_t Size() const;
  bool Empty() const;
  char operator[](size_t idx) const;
  bool StartsWith(const StringRange& str) const;
  bool EndsWith(const StringRange& str) const;
  size_t Find(const char ch) const;

  // Re-point the view at different storage (or at nothing).
  void Reset();
  void Reset(const char* data, size_t size);
  void Reset(const std::string& str);

  // Narrow the view from the left. The no-argument form removes whitespace;
  // each overload reports whether it removed anything it was asked to.
  bool LStrip();
  bool LStrip(size_t size);
  bool LStrip(StringRange str);

  // Narrow the view from the right; same conventions as LStrip.
  bool RStrip();
  bool RStrip(size_t size);
  bool RStrip(StringRange str);

  // Combined helpers built on the strips above.
  bool LAndRStrip();
  void ParensWhitespaceStrip();

  // Capture support: these expose the part of the buffer that LStrip() has
  // discarded since the capture was last (re)started.
  StringRange GetCaptured();
  void RestartCapture();

 private:
  // [data_, data_ + size_) is the "valid" slice of the external buffer.
  const char* data_;
  size_t size_;

  // [start_, end_) is the captured slice; end_ moves forward whenever
  // LStrip() removes characters.
  const char* start_;
  const char* end_;
};
82 
83 std::unordered_map<std::string, TypeProto>&
GetTypeStrToProtoMap()84 DataTypeUtils::GetTypeStrToProtoMap() {
85   static std::unordered_map<std::string, TypeProto> map;
86   return map;
87 }
88 
GetTypeStrLock()89 std::mutex& DataTypeUtils::GetTypeStrLock() {
90   static std::mutex lock;
91   return lock;
92 }
93 
ToType(const TypeProto & type_proto)94 DataType DataTypeUtils::ToType(const TypeProto& type_proto) {
95   auto typeStr = ToString(type_proto);
96   std::lock_guard<std::mutex> lock(GetTypeStrLock());
97   if (GetTypeStrToProtoMap().find(typeStr) == GetTypeStrToProtoMap().end()) {
98     TypeProto type;
99     FromString(typeStr, type);
100     GetTypeStrToProtoMap()[typeStr] = type;
101   }
102   return &(GetTypeStrToProtoMap().find(typeStr)->first);
103 }
104 
ToType(const std::string & type_str)105 DataType DataTypeUtils::ToType(const std::string& type_str) {
106   TypeProto type;
107   FromString(type_str, type);
108   return ToType(type);
109 }
110 
ToTypeProto(const DataType & data_type)111 const TypeProto& DataTypeUtils::ToTypeProto(const DataType& data_type) {
112   std::lock_guard<std::mutex> lock(GetTypeStrLock());
113   auto it = GetTypeStrToProtoMap().find(*data_type);
114   assert(it != GetTypeStrToProtoMap().end());
115   return it->second;
116 }
117 
ToString(const TypeProto & type_proto,const std::string & left,const std::string & right)118 std::string DataTypeUtils::ToString(
119     const TypeProto& type_proto,
120     const std::string& left,
121     const std::string& right) {
122   switch (type_proto.value_case()) {
123     case TypeProto::ValueCase::kTensorType: {
124       // Note: We do not distinguish tensors with zero rank (a shape consisting
125       // of an empty sequence of dimensions) here.
126       return left + "tensor(" +
127           ToDataTypeString(type_proto.tensor_type().elem_type()) + ")" + right;
128     }
129 #ifdef ONNX_ML
130     case TypeProto::ValueCase::kSequenceType: {
131       return ToString(
132           type_proto.sequence_type().elem_type(), left + "seq(", ")" + right);
133     }
134 
135     case TypeProto::ValueCase::kMapType: {
136       std::string map_str =
137           "map(" + ToDataTypeString(type_proto.map_type().key_type()) + ",";
138       return ToString(
139           type_proto.map_type().value_type(), left + map_str, ")" + right);
140     }
141     case TypeProto::ValueCase::kOpaqueType: {
142       static const std::string empty;
143       std::string result;
144       const auto& op_type = type_proto.opaque_type();
145       result.append(left).append("opaque(");
146       if (op_type.has_domain() && !op_type.domain().empty()) {
147         result.append(op_type.domain()).append(",");
148       }
149       if (op_type.has_name() && !op_type.name().empty()) {
150         result.append(op_type.name());
151       }
152       result.append(")").append(right);
153       return result;
154     }
155     case TypeProto::ValueCase::kSparseTensorType: {
156       // Note: We do not distinguish tensors with zero rank (a shape consisting
157       // of an empty sequence of dimensions) here.
158       return left + "sparse_tensor(" +
159           ToDataTypeString(type_proto.sparse_tensor_type().elem_type()) + ")" +
160           right;
161     }
162 #endif
163     default:
164       assert(false);
165       return std::string();
166   }
167 }
168 
ToDataTypeString(int32_t tensor_data_type)169 std::string DataTypeUtils::ToDataTypeString(int32_t tensor_data_type) {
170   TypesWrapper& t = TypesWrapper::GetTypesWrapper();
171   auto iter = t.TensorDataTypeToTypeStr().find(tensor_data_type);
172   assert(t.TensorDataTypeToTypeStr().end() != iter);
173   return iter->second;
174 }
175 
FromString(const std::string & type_str,TypeProto & type_proto)176 void DataTypeUtils::FromString(
177     const std::string& type_str,
178     TypeProto& type_proto) {
179   StringRange s(type_str);
180   type_proto.Clear();
181 #ifdef ONNX_ML
182   if (s.LStrip("seq")) {
183     s.ParensWhitespaceStrip();
184     return FromString(
185         std::string(s.Data(), s.Size()),
186         *type_proto.mutable_sequence_type()->mutable_elem_type());
187   } else if (s.LStrip("map")) {
188     s.ParensWhitespaceStrip();
189     size_t key_size = s.Find(',');
190     StringRange k(s.Data(), key_size);
191     std::string key(k.Data(), k.Size());
192     s.LStrip(key_size);
193     s.LStrip(",");
194     StringRange v(s.Data(), s.Size());
195     int32_t key_type;
196     FromDataTypeString(key, key_type);
197     type_proto.mutable_map_type()->set_key_type(key_type);
198     return FromString(
199         std::string(v.Data(), v.Size()),
200         *type_proto.mutable_map_type()->mutable_value_type());
201   } else if (s.LStrip("opaque")) {
202     auto* opaque_type = type_proto.mutable_opaque_type();
203     s.ParensWhitespaceStrip();
204     if (!s.Empty()) {
205       size_t cm = s.Find(',');
206       if (cm != std::string::npos) {
207         if (cm > 0) {
208           opaque_type->mutable_domain()->assign(s.Data(), cm);
209         }
210         s.LStrip(cm + 1); // skip comma
211       }
212       if (!s.Empty()) {
213         opaque_type->mutable_name()->assign(s.Data(), s.Size());
214       }
215     }
216   } else if (s.LStrip("sparse_tensor")) {
217     s.ParensWhitespaceStrip();
218     int32_t e;
219     FromDataTypeString(std::string(s.Data(), s.Size()), e);
220     type_proto.mutable_sparse_tensor_type()->set_elem_type(e);
221   } else
222 #endif
223       if (s.LStrip("tensor")) {
224     s.ParensWhitespaceStrip();
225     int32_t e;
226     FromDataTypeString(std::string(s.Data(), s.Size()), e);
227     type_proto.mutable_tensor_type()->set_elem_type(e);
228   } else {
229     // Scalar
230     int32_t e;
231     FromDataTypeString(std::string(s.Data(), s.Size()), e);
232     TypeProto::Tensor* t = type_proto.mutable_tensor_type();
233     t->set_elem_type(e);
234     // Call mutable_shape() to initialize a shape with no dimension.
235     t->mutable_shape();
236   }
237 } // namespace Utils
238 
IsValidDataTypeString(const std::string & type_str)239 bool DataTypeUtils::IsValidDataTypeString(const std::string& type_str) {
240   TypesWrapper& t = TypesWrapper::GetTypesWrapper();
241   const auto& allowedSet = t.GetAllowedDataTypes();
242   return (allowedSet.find(type_str) != allowedSet.end());
243 }
244 
FromDataTypeString(const std::string & type_str,int32_t & tensor_data_type)245 void DataTypeUtils::FromDataTypeString(
246     const std::string& type_str,
247     int32_t& tensor_data_type) {
248   assert(IsValidDataTypeString(type_str));
249 
250   TypesWrapper& t = TypesWrapper::GetTypesWrapper();
251   tensor_data_type = t.TypeStrToTensorDataType()[type_str];
252 }
253 
StringRange()254 StringRange::StringRange() : data_(""), size_(0), start_(data_), end_(data_) {}
255 
StringRange(const char * p_data,size_t p_size)256 StringRange::StringRange(const char* p_data, size_t p_size)
257     : data_(p_data), size_(p_size), start_(data_), end_(data_) {
258   assert(p_data != nullptr);
259   LAndRStrip();
260 }
261 
StringRange(const std::string & p_str)262 StringRange::StringRange(const std::string& p_str)
263     : data_(p_str.data()), size_(p_str.size()), start_(data_), end_(data_) {
264   LAndRStrip();
265 }
266 
StringRange(const char * p_data)267 StringRange::StringRange(const char* p_data)
268     : data_(p_data), size_(strlen(p_data)), start_(data_), end_(data_) {
269   LAndRStrip();
270 }
271 
Data() const272 const char* StringRange::Data() const {
273   return data_;
274 }
275 
Size() const276 size_t StringRange::Size() const {
277   return size_;
278 }
279 
Empty() const280 bool StringRange::Empty() const {
281   return size_ == 0;
282 }
283 
operator [](size_t idx) const284 char StringRange::operator[](size_t idx) const {
285   return data_[idx];
286 }
287 
Reset()288 void StringRange::Reset() {
289   data_ = "";
290   size_ = 0;
291   start_ = end_ = data_;
292 }
293 
Reset(const char * data,size_t size)294 void StringRange::Reset(const char* data, size_t size) {
295   data_ = data;
296   size_ = size;
297   start_ = end_ = data_;
298 }
299 
Reset(const std::string & str)300 void StringRange::Reset(const std::string& str) {
301   data_ = str.data();
302   size_ = str.size();
303   start_ = end_ = data_;
304 }
305 
StartsWith(const StringRange & str) const306 bool StringRange::StartsWith(const StringRange& str) const {
307   return ((size_ >= str.size_) && (memcmp(data_, str.data_, str.size_) == 0));
308 }
309 
EndsWith(const StringRange & str) const310 bool StringRange::EndsWith(const StringRange& str) const {
311   return (
312       (size_ >= str.size_) &&
313       (memcmp(data_ + (size_ - str.size_), str.data_, str.size_) == 0));
314 }
315 
LStrip()316 bool StringRange::LStrip() {
317   size_t count = 0;
318   const char* ptr = data_;
319   while (count < size_ && isspace(*ptr)) {
320     count++;
321     ptr++;
322   }
323 
324   if (count > 0) {
325     return LStrip(count);
326   }
327   return false;
328 }
329 
LStrip(size_t size)330 bool StringRange::LStrip(size_t size) {
331   if (size <= size_) {
332     data_ += size;
333     size_ -= size;
334     end_ += size;
335     return true;
336   }
337   return false;
338 }
339 
LStrip(StringRange str)340 bool StringRange::LStrip(StringRange str) {
341   if (StartsWith(str)) {
342     return LStrip(str.size_);
343   }
344   return false;
345 }
346 
RStrip()347 bool StringRange::RStrip() {
348   size_t count = 0;
349   const char* ptr = data_ + size_ - 1;
350   while (count < size_ && isspace(*ptr)) {
351     ++count;
352     --ptr;
353   }
354 
355   if (count > 0) {
356     return RStrip(count);
357   }
358   return false;
359 }
360 
RStrip(size_t size)361 bool StringRange::RStrip(size_t size) {
362   if (size_ >= size) {
363     size_ -= size;
364     return true;
365   }
366   return false;
367 }
368 
RStrip(StringRange str)369 bool StringRange::RStrip(StringRange str) {
370   if (EndsWith(str)) {
371     return RStrip(str.size_);
372   }
373   return false;
374 }
375 
LAndRStrip()376 bool StringRange::LAndRStrip() {
377   bool l = LStrip();
378   bool r = RStrip();
379   return l || r;
380 }
381 
ParensWhitespaceStrip()382 void StringRange::ParensWhitespaceStrip() {
383   LStrip();
384   LStrip("(");
385   LAndRStrip();
386   RStrip(")");
387   RStrip();
388 }
389 
Find(const char ch) const390 size_t StringRange::Find(const char ch) const {
391   size_t idx = 0;
392   while (idx < size_) {
393     if (data_[idx] == ch) {
394       return idx;
395     }
396     idx++;
397   }
398   return std::string::npos;
399 }
400 
RestartCapture()401 void StringRange::RestartCapture() {
402   start_ = data_;
403   end_ = data_;
404 }
405 
GetCaptured()406 StringRange StringRange::GetCaptured() {
407   return StringRange(start_, end_ - start_);
408 }
409 
GetTypesWrapper()410 TypesWrapper& TypesWrapper::GetTypesWrapper() {
411   static TypesWrapper types;
412   return types;
413 }
414 
GetAllowedDataTypes()415 std::unordered_set<std::string>& TypesWrapper::GetAllowedDataTypes() {
416   return allowed_data_types_;
417 }
418 
TypeStrToTensorDataType()419 std::unordered_map<std::string, int>& TypesWrapper::TypeStrToTensorDataType() {
420   return type_str_to_tensor_data_type_;
421 }
422 
TensorDataTypeToTypeStr()423 std::unordered_map<int, std::string>& TypesWrapper::TensorDataTypeToTypeStr() {
424   return tensor_data_type_to_type_str_;
425 }
426 
TypesWrapper()427 TypesWrapper::TypesWrapper() {
428   // DataType strings. These should match the DataTypes defined in onnx.proto
429   type_str_to_tensor_data_type_["float"] = TensorProto_DataType_FLOAT;
430   type_str_to_tensor_data_type_["float16"] = TensorProto_DataType_FLOAT16;
431   type_str_to_tensor_data_type_["bfloat16"] = TensorProto_DataType_BFLOAT16;
432   type_str_to_tensor_data_type_["double"] = TensorProto_DataType_DOUBLE;
433   type_str_to_tensor_data_type_["int8"] = TensorProto_DataType_INT8;
434   type_str_to_tensor_data_type_["int16"] = TensorProto_DataType_INT16;
435   type_str_to_tensor_data_type_["int32"] = TensorProto_DataType_INT32;
436   type_str_to_tensor_data_type_["int64"] = TensorProto_DataType_INT64;
437   type_str_to_tensor_data_type_["uint8"] = TensorProto_DataType_UINT8;
438   type_str_to_tensor_data_type_["uint16"] = TensorProto_DataType_UINT16;
439   type_str_to_tensor_data_type_["uint32"] = TensorProto_DataType_UINT32;
440   type_str_to_tensor_data_type_["uint64"] = TensorProto_DataType_UINT64;
441   type_str_to_tensor_data_type_["complex64"] = TensorProto_DataType_COMPLEX64;
442   type_str_to_tensor_data_type_["complex128"] = TensorProto_DataType_COMPLEX128;
443   type_str_to_tensor_data_type_["string"] = TensorProto_DataType_STRING;
444   type_str_to_tensor_data_type_["bool"] = TensorProto_DataType_BOOL;
445 
446   for (auto& str_type_pair : type_str_to_tensor_data_type_) {
447     tensor_data_type_to_type_str_[str_type_pair.second] = str_type_pair.first;
448     allowed_data_types_.insert(str_type_pair.first);
449   }
450 }
451 } // namespace Utils
452 } // namespace ONNX_NAMESPACE
453