1 //===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a simple IR generation targeting MLIR from a Module AST
10 // for the Toy language.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "toy/MLIRGen.h"
15 #include "toy/AST.h"
16 #include "toy/Dialect.h"
17
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/Verifier.h"
24
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/ScopedHashTable.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include <numeric>
29
30 using namespace mlir::toy;
31 using namespace toy;
32
33 using llvm::ArrayRef;
34 using llvm::cast;
35 using llvm::dyn_cast;
36 using llvm::isa;
37 using llvm::makeArrayRef;
38 using llvm::ScopedHashTableScope;
39 using llvm::SmallVector;
40 using llvm::StringRef;
41 using llvm::Twine;
42
43 namespace {
44
45 /// Implementation of a simple MLIR emission from the Toy AST.
46 ///
47 /// This will emit operations that are specific to the Toy language, preserving
48 /// the semantics of the language and (hopefully) allow to perform accurate
49 /// analysis and transformation based on these high level semantics.
50 class MLIRGenImpl {
51 public:
MLIRGenImpl(mlir::MLIRContext & context)52 MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
53
54 /// Public API: convert the AST for a Toy module (source file) to an MLIR
55 /// Module operation.
mlirGen(ModuleAST & moduleAST)56 mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
57 // We create an empty MLIR module and codegen functions one at a time and
58 // add them to the module.
59 theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
60
61 for (auto &record : moduleAST) {
62 if (FunctionAST *funcAST = llvm::dyn_cast<FunctionAST>(record.get())) {
63 auto func = mlirGen(*funcAST);
64 if (!func)
65 return nullptr;
66
67 theModule.push_back(func);
68 functionMap.insert({func.getName(), func});
69 } else if (StructAST *str = llvm::dyn_cast<StructAST>(record.get())) {
70 if (failed(mlirGen(*str)))
71 return nullptr;
72 } else {
73 llvm_unreachable("unknown record type");
74 }
75 }
76
77 // Verify the module after we have finished constructing it, this will check
78 // the structural properties of the IR and invoke any specific verifiers we
79 // have on the Toy operations.
80 if (failed(mlir::verify(theModule))) {
81 theModule.emitError("module verification error");
82 return nullptr;
83 }
84
85 return theModule;
86 }
87
88 private:
89 /// A "module" matches a Toy source file: containing a list of functions.
90 mlir::ModuleOp theModule;
91
92 /// The builder is a helper class to create IR inside a function. The builder
93 /// is stateful, in particular it keeps an "insertion point": this is where
94 /// the next operations will be introduced.
95 mlir::OpBuilder builder;
96
97 /// The symbol table maps a variable name to a value in the current scope.
98 /// Entering a function creates a new scope, and the function arguments are
99 /// added to the mapping. When the processing of a function is terminated, the
100 /// scope is destroyed and the mappings created in this scope are dropped.
101 llvm::ScopedHashTable<StringRef, std::pair<mlir::Value, VarDeclExprAST *>>
102 symbolTable;
103 using SymbolTableScopeT =
104 llvm::ScopedHashTableScope<StringRef,
105 std::pair<mlir::Value, VarDeclExprAST *>>;
106
107 /// A mapping for the functions that have been code generated to MLIR.
108 llvm::StringMap<mlir::FuncOp> functionMap;
109
110 /// A mapping for named struct types to the underlying MLIR type and the
111 /// original AST node.
112 llvm::StringMap<std::pair<mlir::Type, StructAST *>> structMap;
113
114 /// Helper conversion for a Toy AST location to an MLIR location.
loc(Location loc)115 mlir::Location loc(Location loc) {
116 return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
117 loc.col);
118 }
119
120 /// Declare a variable in the current scope, return success if the variable
121 /// wasn't declared yet.
declare(VarDeclExprAST & var,mlir::Value value)122 mlir::LogicalResult declare(VarDeclExprAST &var, mlir::Value value) {
123 if (symbolTable.count(var.getName()))
124 return mlir::failure();
125 symbolTable.insert(var.getName(), {value, &var});
126 return mlir::success();
127 }
128
129 /// Create an MLIR type for the given struct.
mlirGen(StructAST & str)130 mlir::LogicalResult mlirGen(StructAST &str) {
131 if (structMap.count(str.getName()))
132 return emitError(loc(str.loc())) << "error: struct type with name `"
133 << str.getName() << "' already exists";
134
135 auto variables = str.getVariables();
136 std::vector<mlir::Type> elementTypes;
137 elementTypes.reserve(variables.size());
138 for (auto &variable : variables) {
139 if (variable->getInitVal())
140 return emitError(loc(variable->loc()))
141 << "error: variables within a struct definition must not have "
142 "initializers";
143 if (!variable->getType().shape.empty())
144 return emitError(loc(variable->loc()))
145 << "error: variables within a struct definition must not have "
146 "initializers";
147
148 mlir::Type type = getType(variable->getType(), variable->loc());
149 if (!type)
150 return mlir::failure();
151 elementTypes.push_back(type);
152 }
153
154 structMap.try_emplace(str.getName(), StructType::get(elementTypes), &str);
155 return mlir::success();
156 }
157
158 /// Create the prototype for an MLIR function with as many arguments as the
159 /// provided Toy AST prototype.
mlirGen(PrototypeAST & proto)160 mlir::FuncOp mlirGen(PrototypeAST &proto) {
161 auto location = loc(proto.loc());
162
163 // This is a generic function, the return type will be inferred later.
164 llvm::SmallVector<mlir::Type, 4> argTypes;
165 argTypes.reserve(proto.getArgs().size());
166 for (auto &arg : proto.getArgs()) {
167 mlir::Type type = getType(arg->getType(), arg->loc());
168 if (!type)
169 return nullptr;
170 argTypes.push_back(type);
171 }
172 auto func_type = builder.getFunctionType(argTypes, llvm::None);
173 return mlir::FuncOp::create(location, proto.getName(), func_type);
174 }
175
176 /// Emit a new function and add it to the MLIR module.
mlirGen(FunctionAST & funcAST)177 mlir::FuncOp mlirGen(FunctionAST &funcAST) {
178 // Create a scope in the symbol table to hold variable declarations.
179 SymbolTableScopeT var_scope(symbolTable);
180
181 // Create an MLIR function for the given prototype.
182 mlir::FuncOp function(mlirGen(*funcAST.getProto()));
183 if (!function)
184 return nullptr;
185
186 // Let's start the body of the function now!
187 // In MLIR the entry block of the function is special: it must have the same
188 // argument list as the function itself.
189 auto &entryBlock = *function.addEntryBlock();
190 auto protoArgs = funcAST.getProto()->getArgs();
191
192 // Declare all the function arguments in the symbol table.
193 for (const auto nameValue :
194 llvm::zip(protoArgs, entryBlock.getArguments())) {
195 if (failed(declare(*std::get<0>(nameValue), std::get<1>(nameValue))))
196 return nullptr;
197 }
198
199 // Set the insertion point in the builder to the beginning of the function
200 // body, it will be used throughout the codegen to create operations in this
201 // function.
202 builder.setInsertionPointToStart(&entryBlock);
203
204 // Emit the body of the function.
205 if (mlir::failed(mlirGen(*funcAST.getBody()))) {
206 function.erase();
207 return nullptr;
208 }
209
210 // Implicitly return void if no return statement was emitted.
211 // FIXME: we may fix the parser instead to always return the last expression
212 // (this would possibly help the REPL case later)
213 ReturnOp returnOp;
214 if (!entryBlock.empty())
215 returnOp = dyn_cast<ReturnOp>(entryBlock.back());
216 if (!returnOp) {
217 builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
218 } else if (returnOp.hasOperand()) {
219 // Otherwise, if this return operation has an operand then add a result to
220 // the function.
221 function.setType(builder.getFunctionType(function.getType().getInputs(),
222 *returnOp.operand_type_begin()));
223 }
224
225 // If this function isn't main, then set the visibility to private.
226 if (funcAST.getProto()->getName() != "main")
227 function.setPrivate();
228
229 return function;
230 }
231
232 /// Return the struct type that is the result of the given expression, or null
233 /// if it cannot be inferred.
getStructFor(ExprAST * expr)234 StructAST *getStructFor(ExprAST *expr) {
235 llvm::StringRef structName;
236 if (auto *decl = llvm::dyn_cast<VariableExprAST>(expr)) {
237 auto varIt = symbolTable.lookup(decl->getName());
238 if (!varIt.first)
239 return nullptr;
240 structName = varIt.second->getType().name;
241 } else if (auto *access = llvm::dyn_cast<BinaryExprAST>(expr)) {
242 if (access->getOp() != '.')
243 return nullptr;
244 // The name being accessed should be in the RHS.
245 auto *name = llvm::dyn_cast<VariableExprAST>(access->getRHS());
246 if (!name)
247 return nullptr;
248 StructAST *parentStruct = getStructFor(access->getLHS());
249 if (!parentStruct)
250 return nullptr;
251
252 // Get the element within the struct corresponding to the name.
253 VarDeclExprAST *decl = nullptr;
254 for (auto &var : parentStruct->getVariables()) {
255 if (var->getName() == name->getName()) {
256 decl = var.get();
257 break;
258 }
259 }
260 if (!decl)
261 return nullptr;
262 structName = decl->getType().name;
263 }
264 if (structName.empty())
265 return nullptr;
266
267 // If the struct name was valid, check for an entry in the struct map.
268 auto structIt = structMap.find(structName);
269 if (structIt == structMap.end())
270 return nullptr;
271 return structIt->second.second;
272 }
273
274 /// Return the numeric member index of the given struct access expression.
getMemberIndex(BinaryExprAST & accessOp)275 llvm::Optional<size_t> getMemberIndex(BinaryExprAST &accessOp) {
276 assert(accessOp.getOp() == '.' && "expected access operation");
277
278 // Lookup the struct node for the LHS.
279 StructAST *structAST = getStructFor(accessOp.getLHS());
280 if (!structAST)
281 return llvm::None;
282
283 // Get the name from the RHS.
284 VariableExprAST *name = llvm::dyn_cast<VariableExprAST>(accessOp.getRHS());
285 if (!name)
286 return llvm::None;
287
288 auto structVars = structAST->getVariables();
289 auto it = llvm::find_if(structVars, [&](auto &var) {
290 return var->getName() == name->getName();
291 });
292 if (it == structVars.end())
293 return llvm::None;
294 return it - structVars.begin();
295 }
296
297 /// Emit a binary operation
mlirGen(BinaryExprAST & binop)298 mlir::Value mlirGen(BinaryExprAST &binop) {
299 // First emit the operations for each side of the operation before emitting
300 // the operation itself. For example if the expression is `a + foo(a)`
301 // 1) First it will visiting the LHS, which will return a reference to the
302 // value holding `a`. This value should have been emitted at declaration
303 // time and registered in the symbol table, so nothing would be
304 // codegen'd. If the value is not in the symbol table, an error has been
305 // emitted and nullptr is returned.
306 // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
307 // and the result value is returned. If an error occurs we get a nullptr
308 // and propagate.
309 //
310 mlir::Value lhs = mlirGen(*binop.getLHS());
311 if (!lhs)
312 return nullptr;
313 auto location = loc(binop.loc());
314
315 // If this is an access operation, handle it immediately.
316 if (binop.getOp() == '.') {
317 llvm::Optional<size_t> accessIndex = getMemberIndex(binop);
318 if (!accessIndex) {
319 emitError(location, "invalid access into struct expression");
320 return nullptr;
321 }
322 return builder.create<StructAccessOp>(location, lhs, *accessIndex);
323 }
324
325 // Otherwise, this is a normal binary op.
326 mlir::Value rhs = mlirGen(*binop.getRHS());
327 if (!rhs)
328 return nullptr;
329
330 // Derive the operation name from the binary operator. At the moment we only
331 // support '+' and '*'.
332 switch (binop.getOp()) {
333 case '+':
334 return builder.create<AddOp>(location, lhs, rhs);
335 case '*':
336 return builder.create<MulOp>(location, lhs, rhs);
337 }
338
339 emitError(location, "invalid binary operator '") << binop.getOp() << "'";
340 return nullptr;
341 }
342
343 /// This is a reference to a variable in an expression. The variable is
344 /// expected to have been declared and so should have a value in the symbol
345 /// table, otherwise emit an error and return nullptr.
mlirGen(VariableExprAST & expr)346 mlir::Value mlirGen(VariableExprAST &expr) {
347 if (auto variable = symbolTable.lookup(expr.getName()).first)
348 return variable;
349
350 emitError(loc(expr.loc()), "error: unknown variable '")
351 << expr.getName() << "'";
352 return nullptr;
353 }
354
355 /// Emit a return operation. This will return failure if any generation fails.
mlirGen(ReturnExprAST & ret)356 mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
357 auto location = loc(ret.loc());
358
359 // 'return' takes an optional expression, handle that case here.
360 mlir::Value expr = nullptr;
361 if (ret.getExpr().hasValue()) {
362 if (!(expr = mlirGen(*ret.getExpr().getValue())))
363 return mlir::failure();
364 }
365
366 // Otherwise, this return operation has zero operands.
367 builder.create<ReturnOp>(location, expr ? makeArrayRef(expr)
368 : ArrayRef<mlir::Value>());
369 return mlir::success();
370 }
371
372 /// Emit a constant for a literal/constant array. It will be emitted as a
373 /// flattened array of data in an Attribute attached to a `toy.constant`
374 /// operation. See documentation on [Attributes](LangRef.md#attributes) for
375 /// more details. Here is an excerpt:
376 ///
377 /// Attributes are the mechanism for specifying constant data in MLIR in
378 /// places where a variable is never allowed [...]. They consist of a name
379 /// and a concrete attribute value. The set of expected attributes, their
380 /// structure, and their interpretation are all contextually dependent on
381 /// what they are attached to.
382 ///
383 /// Example, the source level statement:
384 /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
385 /// will be converted to:
386 /// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
387 /// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
388 /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
389 ///
getConstantAttr(LiteralExprAST & lit)390 mlir::DenseElementsAttr getConstantAttr(LiteralExprAST &lit) {
391 // The attribute is a vector with a floating point value per element
392 // (number) in the array, see `collectData()` below for more details.
393 std::vector<double> data;
394 data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
395 std::multiplies<int>()));
396 collectData(lit, data);
397
398 // The type of this attribute is tensor of 64-bit floating-point with the
399 // shape of the literal.
400 mlir::Type elementType = builder.getF64Type();
401 auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType);
402
403 // This is the actual attribute that holds the list of values for this
404 // tensor literal.
405 return mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
406 }
getConstantAttr(NumberExprAST & lit)407 mlir::DenseElementsAttr getConstantAttr(NumberExprAST &lit) {
408 // The type of this attribute is tensor of 64-bit floating-point with no
409 // shape.
410 mlir::Type elementType = builder.getF64Type();
411 auto dataType = mlir::RankedTensorType::get({}, elementType);
412
413 // This is the actual attribute that holds the list of values for this
414 // tensor literal.
415 return mlir::DenseElementsAttr::get(dataType,
416 llvm::makeArrayRef(lit.getValue()));
417 }
418 /// Emit a constant for a struct literal. It will be emitted as an array of
419 /// other literals in an Attribute attached to a `toy.struct_constant`
420 /// operation. This function returns the generated constant, along with the
421 /// corresponding struct type.
422 std::pair<mlir::ArrayAttr, mlir::Type>
getConstantAttr(StructLiteralExprAST & lit)423 getConstantAttr(StructLiteralExprAST &lit) {
424 std::vector<mlir::Attribute> attrElements;
425 std::vector<mlir::Type> typeElements;
426
427 for (auto &var : lit.getValues()) {
428 if (auto *number = llvm::dyn_cast<NumberExprAST>(var.get())) {
429 attrElements.push_back(getConstantAttr(*number));
430 typeElements.push_back(getType(llvm::None));
431 } else if (auto *lit = llvm::dyn_cast<LiteralExprAST>(var.get())) {
432 attrElements.push_back(getConstantAttr(*lit));
433 typeElements.push_back(getType(llvm::None));
434 } else {
435 auto *structLit = llvm::cast<StructLiteralExprAST>(var.get());
436 auto attrTypePair = getConstantAttr(*structLit);
437 attrElements.push_back(attrTypePair.first);
438 typeElements.push_back(attrTypePair.second);
439 }
440 }
441 mlir::ArrayAttr dataAttr = builder.getArrayAttr(attrElements);
442 mlir::Type dataType = StructType::get(typeElements);
443 return std::make_pair(dataAttr, dataType);
444 }
445
446 /// Emit an array literal.
mlirGen(LiteralExprAST & lit)447 mlir::Value mlirGen(LiteralExprAST &lit) {
448 mlir::Type type = getType(lit.getDims());
449 mlir::DenseElementsAttr dataAttribute = getConstantAttr(lit);
450
451 // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
452 // method.
453 return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
454 }
455
456 /// Emit a struct literal. It will be emitted as an array of
457 /// other literals in an Attribute attached to a `toy.struct_constant`
458 /// operation.
mlirGen(StructLiteralExprAST & lit)459 mlir::Value mlirGen(StructLiteralExprAST &lit) {
460 mlir::ArrayAttr dataAttr;
461 mlir::Type dataType;
462 std::tie(dataAttr, dataType) = getConstantAttr(lit);
463
464 // Build the MLIR op `toy.struct_constant`. This invokes the
465 // `StructConstantOp::build` method.
466 return builder.create<StructConstantOp>(loc(lit.loc()), dataType, dataAttr);
467 }
468
469 /// Recursive helper function to accumulate the data that compose an array
470 /// literal. It flattens the nested structure in the supplied vector. For
471 /// example with this array:
472 /// [[1, 2], [3, 4]]
473 /// we will generate:
474 /// [ 1, 2, 3, 4 ]
475 /// Individual numbers are represented as doubles.
476 /// Attributes are the way MLIR attaches constant to operations.
collectData(ExprAST & expr,std::vector<double> & data)477 void collectData(ExprAST &expr, std::vector<double> &data) {
478 if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
479 for (auto &value : lit->getValues())
480 collectData(*value, data);
481 return;
482 }
483
484 assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
485 data.push_back(cast<NumberExprAST>(expr).getValue());
486 }
487
488 /// Emit a call expression. It emits specific operations for the `transpose`
489 /// builtin. Other identifiers are assumed to be user-defined functions.
mlirGen(CallExprAST & call)490 mlir::Value mlirGen(CallExprAST &call) {
491 llvm::StringRef callee = call.getCallee();
492 auto location = loc(call.loc());
493
494 // Codegen the operands first.
495 SmallVector<mlir::Value, 4> operands;
496 for (auto &expr : call.getArgs()) {
497 auto arg = mlirGen(*expr);
498 if (!arg)
499 return nullptr;
500 operands.push_back(arg);
501 }
502
503 // Builtin calls have their custom operation, meaning this is a
504 // straightforward emission.
505 if (callee == "transpose") {
506 if (call.getArgs().size() != 1) {
507 emitError(location, "MLIR codegen encountered an error: toy.transpose "
508 "does not accept multiple arguments");
509 return nullptr;
510 }
511 return builder.create<TransposeOp>(location, operands[0]);
512 }
513
514 // Otherwise this is a call to a user-defined function. Calls to
515 // user-defined functions are mapped to a custom call that takes the callee
516 // name as an attribute.
517 auto calledFuncIt = functionMap.find(callee);
518 if (calledFuncIt == functionMap.end()) {
519 emitError(location) << "no defined function found for '" << callee << "'";
520 return nullptr;
521 }
522 mlir::FuncOp calledFunc = calledFuncIt->second;
523 return builder.create<GenericCallOp>(
524 location, calledFunc.getType().getResult(0),
525 builder.getSymbolRefAttr(callee), operands);
526 }
527
528 /// Emit a print expression. It emits specific operations for two builtins:
529 /// transpose(x) and print(x).
mlirGen(PrintExprAST & call)530 mlir::LogicalResult mlirGen(PrintExprAST &call) {
531 auto arg = mlirGen(*call.getArg());
532 if (!arg)
533 return mlir::failure();
534
535 builder.create<PrintOp>(loc(call.loc()), arg);
536 return mlir::success();
537 }
538
539 /// Emit a constant for a single number (FIXME: semantic? broadcast?)
mlirGen(NumberExprAST & num)540 mlir::Value mlirGen(NumberExprAST &num) {
541 return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
542 }
543
544 /// Dispatch codegen for the right expression subclass using RTTI.
mlirGen(ExprAST & expr)545 mlir::Value mlirGen(ExprAST &expr) {
546 switch (expr.getKind()) {
547 case toy::ExprAST::Expr_BinOp:
548 return mlirGen(cast<BinaryExprAST>(expr));
549 case toy::ExprAST::Expr_Var:
550 return mlirGen(cast<VariableExprAST>(expr));
551 case toy::ExprAST::Expr_Literal:
552 return mlirGen(cast<LiteralExprAST>(expr));
553 case toy::ExprAST::Expr_StructLiteral:
554 return mlirGen(cast<StructLiteralExprAST>(expr));
555 case toy::ExprAST::Expr_Call:
556 return mlirGen(cast<CallExprAST>(expr));
557 case toy::ExprAST::Expr_Num:
558 return mlirGen(cast<NumberExprAST>(expr));
559 default:
560 emitError(loc(expr.loc()))
561 << "MLIR codegen encountered an unhandled expr kind '"
562 << Twine(expr.getKind()) << "'";
563 return nullptr;
564 }
565 }
566
567 /// Handle a variable declaration, we'll codegen the expression that forms the
568 /// initializer and record the value in the symbol table before returning it.
569 /// Future expressions will be able to reference this variable through symbol
570 /// table lookup.
mlirGen(VarDeclExprAST & vardecl)571 mlir::Value mlirGen(VarDeclExprAST &vardecl) {
572 auto init = vardecl.getInitVal();
573 if (!init) {
574 emitError(loc(vardecl.loc()),
575 "missing initializer in variable declaration");
576 return nullptr;
577 }
578
579 mlir::Value value = mlirGen(*init);
580 if (!value)
581 return nullptr;
582
583 // Handle the case where we are initializing a struct value.
584 VarType varType = vardecl.getType();
585 if (!varType.name.empty()) {
586 // Check that the initializer type is the same as the variable
587 // declaration.
588 mlir::Type type = getType(varType, vardecl.loc());
589 if (!type)
590 return nullptr;
591 if (type != value.getType()) {
592 emitError(loc(vardecl.loc()))
593 << "struct type of initializer is different than the variable "
594 "declaration. Got "
595 << value.getType() << ", but expected " << type;
596 return nullptr;
597 }
598
599 // Otherwise, we have the initializer value, but in case the variable was
600 // declared with specific shape, we emit a "reshape" operation. It will
601 // get optimized out later as needed.
602 } else if (!varType.shape.empty()) {
603 value = builder.create<ReshapeOp>(loc(vardecl.loc()),
604 getType(varType.shape), value);
605 }
606
607 // Register the value in the symbol table.
608 if (failed(declare(vardecl, value)))
609 return nullptr;
610 return value;
611 }
612
613 /// Codegen a list of expression, return failure if one of them hit an error.
mlirGen(ExprASTList & blockAST)614 mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
615 SymbolTableScopeT var_scope(symbolTable);
616 for (auto &expr : blockAST) {
617 // Specific handling for variable declarations, return statement, and
618 // print. These can only appear in block list and not in nested
619 // expressions.
620 if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
621 if (!mlirGen(*vardecl))
622 return mlir::failure();
623 continue;
624 }
625 if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
626 return mlirGen(*ret);
627 if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
628 if (mlir::failed(mlirGen(*print)))
629 return mlir::success();
630 continue;
631 }
632
633 // Generic expression dispatch codegen.
634 if (!mlirGen(*expr))
635 return mlir::failure();
636 }
637 return mlir::success();
638 }
639
640 /// Build a tensor type from a list of shape dimensions.
getType(ArrayRef<int64_t> shape)641 mlir::Type getType(ArrayRef<int64_t> shape) {
642 // If the shape is empty, then this type is unranked.
643 if (shape.empty())
644 return mlir::UnrankedTensorType::get(builder.getF64Type());
645
646 // Otherwise, we use the given shape.
647 return mlir::RankedTensorType::get(shape, builder.getF64Type());
648 }
649
650 /// Build an MLIR type from a Toy AST variable type (forward to the generic
651 /// getType above for non-struct types).
getType(const VarType & type,const Location & location)652 mlir::Type getType(const VarType &type, const Location &location) {
653 if (!type.name.empty()) {
654 auto it = structMap.find(type.name);
655 if (it == structMap.end()) {
656 emitError(loc(location))
657 << "error: unknown struct type '" << type.name << "'";
658 return nullptr;
659 }
660 return it->second.first;
661 }
662
663 return getType(type.shape);
664 }
665 };
666
667 } // namespace
668
669 namespace toy {
670
671 // The public API for codegen.
mlirGen(mlir::MLIRContext & context,ModuleAST & moduleAST)672 mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
673 ModuleAST &moduleAST) {
674 return MLIRGenImpl(context).mlirGen(moduleAST);
675 }
676
677 } // namespace toy
678