1 //===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the dialect for the Toy IR: dialect registration plus
// the custom builders, parsers, printers and verifiers of the Toy operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "toy/Dialect.h"
15 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/OpImplementation.h"
19 
20 using namespace mlir;
21 using namespace mlir::toy;
22 
23 //===----------------------------------------------------------------------===//
24 // ToyDialect
25 //===----------------------------------------------------------------------===//
26 
/// Dialect creation; the instance will be owned by the context (`ctx`).
/// Every operation in the TableGen-generated op list is registered here.
/// No custom types are registered by this constructor.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
  // With GET_OP_LIST defined, the generated include expands to the list of op
  // classes used as the template arguments of addOperations.
  addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
      >();
}
36 
37 //===----------------------------------------------------------------------===//
38 // Toy Operations
39 //===----------------------------------------------------------------------===//
40 
41 /// A generalized parser for binary operations. This parses the different forms
42 /// of 'printBinaryOp' below.
parseBinaryOp(mlir::OpAsmParser & parser,mlir::OperationState & result)43 static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
44                                        mlir::OperationState &result) {
45   SmallVector<mlir::OpAsmParser::OperandType, 2> operands;
46   llvm::SMLoc operandsLoc = parser.getCurrentLocation();
47   Type type;
48   if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
49       parser.parseOptionalAttrDict(result.attributes) ||
50       parser.parseColonType(type))
51     return mlir::failure();
52 
53   // If the type is a function type, it contains the input and result types of
54   // this operation.
55   if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
56     if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
57                                result.operands))
58       return mlir::failure();
59     result.addTypes(funcType.getResults());
60     return mlir::success();
61   }
62 
63   // Otherwise, the parsed type is the type of both operands and results.
64   if (parser.resolveOperands(operands, type, result.operands))
65     return mlir::failure();
66   result.addTypes(type);
67   return mlir::success();
68 }
69 
70 /// A generalized printer for binary operations. It prints in two different
71 /// forms depending on if all of the types match.
printBinaryOp(mlir::OpAsmPrinter & printer,mlir::Operation * op)72 static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) {
73   printer << op->getName() << " " << op->getOperands();
74   printer.printOptionalAttrDict(op->getAttrs());
75   printer << " : ";
76 
77   // If all of the types are the same, print the type directly.
78   Type resultType = *op->result_type_begin();
79   if (llvm::all_of(op->getOperandTypes(),
80                    [=](Type type) { return type == resultType; })) {
81     printer << resultType;
82     return;
83   }
84 
85   // Otherwise, print a functional type.
86   printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // ConstantOp
91 
92 /// Build a constant operation.
93 /// The builder is passed as an argument, so is the state that this method is
94 /// expected to fill in order to build the operation.
build(mlir::OpBuilder & builder,mlir::OperationState & state,double value)95 void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
96                        double value) {
97   auto dataType = RankedTensorType::get({}, builder.getF64Type());
98   auto dataAttribute = DenseElementsAttr::get(dataType, value);
99   ConstantOp::build(builder, state, dataType, dataAttribute);
100 }
101 
102 /// The 'OpAsmParser' class provides a collection of methods for parsing
103 /// various punctuation, as well as attributes, operands, types, etc. Each of
104 /// these methods returns a `ParseResult`. This class is a wrapper around
105 /// `LogicalResult` that can be converted to a boolean `true` value on failure,
106 /// or `false` on success. This allows for easily chaining together a set of
107 /// parser rules. These rules are used to populate an `mlir::OperationState`
108 /// similarly to the `build` methods described above.
parseConstantOp(mlir::OpAsmParser & parser,mlir::OperationState & result)109 static mlir::ParseResult parseConstantOp(mlir::OpAsmParser &parser,
110                                          mlir::OperationState &result) {
111   mlir::DenseElementsAttr value;
112   if (parser.parseOptionalAttrDict(result.attributes) ||
113       parser.parseAttribute(value, "value", result.attributes))
114     return failure();
115 
116   result.addTypes(value.getType());
117   return success();
118 }
119 
120 /// The 'OpAsmPrinter' class is a stream that allows for formatting
121 /// strings, attributes, operands, types, etc.
print(mlir::OpAsmPrinter & printer,ConstantOp op)122 static void print(mlir::OpAsmPrinter &printer, ConstantOp op) {
123   printer << "toy.constant ";
124   printer.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
125   printer << op.value();
126 }
127 
128 /// Verifier for the constant operation. This corresponds to the `::verify(...)`
129 /// in the op definition.
verify(ConstantOp op)130 static mlir::LogicalResult verify(ConstantOp op) {
131   // If the return type of the constant is not an unranked tensor, the shape
132   // must match the shape of the attribute holding the data.
133   auto resultType = op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
134   if (!resultType)
135     return success();
136 
137   // Check that the rank of the attribute type matches the rank of the constant
138   // result type.
139   auto attrType = op.value().getType().cast<mlir::TensorType>();
140   if (attrType.getRank() != resultType.getRank()) {
141     return op.emitOpError(
142                "return type must match the one of the attached value "
143                "attribute: ")
144            << attrType.getRank() << " != " << resultType.getRank();
145   }
146 
147   // Check that each of the dimensions match between the two types.
148   for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
149     if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
150       return op.emitOpError(
151                  "return type shape mismatches its attribute at dimension ")
152              << dim << ": " << attrType.getShape()[dim]
153              << " != " << resultType.getShape()[dim];
154     }
155   }
156   return mlir::success();
157 }
158 
159 //===----------------------------------------------------------------------===//
160 // AddOp
161 
build(mlir::OpBuilder & builder,mlir::OperationState & state,mlir::Value lhs,mlir::Value rhs)162 void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
163                   mlir::Value lhs, mlir::Value rhs) {
164   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
165   state.addOperands({lhs, rhs});
166 }
167 
168 //===----------------------------------------------------------------------===//
169 // GenericCallOp
170 
build(mlir::OpBuilder & builder,mlir::OperationState & state,StringRef callee,ArrayRef<mlir::Value> arguments)171 void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
172                           StringRef callee, ArrayRef<mlir::Value> arguments) {
173   // Generic call always returns an unranked Tensor initially.
174   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
175   state.addOperands(arguments);
176   state.addAttribute("callee", builder.getSymbolRefAttr(callee));
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // MulOp
181 
build(mlir::OpBuilder & builder,mlir::OperationState & state,mlir::Value lhs,mlir::Value rhs)182 void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
183                   mlir::Value lhs, mlir::Value rhs) {
184   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
185   state.addOperands({lhs, rhs});
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // ReturnOp
190 
verify(ReturnOp op)191 static mlir::LogicalResult verify(ReturnOp op) {
192   // We know that the parent operation is a function, because of the 'HasParent'
193   // trait attached to the operation definition.
194   auto function = cast<FuncOp>(op->getParentOp());
195 
196   /// ReturnOps can only have a single optional operand.
197   if (op.getNumOperands() > 1)
198     return op.emitOpError() << "expects at most 1 return operand";
199 
200   // The operand number and types must match the function signature.
201   const auto &results = function.getType().getResults();
202   if (op.getNumOperands() != results.size())
203     return op.emitOpError()
204            << "does not return the same number of values ("
205            << op.getNumOperands() << ") as the enclosing function ("
206            << results.size() << ")";
207 
208   // If the operation does not have an input, we are done.
209   if (!op.hasOperand())
210     return mlir::success();
211 
212   auto inputType = *op.operand_type_begin();
213   auto resultType = results.front();
214 
215   // Check that the result type of the function matches the operand type.
216   if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
217       resultType.isa<mlir::UnrankedTensorType>())
218     return mlir::success();
219 
220   return op.emitError() << "type of return operand (" << inputType
221                         << ") doesn't match function result type ("
222                         << resultType << ")";
223 }
224 
225 //===----------------------------------------------------------------------===//
226 // TransposeOp
227 
build(mlir::OpBuilder & builder,mlir::OperationState & state,mlir::Value value)228 void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
229                         mlir::Value value) {
230   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
231   state.addOperands(value);
232 }
233 
verify(TransposeOp op)234 static mlir::LogicalResult verify(TransposeOp op) {
235   auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
236   auto resultType = op.getType().dyn_cast<RankedTensorType>();
237   if (!inputType || !resultType)
238     return mlir::success();
239 
240   auto inputShape = inputType.getShape();
241   if (!std::equal(inputShape.begin(), inputShape.end(),
242                   resultType.getShape().rbegin())) {
243     return op.emitError()
244            << "expected result shape to be a transpose of the input";
245   }
246   return mlir::success();
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // TableGen'd op method definitions
251 //===----------------------------------------------------------------------===//
252 
253 #define GET_OP_CLASSES
254 #include "toy/Ops.cpp.inc"
255