#include "data_type_utils.h"

#include <cassert>
#include <cctype>
#include <cstring>
#include <iostream>
#include <iterator>
#include <sstream>
7
8 namespace ONNX_NAMESPACE {
9 namespace Utils {
10
11 // Singleton wrapper around allowed data types.
12 // This implements construct on first use which is needed to ensure
13 // static objects are initialized before use. Ops registration does not work
14 // properly without this.
// Singleton wrapper around the set of allowed tensor data types and the
// mappings between type strings (e.g. "float") and TensorProto data-type
// enum values.
//
// Implemented as construct-on-first-use so the tables are guaranteed to be
// initialized before any op registration uses them; plain file-scope statics
// give no such ordering guarantee.
class TypesWrapper final {
 public:
  // Returns the process-wide instance, constructing it on first use.
  static TypesWrapper& GetTypesWrapper();

  // Set of type strings accepted by IsValidDataTypeString().
  std::unordered_set<std::string>& GetAllowedDataTypes();

  // Maps a type string (e.g. "int64") to its TensorProto_DataType value.
  std::unordered_map<std::string, int32_t>& TypeStrToTensorDataType();

  // Inverse of TypeStrToTensorDataType().
  std::unordered_map<int32_t, std::string>& TensorDataTypeToTypeStr();

  ~TypesWrapper() = default;
  TypesWrapper(const TypesWrapper&) = delete;
  void operator=(const TypesWrapper&) = delete;

 private:
  TypesWrapper();

  // int32_t to match the accessor declarations above (previously plain int).
  std::unordered_map<std::string, int32_t> type_str_to_tensor_data_type_;
  std::unordered_map<int32_t, std::string> tensor_data_type_to_type_str_;
  std::unordered_set<std::string> allowed_data_types_;
};
36
37 // Simple class which contains pointers to external string buffer and a size.
38 // This can be used to track a "valid" range/slice of the string.
39 // Caller should ensure StringRange is not used after external storage has
40 // been freed.
// Lightweight non-owning view over an external character buffer: a pointer
// plus a size describing the currently "valid" slice. The caller must keep
// the underlying storage alive for the lifetime of the StringRange.
class StringRange final {
 public:
  StringRange();
  StringRange(const char* data, size_t size);
  StringRange(const std::string& str);
  StringRange(const char* data);

  // Accessors for the current slice. Data() is not NUL-terminated.
  const char* Data() const;
  size_t Size() const;
  bool Empty() const;
  char operator[](size_t idx) const;

  // Re-point the range at new storage (also restarts the capture markers).
  void Reset();
  void Reset(const char* data, size_t size);
  void Reset(const std::string& str);

  bool StartsWith(const StringRange& str) const;
  bool EndsWith(const StringRange& str) const;

  // Strip operations shrink the range in place; each returns true when
  // something was removed. The no-argument forms remove whitespace.
  bool LStrip();
  bool LStrip(size_t size);
  bool LStrip(StringRange str);
  bool RStrip();
  bool RStrip(size_t size);
  bool RStrip(StringRange str);
  bool LAndRStrip();

  // Removes an enclosing "( ... )" together with surrounding whitespace.
  void ParensWhitespaceStrip();

  // Index of the first occurrence of ch, or std::string::npos.
  size_t Find(const char ch) const;

  // Capture support: returns the sub-range consumed by LStrip() since the
  // last RestartCapture() call.
  StringRange GetCaptured();
  void RestartCapture();

 private:
  // data_/size_ describe the valid slice of the external buffer.
  const char* data_;
  size_t size_;

  // start_/end_ bound the captured range; end_ advances on LStrip().
  const char* start_;
  const char* end_;
};
82
83 std::unordered_map<std::string, TypeProto>&
GetTypeStrToProtoMap()84 DataTypeUtils::GetTypeStrToProtoMap() {
85 static std::unordered_map<std::string, TypeProto> map;
86 return map;
87 }
88
GetTypeStrLock()89 std::mutex& DataTypeUtils::GetTypeStrLock() {
90 static std::mutex lock;
91 return lock;
92 }
93
ToType(const TypeProto & type_proto)94 DataType DataTypeUtils::ToType(const TypeProto& type_proto) {
95 auto typeStr = ToString(type_proto);
96 std::lock_guard<std::mutex> lock(GetTypeStrLock());
97 if (GetTypeStrToProtoMap().find(typeStr) == GetTypeStrToProtoMap().end()) {
98 TypeProto type;
99 FromString(typeStr, type);
100 GetTypeStrToProtoMap()[typeStr] = type;
101 }
102 return &(GetTypeStrToProtoMap().find(typeStr)->first);
103 }
104
ToType(const std::string & type_str)105 DataType DataTypeUtils::ToType(const std::string& type_str) {
106 TypeProto type;
107 FromString(type_str, type);
108 return ToType(type);
109 }
110
ToTypeProto(const DataType & data_type)111 const TypeProto& DataTypeUtils::ToTypeProto(const DataType& data_type) {
112 std::lock_guard<std::mutex> lock(GetTypeStrLock());
113 auto it = GetTypeStrToProtoMap().find(*data_type);
114 assert(it != GetTypeStrToProtoMap().end());
115 return it->second;
116 }
117
ToString(const TypeProto & type_proto,const std::string & left,const std::string & right)118 std::string DataTypeUtils::ToString(
119 const TypeProto& type_proto,
120 const std::string& left,
121 const std::string& right) {
122 switch (type_proto.value_case()) {
123 case TypeProto::ValueCase::kTensorType: {
124 // Note: We do not distinguish tensors with zero rank (a shape consisting
125 // of an empty sequence of dimensions) here.
126 return left + "tensor(" +
127 ToDataTypeString(type_proto.tensor_type().elem_type()) + ")" + right;
128 }
129 #ifdef ONNX_ML
130 case TypeProto::ValueCase::kSequenceType: {
131 return ToString(
132 type_proto.sequence_type().elem_type(), left + "seq(", ")" + right);
133 }
134
135 case TypeProto::ValueCase::kMapType: {
136 std::string map_str =
137 "map(" + ToDataTypeString(type_proto.map_type().key_type()) + ",";
138 return ToString(
139 type_proto.map_type().value_type(), left + map_str, ")" + right);
140 }
141 case TypeProto::ValueCase::kOpaqueType: {
142 static const std::string empty;
143 std::string result;
144 const auto& op_type = type_proto.opaque_type();
145 result.append(left).append("opaque(");
146 if (op_type.has_domain() && !op_type.domain().empty()) {
147 result.append(op_type.domain()).append(",");
148 }
149 if (op_type.has_name() && !op_type.name().empty()) {
150 result.append(op_type.name());
151 }
152 result.append(")").append(right);
153 return result;
154 }
155 case TypeProto::ValueCase::kSparseTensorType: {
156 // Note: We do not distinguish tensors with zero rank (a shape consisting
157 // of an empty sequence of dimensions) here.
158 return left + "sparse_tensor(" +
159 ToDataTypeString(type_proto.sparse_tensor_type().elem_type()) + ")" +
160 right;
161 }
162 #endif
163 default:
164 assert(false);
165 return std::string();
166 }
167 }
168
ToDataTypeString(int32_t tensor_data_type)169 std::string DataTypeUtils::ToDataTypeString(int32_t tensor_data_type) {
170 TypesWrapper& t = TypesWrapper::GetTypesWrapper();
171 auto iter = t.TensorDataTypeToTypeStr().find(tensor_data_type);
172 assert(t.TensorDataTypeToTypeStr().end() != iter);
173 return iter->second;
174 }
175
FromString(const std::string & type_str,TypeProto & type_proto)176 void DataTypeUtils::FromString(
177 const std::string& type_str,
178 TypeProto& type_proto) {
179 StringRange s(type_str);
180 type_proto.Clear();
181 #ifdef ONNX_ML
182 if (s.LStrip("seq")) {
183 s.ParensWhitespaceStrip();
184 return FromString(
185 std::string(s.Data(), s.Size()),
186 *type_proto.mutable_sequence_type()->mutable_elem_type());
187 } else if (s.LStrip("map")) {
188 s.ParensWhitespaceStrip();
189 size_t key_size = s.Find(',');
190 StringRange k(s.Data(), key_size);
191 std::string key(k.Data(), k.Size());
192 s.LStrip(key_size);
193 s.LStrip(",");
194 StringRange v(s.Data(), s.Size());
195 int32_t key_type;
196 FromDataTypeString(key, key_type);
197 type_proto.mutable_map_type()->set_key_type(key_type);
198 return FromString(
199 std::string(v.Data(), v.Size()),
200 *type_proto.mutable_map_type()->mutable_value_type());
201 } else if (s.LStrip("opaque")) {
202 auto* opaque_type = type_proto.mutable_opaque_type();
203 s.ParensWhitespaceStrip();
204 if (!s.Empty()) {
205 size_t cm = s.Find(',');
206 if (cm != std::string::npos) {
207 if (cm > 0) {
208 opaque_type->mutable_domain()->assign(s.Data(), cm);
209 }
210 s.LStrip(cm + 1); // skip comma
211 }
212 if (!s.Empty()) {
213 opaque_type->mutable_name()->assign(s.Data(), s.Size());
214 }
215 }
216 } else if (s.LStrip("sparse_tensor")) {
217 s.ParensWhitespaceStrip();
218 int32_t e;
219 FromDataTypeString(std::string(s.Data(), s.Size()), e);
220 type_proto.mutable_sparse_tensor_type()->set_elem_type(e);
221 } else
222 #endif
223 if (s.LStrip("tensor")) {
224 s.ParensWhitespaceStrip();
225 int32_t e;
226 FromDataTypeString(std::string(s.Data(), s.Size()), e);
227 type_proto.mutable_tensor_type()->set_elem_type(e);
228 } else {
229 // Scalar
230 int32_t e;
231 FromDataTypeString(std::string(s.Data(), s.Size()), e);
232 TypeProto::Tensor* t = type_proto.mutable_tensor_type();
233 t->set_elem_type(e);
234 // Call mutable_shape() to initialize a shape with no dimension.
235 t->mutable_shape();
236 }
237 } // namespace Utils
238
IsValidDataTypeString(const std::string & type_str)239 bool DataTypeUtils::IsValidDataTypeString(const std::string& type_str) {
240 TypesWrapper& t = TypesWrapper::GetTypesWrapper();
241 const auto& allowedSet = t.GetAllowedDataTypes();
242 return (allowedSet.find(type_str) != allowedSet.end());
243 }
244
FromDataTypeString(const std::string & type_str,int32_t & tensor_data_type)245 void DataTypeUtils::FromDataTypeString(
246 const std::string& type_str,
247 int32_t& tensor_data_type) {
248 assert(IsValidDataTypeString(type_str));
249
250 TypesWrapper& t = TypesWrapper::GetTypesWrapper();
251 tensor_data_type = t.TypeStrToTensorDataType()[type_str];
252 }
253
StringRange()254 StringRange::StringRange() : data_(""), size_(0), start_(data_), end_(data_) {}
255
StringRange(const char * p_data,size_t p_size)256 StringRange::StringRange(const char* p_data, size_t p_size)
257 : data_(p_data), size_(p_size), start_(data_), end_(data_) {
258 assert(p_data != nullptr);
259 LAndRStrip();
260 }
261
StringRange(const std::string & p_str)262 StringRange::StringRange(const std::string& p_str)
263 : data_(p_str.data()), size_(p_str.size()), start_(data_), end_(data_) {
264 LAndRStrip();
265 }
266
StringRange(const char * p_data)267 StringRange::StringRange(const char* p_data)
268 : data_(p_data), size_(strlen(p_data)), start_(data_), end_(data_) {
269 LAndRStrip();
270 }
271
Data() const272 const char* StringRange::Data() const {
273 return data_;
274 }
275
Size() const276 size_t StringRange::Size() const {
277 return size_;
278 }
279
Empty() const280 bool StringRange::Empty() const {
281 return size_ == 0;
282 }
283
operator [](size_t idx) const284 char StringRange::operator[](size_t idx) const {
285 return data_[idx];
286 }
287
Reset()288 void StringRange::Reset() {
289 data_ = "";
290 size_ = 0;
291 start_ = end_ = data_;
292 }
293
Reset(const char * data,size_t size)294 void StringRange::Reset(const char* data, size_t size) {
295 data_ = data;
296 size_ = size;
297 start_ = end_ = data_;
298 }
299
Reset(const std::string & str)300 void StringRange::Reset(const std::string& str) {
301 data_ = str.data();
302 size_ = str.size();
303 start_ = end_ = data_;
304 }
305
StartsWith(const StringRange & str) const306 bool StringRange::StartsWith(const StringRange& str) const {
307 return ((size_ >= str.size_) && (memcmp(data_, str.data_, str.size_) == 0));
308 }
309
EndsWith(const StringRange & str) const310 bool StringRange::EndsWith(const StringRange& str) const {
311 return (
312 (size_ >= str.size_) &&
313 (memcmp(data_ + (size_ - str.size_), str.data_, str.size_) == 0));
314 }
315
LStrip()316 bool StringRange::LStrip() {
317 size_t count = 0;
318 const char* ptr = data_;
319 while (count < size_ && isspace(*ptr)) {
320 count++;
321 ptr++;
322 }
323
324 if (count > 0) {
325 return LStrip(count);
326 }
327 return false;
328 }
329
LStrip(size_t size)330 bool StringRange::LStrip(size_t size) {
331 if (size <= size_) {
332 data_ += size;
333 size_ -= size;
334 end_ += size;
335 return true;
336 }
337 return false;
338 }
339
LStrip(StringRange str)340 bool StringRange::LStrip(StringRange str) {
341 if (StartsWith(str)) {
342 return LStrip(str.size_);
343 }
344 return false;
345 }
346
RStrip()347 bool StringRange::RStrip() {
348 size_t count = 0;
349 const char* ptr = data_ + size_ - 1;
350 while (count < size_ && isspace(*ptr)) {
351 ++count;
352 --ptr;
353 }
354
355 if (count > 0) {
356 return RStrip(count);
357 }
358 return false;
359 }
360
RStrip(size_t size)361 bool StringRange::RStrip(size_t size) {
362 if (size_ >= size) {
363 size_ -= size;
364 return true;
365 }
366 return false;
367 }
368
RStrip(StringRange str)369 bool StringRange::RStrip(StringRange str) {
370 if (EndsWith(str)) {
371 return RStrip(str.size_);
372 }
373 return false;
374 }
375
LAndRStrip()376 bool StringRange::LAndRStrip() {
377 bool l = LStrip();
378 bool r = RStrip();
379 return l || r;
380 }
381
ParensWhitespaceStrip()382 void StringRange::ParensWhitespaceStrip() {
383 LStrip();
384 LStrip("(");
385 LAndRStrip();
386 RStrip(")");
387 RStrip();
388 }
389
Find(const char ch) const390 size_t StringRange::Find(const char ch) const {
391 size_t idx = 0;
392 while (idx < size_) {
393 if (data_[idx] == ch) {
394 return idx;
395 }
396 idx++;
397 }
398 return std::string::npos;
399 }
400
RestartCapture()401 void StringRange::RestartCapture() {
402 start_ = data_;
403 end_ = data_;
404 }
405
GetCaptured()406 StringRange StringRange::GetCaptured() {
407 return StringRange(start_, end_ - start_);
408 }
409
GetTypesWrapper()410 TypesWrapper& TypesWrapper::GetTypesWrapper() {
411 static TypesWrapper types;
412 return types;
413 }
414
GetAllowedDataTypes()415 std::unordered_set<std::string>& TypesWrapper::GetAllowedDataTypes() {
416 return allowed_data_types_;
417 }
418
TypeStrToTensorDataType()419 std::unordered_map<std::string, int>& TypesWrapper::TypeStrToTensorDataType() {
420 return type_str_to_tensor_data_type_;
421 }
422
TensorDataTypeToTypeStr()423 std::unordered_map<int, std::string>& TypesWrapper::TensorDataTypeToTypeStr() {
424 return tensor_data_type_to_type_str_;
425 }
426
TypesWrapper()427 TypesWrapper::TypesWrapper() {
428 // DataType strings. These should match the DataTypes defined in onnx.proto
429 type_str_to_tensor_data_type_["float"] = TensorProto_DataType_FLOAT;
430 type_str_to_tensor_data_type_["float16"] = TensorProto_DataType_FLOAT16;
431 type_str_to_tensor_data_type_["bfloat16"] = TensorProto_DataType_BFLOAT16;
432 type_str_to_tensor_data_type_["double"] = TensorProto_DataType_DOUBLE;
433 type_str_to_tensor_data_type_["int8"] = TensorProto_DataType_INT8;
434 type_str_to_tensor_data_type_["int16"] = TensorProto_DataType_INT16;
435 type_str_to_tensor_data_type_["int32"] = TensorProto_DataType_INT32;
436 type_str_to_tensor_data_type_["int64"] = TensorProto_DataType_INT64;
437 type_str_to_tensor_data_type_["uint8"] = TensorProto_DataType_UINT8;
438 type_str_to_tensor_data_type_["uint16"] = TensorProto_DataType_UINT16;
439 type_str_to_tensor_data_type_["uint32"] = TensorProto_DataType_UINT32;
440 type_str_to_tensor_data_type_["uint64"] = TensorProto_DataType_UINT64;
441 type_str_to_tensor_data_type_["complex64"] = TensorProto_DataType_COMPLEX64;
442 type_str_to_tensor_data_type_["complex128"] = TensorProto_DataType_COMPLEX128;
443 type_str_to_tensor_data_type_["string"] = TensorProto_DataType_STRING;
444 type_str_to_tensor_data_type_["bool"] = TensorProto_DataType_BOOL;
445
446 for (auto& str_type_pair : type_str_to_tensor_data_type_) {
447 tensor_data_type_to_type_str_[str_type_pair.second] = str_type_pair.first;
448 allowed_data_types_.insert(str_type_pair.first);
449 }
450 }
451 } // namespace Utils
452 } // namespace ONNX_NAMESPACE
453