1 //===- Lexer.cpp - MLIR Lexer Implementation ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the lexer for the MLIR textual form.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Lexer.h"
14 #include "mlir/IR/Diagnostics.h"
15 #include "mlir/IR/Identifier.h"
16 #include "mlir/IR/Location.h"
17 #include "mlir/IR/MLIRContext.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/ADT/StringSwitch.h"
20 #include "llvm/Support/SourceMgr.h"
21 using namespace mlir;
22 
23 using llvm::SMLoc;
24 using llvm::SourceMgr;
25 
/// Return true if `c` is one of the punctuation characters permitted inside
/// suffix identifiers, namely one of `$`, `.`, `_`, or `-`; false otherwise.
static bool isPunct(char c) {
  switch (c) {
  case '$':
  case '.':
  case '_':
  case '-':
    return true;
  default:
    return false;
  }
}
31 
Lexer(const llvm::SourceMgr & sourceMgr,MLIRContext * context)32 Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context)
33     : sourceMgr(sourceMgr), context(context) {
34   auto bufferID = sourceMgr.getMainFileID();
35   curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer();
36   curPtr = curBuffer.begin();
37 }
38 
39 /// Encode the specified source location information into an attribute for
40 /// attachment to the IR.
getEncodedSourceLocation(llvm::SMLoc loc)41 Location Lexer::getEncodedSourceLocation(llvm::SMLoc loc) {
42   auto &sourceMgr = getSourceMgr();
43   unsigned mainFileID = sourceMgr.getMainFileID();
44   auto lineAndColumn = sourceMgr.getLineAndColumn(loc, mainFileID);
45   auto *buffer = sourceMgr.getMemoryBuffer(mainFileID);
46 
47   return FileLineColLoc::get(buffer->getBufferIdentifier(), lineAndColumn.first,
48                              lineAndColumn.second, context);
49 }
50 
51 /// emitError - Emit an error message and return an Token::error token.
emitError(const char * loc,const Twine & message)52 Token Lexer::emitError(const char *loc, const Twine &message) {
53   mlir::emitError(getEncodedSourceLocation(SMLoc::getFromPointer(loc)),
54                   message);
55   return formToken(Token::error, loc);
56 }
57 
lexToken()58 Token Lexer::lexToken() {
59   while (true) {
60     const char *tokStart = curPtr;
61     switch (*curPtr++) {
62     default:
63       // Handle bare identifiers.
64       if (isalpha(curPtr[-1]))
65         return lexBareIdentifierOrKeyword(tokStart);
66 
67       // Unknown character, emit an error.
68       return emitError(tokStart, "unexpected character");
69 
70     case ' ':
71     case '\t':
72     case '\n':
73     case '\r':
74       // Handle whitespace.
75       continue;
76 
77     case '_':
78       // Handle bare identifiers.
79       return lexBareIdentifierOrKeyword(tokStart);
80 
81     case 0:
82       // This may either be a nul character in the source file or may be the EOF
83       // marker that llvm::MemoryBuffer guarantees will be there.
84       if (curPtr - 1 == curBuffer.end())
85         return formToken(Token::eof, tokStart);
86       continue;
87 
88     case ':':
89       return formToken(Token::colon, tokStart);
90     case ',':
91       return formToken(Token::comma, tokStart);
92     case '.':
93       return lexEllipsis(tokStart);
94     case '(':
95       return formToken(Token::l_paren, tokStart);
96     case ')':
97       return formToken(Token::r_paren, tokStart);
98     case '{':
99       return formToken(Token::l_brace, tokStart);
100     case '}':
101       return formToken(Token::r_brace, tokStart);
102     case '[':
103       return formToken(Token::l_square, tokStart);
104     case ']':
105       return formToken(Token::r_square, tokStart);
106     case '<':
107       return formToken(Token::less, tokStart);
108     case '>':
109       return formToken(Token::greater, tokStart);
110     case '=':
111       return formToken(Token::equal, tokStart);
112 
113     case '+':
114       return formToken(Token::plus, tokStart);
115     case '*':
116       return formToken(Token::star, tokStart);
117     case '-':
118       if (*curPtr == '>') {
119         ++curPtr;
120         return formToken(Token::arrow, tokStart);
121       }
122       return formToken(Token::minus, tokStart);
123 
124     case '?':
125       return formToken(Token::question, tokStart);
126 
127     case '/':
128       if (*curPtr == '/') {
129         skipComment();
130         continue;
131       }
132       return emitError(tokStart, "unexpected character");
133 
134     case '@':
135       return lexAtIdentifier(tokStart);
136 
137     case '!':
138       LLVM_FALLTHROUGH;
139     case '^':
140       LLVM_FALLTHROUGH;
141     case '#':
142       LLVM_FALLTHROUGH;
143     case '%':
144       return lexPrefixedIdentifier(tokStart);
145     case '"':
146       return lexString(tokStart);
147 
148     case '0':
149     case '1':
150     case '2':
151     case '3':
152     case '4':
153     case '5':
154     case '6':
155     case '7':
156     case '8':
157     case '9':
158       return lexNumber(tokStart);
159     }
160   }
161 }
162 
163 /// Lex an '@foo' identifier.
164 ///
165 ///   symbol-ref-id ::= `@` (bare-id | string-literal)
166 ///
lexAtIdentifier(const char * tokStart)167 Token Lexer::lexAtIdentifier(const char *tokStart) {
168   char cur = *curPtr++;
169 
170   // Try to parse a string literal, if present.
171   if (cur == '"') {
172     Token stringIdentifier = lexString(curPtr);
173     if (stringIdentifier.is(Token::error))
174       return stringIdentifier;
175     return formToken(Token::at_identifier, tokStart);
176   }
177 
178   // Otherwise, these always start with a letter or underscore.
179   if (!isalpha(cur) && cur != '_')
180     return emitError(curPtr - 1,
181                      "@ identifier expected to start with letter or '_'");
182 
183   while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_' ||
184          *curPtr == '$' || *curPtr == '.')
185     ++curPtr;
186   return formToken(Token::at_identifier, tokStart);
187 }
188 
189 /// Lex a bare identifier or keyword that starts with a letter.
190 ///
191 ///   bare-id ::= (letter|[_]) (letter|digit|[_$.])*
192 ///   integer-type ::= `[su]?i[1-9][0-9]*`
193 ///
lexBareIdentifierOrKeyword(const char * tokStart)194 Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
195   // Match the rest of the identifier regex: [0-9a-zA-Z_.$]*
196   while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_' ||
197          *curPtr == '$' || *curPtr == '.')
198     ++curPtr;
199 
200   // Check to see if this identifier is a keyword.
201   StringRef spelling(tokStart, curPtr - tokStart);
202 
203   auto isAllDigit = [](StringRef str) {
204     return llvm::all_of(str, [](char c) { return llvm::isDigit(c); });
205   };
206 
207   // Check for i123, si456, ui789.
208   if ((spelling.size() > 1 && tokStart[0] == 'i' &&
209        isAllDigit(spelling.drop_front())) ||
210       ((spelling.size() > 2 && tokStart[1] == 'i' &&
211         (tokStart[0] == 's' || tokStart[0] == 'u')) &&
212        isAllDigit(spelling.drop_front(2))))
213     return Token(Token::inttype, spelling);
214 
215   Token::Kind kind = StringSwitch<Token::Kind>(spelling)
216 #define TOK_KEYWORD(SPELLING) .Case(#SPELLING, Token::kw_##SPELLING)
217 #include "TokenKinds.def"
218                          .Default(Token::bare_identifier);
219 
220   return Token(kind, spelling);
221 }
222 
223 /// Skip a comment line, starting with a '//'.
224 ///
225 ///   TODO: add a regex for comments here and to the spec.
226 ///
skipComment()227 void Lexer::skipComment() {
228   // Advance over the second '/' in a '//' comment.
229   assert(*curPtr == '/');
230   ++curPtr;
231 
232   while (true) {
233     switch (*curPtr++) {
234     case '\n':
235     case '\r':
236       // Newline is end of comment.
237       return;
238     case 0:
239       // If this is the end of the buffer, end the comment.
240       if (curPtr - 1 == curBuffer.end()) {
241         --curPtr;
242         return;
243       }
244       LLVM_FALLTHROUGH;
245     default:
246       // Skip over other characters.
247       break;
248     }
249   }
250 }
251 
252 /// Lex an ellipsis.
253 ///
254 ///   ellipsis ::= '...'
255 ///
lexEllipsis(const char * tokStart)256 Token Lexer::lexEllipsis(const char *tokStart) {
257   assert(curPtr[-1] == '.');
258 
259   if (curPtr == curBuffer.end() || *curPtr != '.' || *(curPtr + 1) != '.')
260     return emitError(curPtr, "expected three consecutive dots for an ellipsis");
261 
262   curPtr += 2;
263   return formToken(Token::ellipsis, tokStart);
264 }
265 
266 /// Lex a number literal.
267 ///
268 ///   integer-literal ::= digit+ | `0x` hex_digit+
269 ///   float-literal ::= [-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)?
270 ///
lexNumber(const char * tokStart)271 Token Lexer::lexNumber(const char *tokStart) {
272   assert(isdigit(curPtr[-1]));
273 
274   // Handle the hexadecimal case.
275   if (curPtr[-1] == '0' && *curPtr == 'x') {
276     // If we see stuff like 0xi32, this is a literal `0` followed by an
277     // identifier `xi32`, stop after `0`.
278     if (!isxdigit(curPtr[1]))
279       return formToken(Token::integer, tokStart);
280 
281     curPtr += 2;
282     while (isxdigit(*curPtr))
283       ++curPtr;
284 
285     return formToken(Token::integer, tokStart);
286   }
287 
288   // Handle the normal decimal case.
289   while (isdigit(*curPtr))
290     ++curPtr;
291 
292   if (*curPtr != '.')
293     return formToken(Token::integer, tokStart);
294   ++curPtr;
295 
296   // Skip over [0-9]*([eE][-+]?[0-9]+)?
297   while (isdigit(*curPtr))
298     ++curPtr;
299 
300   if (*curPtr == 'e' || *curPtr == 'E') {
301     if (isdigit(static_cast<unsigned char>(curPtr[1])) ||
302         ((curPtr[1] == '-' || curPtr[1] == '+') &&
303          isdigit(static_cast<unsigned char>(curPtr[2])))) {
304       curPtr += 2;
305       while (isdigit(*curPtr))
306         ++curPtr;
307     }
308   }
309   return formToken(Token::floatliteral, tokStart);
310 }
311 
312 /// Lex an identifier that starts with a prefix followed by suffix-id.
313 ///
314 ///   attribute-id  ::= `#` suffix-id
315 ///   ssa-id        ::= '%' suffix-id
316 ///   block-id      ::= '^' suffix-id
317 ///   type-id       ::= '!' suffix-id
318 ///   suffix-id     ::= digit+ | (letter|id-punct) (letter|id-punct|digit)*
319 ///   id-punct      ::= `$` | `.` | `_` | `-`
320 ///
lexPrefixedIdentifier(const char * tokStart)321 Token Lexer::lexPrefixedIdentifier(const char *tokStart) {
322   Token::Kind kind;
323   StringRef errorKind;
324   switch (*tokStart) {
325   case '#':
326     kind = Token::hash_identifier;
327     errorKind = "invalid attribute name";
328     break;
329   case '%':
330     kind = Token::percent_identifier;
331     errorKind = "invalid SSA name";
332     break;
333   case '^':
334     kind = Token::caret_identifier;
335     errorKind = "invalid block name";
336     break;
337   case '!':
338     kind = Token::exclamation_identifier;
339     errorKind = "invalid type identifier";
340     break;
341   default:
342     llvm_unreachable("invalid caller");
343   }
344 
345   // Parse suffix-id.
346   if (isdigit(*curPtr)) {
347     // If suffix-id starts with a digit, the rest must be digits.
348     while (isdigit(*curPtr)) {
349       ++curPtr;
350     }
351   } else if (isalpha(*curPtr) || isPunct(*curPtr)) {
352     do {
353       ++curPtr;
354     } while (isalpha(*curPtr) || isdigit(*curPtr) || isPunct(*curPtr));
355   } else {
356     return emitError(curPtr - 1, errorKind);
357   }
358 
359   return formToken(kind, tokStart);
360 }
361 
362 /// Lex a string literal.
363 ///
364 ///   string-literal ::= '"' [^"\n\f\v\r]* '"'
365 ///
366 /// TODO: define escaping rules.
lexString(const char * tokStart)367 Token Lexer::lexString(const char *tokStart) {
368   assert(curPtr[-1] == '"');
369 
370   while (true) {
371     switch (*curPtr++) {
372     case '"':
373       return formToken(Token::string, tokStart);
374     case 0:
375       // If this is a random nul character in the middle of a string, just
376       // include it.  If it is the end of file, then it is an error.
377       if (curPtr - 1 != curBuffer.end())
378         continue;
379       LLVM_FALLTHROUGH;
380     case '\n':
381     case '\v':
382     case '\f':
383       return emitError(curPtr - 1, "expected '\"' in string literal");
384     case '\\':
385       // Handle explicitly a few escapes.
386       if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || *curPtr == 't')
387         ++curPtr;
388       else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1]))
389         // Support \xx for two hex digits.
390         curPtr += 2;
391       else
392         return emitError(curPtr - 1, "unknown escape in string literal");
393       continue;
394 
395     default:
396       continue;
397     }
398   }
399 }
400