1 //===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the dialect for the Toy IR: operation registration,
// custom operation parsing/printing, builders, and operation verification.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "toy/Dialect.h"
15 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/OpImplementation.h"
19 
20 using namespace mlir;
21 using namespace mlir::toy;
22 
23 #include "toy/Dialect.cpp.inc"
24 
25 //===----------------------------------------------------------------------===//
26 // ToyDialect
27 //===----------------------------------------------------------------------===//
28 
/// Dialect initialization, the instance will be owned by the context. This is
/// the point of registration for the dialect's operations; no custom types are
/// registered here.
void ToyDialect::initialize() {
  // With GET_OP_LIST defined, the TableGen-generated Ops.cpp.inc expands to the
  // list of generated op classes, which becomes the template argument list of
  // addOperations<...>().
  addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
      >();
}
37 
38 //===----------------------------------------------------------------------===//
39 // Toy Operations
40 //===----------------------------------------------------------------------===//
41 
42 /// A generalized parser for binary operations. This parses the different forms
43 /// of 'printBinaryOp' below.
parseBinaryOp(mlir::OpAsmParser & parser,mlir::OperationState & result)44 static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
45                                        mlir::OperationState &result) {
46   SmallVector<mlir::OpAsmParser::OperandType, 2> operands;
47   llvm::SMLoc operandsLoc = parser.getCurrentLocation();
48   Type type;
49   if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
50       parser.parseOptionalAttrDict(result.attributes) ||
51       parser.parseColonType(type))
52     return mlir::failure();
53 
54   // If the type is a function type, it contains the input and result types of
55   // this operation.
56   if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
57     if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
58                                result.operands))
59       return mlir::failure();
60     result.addTypes(funcType.getResults());
61     return mlir::success();
62   }
63 
64   // Otherwise, the parsed type is the type of both operands and results.
65   if (parser.resolveOperands(operands, type, result.operands))
66     return mlir::failure();
67   result.addTypes(type);
68   return mlir::success();
69 }
70 
71 /// A generalized printer for binary operations. It prints in two different
72 /// forms depending on if all of the types match.
printBinaryOp(mlir::OpAsmPrinter & printer,mlir::Operation * op)73 static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) {
74   printer << " " << op->getOperands();
75   printer.printOptionalAttrDict(op->getAttrs());
76   printer << " : ";
77 
78   // If all of the types are the same, print the type directly.
79   Type resultType = *op->result_type_begin();
80   if (llvm::all_of(op->getOperandTypes(),
81                    [=](Type type) { return type == resultType; })) {
82     printer << resultType;
83     return;
84   }
85 
86   // Otherwise, print a functional type.
87   printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // ConstantOp
92 
93 /// Build a constant operation.
94 /// The builder is passed as an argument, so is the state that this method is
95 /// expected to fill in order to build the operation.
build(mlir::OpBuilder & builder,mlir::OperationState & state,double value)96 void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
97                        double value) {
98   auto dataType = RankedTensorType::get({}, builder.getF64Type());
99   auto dataAttribute = DenseElementsAttr::get(dataType, value);
100   ConstantOp::build(builder, state, dataType, dataAttribute);
101 }
102 
103 /// The 'OpAsmParser' class provides a collection of methods for parsing
104 /// various punctuation, as well as attributes, operands, types, etc. Each of
105 /// these methods returns a `ParseResult`. This class is a wrapper around
106 /// `LogicalResult` that can be converted to a boolean `true` value on failure,
107 /// or `false` on success. This allows for easily chaining together a set of
108 /// parser rules. These rules are used to populate an `mlir::OperationState`
109 /// similarly to the `build` methods described above.
parseConstantOp(mlir::OpAsmParser & parser,mlir::OperationState & result)110 static mlir::ParseResult parseConstantOp(mlir::OpAsmParser &parser,
111                                          mlir::OperationState &result) {
112   mlir::DenseElementsAttr value;
113   if (parser.parseOptionalAttrDict(result.attributes) ||
114       parser.parseAttribute(value, "value", result.attributes))
115     return failure();
116 
117   result.addTypes(value.getType());
118   return success();
119 }
120 
121 /// The 'OpAsmPrinter' class is a stream that allows for formatting
122 /// strings, attributes, operands, types, etc.
print(mlir::OpAsmPrinter & printer,ConstantOp op)123 static void print(mlir::OpAsmPrinter &printer, ConstantOp op) {
124   printer << " ";
125   printer.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
126   printer << op.value();
127 }
128 
129 /// Verifier for the constant operation. This corresponds to the `::verify(...)`
130 /// in the op definition.
verify(ConstantOp op)131 static mlir::LogicalResult verify(ConstantOp op) {
132   // If the return type of the constant is not an unranked tensor, the shape
133   // must match the shape of the attribute holding the data.
134   auto resultType = op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
135   if (!resultType)
136     return success();
137 
138   // Check that the rank of the attribute type matches the rank of the constant
139   // result type.
140   auto attrType = op.value().getType().cast<mlir::TensorType>();
141   if (attrType.getRank() != resultType.getRank()) {
142     return op.emitOpError(
143                "return type must match the one of the attached value "
144                "attribute: ")
145            << attrType.getRank() << " != " << resultType.getRank();
146   }
147 
148   // Check that each of the dimensions match between the two types.
149   for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
150     if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
151       return op.emitOpError(
152                  "return type shape mismatches its attribute at dimension ")
153              << dim << ": " << attrType.getShape()[dim]
154              << " != " << resultType.getShape()[dim];
155     }
156   }
157   return mlir::success();
158 }
159 
160 //===----------------------------------------------------------------------===//
161 // AddOp
162 
build(mlir::OpBuilder & builder,mlir::OperationState & state,mlir::Value lhs,mlir::Value rhs)163 void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
164                   mlir::Value lhs, mlir::Value rhs) {
165   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
166   state.addOperands({lhs, rhs});
167 }
168 
169 //===----------------------------------------------------------------------===//
170 // GenericCallOp
171 
build(mlir::OpBuilder & builder,mlir::OperationState & state,StringRef callee,ArrayRef<mlir::Value> arguments)172 void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
173                           StringRef callee, ArrayRef<mlir::Value> arguments) {
174   // Generic call always returns an unranked Tensor initially.
175   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
176   state.addOperands(arguments);
177   state.addAttribute("callee",
178                      mlir::SymbolRefAttr::get(builder.getContext(), callee));
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // MulOp
183 
build(mlir::OpBuilder & builder,mlir::OperationState & state,mlir::Value lhs,mlir::Value rhs)184 void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
185                   mlir::Value lhs, mlir::Value rhs) {
186   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
187   state.addOperands({lhs, rhs});
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // ReturnOp
192 
verify(ReturnOp op)193 static mlir::LogicalResult verify(ReturnOp op) {
194   // We know that the parent operation is a function, because of the 'HasParent'
195   // trait attached to the operation definition.
196   auto function = cast<FuncOp>(op->getParentOp());
197 
198   /// ReturnOps can only have a single optional operand.
199   if (op.getNumOperands() > 1)
200     return op.emitOpError() << "expects at most 1 return operand";
201 
202   // The operand number and types must match the function signature.
203   const auto &results = function.getType().getResults();
204   if (op.getNumOperands() != results.size())
205     return op.emitOpError()
206            << "does not return the same number of values ("
207            << op.getNumOperands() << ") as the enclosing function ("
208            << results.size() << ")";
209 
210   // If the operation does not have an input, we are done.
211   if (!op.hasOperand())
212     return mlir::success();
213 
214   auto inputType = *op.operand_type_begin();
215   auto resultType = results.front();
216 
217   // Check that the result type of the function matches the operand type.
218   if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
219       resultType.isa<mlir::UnrankedTensorType>())
220     return mlir::success();
221 
222   return op.emitError() << "type of return operand (" << inputType
223                         << ") doesn't match function result type ("
224                         << resultType << ")";
225 }
226 
227 //===----------------------------------------------------------------------===//
228 // TransposeOp
229 
build(mlir::OpBuilder & builder,mlir::OperationState & state,mlir::Value value)230 void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
231                         mlir::Value value) {
232   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
233   state.addOperands(value);
234 }
235 
verify(TransposeOp op)236 static mlir::LogicalResult verify(TransposeOp op) {
237   auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
238   auto resultType = op.getType().dyn_cast<RankedTensorType>();
239   if (!inputType || !resultType)
240     return mlir::success();
241 
242   auto inputShape = inputType.getShape();
243   if (!std::equal(inputShape.begin(), inputShape.end(),
244                   resultType.getShape().rbegin())) {
245     return op.emitError()
246            << "expected result shape to be a transpose of the input";
247   }
248   return mlir::success();
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // TableGen'd op method definitions
253 //===----------------------------------------------------------------------===//
254 
255 #define GET_OP_CLASSES
256 #include "toy/Ops.cpp.inc"
257