1 //===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a simple IR generation targeting MLIR from a Module AST
10 // for the Toy language.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "toy/MLIRGen.h"
15 #include "toy/AST.h"
16 #include "toy/Dialect.h"
17 
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Function.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/Module.h"
23 #include "mlir/IR/StandardTypes.h"
24 #include "mlir/IR/Verifier.h"
25 
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/ScopedHashTable.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include <numeric>
30 
31 using namespace mlir::toy;
32 using namespace toy;
33 
34 using llvm::ArrayRef;
35 using llvm::cast;
36 using llvm::dyn_cast;
37 using llvm::isa;
38 using llvm::makeArrayRef;
39 using llvm::ScopedHashTableScope;
40 using llvm::SmallVector;
41 using llvm::StringRef;
42 using llvm::Twine;
43 
44 namespace {
45 
46 /// Implementation of a simple MLIR emission from the Toy AST.
47 ///
48 /// This will emit operations that are specific to the Toy language, preserving
49 /// the semantics of the language and (hopefully) allow to perform accurate
50 /// analysis and transformation based on these high level semantics.
51 class MLIRGenImpl {
52 public:
MLIRGenImpl(mlir::MLIRContext & context)53   MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
54 
55   /// Public API: convert the AST for a Toy module (source file) to an MLIR
56   /// Module operation.
mlirGen(ModuleAST & moduleAST)57   mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
58     // We create an empty MLIR module and codegen functions one at a time and
59     // add them to the module.
60     theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
61 
62     for (auto &record : moduleAST) {
63       if (FunctionAST *funcAST = llvm::dyn_cast<FunctionAST>(record.get())) {
64         auto func = mlirGen(*funcAST);
65         if (!func)
66           return nullptr;
67 
68         theModule.push_back(func);
69         functionMap.insert({func.getName(), func});
70       } else if (StructAST *str = llvm::dyn_cast<StructAST>(record.get())) {
71         if (failed(mlirGen(*str)))
72           return nullptr;
73       } else {
74         llvm_unreachable("unknown record type");
75       }
76     }
77 
78     // Verify the module after we have finished constructing it, this will check
79     // the structural properties of the IR and invoke any specific verifiers we
80     // have on the Toy operations.
81     if (failed(mlir::verify(theModule))) {
82       theModule.emitError("module verification error");
83       return nullptr;
84     }
85 
86     return theModule;
87   }
88 
89 private:
90   /// A "module" matches a Toy source file: containing a list of functions.
91   mlir::ModuleOp theModule;
92 
93   /// The builder is a helper class to create IR inside a function. The builder
94   /// is stateful, in particular it keeps an "insertion point": this is where
95   /// the next operations will be introduced.
96   mlir::OpBuilder builder;
97 
98   /// The symbol table maps a variable name to a value in the current scope.
99   /// Entering a function creates a new scope, and the function arguments are
100   /// added to the mapping. When the processing of a function is terminated, the
101   /// scope is destroyed and the mappings created in this scope are dropped.
102   llvm::ScopedHashTable<StringRef, std::pair<mlir::Value, VarDeclExprAST *>>
103       symbolTable;
104   using SymbolTableScopeT =
105       llvm::ScopedHashTableScope<StringRef,
106                                  std::pair<mlir::Value, VarDeclExprAST *>>;
107 
108   /// A mapping for the functions that have been code generated to MLIR.
109   llvm::StringMap<mlir::FuncOp> functionMap;
110 
111   /// A mapping for named struct types to the underlying MLIR type and the
112   /// original AST node.
113   llvm::StringMap<std::pair<mlir::Type, StructAST *>> structMap;
114 
115   /// Helper conversion for a Toy AST location to an MLIR location.
loc(Location loc)116   mlir::Location loc(Location loc) {
117     return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
118                                      loc.col);
119   }
120 
121   /// Declare a variable in the current scope, return success if the variable
122   /// wasn't declared yet.
declare(VarDeclExprAST & var,mlir::Value value)123   mlir::LogicalResult declare(VarDeclExprAST &var, mlir::Value value) {
124     if (symbolTable.count(var.getName()))
125       return mlir::failure();
126     symbolTable.insert(var.getName(), {value, &var});
127     return mlir::success();
128   }
129 
130   /// Create an MLIR type for the given struct.
mlirGen(StructAST & str)131   mlir::LogicalResult mlirGen(StructAST &str) {
132     if (structMap.count(str.getName()))
133       return emitError(loc(str.loc())) << "error: struct type with name `"
134                                        << str.getName() << "' already exists";
135 
136     auto variables = str.getVariables();
137     std::vector<mlir::Type> elementTypes;
138     elementTypes.reserve(variables.size());
139     for (auto &variable : variables) {
140       if (variable->getInitVal())
141         return emitError(loc(variable->loc()))
142                << "error: variables within a struct definition must not have "
143                   "initializers";
144       if (!variable->getType().shape.empty())
145         return emitError(loc(variable->loc()))
146                << "error: variables within a struct definition must not have "
147                   "initializers";
148 
149       mlir::Type type = getType(variable->getType(), variable->loc());
150       if (!type)
151         return mlir::failure();
152       elementTypes.push_back(type);
153     }
154 
155     structMap.try_emplace(str.getName(), StructType::get(elementTypes), &str);
156     return mlir::success();
157   }
158 
159   /// Create the prototype for an MLIR function with as many arguments as the
160   /// provided Toy AST prototype.
mlirGen(PrototypeAST & proto)161   mlir::FuncOp mlirGen(PrototypeAST &proto) {
162     auto location = loc(proto.loc());
163 
164     // This is a generic function, the return type will be inferred later.
165     llvm::SmallVector<mlir::Type, 4> argTypes;
166     argTypes.reserve(proto.getArgs().size());
167     for (auto &arg : proto.getArgs()) {
168       mlir::Type type = getType(arg->getType(), arg->loc());
169       if (!type)
170         return nullptr;
171       argTypes.push_back(type);
172     }
173     auto func_type = builder.getFunctionType(argTypes, llvm::None);
174     return mlir::FuncOp::create(location, proto.getName(), func_type);
175   }
176 
177   /// Emit a new function and add it to the MLIR module.
mlirGen(FunctionAST & funcAST)178   mlir::FuncOp mlirGen(FunctionAST &funcAST) {
179     // Create a scope in the symbol table to hold variable declarations.
180     SymbolTableScopeT var_scope(symbolTable);
181 
182     // Create an MLIR function for the given prototype.
183     mlir::FuncOp function(mlirGen(*funcAST.getProto()));
184     if (!function)
185       return nullptr;
186 
187     // Let's start the body of the function now!
188     // In MLIR the entry block of the function is special: it must have the same
189     // argument list as the function itself.
190     auto &entryBlock = *function.addEntryBlock();
191     auto protoArgs = funcAST.getProto()->getArgs();
192 
193     // Declare all the function arguments in the symbol table.
194     for (const auto &name_value :
195          llvm::zip(protoArgs, entryBlock.getArguments())) {
196       if (failed(declare(*std::get<0>(name_value), std::get<1>(name_value))))
197         return nullptr;
198     }
199 
200     // Set the insertion point in the builder to the beginning of the function
201     // body, it will be used throughout the codegen to create operations in this
202     // function.
203     builder.setInsertionPointToStart(&entryBlock);
204 
205     // Emit the body of the function.
206     if (mlir::failed(mlirGen(*funcAST.getBody()))) {
207       function.erase();
208       return nullptr;
209     }
210 
211     // Implicitly return void if no return statement was emitted.
212     // FIXME: we may fix the parser instead to always return the last expression
213     // (this would possibly help the REPL case later)
214     ReturnOp returnOp;
215     if (!entryBlock.empty())
216       returnOp = dyn_cast<ReturnOp>(entryBlock.back());
217     if (!returnOp) {
218       builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
219     } else if (returnOp.hasOperand()) {
220       // Otherwise, if this return operation has an operand then add a result to
221       // the function.
222       function.setType(builder.getFunctionType(function.getType().getInputs(),
223                                                *returnOp.operand_type_begin()));
224     }
225 
226     // If this function isn't main, then set the visibility to private.
227     if (funcAST.getProto()->getName() != "main")
228       function.setVisibility(mlir::FuncOp::Visibility::Private);
229 
230     return function;
231   }
232 
233   /// Return the struct type that is the result of the given expression, or null
234   /// if it cannot be inferred.
getStructFor(ExprAST * expr)235   StructAST *getStructFor(ExprAST *expr) {
236     llvm::StringRef structName;
237     if (auto *decl = llvm::dyn_cast<VariableExprAST>(expr)) {
238       auto varIt = symbolTable.lookup(decl->getName());
239       if (!varIt.first)
240         return nullptr;
241       structName = varIt.second->getType().name;
242     } else if (auto *access = llvm::dyn_cast<BinaryExprAST>(expr)) {
243       if (access->getOp() != '.')
244         return nullptr;
245       // The name being accessed should be in the RHS.
246       auto *name = llvm::dyn_cast<VariableExprAST>(access->getRHS());
247       if (!name)
248         return nullptr;
249       StructAST *parentStruct = getStructFor(access->getLHS());
250       if (!parentStruct)
251         return nullptr;
252 
253       // Get the element within the struct corresponding to the name.
254       VarDeclExprAST *decl = nullptr;
255       for (auto &var : parentStruct->getVariables()) {
256         if (var->getName() == name->getName()) {
257           decl = var.get();
258           break;
259         }
260       }
261       if (!decl)
262         return nullptr;
263       structName = decl->getType().name;
264     }
265     if (structName.empty())
266       return nullptr;
267 
268     // If the struct name was valid, check for an entry in the struct map.
269     auto structIt = structMap.find(structName);
270     if (structIt == structMap.end())
271       return nullptr;
272     return structIt->second.second;
273   }
274 
275   /// Return the numeric member index of the given struct access expression.
getMemberIndex(BinaryExprAST & accessOp)276   llvm::Optional<size_t> getMemberIndex(BinaryExprAST &accessOp) {
277     assert(accessOp.getOp() == '.' && "expected access operation");
278 
279     // Lookup the struct node for the LHS.
280     StructAST *structAST = getStructFor(accessOp.getLHS());
281     if (!structAST)
282       return llvm::None;
283 
284     // Get the name from the RHS.
285     VariableExprAST *name = llvm::dyn_cast<VariableExprAST>(accessOp.getRHS());
286     if (!name)
287       return llvm::None;
288 
289     auto structVars = structAST->getVariables();
290     auto it = llvm::find_if(structVars, [&](auto &var) {
291       return var->getName() == name->getName();
292     });
293     if (it == structVars.end())
294       return llvm::None;
295     return it - structVars.begin();
296   }
297 
298   /// Emit a binary operation
mlirGen(BinaryExprAST & binop)299   mlir::Value mlirGen(BinaryExprAST &binop) {
300     // First emit the operations for each side of the operation before emitting
301     // the operation itself. For example if the expression is `a + foo(a)`
302     // 1) First it will visiting the LHS, which will return a reference to the
303     //    value holding `a`. This value should have been emitted at declaration
304     //    time and registered in the symbol table, so nothing would be
305     //    codegen'd. If the value is not in the symbol table, an error has been
306     //    emitted and nullptr is returned.
307     // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
308     //    and the result value is returned. If an error occurs we get a nullptr
309     //    and propagate.
310     //
311     mlir::Value lhs = mlirGen(*binop.getLHS());
312     if (!lhs)
313       return nullptr;
314     auto location = loc(binop.loc());
315 
316     // If this is an access operation, handle it immediately.
317     if (binop.getOp() == '.') {
318       llvm::Optional<size_t> accessIndex = getMemberIndex(binop);
319       if (!accessIndex) {
320         emitError(location, "invalid access into struct expression");
321         return nullptr;
322       }
323       return builder.create<StructAccessOp>(location, lhs, *accessIndex);
324     }
325 
326     // Otherwise, this is a normal binary op.
327     mlir::Value rhs = mlirGen(*binop.getRHS());
328     if (!rhs)
329       return nullptr;
330 
331     // Derive the operation name from the binary operator. At the moment we only
332     // support '+' and '*'.
333     switch (binop.getOp()) {
334     case '+':
335       return builder.create<AddOp>(location, lhs, rhs);
336     case '*':
337       return builder.create<MulOp>(location, lhs, rhs);
338     }
339 
340     emitError(location, "invalid binary operator '") << binop.getOp() << "'";
341     return nullptr;
342   }
343 
344   /// This is a reference to a variable in an expression. The variable is
345   /// expected to have been declared and so should have a value in the symbol
346   /// table, otherwise emit an error and return nullptr.
mlirGen(VariableExprAST & expr)347   mlir::Value mlirGen(VariableExprAST &expr) {
348     if (auto variable = symbolTable.lookup(expr.getName()).first)
349       return variable;
350 
351     emitError(loc(expr.loc()), "error: unknown variable '")
352         << expr.getName() << "'";
353     return nullptr;
354   }
355 
356   /// Emit a return operation. This will return failure if any generation fails.
mlirGen(ReturnExprAST & ret)357   mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
358     auto location = loc(ret.loc());
359 
360     // 'return' takes an optional expression, handle that case here.
361     mlir::Value expr = nullptr;
362     if (ret.getExpr().hasValue()) {
363       if (!(expr = mlirGen(*ret.getExpr().getValue())))
364         return mlir::failure();
365     }
366 
367     // Otherwise, this return operation has zero operands.
368     builder.create<ReturnOp>(location, expr ? makeArrayRef(expr)
369                                             : ArrayRef<mlir::Value>());
370     return mlir::success();
371   }
372 
373   /// Emit a constant for a literal/constant array. It will be emitted as a
374   /// flattened array of data in an Attribute attached to a `toy.constant`
375   /// operation. See documentation on [Attributes](LangRef.md#attributes) for
376   /// more details. Here is an excerpt:
377   ///
378   ///   Attributes are the mechanism for specifying constant data in MLIR in
379   ///   places where a variable is never allowed [...]. They consist of a name
380   ///   and a concrete attribute value. The set of expected attributes, their
381   ///   structure, and their interpretation are all contextually dependent on
382   ///   what they are attached to.
383   ///
384   /// Example, the source level statement:
385   ///   var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
386   /// will be converted to:
387   ///   %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
388   ///     [[1.000000e+00, 2.000000e+00, 3.000000e+00],
389   ///      [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
390   ///
getConstantAttr(LiteralExprAST & lit)391   mlir::DenseElementsAttr getConstantAttr(LiteralExprAST &lit) {
392     // The attribute is a vector with a floating point value per element
393     // (number) in the array, see `collectData()` below for more details.
394     std::vector<double> data;
395     data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
396                                  std::multiplies<int>()));
397     collectData(lit, data);
398 
399     // The type of this attribute is tensor of 64-bit floating-point with the
400     // shape of the literal.
401     mlir::Type elementType = builder.getF64Type();
402     auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType);
403 
404     // This is the actual attribute that holds the list of values for this
405     // tensor literal.
406     return mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
407   }
getConstantAttr(NumberExprAST & lit)408   mlir::DenseElementsAttr getConstantAttr(NumberExprAST &lit) {
409     // The type of this attribute is tensor of 64-bit floating-point with no
410     // shape.
411     mlir::Type elementType = builder.getF64Type();
412     auto dataType = mlir::RankedTensorType::get({}, elementType);
413 
414     // This is the actual attribute that holds the list of values for this
415     // tensor literal.
416     return mlir::DenseElementsAttr::get(dataType,
417                                         llvm::makeArrayRef(lit.getValue()));
418   }
419   /// Emit a constant for a struct literal. It will be emitted as an array of
420   /// other literals in an Attribute attached to a `toy.struct_constant`
421   /// operation. This function returns the generated constant, along with the
422   /// corresponding struct type.
423   std::pair<mlir::ArrayAttr, mlir::Type>
getConstantAttr(StructLiteralExprAST & lit)424   getConstantAttr(StructLiteralExprAST &lit) {
425     std::vector<mlir::Attribute> attrElements;
426     std::vector<mlir::Type> typeElements;
427 
428     for (auto &var : lit.getValues()) {
429       if (auto *number = llvm::dyn_cast<NumberExprAST>(var.get())) {
430         attrElements.push_back(getConstantAttr(*number));
431         typeElements.push_back(getType(llvm::None));
432       } else if (auto *lit = llvm::dyn_cast<LiteralExprAST>(var.get())) {
433         attrElements.push_back(getConstantAttr(*lit));
434         typeElements.push_back(getType(llvm::None));
435       } else {
436         auto *structLit = llvm::cast<StructLiteralExprAST>(var.get());
437         auto attrTypePair = getConstantAttr(*structLit);
438         attrElements.push_back(attrTypePair.first);
439         typeElements.push_back(attrTypePair.second);
440       }
441     }
442     mlir::ArrayAttr dataAttr = builder.getArrayAttr(attrElements);
443     mlir::Type dataType = StructType::get(typeElements);
444     return std::make_pair(dataAttr, dataType);
445   }
446 
447   /// Emit an array literal.
mlirGen(LiteralExprAST & lit)448   mlir::Value mlirGen(LiteralExprAST &lit) {
449     mlir::Type type = getType(lit.getDims());
450     mlir::DenseElementsAttr dataAttribute = getConstantAttr(lit);
451 
452     // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
453     // method.
454     return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
455   }
456 
457   /// Emit a struct literal. It will be emitted as an array of
458   /// other literals in an Attribute attached to a `toy.struct_constant`
459   /// operation.
mlirGen(StructLiteralExprAST & lit)460   mlir::Value mlirGen(StructLiteralExprAST &lit) {
461     mlir::ArrayAttr dataAttr;
462     mlir::Type dataType;
463     std::tie(dataAttr, dataType) = getConstantAttr(lit);
464 
465     // Build the MLIR op `toy.struct_constant`. This invokes the
466     // `StructConstantOp::build` method.
467     return builder.create<StructConstantOp>(loc(lit.loc()), dataType, dataAttr);
468   }
469 
470   /// Recursive helper function to accumulate the data that compose an array
471   /// literal. It flattens the nested structure in the supplied vector. For
472   /// example with this array:
473   ///  [[1, 2], [3, 4]]
474   /// we will generate:
475   ///  [ 1, 2, 3, 4 ]
476   /// Individual numbers are represented as doubles.
477   /// Attributes are the way MLIR attaches constant to operations.
collectData(ExprAST & expr,std::vector<double> & data)478   void collectData(ExprAST &expr, std::vector<double> &data) {
479     if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
480       for (auto &value : lit->getValues())
481         collectData(*value, data);
482       return;
483     }
484 
485     assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
486     data.push_back(cast<NumberExprAST>(expr).getValue());
487   }
488 
489   /// Emit a call expression. It emits specific operations for the `transpose`
490   /// builtin. Other identifiers are assumed to be user-defined functions.
mlirGen(CallExprAST & call)491   mlir::Value mlirGen(CallExprAST &call) {
492     llvm::StringRef callee = call.getCallee();
493     auto location = loc(call.loc());
494 
495     // Codegen the operands first.
496     SmallVector<mlir::Value, 4> operands;
497     for (auto &expr : call.getArgs()) {
498       auto arg = mlirGen(*expr);
499       if (!arg)
500         return nullptr;
501       operands.push_back(arg);
502     }
503 
504     // Builtin calls have their custom operation, meaning this is a
505     // straightforward emission.
506     if (callee == "transpose") {
507       if (call.getArgs().size() != 1) {
508         emitError(location, "MLIR codegen encountered an error: toy.transpose "
509                             "does not accept multiple arguments");
510         return nullptr;
511       }
512       return builder.create<TransposeOp>(location, operands[0]);
513     }
514 
515     // Otherwise this is a call to a user-defined function. Calls to
516     // user-defined functions are mapped to a custom call that takes the callee
517     // name as an attribute.
518     auto calledFuncIt = functionMap.find(callee);
519     if (calledFuncIt == functionMap.end()) {
520       emitError(location) << "no defined function found for '" << callee << "'";
521       return nullptr;
522     }
523     mlir::FuncOp calledFunc = calledFuncIt->second;
524     return builder.create<GenericCallOp>(
525         location, calledFunc.getType().getResult(0),
526         builder.getSymbolRefAttr(callee), operands);
527   }
528 
529   /// Emit a print expression. It emits specific operations for two builtins:
530   /// transpose(x) and print(x).
mlirGen(PrintExprAST & call)531   mlir::LogicalResult mlirGen(PrintExprAST &call) {
532     auto arg = mlirGen(*call.getArg());
533     if (!arg)
534       return mlir::failure();
535 
536     builder.create<PrintOp>(loc(call.loc()), arg);
537     return mlir::success();
538   }
539 
540   /// Emit a constant for a single number (FIXME: semantic? broadcast?)
mlirGen(NumberExprAST & num)541   mlir::Value mlirGen(NumberExprAST &num) {
542     return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
543   }
544 
545   /// Dispatch codegen for the right expression subclass using RTTI.
mlirGen(ExprAST & expr)546   mlir::Value mlirGen(ExprAST &expr) {
547     switch (expr.getKind()) {
548     case toy::ExprAST::Expr_BinOp:
549       return mlirGen(cast<BinaryExprAST>(expr));
550     case toy::ExprAST::Expr_Var:
551       return mlirGen(cast<VariableExprAST>(expr));
552     case toy::ExprAST::Expr_Literal:
553       return mlirGen(cast<LiteralExprAST>(expr));
554     case toy::ExprAST::Expr_StructLiteral:
555       return mlirGen(cast<StructLiteralExprAST>(expr));
556     case toy::ExprAST::Expr_Call:
557       return mlirGen(cast<CallExprAST>(expr));
558     case toy::ExprAST::Expr_Num:
559       return mlirGen(cast<NumberExprAST>(expr));
560     default:
561       emitError(loc(expr.loc()))
562           << "MLIR codegen encountered an unhandled expr kind '"
563           << Twine(expr.getKind()) << "'";
564       return nullptr;
565     }
566   }
567 
568   /// Handle a variable declaration, we'll codegen the expression that forms the
569   /// initializer and record the value in the symbol table before returning it.
570   /// Future expressions will be able to reference this variable through symbol
571   /// table lookup.
mlirGen(VarDeclExprAST & vardecl)572   mlir::Value mlirGen(VarDeclExprAST &vardecl) {
573     auto init = vardecl.getInitVal();
574     if (!init) {
575       emitError(loc(vardecl.loc()),
576                 "missing initializer in variable declaration");
577       return nullptr;
578     }
579 
580     mlir::Value value = mlirGen(*init);
581     if (!value)
582       return nullptr;
583 
584     // Handle the case where we are initializing a struct value.
585     VarType varType = vardecl.getType();
586     if (!varType.name.empty()) {
587       // Check that the initializer type is the same as the variable
588       // declaration.
589       mlir::Type type = getType(varType, vardecl.loc());
590       if (!type)
591         return nullptr;
592       if (type != value.getType()) {
593         emitError(loc(vardecl.loc()))
594             << "struct type of initializer is different than the variable "
595                "declaration. Got "
596             << value.getType() << ", but expected " << type;
597         return nullptr;
598       }
599 
600       // Otherwise, we have the initializer value, but in case the variable was
601       // declared with specific shape, we emit a "reshape" operation. It will
602       // get optimized out later as needed.
603     } else if (!varType.shape.empty()) {
604       value = builder.create<ReshapeOp>(loc(vardecl.loc()),
605                                         getType(varType.shape), value);
606     }
607 
608     // Register the value in the symbol table.
609     if (failed(declare(vardecl, value)))
610       return nullptr;
611     return value;
612   }
613 
614   /// Codegen a list of expression, return failure if one of them hit an error.
mlirGen(ExprASTList & blockAST)615   mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
616     SymbolTableScopeT var_scope(symbolTable);
617     for (auto &expr : blockAST) {
618       // Specific handling for variable declarations, return statement, and
619       // print. These can only appear in block list and not in nested
620       // expressions.
621       if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
622         if (!mlirGen(*vardecl))
623           return mlir::failure();
624         continue;
625       }
626       if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
627         return mlirGen(*ret);
628       if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
629         if (mlir::failed(mlirGen(*print)))
630           return mlir::success();
631         continue;
632       }
633 
634       // Generic expression dispatch codegen.
635       if (!mlirGen(*expr))
636         return mlir::failure();
637     }
638     return mlir::success();
639   }
640 
641   /// Build a tensor type from a list of shape dimensions.
getType(ArrayRef<int64_t> shape)642   mlir::Type getType(ArrayRef<int64_t> shape) {
643     // If the shape is empty, then this type is unranked.
644     if (shape.empty())
645       return mlir::UnrankedTensorType::get(builder.getF64Type());
646 
647     // Otherwise, we use the given shape.
648     return mlir::RankedTensorType::get(shape, builder.getF64Type());
649   }
650 
651   /// Build an MLIR type from a Toy AST variable type (forward to the generic
652   /// getType above for non-struct types).
getType(const VarType & type,const Location & location)653   mlir::Type getType(const VarType &type, const Location &location) {
654     if (!type.name.empty()) {
655       auto it = structMap.find(type.name);
656       if (it == structMap.end()) {
657         emitError(loc(location))
658             << "error: unknown struct type '" << type.name << "'";
659         return nullptr;
660       }
661       return it->second.first;
662     }
663 
664     return getType(type.shape);
665   }
666 };
667 
668 } // namespace
669 
670 namespace toy {
671 
672 // The public API for codegen.
mlirGen(mlir::MLIRContext & context,ModuleAST & moduleAST)673 mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
674                               ModuleAST &moduleAST) {
675   return MLIRGenImpl(context).mlirGen(moduleAST);
676 }
677 
678 } // namespace toy
679