1 //===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the dialect for the Toy IR: custom type parsing and
10 // operation verification.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "toy/Dialect.h"
15 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/StandardTypes.h"
19 #include "mlir/Transforms/InliningUtils.h"
20 
21 using namespace mlir;
22 using namespace mlir::toy;
23 
24 //===----------------------------------------------------------------------===//
25 // ToyInlinerInterface
26 //===----------------------------------------------------------------------===//
27 
28 /// This class defines the interface for handling inlining with Toy
29 /// operations.
30 struct ToyInlinerInterface : public DialectInlinerInterface {
31   using DialectInlinerInterface::DialectInlinerInterface;
32 
33   //===--------------------------------------------------------------------===//
34   // Analysis Hooks
35   //===--------------------------------------------------------------------===//
36 
37   /// All operations within toy can be inlined.
isLegalToInlineToyInlinerInterface38   bool isLegalToInline(Operation *, Region *,
39                        BlockAndValueMapping &) const final {
40     return true;
41   }
42 
43   //===--------------------------------------------------------------------===//
44   // Transformation Hooks
45   //===--------------------------------------------------------------------===//
46 
47   /// Handle the given inlined terminator(toy.return) by replacing it with a new
48   /// operation as necessary.
handleTerminatorToyInlinerInterface49   void handleTerminator(Operation *op,
50                         ArrayRef<Value> valuesToRepl) const final {
51     // Only "toy.return" needs to be handled here.
52     auto returnOp = cast<ReturnOp>(op);
53 
54     // Replace the values directly with the return operands.
55     assert(returnOp.getNumOperands() == valuesToRepl.size());
56     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
57       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
58   }
59 
60   /// Attempts to materialize a conversion for a type mismatch between a call
61   /// from this dialect, and a callable region. This method should generate an
62   /// operation that takes 'input' as the only operand, and produces a single
63   /// result of 'resultType'. If a conversion can not be generated, nullptr
64   /// should be returned.
materializeCallConversionToyInlinerInterface65   Operation *materializeCallConversion(OpBuilder &builder, Value input,
66                                        Type resultType,
67                                        Location conversionLoc) const final {
68     return builder.create<CastOp>(conversionLoc, resultType, input);
69   }
70 };
71 
72 //===----------------------------------------------------------------------===//
73 // ToyDialect
74 //===----------------------------------------------------------------------===//
75 
76 /// Dialect creation, the instance will be owned by the context. This is the
77 /// point of registration of custom types and operations for the dialect.
ToyDialect(mlir::MLIRContext * ctx)78 ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
79   addOperations<
80 #define GET_OP_LIST
81 #include "toy/Ops.cpp.inc"
82       >();
83   addInterfaces<ToyInlinerInterface>();
84   addTypes<StructType>();
85 }
86 
materializeConstant(mlir::OpBuilder & builder,mlir::Attribute value,mlir::Type type,mlir::Location loc)87 mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
88                                                  mlir::Attribute value,
89                                                  mlir::Type type,
90                                                  mlir::Location loc) {
91   if (type.isa<StructType>())
92     return builder.create<StructConstantOp>(loc, type,
93                                             value.cast<mlir::ArrayAttr>());
94   return builder.create<ConstantOp>(loc, type,
95                                     value.cast<mlir::DenseElementsAttr>());
96 }
97 
98 //===----------------------------------------------------------------------===//
99 // Toy Operations
100 //===----------------------------------------------------------------------===//
101 
102 //===----------------------------------------------------------------------===//
103 // ConstantOp
104 
105 /// Build a constant operation.
106 /// The builder is passed as an argument, so is the state that this method is
107 /// expected to fill in order to build the operation.
build(mlir::Builder * builder,mlir::OperationState & state,double value)108 void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
109                        double value) {
110   auto dataType = RankedTensorType::get({}, builder->getF64Type());
111   auto dataAttribute = DenseElementsAttr::get(dataType, value);
112   ConstantOp::build(builder, state, dataType, dataAttribute);
113 }
114 
115 /// Verify that the given attribute value is valid for the given type.
verifyConstantForType(mlir::Type type,mlir::Attribute opaqueValue,mlir::Operation * op)116 static mlir::LogicalResult verifyConstantForType(mlir::Type type,
117                                                  mlir::Attribute opaqueValue,
118                                                  mlir::Operation *op) {
119   if (type.isa<mlir::TensorType>()) {
120     // Check that the value is a elements attribute.
121     auto attrValue = opaqueValue.dyn_cast<mlir::DenseFPElementsAttr>();
122     if (!attrValue)
123       return op->emitError("constant of TensorType must be initialized by "
124                            "a DenseFPElementsAttr, got ")
125              << opaqueValue;
126 
127     // If the return type of the constant is not an unranked tensor, the shape
128     // must match the shape of the attribute holding the data.
129     auto resultType = type.dyn_cast<mlir::RankedTensorType>();
130     if (!resultType)
131       return success();
132 
133     // Check that the rank of the attribute type matches the rank of the
134     // constant result type.
135     auto attrType = attrValue.getType().cast<mlir::TensorType>();
136     if (attrType.getRank() != resultType.getRank()) {
137       return op->emitOpError("return type must match the one of the attached "
138                              "value attribute: ")
139              << attrType.getRank() << " != " << resultType.getRank();
140     }
141 
142     // Check that each of the dimensions match between the two types.
143     for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
144       if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
145         return op->emitOpError(
146                    "return type shape mismatches its attribute at dimension ")
147                << dim << ": " << attrType.getShape()[dim]
148                << " != " << resultType.getShape()[dim];
149       }
150     }
151     return mlir::success();
152   }
153   auto resultType = type.cast<StructType>();
154   llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes();
155 
156   // Verify that the initializer is an Array.
157   auto attrValue = opaqueValue.dyn_cast<ArrayAttr>();
158   if (!attrValue || attrValue.getValue().size() != resultElementTypes.size())
159     return op->emitError("constant of StructType must be initialized by an "
160                          "ArrayAttr with the same number of elements, got ")
161            << opaqueValue;
162 
163   // Check that each of the elements are valid.
164   llvm::ArrayRef<mlir::Attribute> attrElementValues = attrValue.getValue();
165   for (const auto &it : llvm::zip(resultElementTypes, attrElementValues))
166     if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op)))
167       return mlir::failure();
168   return mlir::success();
169 }
170 
171 /// Verifier for the constant operation. This corresponds to the `::verify(...)`
172 /// in the op definition.
verify(ConstantOp op)173 static mlir::LogicalResult verify(ConstantOp op) {
174   return verifyConstantForType(op.getResult().getType(), op.value(), op);
175 }
176 
verify(StructConstantOp op)177 static mlir::LogicalResult verify(StructConstantOp op) {
178   return verifyConstantForType(op.getResult().getType(), op.value(), op);
179 }
180 
181 /// Infer the output shape of the ConstantOp, this is required by the shape
182 /// inference interface.
inferShapes()183 void ConstantOp::inferShapes() { getResult().setType(value().getType()); }
184 
185 //===----------------------------------------------------------------------===//
186 // AddOp
187 
build(mlir::Builder * builder,mlir::OperationState & state,mlir::Value lhs,mlir::Value rhs)188 void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
189                   mlir::Value lhs, mlir::Value rhs) {
190   state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
191   state.addOperands({lhs, rhs});
192 }
193 
194 /// Infer the output shape of the AddOp, this is required by the shape inference
195 /// interface.
inferShapes()196 void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
197 
198 //===----------------------------------------------------------------------===//
199 // CastOp
200 
201 /// Infer the output shape of the CastOp, this is required by the shape
202 /// inference interface.
inferShapes()203 void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
204 
205 //===----------------------------------------------------------------------===//
206 // GenericCallOp
207 
build(mlir::Builder * builder,mlir::OperationState & state,StringRef callee,ArrayRef<mlir::Value> arguments)208 void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
209                           StringRef callee, ArrayRef<mlir::Value> arguments) {
210   // Generic call always returns an unranked Tensor initially.
211   state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
212   state.addOperands(arguments);
213   state.addAttribute("callee", builder->getSymbolRefAttr(callee));
214 }
215 
216 /// Return the callee of the generic call operation, this is required by the
217 /// call interface.
getCallableForCallee()218 CallInterfaceCallable GenericCallOp::getCallableForCallee() {
219   return getAttrOfType<SymbolRefAttr>("callee");
220 }
221 
222 /// Get the argument operands to the called function, this is required by the
223 /// call interface.
getArgOperands()224 Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
225 
226 //===----------------------------------------------------------------------===//
227 // MulOp
228 
build(mlir::Builder * builder,mlir::OperationState & state,mlir::Value lhs,mlir::Value rhs)229 void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
230                   mlir::Value lhs, mlir::Value rhs) {
231   state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
232   state.addOperands({lhs, rhs});
233 }
234 
235 /// Infer the output shape of the MulOp, this is required by the shape inference
236 /// interface.
inferShapes()237 void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
238 
239 //===----------------------------------------------------------------------===//
240 // ReturnOp
241 
verify(ReturnOp op)242 static mlir::LogicalResult verify(ReturnOp op) {
243   // We know that the parent operation is a function, because of the 'HasParent'
244   // trait attached to the operation definition.
245   auto function = cast<FuncOp>(op.getParentOp());
246 
247   /// ReturnOps can only have a single optional operand.
248   if (op.getNumOperands() > 1)
249     return op.emitOpError() << "expects at most 1 return operand";
250 
251   // The operand number and types must match the function signature.
252   const auto &results = function.getType().getResults();
253   if (op.getNumOperands() != results.size())
254     return op.emitOpError()
255            << "does not return the same number of values ("
256            << op.getNumOperands() << ") as the enclosing function ("
257            << results.size() << ")";
258 
259   // If the operation does not have an input, we are done.
260   if (!op.hasOperand())
261     return mlir::success();
262 
263   auto inputType = *op.operand_type_begin();
264   auto resultType = results.front();
265 
266   // Check that the result type of the function matches the operand type.
267   if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
268       resultType.isa<mlir::UnrankedTensorType>())
269     return mlir::success();
270 
271   return op.emitError() << "type of return operand ("
272                         << *op.operand_type_begin()
273                         << ") doesn't match function result type ("
274                         << results.front() << ")";
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // StructAccessOp
279 
build(mlir::Builder * b,mlir::OperationState & state,mlir::Value input,size_t index)280 void StructAccessOp::build(mlir::Builder *b, mlir::OperationState &state,
281                            mlir::Value input, size_t index) {
282   // Extract the result type from the input type.
283   StructType structTy = input.getType().cast<StructType>();
284   assert(index < structTy.getNumElementTypes());
285   mlir::Type resultType = structTy.getElementTypes()[index];
286 
287   // Call into the auto-generated build method.
288   build(b, state, resultType, input, b->getI64IntegerAttr(index));
289 }
290 
verify(StructAccessOp op)291 static mlir::LogicalResult verify(StructAccessOp op) {
292   StructType structTy = op.input().getType().cast<StructType>();
293   size_t index = op.index().getZExtValue();
294   if (index >= structTy.getNumElementTypes())
295     return op.emitOpError()
296            << "index should be within the range of the input struct type";
297   mlir::Type resultType = op.getResult().getType();
298   if (resultType != structTy.getElementTypes()[index])
299     return op.emitOpError() << "must have the same result type as the struct "
300                                "element referred to by the index";
301   return mlir::success();
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // TransposeOp
306 
build(mlir::Builder * builder,mlir::OperationState & state,mlir::Value value)307 void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
308                         mlir::Value value) {
309   state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
310   state.addOperands(value);
311 }
312 
inferShapes()313 void TransposeOp::inferShapes() {
314   auto arrayTy = getOperand().getType().cast<RankedTensorType>();
315   SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
316   getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
317 }
318 
verify(TransposeOp op)319 static mlir::LogicalResult verify(TransposeOp op) {
320   auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
321   auto resultType = op.getType().dyn_cast<RankedTensorType>();
322   if (!inputType || !resultType)
323     return mlir::success();
324 
325   auto inputShape = inputType.getShape();
326   if (!std::equal(inputShape.begin(), inputShape.end(),
327                   resultType.getShape().rbegin())) {
328     return op.emitError()
329            << "expected result shape to be a transpose of the input";
330   }
331   return mlir::success();
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // Toy Types
336 //===----------------------------------------------------------------------===//
337 
338 namespace mlir {
339 namespace toy {
340 namespace detail {
341 /// This class represents the internal storage of the Toy `StructType`.
342 struct StructTypeStorage : public mlir::TypeStorage {
343   /// The `KeyTy` is a required type that provides an interface for the storage
344   /// instance. This type will be used when uniquing an instance of the type
345   /// storage. For our struct type, we will unique each instance structurally on
346   /// the elements that it contains.
347   using KeyTy = llvm::ArrayRef<mlir::Type>;
348 
349   /// A constructor for the type storage instance.
StructTypeStoragemlir::toy::detail::StructTypeStorage350   StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes)
351       : elementTypes(elementTypes) {}
352 
353   /// Define the comparison function for the key type with the current storage
354   /// instance. This is used when constructing a new instance to ensure that we
355   /// haven't already uniqued an instance of the given key.
operator ==mlir::toy::detail::StructTypeStorage356   bool operator==(const KeyTy &key) const { return key == elementTypes; }
357 
358   /// Define a hash function for the key type. This is used when uniquing
359   /// instances of the storage, see the `StructType::get` method.
360   /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type
361   /// have hash functions available, so we could just omit this entirely.
hashKeymlir::toy::detail::StructTypeStorage362   static llvm::hash_code hashKey(const KeyTy &key) {
363     return llvm::hash_value(key);
364   }
365 
366   /// Define a construction function for the key type from a set of parameters.
367   /// These parameters will be provided when constructing the storage instance
368   /// itself.
369   /// Note: This method isn't necessary because KeyTy can be directly
370   /// constructed with the given parameters.
getKeymlir::toy::detail::StructTypeStorage371   static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) {
372     return KeyTy(elementTypes);
373   }
374 
375   /// Define a construction method for creating a new instance of this storage.
376   /// This method takes an instance of a storage allocator, and an instance of a
377   /// `KeyTy`. The given allocator must be used for *all* necessary dynamic
378   /// allocations used to create the type storage and its internal.
constructmlir::toy::detail::StructTypeStorage379   static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
380                                       const KeyTy &key) {
381     // Copy the elements from the provided `KeyTy` into the allocator.
382     llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key);
383 
384     // Allocate the storage instance and construct it.
385     return new (allocator.allocate<StructTypeStorage>())
386         StructTypeStorage(elementTypes);
387   }
388 
389   /// The following field contains the element types of the struct.
390   llvm::ArrayRef<mlir::Type> elementTypes;
391 };
392 } // end namespace detail
393 } // end namespace toy
394 } // end namespace mlir
395 
396 /// Create an instance of a `StructType` with the given element types. There
397 /// *must* be at least one element type.
get(llvm::ArrayRef<mlir::Type> elementTypes)398 StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) {
399   assert(!elementTypes.empty() && "expected at least 1 element type");
400 
401   // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
402   // of this type. The first two parameters are the context to unique in and the
403   // kind of the type. The parameters after the type kind are forwarded to the
404   // storage instance.
405   mlir::MLIRContext *ctx = elementTypes.front().getContext();
406   return Base::get(ctx, ToyTypes::Struct, elementTypes);
407 }
408 
409 /// Returns the element types of this struct type.
getElementTypes()410 llvm::ArrayRef<mlir::Type> StructType::getElementTypes() {
411   // 'getImpl' returns a pointer to the internal storage instance.
412   return getImpl()->elementTypes;
413 }
414 
415 /// Parse an instance of a type registered to the toy dialect.
parseType(mlir::DialectAsmParser & parser) const416 mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
417   // Parse a struct type in the following form:
418   //   struct-type ::= `struct` `<` type (`,` type)* `>`
419 
420   // NOTE: All MLIR parser function return a ParseResult. This is a
421   // specialization of LogicalResult that auto-converts to a `true` boolean
422   // value on failure to allow for chaining, but may be used with explicit
423   // `mlir::failed/mlir::succeeded` as desired.
424 
425   // Parse: `struct` `<`
426   if (parser.parseKeyword("struct") || parser.parseLess())
427     return Type();
428 
429   // Parse the element types of the struct.
430   SmallVector<mlir::Type, 1> elementTypes;
431   do {
432     // Parse the current element type.
433     llvm::SMLoc typeLoc = parser.getCurrentLocation();
434     mlir::Type elementType;
435     if (parser.parseType(elementType))
436       return nullptr;
437 
438     // Check that the type is either a TensorType or another StructType.
439     if (!elementType.isa<mlir::TensorType>() &&
440         !elementType.isa<StructType>()) {
441       parser.emitError(typeLoc, "element type for a struct must either "
442                                 "be a TensorType or a StructType, got: ")
443           << elementType;
444       return Type();
445     }
446     elementTypes.push_back(elementType);
447 
448     // Parse the optional: `,`
449   } while (succeeded(parser.parseOptionalComma()));
450 
451   // Parse: `>`
452   if (parser.parseGreater())
453     return Type();
454   return StructType::get(elementTypes);
455 }
456 
457 /// Print an instance of a type registered to the toy dialect.
printType(mlir::Type type,mlir::DialectAsmPrinter & printer) const458 void ToyDialect::printType(mlir::Type type,
459                            mlir::DialectAsmPrinter &printer) const {
460   // Currently the only toy type is a struct type.
461   StructType structType = type.cast<StructType>();
462 
463   // Print the struct type according to the parser format.
464   printer << "struct<";
465   mlir::interleaveComma(structType.getElementTypes(), printer);
466   printer << '>';
467 }
468 
469 //===----------------------------------------------------------------------===//
470 // TableGen'd op method definitions
471 //===----------------------------------------------------------------------===//
472 
473 #define GET_OP_CLASSES
474 #include "toy/Ops.cpp.inc"
475