1 //===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the dialect for the Toy IR: custom type parsing and
10 // operation verification.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "toy/Dialect.h"
15
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/StandardTypes.h"
19 #include "mlir/Transforms/InliningUtils.h"
20
21 using namespace mlir;
22 using namespace mlir::toy;
23
24 //===----------------------------------------------------------------------===//
25 // ToyInlinerInterface
26 //===----------------------------------------------------------------------===//
27
28 /// This class defines the interface for handling inlining with Toy
29 /// operations.
30 struct ToyInlinerInterface : public DialectInlinerInterface {
31 using DialectInlinerInterface::DialectInlinerInterface;
32
33 //===--------------------------------------------------------------------===//
34 // Analysis Hooks
35 //===--------------------------------------------------------------------===//
36
37 /// All operations within toy can be inlined.
isLegalToInlineToyInlinerInterface38 bool isLegalToInline(Operation *, Region *,
39 BlockAndValueMapping &) const final {
40 return true;
41 }
42
43 //===--------------------------------------------------------------------===//
44 // Transformation Hooks
45 //===--------------------------------------------------------------------===//
46
47 /// Handle the given inlined terminator(toy.return) by replacing it with a new
48 /// operation as necessary.
handleTerminatorToyInlinerInterface49 void handleTerminator(Operation *op,
50 ArrayRef<Value> valuesToRepl) const final {
51 // Only "toy.return" needs to be handled here.
52 auto returnOp = cast<ReturnOp>(op);
53
54 // Replace the values directly with the return operands.
55 assert(returnOp.getNumOperands() == valuesToRepl.size());
56 for (const auto &it : llvm::enumerate(returnOp.getOperands()))
57 valuesToRepl[it.index()].replaceAllUsesWith(it.value());
58 }
59
60 /// Attempts to materialize a conversion for a type mismatch between a call
61 /// from this dialect, and a callable region. This method should generate an
62 /// operation that takes 'input' as the only operand, and produces a single
63 /// result of 'resultType'. If a conversion can not be generated, nullptr
64 /// should be returned.
materializeCallConversionToyInlinerInterface65 Operation *materializeCallConversion(OpBuilder &builder, Value input,
66 Type resultType,
67 Location conversionLoc) const final {
68 return builder.create<CastOp>(conversionLoc, resultType, input);
69 }
70 };
71
72 //===----------------------------------------------------------------------===//
73 // ToyDialect
74 //===----------------------------------------------------------------------===//
75
76 /// Dialect creation, the instance will be owned by the context. This is the
77 /// point of registration of custom types and operations for the dialect.
ToyDialect(mlir::MLIRContext * ctx)78 ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
79 addOperations<
80 #define GET_OP_LIST
81 #include "toy/Ops.cpp.inc"
82 >();
83 addInterfaces<ToyInlinerInterface>();
84 addTypes<StructType>();
85 }
86
materializeConstant(mlir::OpBuilder & builder,mlir::Attribute value,mlir::Type type,mlir::Location loc)87 mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
88 mlir::Attribute value,
89 mlir::Type type,
90 mlir::Location loc) {
91 if (type.isa<StructType>())
92 return builder.create<StructConstantOp>(loc, type,
93 value.cast<mlir::ArrayAttr>());
94 return builder.create<ConstantOp>(loc, type,
95 value.cast<mlir::DenseElementsAttr>());
96 }
97
98 //===----------------------------------------------------------------------===//
99 // Toy Operations
100 //===----------------------------------------------------------------------===//
101
102 //===----------------------------------------------------------------------===//
103 // ConstantOp
104
105 /// Build a constant operation.
106 /// The builder is passed as an argument, so is the state that this method is
107 /// expected to fill in order to build the operation.
build(mlir::Builder * builder,mlir::OperationState & state,double value)108 void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
109 double value) {
110 auto dataType = RankedTensorType::get({}, builder->getF64Type());
111 auto dataAttribute = DenseElementsAttr::get(dataType, value);
112 ConstantOp::build(builder, state, dataType, dataAttribute);
113 }
114
115 /// Verify that the given attribute value is valid for the given type.
verifyConstantForType(mlir::Type type,mlir::Attribute opaqueValue,mlir::Operation * op)116 static mlir::LogicalResult verifyConstantForType(mlir::Type type,
117 mlir::Attribute opaqueValue,
118 mlir::Operation *op) {
119 if (type.isa<mlir::TensorType>()) {
120 // Check that the value is a elements attribute.
121 auto attrValue = opaqueValue.dyn_cast<mlir::DenseFPElementsAttr>();
122 if (!attrValue)
123 return op->emitError("constant of TensorType must be initialized by "
124 "a DenseFPElementsAttr, got ")
125 << opaqueValue;
126
127 // If the return type of the constant is not an unranked tensor, the shape
128 // must match the shape of the attribute holding the data.
129 auto resultType = type.dyn_cast<mlir::RankedTensorType>();
130 if (!resultType)
131 return success();
132
133 // Check that the rank of the attribute type matches the rank of the
134 // constant result type.
135 auto attrType = attrValue.getType().cast<mlir::TensorType>();
136 if (attrType.getRank() != resultType.getRank()) {
137 return op->emitOpError("return type must match the one of the attached "
138 "value attribute: ")
139 << attrType.getRank() << " != " << resultType.getRank();
140 }
141
142 // Check that each of the dimensions match between the two types.
143 for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
144 if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
145 return op->emitOpError(
146 "return type shape mismatches its attribute at dimension ")
147 << dim << ": " << attrType.getShape()[dim]
148 << " != " << resultType.getShape()[dim];
149 }
150 }
151 return mlir::success();
152 }
153 auto resultType = type.cast<StructType>();
154 llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes();
155
156 // Verify that the initializer is an Array.
157 auto attrValue = opaqueValue.dyn_cast<ArrayAttr>();
158 if (!attrValue || attrValue.getValue().size() != resultElementTypes.size())
159 return op->emitError("constant of StructType must be initialized by an "
160 "ArrayAttr with the same number of elements, got ")
161 << opaqueValue;
162
163 // Check that each of the elements are valid.
164 llvm::ArrayRef<mlir::Attribute> attrElementValues = attrValue.getValue();
165 for (const auto &it : llvm::zip(resultElementTypes, attrElementValues))
166 if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op)))
167 return mlir::failure();
168 return mlir::success();
169 }
170
171 /// Verifier for the constant operation. This corresponds to the `::verify(...)`
172 /// in the op definition.
verify(ConstantOp op)173 static mlir::LogicalResult verify(ConstantOp op) {
174 return verifyConstantForType(op.getResult().getType(), op.value(), op);
175 }
176
verify(StructConstantOp op)177 static mlir::LogicalResult verify(StructConstantOp op) {
178 return verifyConstantForType(op.getResult().getType(), op.value(), op);
179 }
180
181 /// Infer the output shape of the ConstantOp, this is required by the shape
182 /// inference interface.
inferShapes()183 void ConstantOp::inferShapes() { getResult().setType(value().getType()); }
184
185 //===----------------------------------------------------------------------===//
186 // AddOp
187
build(mlir::Builder * builder,mlir::OperationState & state,mlir::Value lhs,mlir::Value rhs)188 void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
189 mlir::Value lhs, mlir::Value rhs) {
190 state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
191 state.addOperands({lhs, rhs});
192 }
193
194 /// Infer the output shape of the AddOp, this is required by the shape inference
195 /// interface.
inferShapes()196 void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
197
198 //===----------------------------------------------------------------------===//
199 // CastOp
200
201 /// Infer the output shape of the CastOp, this is required by the shape
202 /// inference interface.
inferShapes()203 void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
204
205 //===----------------------------------------------------------------------===//
206 // GenericCallOp
207
build(mlir::Builder * builder,mlir::OperationState & state,StringRef callee,ArrayRef<mlir::Value> arguments)208 void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
209 StringRef callee, ArrayRef<mlir::Value> arguments) {
210 // Generic call always returns an unranked Tensor initially.
211 state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
212 state.addOperands(arguments);
213 state.addAttribute("callee", builder->getSymbolRefAttr(callee));
214 }
215
216 /// Return the callee of the generic call operation, this is required by the
217 /// call interface.
getCallableForCallee()218 CallInterfaceCallable GenericCallOp::getCallableForCallee() {
219 return getAttrOfType<SymbolRefAttr>("callee");
220 }
221
222 /// Get the argument operands to the called function, this is required by the
223 /// call interface.
getArgOperands()224 Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
225
226 //===----------------------------------------------------------------------===//
227 // MulOp
228
build(mlir::Builder * builder,mlir::OperationState & state,mlir::Value lhs,mlir::Value rhs)229 void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
230 mlir::Value lhs, mlir::Value rhs) {
231 state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
232 state.addOperands({lhs, rhs});
233 }
234
235 /// Infer the output shape of the MulOp, this is required by the shape inference
236 /// interface.
inferShapes()237 void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
238
239 //===----------------------------------------------------------------------===//
240 // ReturnOp
241
verify(ReturnOp op)242 static mlir::LogicalResult verify(ReturnOp op) {
243 // We know that the parent operation is a function, because of the 'HasParent'
244 // trait attached to the operation definition.
245 auto function = cast<FuncOp>(op.getParentOp());
246
247 /// ReturnOps can only have a single optional operand.
248 if (op.getNumOperands() > 1)
249 return op.emitOpError() << "expects at most 1 return operand";
250
251 // The operand number and types must match the function signature.
252 const auto &results = function.getType().getResults();
253 if (op.getNumOperands() != results.size())
254 return op.emitOpError()
255 << "does not return the same number of values ("
256 << op.getNumOperands() << ") as the enclosing function ("
257 << results.size() << ")";
258
259 // If the operation does not have an input, we are done.
260 if (!op.hasOperand())
261 return mlir::success();
262
263 auto inputType = *op.operand_type_begin();
264 auto resultType = results.front();
265
266 // Check that the result type of the function matches the operand type.
267 if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
268 resultType.isa<mlir::UnrankedTensorType>())
269 return mlir::success();
270
271 return op.emitError() << "type of return operand ("
272 << *op.operand_type_begin()
273 << ") doesn't match function result type ("
274 << results.front() << ")";
275 }
276
277 //===----------------------------------------------------------------------===//
278 // StructAccessOp
279
build(mlir::Builder * b,mlir::OperationState & state,mlir::Value input,size_t index)280 void StructAccessOp::build(mlir::Builder *b, mlir::OperationState &state,
281 mlir::Value input, size_t index) {
282 // Extract the result type from the input type.
283 StructType structTy = input.getType().cast<StructType>();
284 assert(index < structTy.getNumElementTypes());
285 mlir::Type resultType = structTy.getElementTypes()[index];
286
287 // Call into the auto-generated build method.
288 build(b, state, resultType, input, b->getI64IntegerAttr(index));
289 }
290
verify(StructAccessOp op)291 static mlir::LogicalResult verify(StructAccessOp op) {
292 StructType structTy = op.input().getType().cast<StructType>();
293 size_t index = op.index().getZExtValue();
294 if (index >= structTy.getNumElementTypes())
295 return op.emitOpError()
296 << "index should be within the range of the input struct type";
297 mlir::Type resultType = op.getResult().getType();
298 if (resultType != structTy.getElementTypes()[index])
299 return op.emitOpError() << "must have the same result type as the struct "
300 "element referred to by the index";
301 return mlir::success();
302 }
303
304 //===----------------------------------------------------------------------===//
305 // TransposeOp
306
build(mlir::Builder * builder,mlir::OperationState & state,mlir::Value value)307 void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
308 mlir::Value value) {
309 state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
310 state.addOperands(value);
311 }
312
inferShapes()313 void TransposeOp::inferShapes() {
314 auto arrayTy = getOperand().getType().cast<RankedTensorType>();
315 SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
316 getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
317 }
318
verify(TransposeOp op)319 static mlir::LogicalResult verify(TransposeOp op) {
320 auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
321 auto resultType = op.getType().dyn_cast<RankedTensorType>();
322 if (!inputType || !resultType)
323 return mlir::success();
324
325 auto inputShape = inputType.getShape();
326 if (!std::equal(inputShape.begin(), inputShape.end(),
327 resultType.getShape().rbegin())) {
328 return op.emitError()
329 << "expected result shape to be a transpose of the input";
330 }
331 return mlir::success();
332 }
333
334 //===----------------------------------------------------------------------===//
335 // Toy Types
336 //===----------------------------------------------------------------------===//
337
338 namespace mlir {
339 namespace toy {
340 namespace detail {
341 /// This class represents the internal storage of the Toy `StructType`.
342 struct StructTypeStorage : public mlir::TypeStorage {
343 /// The `KeyTy` is a required type that provides an interface for the storage
344 /// instance. This type will be used when uniquing an instance of the type
345 /// storage. For our struct type, we will unique each instance structurally on
346 /// the elements that it contains.
347 using KeyTy = llvm::ArrayRef<mlir::Type>;
348
349 /// A constructor for the type storage instance.
StructTypeStoragemlir::toy::detail::StructTypeStorage350 StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes)
351 : elementTypes(elementTypes) {}
352
353 /// Define the comparison function for the key type with the current storage
354 /// instance. This is used when constructing a new instance to ensure that we
355 /// haven't already uniqued an instance of the given key.
operator ==mlir::toy::detail::StructTypeStorage356 bool operator==(const KeyTy &key) const { return key == elementTypes; }
357
358 /// Define a hash function for the key type. This is used when uniquing
359 /// instances of the storage, see the `StructType::get` method.
360 /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type
361 /// have hash functions available, so we could just omit this entirely.
hashKeymlir::toy::detail::StructTypeStorage362 static llvm::hash_code hashKey(const KeyTy &key) {
363 return llvm::hash_value(key);
364 }
365
366 /// Define a construction function for the key type from a set of parameters.
367 /// These parameters will be provided when constructing the storage instance
368 /// itself.
369 /// Note: This method isn't necessary because KeyTy can be directly
370 /// constructed with the given parameters.
getKeymlir::toy::detail::StructTypeStorage371 static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) {
372 return KeyTy(elementTypes);
373 }
374
375 /// Define a construction method for creating a new instance of this storage.
376 /// This method takes an instance of a storage allocator, and an instance of a
377 /// `KeyTy`. The given allocator must be used for *all* necessary dynamic
378 /// allocations used to create the type storage and its internal.
constructmlir::toy::detail::StructTypeStorage379 static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
380 const KeyTy &key) {
381 // Copy the elements from the provided `KeyTy` into the allocator.
382 llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key);
383
384 // Allocate the storage instance and construct it.
385 return new (allocator.allocate<StructTypeStorage>())
386 StructTypeStorage(elementTypes);
387 }
388
389 /// The following field contains the element types of the struct.
390 llvm::ArrayRef<mlir::Type> elementTypes;
391 };
392 } // end namespace detail
393 } // end namespace toy
394 } // end namespace mlir
395
396 /// Create an instance of a `StructType` with the given element types. There
397 /// *must* be at least one element type.
get(llvm::ArrayRef<mlir::Type> elementTypes)398 StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) {
399 assert(!elementTypes.empty() && "expected at least 1 element type");
400
401 // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
402 // of this type. The first two parameters are the context to unique in and the
403 // kind of the type. The parameters after the type kind are forwarded to the
404 // storage instance.
405 mlir::MLIRContext *ctx = elementTypes.front().getContext();
406 return Base::get(ctx, ToyTypes::Struct, elementTypes);
407 }
408
409 /// Returns the element types of this struct type.
getElementTypes()410 llvm::ArrayRef<mlir::Type> StructType::getElementTypes() {
411 // 'getImpl' returns a pointer to the internal storage instance.
412 return getImpl()->elementTypes;
413 }
414
415 /// Parse an instance of a type registered to the toy dialect.
parseType(mlir::DialectAsmParser & parser) const416 mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
417 // Parse a struct type in the following form:
418 // struct-type ::= `struct` `<` type (`,` type)* `>`
419
420 // NOTE: All MLIR parser function return a ParseResult. This is a
421 // specialization of LogicalResult that auto-converts to a `true` boolean
422 // value on failure to allow for chaining, but may be used with explicit
423 // `mlir::failed/mlir::succeeded` as desired.
424
425 // Parse: `struct` `<`
426 if (parser.parseKeyword("struct") || parser.parseLess())
427 return Type();
428
429 // Parse the element types of the struct.
430 SmallVector<mlir::Type, 1> elementTypes;
431 do {
432 // Parse the current element type.
433 llvm::SMLoc typeLoc = parser.getCurrentLocation();
434 mlir::Type elementType;
435 if (parser.parseType(elementType))
436 return nullptr;
437
438 // Check that the type is either a TensorType or another StructType.
439 if (!elementType.isa<mlir::TensorType>() &&
440 !elementType.isa<StructType>()) {
441 parser.emitError(typeLoc, "element type for a struct must either "
442 "be a TensorType or a StructType, got: ")
443 << elementType;
444 return Type();
445 }
446 elementTypes.push_back(elementType);
447
448 // Parse the optional: `,`
449 } while (succeeded(parser.parseOptionalComma()));
450
451 // Parse: `>`
452 if (parser.parseGreater())
453 return Type();
454 return StructType::get(elementTypes);
455 }
456
457 /// Print an instance of a type registered to the toy dialect.
printType(mlir::Type type,mlir::DialectAsmPrinter & printer) const458 void ToyDialect::printType(mlir::Type type,
459 mlir::DialectAsmPrinter &printer) const {
460 // Currently the only toy type is a struct type.
461 StructType structType = type.cast<StructType>();
462
463 // Print the struct type according to the parser format.
464 printer << "struct<";
465 mlir::interleaveComma(structType.getElementTypes(), printer);
466 printer << '>';
467 }
468
469 //===----------------------------------------------------------------------===//
470 // TableGen'd op method definitions
471 //===----------------------------------------------------------------------===//
472
473 #define GET_OP_CLASSES
474 #include "toy/Ops.cpp.inc"
475