1 //===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a simple IR generation targeting MLIR from a Module AST
10 // for the Toy language.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "toy/MLIRGen.h"
15 #include "toy/AST.h"
16 #include "toy/Dialect.h"
17 
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/Verifier.h"
24 
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/ScopedHashTable.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include <numeric>
29 
30 using namespace mlir::toy;
31 using namespace toy;
32 
33 using llvm::ArrayRef;
34 using llvm::cast;
35 using llvm::dyn_cast;
36 using llvm::isa;
37 using llvm::makeArrayRef;
38 using llvm::ScopedHashTableScope;
39 using llvm::SmallVector;
40 using llvm::StringRef;
41 using llvm::Twine;
42 
43 namespace {
44 
45 /// Implementation of a simple MLIR emission from the Toy AST.
46 ///
/// This will emit operations that are specific to the Toy language, preserving
/// the semantics of the language and (hopefully) allowing accurate analysis
/// and transformation based on these high-level semantics.
50 class MLIRGenImpl {
51 public:
MLIRGenImpl(mlir::MLIRContext & context)52   MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
53 
54   /// Public API: convert the AST for a Toy module (source file) to an MLIR
55   /// Module operation.
mlirGen(ModuleAST & moduleAST)56   mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
57     // We create an empty MLIR module and codegen functions one at a time and
58     // add them to the module.
59     theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
60 
61     for (auto &record : moduleAST) {
62       if (FunctionAST *funcAST = llvm::dyn_cast<FunctionAST>(record.get())) {
63         auto func = mlirGen(*funcAST);
64         if (!func)
65           return nullptr;
66 
67         theModule.push_back(func);
68         functionMap.insert({func.getName(), func});
69       } else if (StructAST *str = llvm::dyn_cast<StructAST>(record.get())) {
70         if (failed(mlirGen(*str)))
71           return nullptr;
72       } else {
73         llvm_unreachable("unknown record type");
74       }
75     }
76 
77     // Verify the module after we have finished constructing it, this will check
78     // the structural properties of the IR and invoke any specific verifiers we
79     // have on the Toy operations.
80     if (failed(mlir::verify(theModule))) {
81       theModule.emitError("module verification error");
82       return nullptr;
83     }
84 
85     return theModule;
86   }
87 
88 private:
89   /// A "module" matches a Toy source file: containing a list of functions.
90   mlir::ModuleOp theModule;
91 
92   /// The builder is a helper class to create IR inside a function. The builder
93   /// is stateful, in particular it keeps an "insertion point": this is where
94   /// the next operations will be introduced.
95   mlir::OpBuilder builder;
96 
97   /// The symbol table maps a variable name to a value in the current scope.
98   /// Entering a function creates a new scope, and the function arguments are
99   /// added to the mapping. When the processing of a function is terminated, the
100   /// scope is destroyed and the mappings created in this scope are dropped.
101   llvm::ScopedHashTable<StringRef, std::pair<mlir::Value, VarDeclExprAST *>>
102       symbolTable;
103   using SymbolTableScopeT =
104       llvm::ScopedHashTableScope<StringRef,
105                                  std::pair<mlir::Value, VarDeclExprAST *>>;
106 
107   /// A mapping for the functions that have been code generated to MLIR.
108   llvm::StringMap<mlir::FuncOp> functionMap;
109 
110   /// A mapping for named struct types to the underlying MLIR type and the
111   /// original AST node.
112   llvm::StringMap<std::pair<mlir::Type, StructAST *>> structMap;
113 
114   /// Helper conversion for a Toy AST location to an MLIR location.
loc(Location loc)115   mlir::Location loc(Location loc) {
116     return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
117                                      loc.col);
118   }
119 
120   /// Declare a variable in the current scope, return success if the variable
121   /// wasn't declared yet.
declare(VarDeclExprAST & var,mlir::Value value)122   mlir::LogicalResult declare(VarDeclExprAST &var, mlir::Value value) {
123     if (symbolTable.count(var.getName()))
124       return mlir::failure();
125     symbolTable.insert(var.getName(), {value, &var});
126     return mlir::success();
127   }
128 
129   /// Create an MLIR type for the given struct.
mlirGen(StructAST & str)130   mlir::LogicalResult mlirGen(StructAST &str) {
131     if (structMap.count(str.getName()))
132       return emitError(loc(str.loc())) << "error: struct type with name `"
133                                        << str.getName() << "' already exists";
134 
135     auto variables = str.getVariables();
136     std::vector<mlir::Type> elementTypes;
137     elementTypes.reserve(variables.size());
138     for (auto &variable : variables) {
139       if (variable->getInitVal())
140         return emitError(loc(variable->loc()))
141                << "error: variables within a struct definition must not have "
142                   "initializers";
143       if (!variable->getType().shape.empty())
144         return emitError(loc(variable->loc()))
145                << "error: variables within a struct definition must not have "
146                   "initializers";
147 
148       mlir::Type type = getType(variable->getType(), variable->loc());
149       if (!type)
150         return mlir::failure();
151       elementTypes.push_back(type);
152     }
153 
154     structMap.try_emplace(str.getName(), StructType::get(elementTypes), &str);
155     return mlir::success();
156   }
157 
158   /// Create the prototype for an MLIR function with as many arguments as the
159   /// provided Toy AST prototype.
mlirGen(PrototypeAST & proto)160   mlir::FuncOp mlirGen(PrototypeAST &proto) {
161     auto location = loc(proto.loc());
162 
163     // This is a generic function, the return type will be inferred later.
164     llvm::SmallVector<mlir::Type, 4> argTypes;
165     argTypes.reserve(proto.getArgs().size());
166     for (auto &arg : proto.getArgs()) {
167       mlir::Type type = getType(arg->getType(), arg->loc());
168       if (!type)
169         return nullptr;
170       argTypes.push_back(type);
171     }
172     auto func_type = builder.getFunctionType(argTypes, llvm::None);
173     return mlir::FuncOp::create(location, proto.getName(), func_type);
174   }
175 
176   /// Emit a new function and add it to the MLIR module.
mlirGen(FunctionAST & funcAST)177   mlir::FuncOp mlirGen(FunctionAST &funcAST) {
178     // Create a scope in the symbol table to hold variable declarations.
179     SymbolTableScopeT var_scope(symbolTable);
180 
181     // Create an MLIR function for the given prototype.
182     mlir::FuncOp function(mlirGen(*funcAST.getProto()));
183     if (!function)
184       return nullptr;
185 
186     // Let's start the body of the function now!
187     // In MLIR the entry block of the function is special: it must have the same
188     // argument list as the function itself.
189     auto &entryBlock = *function.addEntryBlock();
190     auto protoArgs = funcAST.getProto()->getArgs();
191 
192     // Declare all the function arguments in the symbol table.
193     for (const auto nameValue :
194          llvm::zip(protoArgs, entryBlock.getArguments())) {
195       if (failed(declare(*std::get<0>(nameValue), std::get<1>(nameValue))))
196         return nullptr;
197     }
198 
199     // Set the insertion point in the builder to the beginning of the function
200     // body, it will be used throughout the codegen to create operations in this
201     // function.
202     builder.setInsertionPointToStart(&entryBlock);
203 
204     // Emit the body of the function.
205     if (mlir::failed(mlirGen(*funcAST.getBody()))) {
206       function.erase();
207       return nullptr;
208     }
209 
210     // Implicitly return void if no return statement was emitted.
211     // FIXME: we may fix the parser instead to always return the last expression
212     // (this would possibly help the REPL case later)
213     ReturnOp returnOp;
214     if (!entryBlock.empty())
215       returnOp = dyn_cast<ReturnOp>(entryBlock.back());
216     if (!returnOp) {
217       builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
218     } else if (returnOp.hasOperand()) {
219       // Otherwise, if this return operation has an operand then add a result to
220       // the function.
221       function.setType(builder.getFunctionType(function.getType().getInputs(),
222                                                *returnOp.operand_type_begin()));
223     }
224 
225     // If this function isn't main, then set the visibility to private.
226     if (funcAST.getProto()->getName() != "main")
227       function.setPrivate();
228 
229     return function;
230   }
231 
232   /// Return the struct type that is the result of the given expression, or null
233   /// if it cannot be inferred.
getStructFor(ExprAST * expr)234   StructAST *getStructFor(ExprAST *expr) {
235     llvm::StringRef structName;
236     if (auto *decl = llvm::dyn_cast<VariableExprAST>(expr)) {
237       auto varIt = symbolTable.lookup(decl->getName());
238       if (!varIt.first)
239         return nullptr;
240       structName = varIt.second->getType().name;
241     } else if (auto *access = llvm::dyn_cast<BinaryExprAST>(expr)) {
242       if (access->getOp() != '.')
243         return nullptr;
244       // The name being accessed should be in the RHS.
245       auto *name = llvm::dyn_cast<VariableExprAST>(access->getRHS());
246       if (!name)
247         return nullptr;
248       StructAST *parentStruct = getStructFor(access->getLHS());
249       if (!parentStruct)
250         return nullptr;
251 
252       // Get the element within the struct corresponding to the name.
253       VarDeclExprAST *decl = nullptr;
254       for (auto &var : parentStruct->getVariables()) {
255         if (var->getName() == name->getName()) {
256           decl = var.get();
257           break;
258         }
259       }
260       if (!decl)
261         return nullptr;
262       structName = decl->getType().name;
263     }
264     if (structName.empty())
265       return nullptr;
266 
267     // If the struct name was valid, check for an entry in the struct map.
268     auto structIt = structMap.find(structName);
269     if (structIt == structMap.end())
270       return nullptr;
271     return structIt->second.second;
272   }
273 
274   /// Return the numeric member index of the given struct access expression.
getMemberIndex(BinaryExprAST & accessOp)275   llvm::Optional<size_t> getMemberIndex(BinaryExprAST &accessOp) {
276     assert(accessOp.getOp() == '.' && "expected access operation");
277 
278     // Lookup the struct node for the LHS.
279     StructAST *structAST = getStructFor(accessOp.getLHS());
280     if (!structAST)
281       return llvm::None;
282 
283     // Get the name from the RHS.
284     VariableExprAST *name = llvm::dyn_cast<VariableExprAST>(accessOp.getRHS());
285     if (!name)
286       return llvm::None;
287 
288     auto structVars = structAST->getVariables();
289     auto it = llvm::find_if(structVars, [&](auto &var) {
290       return var->getName() == name->getName();
291     });
292     if (it == structVars.end())
293       return llvm::None;
294     return it - structVars.begin();
295   }
296 
297   /// Emit a binary operation
mlirGen(BinaryExprAST & binop)298   mlir::Value mlirGen(BinaryExprAST &binop) {
299     // First emit the operations for each side of the operation before emitting
300     // the operation itself. For example if the expression is `a + foo(a)`
301     // 1) First it will visiting the LHS, which will return a reference to the
302     //    value holding `a`. This value should have been emitted at declaration
303     //    time and registered in the symbol table, so nothing would be
304     //    codegen'd. If the value is not in the symbol table, an error has been
305     //    emitted and nullptr is returned.
306     // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
307     //    and the result value is returned. If an error occurs we get a nullptr
308     //    and propagate.
309     //
310     mlir::Value lhs = mlirGen(*binop.getLHS());
311     if (!lhs)
312       return nullptr;
313     auto location = loc(binop.loc());
314 
315     // If this is an access operation, handle it immediately.
316     if (binop.getOp() == '.') {
317       llvm::Optional<size_t> accessIndex = getMemberIndex(binop);
318       if (!accessIndex) {
319         emitError(location, "invalid access into struct expression");
320         return nullptr;
321       }
322       return builder.create<StructAccessOp>(location, lhs, *accessIndex);
323     }
324 
325     // Otherwise, this is a normal binary op.
326     mlir::Value rhs = mlirGen(*binop.getRHS());
327     if (!rhs)
328       return nullptr;
329 
330     // Derive the operation name from the binary operator. At the moment we only
331     // support '+' and '*'.
332     switch (binop.getOp()) {
333     case '+':
334       return builder.create<AddOp>(location, lhs, rhs);
335     case '*':
336       return builder.create<MulOp>(location, lhs, rhs);
337     }
338 
339     emitError(location, "invalid binary operator '") << binop.getOp() << "'";
340     return nullptr;
341   }
342 
343   /// This is a reference to a variable in an expression. The variable is
344   /// expected to have been declared and so should have a value in the symbol
345   /// table, otherwise emit an error and return nullptr.
mlirGen(VariableExprAST & expr)346   mlir::Value mlirGen(VariableExprAST &expr) {
347     if (auto variable = symbolTable.lookup(expr.getName()).first)
348       return variable;
349 
350     emitError(loc(expr.loc()), "error: unknown variable '")
351         << expr.getName() << "'";
352     return nullptr;
353   }
354 
355   /// Emit a return operation. This will return failure if any generation fails.
mlirGen(ReturnExprAST & ret)356   mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
357     auto location = loc(ret.loc());
358 
359     // 'return' takes an optional expression, handle that case here.
360     mlir::Value expr = nullptr;
361     if (ret.getExpr().hasValue()) {
362       if (!(expr = mlirGen(*ret.getExpr().getValue())))
363         return mlir::failure();
364     }
365 
366     // Otherwise, this return operation has zero operands.
367     builder.create<ReturnOp>(location, expr ? makeArrayRef(expr)
368                                             : ArrayRef<mlir::Value>());
369     return mlir::success();
370   }
371 
372   /// Emit a constant for a literal/constant array. It will be emitted as a
373   /// flattened array of data in an Attribute attached to a `toy.constant`
374   /// operation. See documentation on [Attributes](LangRef.md#attributes) for
375   /// more details. Here is an excerpt:
376   ///
377   ///   Attributes are the mechanism for specifying constant data in MLIR in
378   ///   places where a variable is never allowed [...]. They consist of a name
379   ///   and a concrete attribute value. The set of expected attributes, their
380   ///   structure, and their interpretation are all contextually dependent on
381   ///   what they are attached to.
382   ///
383   /// Example, the source level statement:
384   ///   var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
385   /// will be converted to:
386   ///   %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
387   ///     [[1.000000e+00, 2.000000e+00, 3.000000e+00],
388   ///      [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
389   ///
getConstantAttr(LiteralExprAST & lit)390   mlir::DenseElementsAttr getConstantAttr(LiteralExprAST &lit) {
391     // The attribute is a vector with a floating point value per element
392     // (number) in the array, see `collectData()` below for more details.
393     std::vector<double> data;
394     data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
395                                  std::multiplies<int>()));
396     collectData(lit, data);
397 
398     // The type of this attribute is tensor of 64-bit floating-point with the
399     // shape of the literal.
400     mlir::Type elementType = builder.getF64Type();
401     auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType);
402 
403     // This is the actual attribute that holds the list of values for this
404     // tensor literal.
405     return mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
406   }
getConstantAttr(NumberExprAST & lit)407   mlir::DenseElementsAttr getConstantAttr(NumberExprAST &lit) {
408     // The type of this attribute is tensor of 64-bit floating-point with no
409     // shape.
410     mlir::Type elementType = builder.getF64Type();
411     auto dataType = mlir::RankedTensorType::get({}, elementType);
412 
413     // This is the actual attribute that holds the list of values for this
414     // tensor literal.
415     return mlir::DenseElementsAttr::get(dataType,
416                                         llvm::makeArrayRef(lit.getValue()));
417   }
418   /// Emit a constant for a struct literal. It will be emitted as an array of
419   /// other literals in an Attribute attached to a `toy.struct_constant`
420   /// operation. This function returns the generated constant, along with the
421   /// corresponding struct type.
422   std::pair<mlir::ArrayAttr, mlir::Type>
getConstantAttr(StructLiteralExprAST & lit)423   getConstantAttr(StructLiteralExprAST &lit) {
424     std::vector<mlir::Attribute> attrElements;
425     std::vector<mlir::Type> typeElements;
426 
427     for (auto &var : lit.getValues()) {
428       if (auto *number = llvm::dyn_cast<NumberExprAST>(var.get())) {
429         attrElements.push_back(getConstantAttr(*number));
430         typeElements.push_back(getType(llvm::None));
431       } else if (auto *lit = llvm::dyn_cast<LiteralExprAST>(var.get())) {
432         attrElements.push_back(getConstantAttr(*lit));
433         typeElements.push_back(getType(llvm::None));
434       } else {
435         auto *structLit = llvm::cast<StructLiteralExprAST>(var.get());
436         auto attrTypePair = getConstantAttr(*structLit);
437         attrElements.push_back(attrTypePair.first);
438         typeElements.push_back(attrTypePair.second);
439       }
440     }
441     mlir::ArrayAttr dataAttr = builder.getArrayAttr(attrElements);
442     mlir::Type dataType = StructType::get(typeElements);
443     return std::make_pair(dataAttr, dataType);
444   }
445 
446   /// Emit an array literal.
mlirGen(LiteralExprAST & lit)447   mlir::Value mlirGen(LiteralExprAST &lit) {
448     mlir::Type type = getType(lit.getDims());
449     mlir::DenseElementsAttr dataAttribute = getConstantAttr(lit);
450 
451     // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
452     // method.
453     return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
454   }
455 
456   /// Emit a struct literal. It will be emitted as an array of
457   /// other literals in an Attribute attached to a `toy.struct_constant`
458   /// operation.
mlirGen(StructLiteralExprAST & lit)459   mlir::Value mlirGen(StructLiteralExprAST &lit) {
460     mlir::ArrayAttr dataAttr;
461     mlir::Type dataType;
462     std::tie(dataAttr, dataType) = getConstantAttr(lit);
463 
464     // Build the MLIR op `toy.struct_constant`. This invokes the
465     // `StructConstantOp::build` method.
466     return builder.create<StructConstantOp>(loc(lit.loc()), dataType, dataAttr);
467   }
468 
469   /// Recursive helper function to accumulate the data that compose an array
470   /// literal. It flattens the nested structure in the supplied vector. For
471   /// example with this array:
472   ///  [[1, 2], [3, 4]]
473   /// we will generate:
474   ///  [ 1, 2, 3, 4 ]
475   /// Individual numbers are represented as doubles.
476   /// Attributes are the way MLIR attaches constant to operations.
collectData(ExprAST & expr,std::vector<double> & data)477   void collectData(ExprAST &expr, std::vector<double> &data) {
478     if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
479       for (auto &value : lit->getValues())
480         collectData(*value, data);
481       return;
482     }
483 
484     assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
485     data.push_back(cast<NumberExprAST>(expr).getValue());
486   }
487 
488   /// Emit a call expression. It emits specific operations for the `transpose`
489   /// builtin. Other identifiers are assumed to be user-defined functions.
mlirGen(CallExprAST & call)490   mlir::Value mlirGen(CallExprAST &call) {
491     llvm::StringRef callee = call.getCallee();
492     auto location = loc(call.loc());
493 
494     // Codegen the operands first.
495     SmallVector<mlir::Value, 4> operands;
496     for (auto &expr : call.getArgs()) {
497       auto arg = mlirGen(*expr);
498       if (!arg)
499         return nullptr;
500       operands.push_back(arg);
501     }
502 
503     // Builtin calls have their custom operation, meaning this is a
504     // straightforward emission.
505     if (callee == "transpose") {
506       if (call.getArgs().size() != 1) {
507         emitError(location, "MLIR codegen encountered an error: toy.transpose "
508                             "does not accept multiple arguments");
509         return nullptr;
510       }
511       return builder.create<TransposeOp>(location, operands[0]);
512     }
513 
514     // Otherwise this is a call to a user-defined function. Calls to
515     // user-defined functions are mapped to a custom call that takes the callee
516     // name as an attribute.
517     auto calledFuncIt = functionMap.find(callee);
518     if (calledFuncIt == functionMap.end()) {
519       emitError(location) << "no defined function found for '" << callee << "'";
520       return nullptr;
521     }
522     mlir::FuncOp calledFunc = calledFuncIt->second;
523     return builder.create<GenericCallOp>(
524         location, calledFunc.getType().getResult(0),
525         builder.getSymbolRefAttr(callee), operands);
526   }
527 
528   /// Emit a print expression. It emits specific operations for two builtins:
529   /// transpose(x) and print(x).
mlirGen(PrintExprAST & call)530   mlir::LogicalResult mlirGen(PrintExprAST &call) {
531     auto arg = mlirGen(*call.getArg());
532     if (!arg)
533       return mlir::failure();
534 
535     builder.create<PrintOp>(loc(call.loc()), arg);
536     return mlir::success();
537   }
538 
539   /// Emit a constant for a single number (FIXME: semantic? broadcast?)
mlirGen(NumberExprAST & num)540   mlir::Value mlirGen(NumberExprAST &num) {
541     return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
542   }
543 
544   /// Dispatch codegen for the right expression subclass using RTTI.
mlirGen(ExprAST & expr)545   mlir::Value mlirGen(ExprAST &expr) {
546     switch (expr.getKind()) {
547     case toy::ExprAST::Expr_BinOp:
548       return mlirGen(cast<BinaryExprAST>(expr));
549     case toy::ExprAST::Expr_Var:
550       return mlirGen(cast<VariableExprAST>(expr));
551     case toy::ExprAST::Expr_Literal:
552       return mlirGen(cast<LiteralExprAST>(expr));
553     case toy::ExprAST::Expr_StructLiteral:
554       return mlirGen(cast<StructLiteralExprAST>(expr));
555     case toy::ExprAST::Expr_Call:
556       return mlirGen(cast<CallExprAST>(expr));
557     case toy::ExprAST::Expr_Num:
558       return mlirGen(cast<NumberExprAST>(expr));
559     default:
560       emitError(loc(expr.loc()))
561           << "MLIR codegen encountered an unhandled expr kind '"
562           << Twine(expr.getKind()) << "'";
563       return nullptr;
564     }
565   }
566 
567   /// Handle a variable declaration, we'll codegen the expression that forms the
568   /// initializer and record the value in the symbol table before returning it.
569   /// Future expressions will be able to reference this variable through symbol
570   /// table lookup.
mlirGen(VarDeclExprAST & vardecl)571   mlir::Value mlirGen(VarDeclExprAST &vardecl) {
572     auto init = vardecl.getInitVal();
573     if (!init) {
574       emitError(loc(vardecl.loc()),
575                 "missing initializer in variable declaration");
576       return nullptr;
577     }
578 
579     mlir::Value value = mlirGen(*init);
580     if (!value)
581       return nullptr;
582 
583     // Handle the case where we are initializing a struct value.
584     VarType varType = vardecl.getType();
585     if (!varType.name.empty()) {
586       // Check that the initializer type is the same as the variable
587       // declaration.
588       mlir::Type type = getType(varType, vardecl.loc());
589       if (!type)
590         return nullptr;
591       if (type != value.getType()) {
592         emitError(loc(vardecl.loc()))
593             << "struct type of initializer is different than the variable "
594                "declaration. Got "
595             << value.getType() << ", but expected " << type;
596         return nullptr;
597       }
598 
599       // Otherwise, we have the initializer value, but in case the variable was
600       // declared with specific shape, we emit a "reshape" operation. It will
601       // get optimized out later as needed.
602     } else if (!varType.shape.empty()) {
603       value = builder.create<ReshapeOp>(loc(vardecl.loc()),
604                                         getType(varType.shape), value);
605     }
606 
607     // Register the value in the symbol table.
608     if (failed(declare(vardecl, value)))
609       return nullptr;
610     return value;
611   }
612 
613   /// Codegen a list of expression, return failure if one of them hit an error.
mlirGen(ExprASTList & blockAST)614   mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
615     SymbolTableScopeT var_scope(symbolTable);
616     for (auto &expr : blockAST) {
617       // Specific handling for variable declarations, return statement, and
618       // print. These can only appear in block list and not in nested
619       // expressions.
620       if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
621         if (!mlirGen(*vardecl))
622           return mlir::failure();
623         continue;
624       }
625       if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
626         return mlirGen(*ret);
627       if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
628         if (mlir::failed(mlirGen(*print)))
629           return mlir::success();
630         continue;
631       }
632 
633       // Generic expression dispatch codegen.
634       if (!mlirGen(*expr))
635         return mlir::failure();
636     }
637     return mlir::success();
638   }
639 
640   /// Build a tensor type from a list of shape dimensions.
getType(ArrayRef<int64_t> shape)641   mlir::Type getType(ArrayRef<int64_t> shape) {
642     // If the shape is empty, then this type is unranked.
643     if (shape.empty())
644       return mlir::UnrankedTensorType::get(builder.getF64Type());
645 
646     // Otherwise, we use the given shape.
647     return mlir::RankedTensorType::get(shape, builder.getF64Type());
648   }
649 
650   /// Build an MLIR type from a Toy AST variable type (forward to the generic
651   /// getType above for non-struct types).
getType(const VarType & type,const Location & location)652   mlir::Type getType(const VarType &type, const Location &location) {
653     if (!type.name.empty()) {
654       auto it = structMap.find(type.name);
655       if (it == structMap.end()) {
656         emitError(loc(location))
657             << "error: unknown struct type '" << type.name << "'";
658         return nullptr;
659       }
660       return it->second.first;
661     }
662 
663     return getType(type.shape);
664   }
665 };
666 
667 } // namespace
668 
669 namespace toy {
670 
671 // The public API for codegen.
mlirGen(mlir::MLIRContext & context,ModuleAST & moduleAST)672 mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
673                               ModuleAST &moduleAST) {
674   return MLIRGenImpl(context).mlirGen(moduleAST);
675 }
676 
677 } // namespace toy
678