1 //===- Parser.h - Toy Language Parser -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the parser for the Toy language. It consumes the tokens
// produced by the Lexer and returns an AST.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_TUTORIAL_TOY_PARSER_H
15 #define MLIR_TUTORIAL_TOY_PARSER_H
16 
17 #include "toy/AST.h"
18 #include "toy/Lexer.h"
19 
20 #include "llvm/ADT/Optional.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/Support/raw_ostream.h"
24 
25 #include <map>
26 #include <utility>
27 #include <vector>
28 
29 namespace toy {
30 
31 /// This is a simple recursive parser for the Toy language. It produces a well
32 /// formed AST from a stream of Token supplied by the Lexer. No semantic checks
33 /// or symbol resolution is performed. For example, variables are referenced by
34 /// string and the code could reference an undeclared variable and the parsing
35 /// succeeds.
36 class Parser {
37 public:
38   /// Create a Parser for the supplied lexer.
Parser(Lexer & lexer)39   Parser(Lexer &lexer) : lexer(lexer) {}
40 
41   /// Parse a full Module. A module is a list of function definitions.
parseModule()42   std::unique_ptr<ModuleAST> parseModule() {
43     lexer.getNextToken(); // prime the lexer
44 
45     // Parse functions one at a time and accumulate in this vector.
46     std::vector<FunctionAST> functions;
47     while (auto f = parseDefinition()) {
48       functions.push_back(std::move(*f));
49       if (lexer.getCurToken() == tok_eof)
50         break;
51     }
52     // If we didn't reach EOF, there was an error during parsing
53     if (lexer.getCurToken() != tok_eof)
54       return parseError<ModuleAST>("nothing", "at end of module");
55 
56     return std::make_unique<ModuleAST>(std::move(functions));
57   }
58 
59 private:
60   Lexer &lexer;
61 
62   /// Parse a return statement.
63   /// return :== return ; | return expr ;
parseReturn()64   std::unique_ptr<ReturnExprAST> parseReturn() {
65     auto loc = lexer.getLastLocation();
66     lexer.consume(tok_return);
67 
68     // return takes an optional argument
69     llvm::Optional<std::unique_ptr<ExprAST>> expr;
70     if (lexer.getCurToken() != ';') {
71       expr = parseExpression();
72       if (!expr)
73         return nullptr;
74     }
75     return std::make_unique<ReturnExprAST>(std::move(loc), std::move(expr));
76   }
77 
78   /// Parse a literal number.
79   /// numberexpr ::= number
parseNumberExpr()80   std::unique_ptr<ExprAST> parseNumberExpr() {
81     auto loc = lexer.getLastLocation();
82     auto result =
83         std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
84     lexer.consume(tok_number);
85     return std::move(result);
86   }
87 
88   /// Parse a literal array expression.
89   /// tensorLiteral ::= [ literalList ] | number
90   /// literalList ::= tensorLiteral | tensorLiteral, literalList
parseTensorLiteralExpr()91   std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
92     auto loc = lexer.getLastLocation();
93     lexer.consume(Token('['));
94 
95     // Hold the list of values at this nesting level.
96     std::vector<std::unique_ptr<ExprAST>> values;
97     // Hold the dimensions for all the nesting inside this level.
98     std::vector<int64_t> dims;
99     do {
100       // We can have either another nested array or a number literal.
101       if (lexer.getCurToken() == '[') {
102         values.push_back(parseTensorLiteralExpr());
103         if (!values.back())
104           return nullptr; // parse error in the nested array.
105       } else {
106         if (lexer.getCurToken() != tok_number)
107           return parseError<ExprAST>("<num> or [", "in literal expression");
108         values.push_back(parseNumberExpr());
109       }
110 
111       // End of this list on ']'
112       if (lexer.getCurToken() == ']')
113         break;
114 
115       // Elements are separated by a comma.
116       if (lexer.getCurToken() != ',')
117         return parseError<ExprAST>("] or ,", "in literal expression");
118 
119       lexer.getNextToken(); // eat ,
120     } while (true);
121     if (values.empty())
122       return parseError<ExprAST>("<something>", "to fill literal expression");
123     lexer.getNextToken(); // eat ]
124 
125     /// Fill in the dimensions now. First the current nesting level:
126     dims.push_back(values.size());
127 
128     /// If there is any nested array, process all of them and ensure that
129     /// dimensions are uniform.
130     if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
131           return llvm::isa<LiteralExprAST>(expr.get());
132         })) {
133       auto *firstLiteral = llvm::dyn_cast<LiteralExprAST>(values.front().get());
134       if (!firstLiteral)
135         return parseError<ExprAST>("uniform well-nested dimensions",
136                                    "inside literal expression");
137 
138       // Append the nested dimensions to the current level
139       auto firstDims = firstLiteral->getDims();
140       dims.insert(dims.end(), firstDims.begin(), firstDims.end());
141 
142       // Sanity check that shape is uniform across all elements of the list.
143       for (auto &expr : values) {
144         auto *exprLiteral = llvm::cast<LiteralExprAST>(expr.get());
145         if (!exprLiteral)
146           return parseError<ExprAST>("uniform well-nested dimensions",
147                                      "inside literal expression");
148         if (exprLiteral->getDims() != firstDims)
149           return parseError<ExprAST>("uniform well-nested dimensions",
150                                      "inside literal expression");
151       }
152     }
153     return std::make_unique<LiteralExprAST>(std::move(loc), std::move(values),
154                                             std::move(dims));
155   }
156 
157   /// parenexpr ::= '(' expression ')'
parseParenExpr()158   std::unique_ptr<ExprAST> parseParenExpr() {
159     lexer.getNextToken(); // eat (.
160     auto v = parseExpression();
161     if (!v)
162       return nullptr;
163 
164     if (lexer.getCurToken() != ')')
165       return parseError<ExprAST>(")", "to close expression with parentheses");
166     lexer.consume(Token(')'));
167     return v;
168   }
169 
170   /// identifierexpr
171   ///   ::= identifier
172   ///   ::= identifier '(' expression ')'
parseIdentifierExpr()173   std::unique_ptr<ExprAST> parseIdentifierExpr() {
174     std::string name(lexer.getId());
175 
176     auto loc = lexer.getLastLocation();
177     lexer.getNextToken(); // eat identifier.
178 
179     if (lexer.getCurToken() != '(') // Simple variable ref.
180       return std::make_unique<VariableExprAST>(std::move(loc), name);
181 
182     // This is a function call.
183     lexer.consume(Token('('));
184     std::vector<std::unique_ptr<ExprAST>> args;
185     if (lexer.getCurToken() != ')') {
186       while (true) {
187         if (auto arg = parseExpression())
188           args.push_back(std::move(arg));
189         else
190           return nullptr;
191 
192         if (lexer.getCurToken() == ')')
193           break;
194 
195         if (lexer.getCurToken() != ',')
196           return parseError<ExprAST>(", or )", "in argument list");
197         lexer.getNextToken();
198       }
199     }
200     lexer.consume(Token(')'));
201 
202     // It can be a builtin call to print
203     if (name == "print") {
204       if (args.size() != 1)
205         return parseError<ExprAST>("<single arg>", "as argument to print()");
206 
207       return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
208     }
209 
210     // Call to a user-defined function
211     return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
212   }
213 
214   /// primary
215   ///   ::= identifierexpr
216   ///   ::= numberexpr
217   ///   ::= parenexpr
218   ///   ::= tensorliteral
parsePrimary()219   std::unique_ptr<ExprAST> parsePrimary() {
220     switch (lexer.getCurToken()) {
221     default:
222       llvm::errs() << "unknown token '" << lexer.getCurToken()
223                    << "' when expecting an expression\n";
224       return nullptr;
225     case tok_identifier:
226       return parseIdentifierExpr();
227     case tok_number:
228       return parseNumberExpr();
229     case '(':
230       return parseParenExpr();
231     case '[':
232       return parseTensorLiteralExpr();
233     case ';':
234       return nullptr;
235     case '}':
236       return nullptr;
237     }
238   }
239 
240   /// Recursively parse the right hand side of a binary expression, the ExprPrec
241   /// argument indicates the precedence of the current binary operator.
242   ///
243   /// binoprhs ::= ('+' primary)*
parseBinOpRHS(int exprPrec,std::unique_ptr<ExprAST> lhs)244   std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
245                                          std::unique_ptr<ExprAST> lhs) {
246     // If this is a binop, find its precedence.
247     while (true) {
248       int tokPrec = getTokPrecedence();
249 
250       // If this is a binop that binds at least as tightly as the current binop,
251       // consume it, otherwise we are done.
252       if (tokPrec < exprPrec)
253         return lhs;
254 
255       // Okay, we know this is a binop.
256       int binOp = lexer.getCurToken();
257       lexer.consume(Token(binOp));
258       auto loc = lexer.getLastLocation();
259 
260       // Parse the primary expression after the binary operator.
261       auto rhs = parsePrimary();
262       if (!rhs)
263         return parseError<ExprAST>("expression", "to complete binary operator");
264 
265       // If BinOp binds less tightly with rhs than the operator after rhs, let
266       // the pending operator take rhs as its lhs.
267       int nextPrec = getTokPrecedence();
268       if (tokPrec < nextPrec) {
269         rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
270         if (!rhs)
271           return nullptr;
272       }
273 
274       // Merge lhs/RHS.
275       lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp,
276                                             std::move(lhs), std::move(rhs));
277     }
278   }
279 
280   /// expression::= primary binop rhs
parseExpression()281   std::unique_ptr<ExprAST> parseExpression() {
282     auto lhs = parsePrimary();
283     if (!lhs)
284       return nullptr;
285 
286     return parseBinOpRHS(0, std::move(lhs));
287   }
288 
289   /// type ::= < shape_list >
290   /// shape_list ::= num | num , shape_list
parseType()291   std::unique_ptr<VarType> parseType() {
292     if (lexer.getCurToken() != '<')
293       return parseError<VarType>("<", "to begin type");
294     lexer.getNextToken(); // eat <
295 
296     auto type = std::make_unique<VarType>();
297 
298     while (lexer.getCurToken() == tok_number) {
299       type->shape.push_back(lexer.getValue());
300       lexer.getNextToken();
301       if (lexer.getCurToken() == ',')
302         lexer.getNextToken();
303     }
304 
305     if (lexer.getCurToken() != '>')
306       return parseError<VarType>(">", "to end type");
307     lexer.getNextToken(); // eat >
308     return type;
309   }
310 
311   /// Parse a variable declaration, it starts with a `var` keyword followed by
312   /// and identifier and an optional type (shape specification) before the
313   /// initializer.
314   /// decl ::= var identifier [ type ] = expr
parseDeclaration()315   std::unique_ptr<VarDeclExprAST> parseDeclaration() {
316     if (lexer.getCurToken() != tok_var)
317       return parseError<VarDeclExprAST>("var", "to begin declaration");
318     auto loc = lexer.getLastLocation();
319     lexer.getNextToken(); // eat var
320 
321     if (lexer.getCurToken() != tok_identifier)
322       return parseError<VarDeclExprAST>("identified",
323                                         "after 'var' declaration");
324     std::string id(lexer.getId());
325     lexer.getNextToken(); // eat id
326 
327     std::unique_ptr<VarType> type; // Type is optional, it can be inferred
328     if (lexer.getCurToken() == '<') {
329       type = parseType();
330       if (!type)
331         return nullptr;
332     }
333 
334     if (!type)
335       type = std::make_unique<VarType>();
336     lexer.consume(Token('='));
337     auto expr = parseExpression();
338     return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
339                                             std::move(*type), std::move(expr));
340   }
341 
342   /// Parse a block: a list of expression separated by semicolons and wrapped in
343   /// curly braces.
344   ///
345   /// block ::= { expression_list }
346   /// expression_list ::= block_expr ; expression_list
347   /// block_expr ::= decl | "return" | expr
parseBlock()348   std::unique_ptr<ExprASTList> parseBlock() {
349     if (lexer.getCurToken() != '{')
350       return parseError<ExprASTList>("{", "to begin block");
351     lexer.consume(Token('{'));
352 
353     auto exprList = std::make_unique<ExprASTList>();
354 
355     // Ignore empty expressions: swallow sequences of semicolons.
356     while (lexer.getCurToken() == ';')
357       lexer.consume(Token(';'));
358 
359     while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
360       if (lexer.getCurToken() == tok_var) {
361         // Variable declaration
362         auto varDecl = parseDeclaration();
363         if (!varDecl)
364           return nullptr;
365         exprList->push_back(std::move(varDecl));
366       } else if (lexer.getCurToken() == tok_return) {
367         // Return statement
368         auto ret = parseReturn();
369         if (!ret)
370           return nullptr;
371         exprList->push_back(std::move(ret));
372       } else {
373         // General expression
374         auto expr = parseExpression();
375         if (!expr)
376           return nullptr;
377         exprList->push_back(std::move(expr));
378       }
379       // Ensure that elements are separated by a semicolon.
380       if (lexer.getCurToken() != ';')
381         return parseError<ExprASTList>(";", "after expression");
382 
383       // Ignore empty expressions: swallow sequences of semicolons.
384       while (lexer.getCurToken() == ';')
385         lexer.consume(Token(';'));
386     }
387 
388     if (lexer.getCurToken() != '}')
389       return parseError<ExprASTList>("}", "to close block");
390 
391     lexer.consume(Token('}'));
392     return exprList;
393   }
394 
395   /// prototype ::= def id '(' decl_list ')'
396   /// decl_list ::= identifier | identifier, decl_list
parsePrototype()397   std::unique_ptr<PrototypeAST> parsePrototype() {
398     auto loc = lexer.getLastLocation();
399 
400     if (lexer.getCurToken() != tok_def)
401       return parseError<PrototypeAST>("def", "in prototype");
402     lexer.consume(tok_def);
403 
404     if (lexer.getCurToken() != tok_identifier)
405       return parseError<PrototypeAST>("function name", "in prototype");
406 
407     std::string fnName(lexer.getId());
408     lexer.consume(tok_identifier);
409 
410     if (lexer.getCurToken() != '(')
411       return parseError<PrototypeAST>("(", "in prototype");
412     lexer.consume(Token('('));
413 
414     std::vector<std::unique_ptr<VariableExprAST>> args;
415     if (lexer.getCurToken() != ')') {
416       do {
417         std::string name(lexer.getId());
418         auto loc = lexer.getLastLocation();
419         lexer.consume(tok_identifier);
420         auto decl = std::make_unique<VariableExprAST>(std::move(loc), name);
421         args.push_back(std::move(decl));
422         if (lexer.getCurToken() != ',')
423           break;
424         lexer.consume(Token(','));
425         if (lexer.getCurToken() != tok_identifier)
426           return parseError<PrototypeAST>(
427               "identifier", "after ',' in function parameter list");
428       } while (true);
429     }
430     if (lexer.getCurToken() != ')')
431       return parseError<PrototypeAST>(")", "to end function prototype");
432 
433     // success.
434     lexer.consume(Token(')'));
435     return std::make_unique<PrototypeAST>(std::move(loc), fnName,
436                                           std::move(args));
437   }
438 
439   /// Parse a function definition, we expect a prototype initiated with the
440   /// `def` keyword, followed by a block containing a list of expressions.
441   ///
442   /// definition ::= prototype block
parseDefinition()443   std::unique_ptr<FunctionAST> parseDefinition() {
444     auto proto = parsePrototype();
445     if (!proto)
446       return nullptr;
447 
448     if (auto block = parseBlock())
449       return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
450     return nullptr;
451   }
452 
453   /// Get the precedence of the pending binary operator token.
getTokPrecedence()454   int getTokPrecedence() {
455     if (!isascii(lexer.getCurToken()))
456       return -1;
457 
458     // 1 is lowest precedence.
459     switch (static_cast<char>(lexer.getCurToken())) {
460     case '-':
461       return 20;
462     case '+':
463       return 20;
464     case '*':
465       return 40;
466     default:
467       return -1;
468     }
469   }
470 
471   /// Helper function to signal errors while parsing, it takes an argument
472   /// indicating the expected token and another argument giving more context.
473   /// Location is retrieved from the lexer to enrich the error message.
474   template <typename R, typename T, typename U = const char *>
475   std::unique_ptr<R> parseError(T &&expected, U &&context = "") {
476     auto curToken = lexer.getCurToken();
477     llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
478                  << lexer.getLastLocation().col << "): expected '" << expected
479                  << "' " << context << " but has Token " << curToken;
480     if (isprint(curToken))
481       llvm::errs() << " '" << (char)curToken << "'";
482     llvm::errs() << "\n";
483     return nullptr;
484   }
485 };
486 
487 } // namespace toy
488 
489 #endif // MLIR_TUTORIAL_TOY_PARSER_H
490