1 /*
2  * SPDX-License-Identifier: Apache-2.0
3  */
4 
5 // Experimental language syntax and parser for ONNX. Please note that the syntax as formalized
6 // by this parser is preliminary and may change.
7 
8 #pragma once
9 
#include <ctype.h>
#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
15 
16 #include "onnx/onnx_pb.h"
17 
18 #include "onnx/common/status.h"
19 #include "onnx/string_utils.h"
20 
21 namespace ONNX_NAMESPACE {
22 
23 using namespace ONNX_NAMESPACE::Common;
24 
25 using IdList = google::protobuf::RepeatedPtrField<std::string>;
26 
27 using NodeList = google::protobuf::RepeatedPtrField<NodeProto>;
28 
29 using AttrList = google::protobuf::RepeatedPtrField<AttributeProto>;
30 
31 using ValueInfoList = google::protobuf::RepeatedPtrField<ValueInfoProto>;
32 
33 using TensorList = google::protobuf::RepeatedPtrField<TensorProto>;
34 
// Evaluates `status` exactly once; if the result is not OK, returns it from the
// enclosing function. The do/while(0) wrapper makes the expansion a single
// statement, so the macro behaves correctly when used as the body of an
// unbraced if/else (a bare {...} block followed by ';' would detach the else).
#define CHECK_PARSER_STATUS(status) \
  do {                              \
    auto local_status_ = (status);  \
    if (!local_status_.IsOK())      \
      return local_status_;         \
  } while (0)
41 
// Bidirectional table between the element-type names accepted in the text
// syntax (e.g. "float", "int64") and their integer codes. Unknown names map to
// 0, and unknown codes map to "undefined".
class PrimitiveTypeNameMap {
 public:
  PrimitiveTypeNameMap()
      : map_{{"float", 1},
             {"uint8", 2},
             {"int8", 3},
             {"uint16", 4},
             {"int16", 5},
             {"int32", 6},
             {"int64", 7},
             {"string", 8},
             {"bool", 9},
             {"float16", 10},
             {"double", 11},
             {"uint32", 12},
             {"uint64", 13},
             {"complex64", 14},
             {"complex128", 15},
             {"bfloat16", 16}} {}

  // Shared name -> code table, built on first use.
  static const std::unordered_map<std::string, int32_t>& Instance() {
    static const PrimitiveTypeNameMap kTable;
    return kTable.map_;
  }

  // Returns the code for `dtype`, or 0 if the name is not a known type.
  static int32_t Lookup(const std::string& dtype) {
    const auto& table = Instance();
    const auto found = table.find(dtype);
    return found == table.end() ? 0 : found->second;
  }

  // True when `dtype` names one of the primitive types above.
  static bool IsTypeName(const std::string& dtype) {
    return 0 != Lookup(dtype);
  }

  // Reverse lookup (linear scan): returns the name for `dtype`, or
  // "undefined" when no entry has that code.
  static const std::string& ToString(int32_t dtype) {
    static const std::string kUndefined("undefined");
    const auto& table = Instance();
    for (auto entry = table.begin(); entry != table.end(); ++entry) {
      if (entry->second == dtype)
        return entry->first;
    }
    return kUndefined;
  }

 private:
  std::unordered_map<std::string, int32_t> map_;
};
91 
// Table of reserved words recognized by the parser (e.g. "ir_version",
// "opset_import"). Any identifier not listed here maps to KeyWord::NONE.
class KeyWordMap {
 public:
  enum class KeyWord {
    NONE,
    IR_VERSION,
    OPSET_IMPORT,
    PRODUCER_NAME,
    PRODUCER_VERSION,
    DOMAIN_KW,
    MODEL_VERSION,
    DOC_STRING,
    METADATA_PROPS
  };

  KeyWordMap()
      : map_{{"ir_version", KeyWord::IR_VERSION},
             {"opset_import", KeyWord::OPSET_IMPORT},
             {"producer_name", KeyWord::PRODUCER_NAME},
             {"producer_version", KeyWord::PRODUCER_VERSION},
             {"domain", KeyWord::DOMAIN_KW},
             {"model_version", KeyWord::MODEL_VERSION},
             {"doc_string", KeyWord::DOC_STRING},
             {"metadata_props", KeyWord::METADATA_PROPS}} {}

  // Shared keyword table, built on first use.
  static const std::unordered_map<std::string, KeyWord>& Instance() {
    static const KeyWordMap kKeywords;
    return kKeywords.map_;
  }

  // Classifies `id`: its KeyWord if reserved, KeyWord::NONE otherwise.
  static KeyWord Lookup(const std::string& id) {
    const auto& keywords = Instance();
    const auto hit = keywords.find(id);
    return hit == keywords.end() ? KeyWord::NONE : hit->second;
  }

 private:
  std::unordered_map<std::string, KeyWord> map_;
};
132 
133 class ParserBase {
134  public:
ParserBase(const std::string & str)135   ParserBase(const std::string& str)
136       : start_(str.data()), next_(str.data()), end_(str.data() + str.length()), saved_pos_(next_) {}
137 
ParserBase(const char * cstr)138   ParserBase(const char* cstr) : start_(cstr), next_(cstr), end_(cstr + strlen(cstr)), saved_pos_(next_) {}
139 
SavePos()140   void SavePos() {
141     saved_pos_ = next_;
142   }
143 
RestorePos()144   void RestorePos() {
145     next_ = saved_pos_;
146   }
147 
GetCurrentPos()148   std::string GetCurrentPos() {
149     uint32_t line = 1, col = 1;
150     for (const char* p = start_; p < next_; ++p) {
151       if (*p == '\n') {
152         ++line;
153         col = 1;
154       } else {
155         ++col;
156       }
157     }
158     return ONNX_NAMESPACE::MakeString("(line: ", line, " column: ", col, ")");
159   }
160 
161   template <typename... Args>
ParseError(const Args &...args)162   Status ParseError(const Args&... args) {
163     return Status(NONE, FAIL, ONNX_NAMESPACE::MakeString("[ParseError at position ", GetCurrentPos(), "]", args...));
164   }
165 
SkipWhiteSpace()166   void SkipWhiteSpace() {
167     while ((next_ < end_) && (isspace(*next_)))
168       ++next_;
169   }
170 
171   int NextChar(bool skipspace = true) {
172     if (skipspace)
173       SkipWhiteSpace();
174     return (next_ < end_) ? *next_ : 0;
175   }
176 
177   bool Matches(char ch, bool skipspace = true) {
178     if (skipspace)
179       SkipWhiteSpace();
180     if ((next_ < end_) && (*next_ == ch)) {
181       ++next_;
182       return true;
183     }
184     return false;
185   }
186 
187   Status Match(char ch, bool skipspace = true) {
188     if (!Matches(ch, skipspace))
189       return ParseError("Expected character ", ch, " not found", ch);
190     return Status::OK();
191   }
192 
EndOfInput()193   bool EndOfInput() {
194     SkipWhiteSpace();
195     return (next_ >= end_);
196   }
197 
198   enum class LiteralType { INT_LITERAL, FLOAT_LITERAL, STRING_LITERAL };
199 
200   struct Literal {
201     LiteralType type;
202     std::string value;
203   };
204 
Parse(Literal & result)205   Status Parse(Literal& result) {
206     bool decimal_point = false;
207     auto nextch = NextChar();
208     auto from = next_;
209     if (nextch == '"') {
210       ++next_;
211       // TODO: Handle escape characters
212       while ((next_ < end_) && (*next_ != '"')) {
213         ++next_;
214       }
215       ++next_;
216       result.type = LiteralType::STRING_LITERAL;
217       result.value = std::string(from + 1, next_ - from - 2); // skip enclosing quotes
218     } else if ((isdigit(nextch) || (nextch == '-'))) {
219       ++next_;
220 
221       while ((next_ < end_) && (isdigit(*next_) || (*next_ == '.'))) {
222         if (*next_ == '.') {
223           if (decimal_point)
224             break; // Only one decimal point allowed in numeric literal
225           decimal_point = true;
226         }
227         ++next_;
228       }
229 
230       if (next_ == from)
231         return ParseError("Value expected but not found.");
232 
233       result.value = std::string(from, next_ - from);
234       result.type = decimal_point ? LiteralType::FLOAT_LITERAL : LiteralType::INT_LITERAL;
235     }
236     return Status::OK();
237   }
238 
Parse(int64_t & val)239   Status Parse(int64_t& val) {
240     Literal literal;
241     CHECK_PARSER_STATUS(Parse(literal));
242     if (literal.type != LiteralType::INT_LITERAL)
243       return ParseError("Integer value expected, but not found.");
244     std::string s = literal.value;
245     val = std::stoll(s);
246     return Status::OK();
247   }
248 
Parse(uint64_t & val)249   Status Parse(uint64_t& val) {
250     Literal literal;
251     CHECK_PARSER_STATUS(Parse(literal));
252     if (literal.type != LiteralType::INT_LITERAL)
253       return ParseError("Integer value expected, but not found.");
254     std::string s = literal.value;
255     val = std::stoull(s);
256     return Status::OK();
257   }
258 
Parse(float & val)259   Status Parse(float& val) {
260     Literal literal;
261     CHECK_PARSER_STATUS(Parse(literal));
262     switch (literal.type) {
263       case LiteralType::INT_LITERAL:
264       case LiteralType::FLOAT_LITERAL:
265         val = std::stof(literal.value);
266         break;
267       default:
268         return ParseError("Unexpected literal type.");
269     }
270     return Status::OK();
271   }
272 
Parse(double & val)273   Status Parse(double& val) {
274     Literal literal;
275     CHECK_PARSER_STATUS(Parse(literal));
276     switch (literal.type) {
277       case LiteralType::INT_LITERAL:
278       case LiteralType::FLOAT_LITERAL:
279         val = std::stod(literal.value);
280         break;
281       default:
282         return ParseError("Unexpected literal type.");
283     }
284     return Status::OK();
285   }
286 
287   // Parse a string-literal enclosed within doube-quotes.
Parse(std::string & val)288   Status Parse(std::string& val) {
289     Literal literal;
290     CHECK_PARSER_STATUS(Parse(literal));
291     if (literal.type != LiteralType::STRING_LITERAL)
292       return ParseError("String value expected, but not found.");
293     val = literal.value;
294     return Status::OK();
295   }
296 
297   // Parse an identifier, including keywords. If none found, this will
298   // return an empty-string identifier.
ParseOptionalIdentifier(std::string & id)299   Status ParseOptionalIdentifier(std::string& id) {
300     SkipWhiteSpace();
301     auto from = next_;
302     if ((next_ < end_) && (isalpha(*next_) || (*next_ == '_'))) {
303       ++next_;
304       while ((next_ < end_) && (isalnum(*next_) || (*next_ == '_')))
305         ++next_;
306     }
307     id = std::string(from, next_ - from);
308     return Status::OK();
309   }
310 
ParseIdentifier(std::string & id)311   Status ParseIdentifier(std::string& id) {
312     ParseOptionalIdentifier(id);
313     if (id.empty())
314       return ParseError("Identifier expected but not found.");
315     return Status::OK();
316   }
317 
PeekIdentifier(std::string & id)318   Status PeekIdentifier(std::string& id) {
319     SavePos();
320     ParseOptionalIdentifier(id);
321     RestorePos();
322     return Status::OK();
323   }
324 
Parse(KeyWordMap::KeyWord & keyword)325   Status Parse(KeyWordMap::KeyWord& keyword) {
326     std::string id;
327     CHECK_PARSER_STATUS(ParseIdentifier(id));
328     keyword = KeyWordMap::Lookup(id);
329     return Status::OK();
330   }
331 
332  protected:
333   const char* start_;
334   const char* next_;
335   const char* end_;
336   const char* saved_pos_;
337 };
338 
339 class OnnxParser : public ParserBase {
340  public:
OnnxParser(const char * cstr)341   OnnxParser(const char* cstr) : ParserBase(cstr) {}
342 
343   Status Parse(TensorShapeProto& shape);
344 
345   Status Parse(TypeProto& typeProto);
346 
347   Status Parse(TensorProto& tensorProto);
348 
349   Status Parse(AttributeProto& attr);
350 
351   Status Parse(AttrList& attrlist);
352 
353   Status Parse(NodeProto& node);
354 
355   Status Parse(NodeList& nodelist);
356 
357   Status Parse(GraphProto& graph);
358 
359   Status Parse(ModelProto& model);
360 
361   template <typename T>
Parse(T & parsedData,const char * input)362   static Status Parse(T& parsedData, const char* input) {
363     OnnxParser parser(input);
364     return parser.Parse(parsedData);
365   }
366 
367  private:
368   Status Parse(std::string name, GraphProto& graph);
369 
370   Status Parse(IdList& idlist);
371 
372   Status ParseSingleAttributeValue(AttributeProto& attr);
373 
374   Status Parse(ValueInfoProto& valueinfo);
375 
376   Status Parse(ValueInfoList& vilist);
377 
378   Status ParseInput(ValueInfoList& vilist, TensorList& initializers);
379 
380   Status ParseValueInfo(ValueInfoList& vilist, TensorList& initializers);
381 
382   Status Parse(TensorProto& tensorProto, const TypeProto& tensorTypeProto);
383 };
384 
385 } // namespace ONNX_NAMESPACE