1 //===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the dialect for the Toy IR: custom type parsing and
10 // operation verification.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "toy/Dialect.h"
15 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/OpImplementation.h"
20 #include "mlir/Transforms/InliningUtils.h"
21 
22 using namespace mlir;
23 using namespace mlir::toy;
24 
25 #include "toy/Dialect.cpp.inc"
26 
27 //===----------------------------------------------------------------------===//
28 // ToyInlinerInterface
29 //===----------------------------------------------------------------------===//
30 
31 /// This class defines the interface for handling inlining with Toy
32 /// operations.
33 struct ToyInlinerInterface : public DialectInlinerInterface {
34   using DialectInlinerInterface::DialectInlinerInterface;
35 
36   //===--------------------------------------------------------------------===//
37   // Analysis Hooks
38   //===--------------------------------------------------------------------===//
39 
40   /// All call operations within toy can be inlined.
isLegalToInlineToyInlinerInterface41   bool isLegalToInline(Operation *call, Operation *callable,
42                        bool wouldBeCloned) const final {
43     return true;
44   }
45 
46   /// All operations within toy can be inlined.
isLegalToInlineToyInlinerInterface47   bool isLegalToInline(Operation *, Region *, bool,
48                        BlockAndValueMapping &) const final {
49     return true;
50   }
51 
52   //===--------------------------------------------------------------------===//
53   // Transformation Hooks
54   //===--------------------------------------------------------------------===//
55 
56   /// Handle the given inlined terminator(toy.return) by replacing it with a new
57   /// operation as necessary.
handleTerminatorToyInlinerInterface58   void handleTerminator(Operation *op,
59                         ArrayRef<Value> valuesToRepl) const final {
60     // Only "toy.return" needs to be handled here.
61     auto returnOp = cast<ReturnOp>(op);
62 
63     // Replace the values directly with the return operands.
64     assert(returnOp.getNumOperands() == valuesToRepl.size());
65     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
66       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
67   }
68 
69   /// Attempts to materialize a conversion for a type mismatch between a call
70   /// from this dialect, and a callable region. This method should generate an
71   /// operation that takes 'input' as the only operand, and produces a single
72   /// result of 'resultType'. If a conversion can not be generated, nullptr
73   /// should be returned.
materializeCallConversionToyInlinerInterface74   Operation *materializeCallConversion(OpBuilder &builder, Value input,
75                                        Type resultType,
76                                        Location conversionLoc) const final {
77     return builder.create<CastOp>(conversionLoc, resultType, input);
78   }
79 };
80 
81 //===----------------------------------------------------------------------===//
82 // Toy Operations
83 //===----------------------------------------------------------------------===//
84 
85 /// A generalized parser for binary operations. This parses the different forms
86 /// of 'printBinaryOp' below.
parseBinaryOp(mlir::OpAsmParser & parser,mlir::OperationState & result)87 static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
88                                        mlir::OperationState &result) {
89   SmallVector<mlir::OpAsmParser::OperandType, 2> operands;
90   llvm::SMLoc operandsLoc = parser.getCurrentLocation();
91   Type type;
92   if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
93       parser.parseOptionalAttrDict(result.attributes) ||
94       parser.parseColonType(type))
95     return mlir::failure();
96 
97   // If the type is a function type, it contains the input and result types of
98   // this operation.
99   if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
100     if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
101                                result.operands))
102       return mlir::failure();
103     result.addTypes(funcType.getResults());
104     return mlir::success();
105   }
106 
107   // Otherwise, the parsed type is the type of both operands and results.
108   if (parser.resolveOperands(operands, type, result.operands))
109     return mlir::failure();
110   result.addTypes(type);
111   return mlir::success();
112 }
113 
114 /// A generalized printer for binary operations. It prints in two different
115 /// forms depending on if all of the types match.
printBinaryOp(mlir::OpAsmPrinter & printer,mlir::Operation * op)116 static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) {
117   printer << op->getName() << " " << op->getOperands();
118   printer.printOptionalAttrDict(op->getAttrs());
119   printer << " : ";
120 
121   // If all of the types are the same, print the type directly.
122   Type resultType = *op->result_type_begin();
123   if (llvm::all_of(op->getOperandTypes(),
124                    [=](Type type) { return type == resultType; })) {
125     printer << resultType;
126     return;
127   }
128 
129   // Otherwise, print a functional type.
130   printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
131 }
132 
133 //===----------------------------------------------------------------------===//
134 // ConstantOp
135 
136 /// Build a constant operation.
137 /// The builder is passed as an argument, so is the state that this method is
138 /// expected to fill in order to build the operation.
build(mlir::OpBuilder & builder,mlir::OperationState & state,double value)139 void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
140                        double value) {
141   auto dataType = RankedTensorType::get({}, builder.getF64Type());
142   auto dataAttribute = DenseElementsAttr::get(dataType, value);
143   ConstantOp::build(builder, state, dataType, dataAttribute);
144 }
145 
146 /// The 'OpAsmParser' class provides a collection of methods for parsing
147 /// various punctuation, as well as attributes, operands, types, etc. Each of
148 /// these methods returns a `ParseResult`. This class is a wrapper around
149 /// `LogicalResult` that can be converted to a boolean `true` value on failure,
150 /// or `false` on success. This allows for easily chaining together a set of
151 /// parser rules. These rules are used to populate an `mlir::OperationState`
152 /// similarly to the `build` methods described above.
parseConstantOp(mlir::OpAsmParser & parser,mlir::OperationState & result)153 static mlir::ParseResult parseConstantOp(mlir::OpAsmParser &parser,
154                                          mlir::OperationState &result) {
155   mlir::DenseElementsAttr value;
156   if (parser.parseOptionalAttrDict(result.attributes) ||
157       parser.parseAttribute(value, "value", result.attributes))
158     return failure();
159 
160   result.addTypes(value.getType());
161   return success();
162 }
163 
164 /// The 'OpAsmPrinter' class is a stream that allows for formatting
165 /// strings, attributes, operands, types, etc.
print(mlir::OpAsmPrinter & printer,ConstantOp op)166 static void print(mlir::OpAsmPrinter &printer, ConstantOp op) {
167   printer << "toy.constant ";
168   printer.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
169   printer << op.value();
170 }
171 
172 /// Verify that the given attribute value is valid for the given type.
verifyConstantForType(mlir::Type type,mlir::Attribute opaqueValue,mlir::Operation * op)173 static mlir::LogicalResult verifyConstantForType(mlir::Type type,
174                                                  mlir::Attribute opaqueValue,
175                                                  mlir::Operation *op) {
176   if (type.isa<mlir::TensorType>()) {
177     // Check that the value is an elements attribute.
178     auto attrValue = opaqueValue.dyn_cast<mlir::DenseFPElementsAttr>();
179     if (!attrValue)
180       return op->emitError("constant of TensorType must be initialized by "
181                            "a DenseFPElementsAttr, got ")
182              << opaqueValue;
183 
184     // If the return type of the constant is not an unranked tensor, the shape
185     // must match the shape of the attribute holding the data.
186     auto resultType = type.dyn_cast<mlir::RankedTensorType>();
187     if (!resultType)
188       return success();
189 
190     // Check that the rank of the attribute type matches the rank of the
191     // constant result type.
192     auto attrType = attrValue.getType().cast<mlir::TensorType>();
193     if (attrType.getRank() != resultType.getRank()) {
194       return op->emitOpError("return type must match the one of the attached "
195                              "value attribute: ")
196              << attrType.getRank() << " != " << resultType.getRank();
197     }
198 
199     // Check that each of the dimensions match between the two types.
200     for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
201       if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
202         return op->emitOpError(
203                    "return type shape mismatches its attribute at dimension ")
204                << dim << ": " << attrType.getShape()[dim]
205                << " != " << resultType.getShape()[dim];
206       }
207     }
208     return mlir::success();
209   }
210   auto resultType = type.cast<StructType>();
211   llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes();
212 
213   // Verify that the initializer is an Array.
214   auto attrValue = opaqueValue.dyn_cast<ArrayAttr>();
215   if (!attrValue || attrValue.getValue().size() != resultElementTypes.size())
216     return op->emitError("constant of StructType must be initialized by an "
217                          "ArrayAttr with the same number of elements, got ")
218            << opaqueValue;
219 
220   // Check that each of the elements are valid.
221   llvm::ArrayRef<mlir::Attribute> attrElementValues = attrValue.getValue();
222   for (const auto it : llvm::zip(resultElementTypes, attrElementValues))
223     if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op)))
224       return mlir::failure();
225   return mlir::success();
226 }
227 
228 /// Verifier for the constant operation. This corresponds to the `::verify(...)`
229 /// in the op definition.
verify(ConstantOp op)230 static mlir::LogicalResult verify(ConstantOp op) {
231   return verifyConstantForType(op.getResult().getType(), op.value(), op);
232 }
233 
verify(StructConstantOp op)234 static mlir::LogicalResult verify(StructConstantOp op) {
235   return verifyConstantForType(op.getResult().getType(), op.value(), op);
236 }
237 
238 /// Infer the output shape of the ConstantOp, this is required by the shape
239 /// inference interface.
inferShapes()240 void ConstantOp::inferShapes() { getResult().setType(value().getType()); }
241 
242 //===----------------------------------------------------------------------===//
243 // AddOp
244 
build(mlir::OpBuilder & builder,mlir::OperationState & state,mlir::Value lhs,mlir::Value rhs)245 void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
246                   mlir::Value lhs, mlir::Value rhs) {
247   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
248   state.addOperands({lhs, rhs});
249 }
250 
251 /// Infer the output shape of the AddOp, this is required by the shape inference
252 /// interface.
inferShapes()253 void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
254 
255 //===----------------------------------------------------------------------===//
256 // CastOp
257 
258 /// Infer the output shape of the CastOp, this is required by the shape
259 /// inference interface.
inferShapes()260 void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
261 
262 /// Returns true if the given set of input and result types are compatible with
263 /// this cast operation. This is required by the `CastOpInterface` to verify
264 /// this operation and provide other additional utilities.
areCastCompatible(TypeRange inputs,TypeRange outputs)265 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
266   if (inputs.size() != 1 || outputs.size() != 1)
267     return false;
268   // The inputs must be Tensors with the same element type.
269   TensorType input = inputs.front().dyn_cast<TensorType>();
270   TensorType output = outputs.front().dyn_cast<TensorType>();
271   if (!input || !output || input.getElementType() != output.getElementType())
272     return false;
273   // The shape is required to match if both types are ranked.
274   return !input.hasRank() || !output.hasRank() || input == output;
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // GenericCallOp
279 
build(mlir::OpBuilder & builder,mlir::OperationState & state,StringRef callee,ArrayRef<mlir::Value> arguments)280 void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
281                           StringRef callee, ArrayRef<mlir::Value> arguments) {
282   // Generic call always returns an unranked Tensor initially.
283   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
284   state.addOperands(arguments);
285   state.addAttribute("callee", builder.getSymbolRefAttr(callee));
286 }
287 
288 /// Return the callee of the generic call operation, this is required by the
289 /// call interface.
getCallableForCallee()290 CallInterfaceCallable GenericCallOp::getCallableForCallee() {
291   return (*this)->getAttrOfType<SymbolRefAttr>("callee");
292 }
293 
294 /// Get the argument operands to the called function, this is required by the
295 /// call interface.
getArgOperands()296 Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
297 
298 //===----------------------------------------------------------------------===//
299 // MulOp
300 
build(mlir::OpBuilder & builder,mlir::OperationState & state,mlir::Value lhs,mlir::Value rhs)301 void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
302                   mlir::Value lhs, mlir::Value rhs) {
303   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
304   state.addOperands({lhs, rhs});
305 }
306 
307 /// Infer the output shape of the MulOp, this is required by the shape inference
308 /// interface.
inferShapes()309 void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
310 
311 //===----------------------------------------------------------------------===//
312 // ReturnOp
313 
verify(ReturnOp op)314 static mlir::LogicalResult verify(ReturnOp op) {
315   // We know that the parent operation is a function, because of the 'HasParent'
316   // trait attached to the operation definition.
317   auto function = cast<FuncOp>(op->getParentOp());
318 
319   /// ReturnOps can only have a single optional operand.
320   if (op.getNumOperands() > 1)
321     return op.emitOpError() << "expects at most 1 return operand";
322 
323   // The operand number and types must match the function signature.
324   const auto &results = function.getType().getResults();
325   if (op.getNumOperands() != results.size())
326     return op.emitOpError()
327            << "does not return the same number of values ("
328            << op.getNumOperands() << ") as the enclosing function ("
329            << results.size() << ")";
330 
331   // If the operation does not have an input, we are done.
332   if (!op.hasOperand())
333     return mlir::success();
334 
335   auto inputType = *op.operand_type_begin();
336   auto resultType = results.front();
337 
338   // Check that the result type of the function matches the operand type.
339   if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
340       resultType.isa<mlir::UnrankedTensorType>())
341     return mlir::success();
342 
343   return op.emitError() << "type of return operand (" << inputType
344                         << ") doesn't match function result type ("
345                         << resultType << ")";
346 }
347 
348 //===----------------------------------------------------------------------===//
349 // StructAccessOp
350 
build(mlir::OpBuilder & b,mlir::OperationState & state,mlir::Value input,size_t index)351 void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
352                            mlir::Value input, size_t index) {
353   // Extract the result type from the input type.
354   StructType structTy = input.getType().cast<StructType>();
355   assert(index < structTy.getNumElementTypes());
356   mlir::Type resultType = structTy.getElementTypes()[index];
357 
358   // Call into the auto-generated build method.
359   build(b, state, resultType, input, b.getI64IntegerAttr(index));
360 }
361 
verify(StructAccessOp op)362 static mlir::LogicalResult verify(StructAccessOp op) {
363   StructType structTy = op.input().getType().cast<StructType>();
364   size_t index = op.index();
365   if (index >= structTy.getNumElementTypes())
366     return op.emitOpError()
367            << "index should be within the range of the input struct type";
368   mlir::Type resultType = op.getResult().getType();
369   if (resultType != structTy.getElementTypes()[index])
370     return op.emitOpError() << "must have the same result type as the struct "
371                                "element referred to by the index";
372   return mlir::success();
373 }
374 
375 //===----------------------------------------------------------------------===//
376 // TransposeOp
377 
build(mlir::OpBuilder & builder,mlir::OperationState & state,mlir::Value value)378 void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
379                         mlir::Value value) {
380   state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
381   state.addOperands(value);
382 }
383 
inferShapes()384 void TransposeOp::inferShapes() {
385   auto arrayTy = getOperand().getType().cast<RankedTensorType>();
386   SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
387   getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
388 }
389 
verify(TransposeOp op)390 static mlir::LogicalResult verify(TransposeOp op) {
391   auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
392   auto resultType = op.getType().dyn_cast<RankedTensorType>();
393   if (!inputType || !resultType)
394     return mlir::success();
395 
396   auto inputShape = inputType.getShape();
397   if (!std::equal(inputShape.begin(), inputShape.end(),
398                   resultType.getShape().rbegin())) {
399     return op.emitError()
400            << "expected result shape to be a transpose of the input";
401   }
402   return mlir::success();
403 }
404 
405 //===----------------------------------------------------------------------===//
406 // Toy Types
407 //===----------------------------------------------------------------------===//
408 
409 namespace mlir {
410 namespace toy {
411 namespace detail {
412 /// This class represents the internal storage of the Toy `StructType`.
413 struct StructTypeStorage : public mlir::TypeStorage {
414   /// The `KeyTy` is a required type that provides an interface for the storage
415   /// instance. This type will be used when uniquing an instance of the type
416   /// storage. For our struct type, we will unique each instance structurally on
417   /// the elements that it contains.
418   using KeyTy = llvm::ArrayRef<mlir::Type>;
419 
420   /// A constructor for the type storage instance.
StructTypeStoragemlir::toy::detail::StructTypeStorage421   StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes)
422       : elementTypes(elementTypes) {}
423 
424   /// Define the comparison function for the key type with the current storage
425   /// instance. This is used when constructing a new instance to ensure that we
426   /// haven't already uniqued an instance of the given key.
operator ==mlir::toy::detail::StructTypeStorage427   bool operator==(const KeyTy &key) const { return key == elementTypes; }
428 
429   /// Define a hash function for the key type. This is used when uniquing
430   /// instances of the storage, see the `StructType::get` method.
431   /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type
432   /// have hash functions available, so we could just omit this entirely.
hashKeymlir::toy::detail::StructTypeStorage433   static llvm::hash_code hashKey(const KeyTy &key) {
434     return llvm::hash_value(key);
435   }
436 
437   /// Define a construction function for the key type from a set of parameters.
438   /// These parameters will be provided when constructing the storage instance
439   /// itself.
440   /// Note: This method isn't necessary because KeyTy can be directly
441   /// constructed with the given parameters.
getKeymlir::toy::detail::StructTypeStorage442   static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) {
443     return KeyTy(elementTypes);
444   }
445 
446   /// Define a construction method for creating a new instance of this storage.
447   /// This method takes an instance of a storage allocator, and an instance of a
448   /// `KeyTy`. The given allocator must be used for *all* necessary dynamic
449   /// allocations used to create the type storage and its internal.
constructmlir::toy::detail::StructTypeStorage450   static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
451                                       const KeyTy &key) {
452     // Copy the elements from the provided `KeyTy` into the allocator.
453     llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key);
454 
455     // Allocate the storage instance and construct it.
456     return new (allocator.allocate<StructTypeStorage>())
457         StructTypeStorage(elementTypes);
458   }
459 
460   /// The following field contains the element types of the struct.
461   llvm::ArrayRef<mlir::Type> elementTypes;
462 };
463 } // end namespace detail
464 } // end namespace toy
465 } // end namespace mlir
466 
467 /// Create an instance of a `StructType` with the given element types. There
468 /// *must* be at least one element type.
get(llvm::ArrayRef<mlir::Type> elementTypes)469 StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) {
470   assert(!elementTypes.empty() && "expected at least 1 element type");
471 
472   // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
473   // of this type. The first parameter is the context to unique in. The
474   // parameters after the context are forwarded to the storage instance.
475   mlir::MLIRContext *ctx = elementTypes.front().getContext();
476   return Base::get(ctx, elementTypes);
477 }
478 
479 /// Returns the element types of this struct type.
getElementTypes()480 llvm::ArrayRef<mlir::Type> StructType::getElementTypes() {
481   // 'getImpl' returns a pointer to the internal storage instance.
482   return getImpl()->elementTypes;
483 }
484 
485 /// Parse an instance of a type registered to the toy dialect.
parseType(mlir::DialectAsmParser & parser) const486 mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
487   // Parse a struct type in the following form:
488   //   struct-type ::= `struct` `<` type (`,` type)* `>`
489 
490   // NOTE: All MLIR parser function return a ParseResult. This is a
491   // specialization of LogicalResult that auto-converts to a `true` boolean
492   // value on failure to allow for chaining, but may be used with explicit
493   // `mlir::failed/mlir::succeeded` as desired.
494 
495   // Parse: `struct` `<`
496   if (parser.parseKeyword("struct") || parser.parseLess())
497     return Type();
498 
499   // Parse the element types of the struct.
500   SmallVector<mlir::Type, 1> elementTypes;
501   do {
502     // Parse the current element type.
503     llvm::SMLoc typeLoc = parser.getCurrentLocation();
504     mlir::Type elementType;
505     if (parser.parseType(elementType))
506       return nullptr;
507 
508     // Check that the type is either a TensorType or another StructType.
509     if (!elementType.isa<mlir::TensorType, StructType>()) {
510       parser.emitError(typeLoc, "element type for a struct must either "
511                                 "be a TensorType or a StructType, got: ")
512           << elementType;
513       return Type();
514     }
515     elementTypes.push_back(elementType);
516 
517     // Parse the optional: `,`
518   } while (succeeded(parser.parseOptionalComma()));
519 
520   // Parse: `>`
521   if (parser.parseGreater())
522     return Type();
523   return StructType::get(elementTypes);
524 }
525 
526 /// Print an instance of a type registered to the toy dialect.
printType(mlir::Type type,mlir::DialectAsmPrinter & printer) const527 void ToyDialect::printType(mlir::Type type,
528                            mlir::DialectAsmPrinter &printer) const {
529   // Currently the only toy type is a struct type.
530   StructType structType = type.cast<StructType>();
531 
532   // Print the struct type according to the parser format.
533   printer << "struct<";
534   llvm::interleaveComma(structType.getElementTypes(), printer);
535   printer << '>';
536 }
537 
538 //===----------------------------------------------------------------------===//
539 // TableGen'd op method definitions
540 //===----------------------------------------------------------------------===//
541 
542 #define GET_OP_CLASSES
543 #include "toy/Ops.cpp.inc"
544 
545 //===----------------------------------------------------------------------===//
546 // ToyDialect
547 //===----------------------------------------------------------------------===//
548 
549 /// Dialect initialization, the instance will be owned by the context. This is
550 /// the point of registration of types and operations for the dialect.
initialize()551 void ToyDialect::initialize() {
552   addOperations<
553 #define GET_OP_LIST
554 #include "toy/Ops.cpp.inc"
555       >();
556   addInterfaces<ToyInlinerInterface>();
557   addTypes<StructType>();
558 }
559 
materializeConstant(mlir::OpBuilder & builder,mlir::Attribute value,mlir::Type type,mlir::Location loc)560 mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
561                                                  mlir::Attribute value,
562                                                  mlir::Type type,
563                                                  mlir::Location loc) {
564   if (type.isa<StructType>())
565     return builder.create<StructConstantOp>(loc, type,
566                                             value.cast<mlir::ArrayAttr>());
567   return builder.create<ConstantOp>(loc, type,
568                                     value.cast<mlir::DenseElementsAttr>());
569 }
570