1 //===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the dialect for the Toy IR: custom type parsing and
10 // operation verification.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "toy/Dialect.h"
15
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/OpImplementation.h"
19
20 using namespace mlir;
21 using namespace mlir::toy;
22
23 #include "toy/Dialect.cpp.inc"
24
25 //===----------------------------------------------------------------------===//
26 // ToyDialect
27 //===----------------------------------------------------------------------===//
28
29 /// Dialect initialization, the instance will be owned by the context. This is
30 /// the point of registration of types and operations for the dialect.
initialize()31 void ToyDialect::initialize() {
32 addOperations<
33 #define GET_OP_LIST
34 #include "toy/Ops.cpp.inc"
35 >();
36 }
37
38 //===----------------------------------------------------------------------===//
39 // Toy Operations
40 //===----------------------------------------------------------------------===//
41
42 /// A generalized parser for binary operations. This parses the different forms
43 /// of 'printBinaryOp' below.
parseBinaryOp(mlir::OpAsmParser & parser,mlir::OperationState & result)44 static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
45 mlir::OperationState &result) {
46 SmallVector<mlir::OpAsmParser::OperandType, 2> operands;
47 llvm::SMLoc operandsLoc = parser.getCurrentLocation();
48 Type type;
49 if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
50 parser.parseOptionalAttrDict(result.attributes) ||
51 parser.parseColonType(type))
52 return mlir::failure();
53
54 // If the type is a function type, it contains the input and result types of
55 // this operation.
56 if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
57 if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
58 result.operands))
59 return mlir::failure();
60 result.addTypes(funcType.getResults());
61 return mlir::success();
62 }
63
64 // Otherwise, the parsed type is the type of both operands and results.
65 if (parser.resolveOperands(operands, type, result.operands))
66 return mlir::failure();
67 result.addTypes(type);
68 return mlir::success();
69 }
70
71 /// A generalized printer for binary operations. It prints in two different
72 /// forms depending on if all of the types match.
printBinaryOp(mlir::OpAsmPrinter & printer,mlir::Operation * op)73 static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) {
74 printer << " " << op->getOperands();
75 printer.printOptionalAttrDict(op->getAttrs());
76 printer << " : ";
77
78 // If all of the types are the same, print the type directly.
79 Type resultType = *op->result_type_begin();
80 if (llvm::all_of(op->getOperandTypes(),
81 [=](Type type) { return type == resultType; })) {
82 printer << resultType;
83 return;
84 }
85
86 // Otherwise, print a functional type.
87 printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
88 }
89
90 //===----------------------------------------------------------------------===//
91 // ConstantOp
92
93 /// Build a constant operation.
94 /// The builder is passed as an argument, so is the state that this method is
95 /// expected to fill in order to build the operation.
build(mlir::OpBuilder & builder,mlir::OperationState & state,double value)96 void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
97 double value) {
98 auto dataType = RankedTensorType::get({}, builder.getF64Type());
99 auto dataAttribute = DenseElementsAttr::get(dataType, value);
100 ConstantOp::build(builder, state, dataType, dataAttribute);
101 }
102
103 /// The 'OpAsmParser' class provides a collection of methods for parsing
104 /// various punctuation, as well as attributes, operands, types, etc. Each of
105 /// these methods returns a `ParseResult`. This class is a wrapper around
106 /// `LogicalResult` that can be converted to a boolean `true` value on failure,
107 /// or `false` on success. This allows for easily chaining together a set of
108 /// parser rules. These rules are used to populate an `mlir::OperationState`
109 /// similarly to the `build` methods described above.
parseConstantOp(mlir::OpAsmParser & parser,mlir::OperationState & result)110 static mlir::ParseResult parseConstantOp(mlir::OpAsmParser &parser,
111 mlir::OperationState &result) {
112 mlir::DenseElementsAttr value;
113 if (parser.parseOptionalAttrDict(result.attributes) ||
114 parser.parseAttribute(value, "value", result.attributes))
115 return failure();
116
117 result.addTypes(value.getType());
118 return success();
119 }
120
121 /// The 'OpAsmPrinter' class is a stream that allows for formatting
122 /// strings, attributes, operands, types, etc.
print(mlir::OpAsmPrinter & printer,ConstantOp op)123 static void print(mlir::OpAsmPrinter &printer, ConstantOp op) {
124 printer << " ";
125 printer.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
126 printer << op.value();
127 }
128
129 /// Verifier for the constant operation. This corresponds to the `::verify(...)`
130 /// in the op definition.
verify(ConstantOp op)131 static mlir::LogicalResult verify(ConstantOp op) {
132 // If the return type of the constant is not an unranked tensor, the shape
133 // must match the shape of the attribute holding the data.
134 auto resultType = op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
135 if (!resultType)
136 return success();
137
138 // Check that the rank of the attribute type matches the rank of the constant
139 // result type.
140 auto attrType = op.value().getType().cast<mlir::TensorType>();
141 if (attrType.getRank() != resultType.getRank()) {
142 return op.emitOpError(
143 "return type must match the one of the attached value "
144 "attribute: ")
145 << attrType.getRank() << " != " << resultType.getRank();
146 }
147
148 // Check that each of the dimensions match between the two types.
149 for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
150 if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
151 return op.emitOpError(
152 "return type shape mismatches its attribute at dimension ")
153 << dim << ": " << attrType.getShape()[dim]
154 << " != " << resultType.getShape()[dim];
155 }
156 }
157 return mlir::success();
158 }
159
160 //===----------------------------------------------------------------------===//
161 // AddOp
162
build(mlir::OpBuilder & builder,mlir::OperationState & state,mlir::Value lhs,mlir::Value rhs)163 void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
164 mlir::Value lhs, mlir::Value rhs) {
165 state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
166 state.addOperands({lhs, rhs});
167 }
168
169 //===----------------------------------------------------------------------===//
170 // GenericCallOp
171
build(mlir::OpBuilder & builder,mlir::OperationState & state,StringRef callee,ArrayRef<mlir::Value> arguments)172 void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
173 StringRef callee, ArrayRef<mlir::Value> arguments) {
174 // Generic call always returns an unranked Tensor initially.
175 state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
176 state.addOperands(arguments);
177 state.addAttribute("callee",
178 mlir::SymbolRefAttr::get(builder.getContext(), callee));
179 }
180
181 //===----------------------------------------------------------------------===//
182 // MulOp
183
build(mlir::OpBuilder & builder,mlir::OperationState & state,mlir::Value lhs,mlir::Value rhs)184 void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
185 mlir::Value lhs, mlir::Value rhs) {
186 state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
187 state.addOperands({lhs, rhs});
188 }
189
190 //===----------------------------------------------------------------------===//
191 // ReturnOp
192
verify(ReturnOp op)193 static mlir::LogicalResult verify(ReturnOp op) {
194 // We know that the parent operation is a function, because of the 'HasParent'
195 // trait attached to the operation definition.
196 auto function = cast<FuncOp>(op->getParentOp());
197
198 /// ReturnOps can only have a single optional operand.
199 if (op.getNumOperands() > 1)
200 return op.emitOpError() << "expects at most 1 return operand";
201
202 // The operand number and types must match the function signature.
203 const auto &results = function.getType().getResults();
204 if (op.getNumOperands() != results.size())
205 return op.emitOpError()
206 << "does not return the same number of values ("
207 << op.getNumOperands() << ") as the enclosing function ("
208 << results.size() << ")";
209
210 // If the operation does not have an input, we are done.
211 if (!op.hasOperand())
212 return mlir::success();
213
214 auto inputType = *op.operand_type_begin();
215 auto resultType = results.front();
216
217 // Check that the result type of the function matches the operand type.
218 if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
219 resultType.isa<mlir::UnrankedTensorType>())
220 return mlir::success();
221
222 return op.emitError() << "type of return operand (" << inputType
223 << ") doesn't match function result type ("
224 << resultType << ")";
225 }
226
227 //===----------------------------------------------------------------------===//
228 // TransposeOp
229
build(mlir::OpBuilder & builder,mlir::OperationState & state,mlir::Value value)230 void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
231 mlir::Value value) {
232 state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
233 state.addOperands(value);
234 }
235
verify(TransposeOp op)236 static mlir::LogicalResult verify(TransposeOp op) {
237 auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
238 auto resultType = op.getType().dyn_cast<RankedTensorType>();
239 if (!inputType || !resultType)
240 return mlir::success();
241
242 auto inputShape = inputType.getShape();
243 if (!std::equal(inputShape.begin(), inputShape.end(),
244 resultType.getShape().rbegin())) {
245 return op.emitError()
246 << "expected result shape to be a transpose of the input";
247 }
248 return mlir::success();
249 }
250
251 //===----------------------------------------------------------------------===//
252 // TableGen'd op method definitions
253 //===----------------------------------------------------------------------===//
254
255 #define GET_OP_CLASSES
256 #include "toy/Ops.cpp.inc"
257