1 /* 2 * SPDX-License-Identifier: Apache-2.0 3 */ 4 5 // Experimental language syntax and parser for ONNX. Please note that the syntax as formalized 6 // by this parser is preliminary and may change. 7 8 #pragma once 9 10 #include <ctype.h> 11 #include <iostream> 12 #include <stdexcept> 13 #include <string> 14 #include <unordered_map> 15 16 #include "onnx/onnx_pb.h" 17 18 #include "onnx/common/status.h" 19 #include "onnx/string_utils.h" 20 21 namespace ONNX_NAMESPACE { 22 23 using namespace ONNX_NAMESPACE::Common; 24 25 using IdList = google::protobuf::RepeatedPtrField<std::string>; 26 27 using NodeList = google::protobuf::RepeatedPtrField<NodeProto>; 28 29 using AttrList = google::protobuf::RepeatedPtrField<AttributeProto>; 30 31 using ValueInfoList = google::protobuf::RepeatedPtrField<ValueInfoProto>; 32 33 using TensorList = google::protobuf::RepeatedPtrField<TensorProto>; 34 35 #define CHECK_PARSER_STATUS(status) \ 36 { \ 37 auto local_status_ = status; \ 38 if (!local_status_.IsOK()) \ 39 return local_status_; \ 40 } 41 42 class PrimitiveTypeNameMap { 43 public: PrimitiveTypeNameMap()44 PrimitiveTypeNameMap() { 45 map_["float"] = 1; 46 map_["uint8"] = 2; 47 map_["int8"] = 3; 48 map_["uint16"] = 4; 49 map_["int16"] = 5; 50 map_["int32"] = 6; 51 map_["int64"] = 7; 52 map_["string"] = 8; 53 map_["bool"] = 9; 54 map_["float16"] = 10; 55 map_["double"] = 11; 56 map_["uint32"] = 12; 57 map_["uint64"] = 13; 58 map_["complex64"] = 14; 59 map_["complex128"] = 15; 60 map_["bfloat16"] = 16; 61 } 62 Instance()63 static const std::unordered_map<std::string, int32_t>& Instance() { 64 static PrimitiveTypeNameMap instance; 65 return instance.map_; 66 } 67 Lookup(const std::string & dtype)68 static int32_t Lookup(const std::string& dtype) { 69 auto it = Instance().find(dtype); 70 if (it != Instance().end()) 71 return it->second; 72 return 0; 73 } 74 IsTypeName(const std::string & dtype)75 static bool IsTypeName(const std::string& dtype) { 76 return Lookup(dtype) != 0; 77 } 78 ToString(int32_t dtype)79 static const std::string& ToString(int32_t dtype) { 80 static std::string undefined("undefined"); 81 for (const auto& pair : Instance()) { 82 if (pair.second == dtype) 83 return pair.first; 84 } 85 return undefined; 86 } 87 88 private: 89 std::unordered_map<std::string, int32_t> map_; 90 }; 91 92 class KeyWordMap { 93 public: 94 enum class KeyWord { 95 NONE, 96 IR_VERSION, 97 OPSET_IMPORT, 98 PRODUCER_NAME, 99 PRODUCER_VERSION, 100 DOMAIN_KW, 101 MODEL_VERSION, 102 DOC_STRING, 103 METADATA_PROPS 104 }; 105 KeyWordMap()106 KeyWordMap() { 107 map_["ir_version"] = KeyWord::IR_VERSION; 108 map_["opset_import"] = KeyWord::OPSET_IMPORT; 109 map_["producer_name"] = KeyWord::PRODUCER_NAME; 110 map_["producer_version"] = KeyWord::PRODUCER_VERSION; 111 map_["domain"] = KeyWord::DOMAIN_KW; 112 map_["model_version"] = KeyWord::MODEL_VERSION; 113 map_["doc_string"] = KeyWord::DOC_STRING; 114 map_["metadata_props"] = KeyWord::METADATA_PROPS; 115 } 116 Instance()117 static const std::unordered_map<std::string, KeyWord>& Instance() { 118 static KeyWordMap instance; 119 return instance.map_; 120 } 121 Lookup(const std::string & id)122 static KeyWord Lookup(const std::string& id) { 123 auto it = Instance().find(id); 124 if (it != Instance().end()) 125 return it->second; 126 return KeyWord::NONE; 127 } 128 129 private: 130 std::unordered_map<std::string, KeyWord> map_; 131 }; 132 133 class ParserBase { 134 public: ParserBase(const std::string & str)135 ParserBase(const std::string& str) 136 : start_(str.data()), next_(str.data()), end_(str.data() + str.length()), saved_pos_(next_) {} 137 ParserBase(const char * cstr)138 ParserBase(const char* cstr) : start_(cstr), next_(cstr), end_(cstr + strlen(cstr)), saved_pos_(next_) {} 139 SavePos()140 void SavePos() { 141 saved_pos_ = next_; 142 } 143 RestorePos()144 void RestorePos() { 145 next_ = saved_pos_; 146 } 147 GetCurrentPos()148 std::string GetCurrentPos() { 149 uint32_t line = 1, col = 1; 150 for (const char* p = start_; p < next_; ++p) { 151 if (*p == '\n') { 152 ++line; 153 col = 1; 154 } else { 155 ++col; 156 } 157 } 158 return ONNX_NAMESPACE::MakeString("(line: ", line, " column: ", col, ")"); 159 } 160 161 template <typename... Args> ParseError(const Args &...args)162 Status ParseError(const Args&... args) { 163 return Status(NONE, FAIL, ONNX_NAMESPACE::MakeString("[ParseError at position ", GetCurrentPos(), "]", args...)); 164 } 165 SkipWhiteSpace()166 void SkipWhiteSpace() { 167 while ((next_ < end_) && (isspace(*next_))) 168 ++next_; 169 } 170 171 int NextChar(bool skipspace = true) { 172 if (skipspace) 173 SkipWhiteSpace(); 174 return (next_ < end_) ? *next_ : 0; 175 } 176 177 bool Matches(char ch, bool skipspace = true) { 178 if (skipspace) 179 SkipWhiteSpace(); 180 if ((next_ < end_) && (*next_ == ch)) { 181 ++next_; 182 return true; 183 } 184 return false; 185 } 186 187 Status Match(char ch, bool skipspace = true) { 188 if (!Matches(ch, skipspace)) 189 return ParseError("Expected character ", ch, " not found", ch); 190 return Status::OK(); 191 } 192 EndOfInput()193 bool EndOfInput() { 194 SkipWhiteSpace(); 195 return (next_ >= end_); 196 } 197 198 enum class LiteralType { INT_LITERAL, FLOAT_LITERAL, STRING_LITERAL }; 199 200 struct Literal { 201 LiteralType type; 202 std::string value; 203 }; 204 Parse(Literal & result)205 Status Parse(Literal& result) { 206 bool decimal_point = false; 207 auto nextch = NextChar(); 208 auto from = next_; 209 if (nextch == '"') { 210 ++next_; 211 // TODO: Handle escape characters 212 while ((next_ < end_) && (*next_ != '"')) { 213 ++next_; 214 } 215 ++next_; 216 result.type = LiteralType::STRING_LITERAL; 217 result.value = std::string(from + 1, next_ - from - 2); // skip enclosing quotes 218 } else if ((isdigit(nextch) || (nextch == '-'))) { 219 ++next_; 220 221 while ((next_ < end_) && (isdigit(*next_) || (*next_ == '.'))) { 222 if (*next_ == '.') { 223 if (decimal_point) 224 break; // Only one decimal point allowed in numeric literal 225 decimal_point = true; 226 } 227 ++next_; 228 } 229 230 if (next_ == from) 231 return ParseError("Value expected but not found."); 232 233 result.value = std::string(from, next_ - from); 234 result.type = decimal_point ? LiteralType::FLOAT_LITERAL : LiteralType::INT_LITERAL; 235 } 236 return Status::OK(); 237 } 238 Parse(int64_t & val)239 Status Parse(int64_t& val) { 240 Literal literal; 241 CHECK_PARSER_STATUS(Parse(literal)); 242 if (literal.type != LiteralType::INT_LITERAL) 243 return ParseError("Integer value expected, but not found."); 244 std::string s = literal.value; 245 val = std::stoll(s); 246 return Status::OK(); 247 } 248 Parse(uint64_t & val)249 Status Parse(uint64_t& val) { 250 Literal literal; 251 CHECK_PARSER_STATUS(Parse(literal)); 252 if (literal.type != LiteralType::INT_LITERAL) 253 return ParseError("Integer value expected, but not found."); 254 std::string s = literal.value; 255 val = std::stoull(s); 256 return Status::OK(); 257 } 258 Parse(float & val)259 Status Parse(float& val) { 260 Literal literal; 261 CHECK_PARSER_STATUS(Parse(literal)); 262 switch (literal.type) { 263 case LiteralType::INT_LITERAL: 264 case LiteralType::FLOAT_LITERAL: 265 val = std::stof(literal.value); 266 break; 267 default: 268 return ParseError("Unexpected literal type."); 269 } 270 return Status::OK(); 271 } 272 Parse(double & val)273 Status Parse(double& val) { 274 Literal literal; 275 CHECK_PARSER_STATUS(Parse(literal)); 276 switch (literal.type) { 277 case LiteralType::INT_LITERAL: 278 case LiteralType::FLOAT_LITERAL: 279 val = std::stod(literal.value); 280 break; 281 default: 282 return ParseError("Unexpected literal type."); 283 } 284 return Status::OK(); 285 } 286 287 // Parse a string-literal enclosed within doube-quotes. Parse(std::string & val)288 Status Parse(std::string& val) { 289 Literal literal; 290 CHECK_PARSER_STATUS(Parse(literal)); 291 if (literal.type != LiteralType::STRING_LITERAL) 292 return ParseError("String value expected, but not found."); 293 val = literal.value; 294 return Status::OK(); 295 } 296 297 // Parse an identifier, including keywords. If none found, this will 298 // return an empty-string identifier. ParseOptionalIdentifier(std::string & id)299 Status ParseOptionalIdentifier(std::string& id) { 300 SkipWhiteSpace(); 301 auto from = next_; 302 if ((next_ < end_) && (isalpha(*next_) || (*next_ == '_'))) { 303 ++next_; 304 while ((next_ < end_) && (isalnum(*next_) || (*next_ == '_'))) 305 ++next_; 306 } 307 id = std::string(from, next_ - from); 308 return Status::OK(); 309 } 310 ParseIdentifier(std::string & id)311 Status ParseIdentifier(std::string& id) { 312 ParseOptionalIdentifier(id); 313 if (id.empty()) 314 return ParseError("Identifier expected but not found."); 315 return Status::OK(); 316 } 317 PeekIdentifier(std::string & id)318 Status PeekIdentifier(std::string& id) { 319 SavePos(); 320 ParseOptionalIdentifier(id); 321 RestorePos(); 322 return Status::OK(); 323 } 324 Parse(KeyWordMap::KeyWord & keyword)325 Status Parse(KeyWordMap::KeyWord& keyword) { 326 std::string id; 327 CHECK_PARSER_STATUS(ParseIdentifier(id)); 328 keyword = KeyWordMap::Lookup(id); 329 return Status::OK(); 330 } 331 332 protected: 333 const char* start_; 334 const char* next_; 335 const char* end_; 336 const char* saved_pos_; 337 }; 338 339 class OnnxParser : public ParserBase { 340 public: OnnxParser(const char * cstr)341 OnnxParser(const char* cstr) : ParserBase(cstr) {} 342 343 Status Parse(TensorShapeProto& shape); 344 345 Status Parse(TypeProto& typeProto); 346 347 Status Parse(TensorProto& tensorProto); 348 349 Status Parse(AttributeProto& attr); 350 351 Status Parse(AttrList& attrlist); 352 353 Status Parse(NodeProto& node); 354 355 Status Parse(NodeList& nodelist); 356 357 Status Parse(GraphProto& graph); 358 359 Status Parse(ModelProto& model); 360 361 template <typename T> Parse(T & parsedData,const char * input)362 static Status Parse(T& parsedData, const char* input) { 363 OnnxParser parser(input); 364 return parser.Parse(parsedData); 365 } 366 367 private: 368 Status Parse(std::string name, GraphProto& graph); 369 370 Status Parse(IdList& idlist); 371 372 Status ParseSingleAttributeValue(AttributeProto& attr); 373 374 Status Parse(ValueInfoProto& valueinfo); 375 376 Status Parse(ValueInfoList& vilist); 377 378 Status ParseInput(ValueInfoList& vilist, TensorList& initializers); 379 380 Status ParseValueInfo(ValueInfoList& vilist, TensorList& initializers); 381 382 Status Parse(TensorProto& tensorProto, const TypeProto& tensorTypeProto); 383 }; 384 385 } // namespace ONNX_NAMESPACE