1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/DLTI/DLTI.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Reducer/ReductionPatternInterface.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/StringSwitch.h"
24 
25 using namespace mlir;
26 using namespace mlir::test;
27 
28 #include "TestOpsDialect.cpp.inc"
29 
registerTestDialect(DialectRegistry & registry)30 void mlir::test::registerTestDialect(DialectRegistry &registry) {
31   registry.insert<TestDialect>();
32 }
33 
34 //===----------------------------------------------------------------------===//
35 // TestDialect Interfaces
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 
40 /// Testing the correctness of some traits.
41 static_assert(
42     llvm::is_detected<OpTrait::has_implicit_terminator_t,
43                       SingleBlockImplicitTerminatorOp>::value,
44     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
45 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
46                   SingleBlockImplicitTerminatorOp>::value,
47               "hasSingleBlockImplicitTerminator does not match "
48               "SingleBlockImplicitTerminatorOp");
49 
50 // Test support for interacting with the AsmPrinter.
51 struct TestOpAsmInterface : public OpAsmDialectInterface {
52   using OpAsmDialectInterface::OpAsmDialectInterface;
53 
getAlias__anon5e0e5cfc0111::TestOpAsmInterface54   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
55     StringAttr strAttr = attr.dyn_cast<StringAttr>();
56     if (!strAttr)
57       return failure();
58 
59     // Check the contents of the string attribute to see what the test alias
60     // should be named.
61     Optional<StringRef> aliasName =
62         StringSwitch<Optional<StringRef>>(strAttr.getValue())
63             .Case("alias_test:dot_in_name", StringRef("test.alias"))
64             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
65             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
66             .Case("alias_test:sanitize_conflict_a",
67                   StringRef("test_alias_conflict0"))
68             .Case("alias_test:sanitize_conflict_b",
69                   StringRef("test_alias_conflict0_"))
70             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
71             .Default(llvm::None);
72     if (!aliasName)
73       return failure();
74 
75     os << *aliasName;
76     return success();
77   }
78 
getAsmResultNames__anon5e0e5cfc0111::TestOpAsmInterface79   void getAsmResultNames(Operation *op,
80                          OpAsmSetValueNameFn setNameFn) const final {
81     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
82       setNameFn(asmOp, "result");
83   }
84 
getAsmBlockArgumentNames__anon5e0e5cfc0111::TestOpAsmInterface85   void getAsmBlockArgumentNames(Block *block,
86                                 OpAsmSetValueNameFn setNameFn) const final {
87     auto op = block->getParentOp();
88     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
89     if (!arrayAttr)
90       return;
91     auto args = block->getArguments();
92     auto e = std::min(arrayAttr.size(), args.size());
93     for (unsigned i = 0; i < e; ++i) {
94       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
95         setNameFn(args[i], strAttr.getValue());
96     }
97   }
98 };
99 
100 struct TestDialectFoldInterface : public DialectFoldInterface {
101   using DialectFoldInterface::DialectFoldInterface;
102 
103   /// Registered hook to check if the given region, which is attached to an
104   /// operation that is *not* isolated from above, should be used when
105   /// materializing constants.
shouldMaterializeInto__anon5e0e5cfc0111::TestDialectFoldInterface106   bool shouldMaterializeInto(Region *region) const final {
107     // If this is a one region operation, then insert into it.
108     return isa<OneRegionOp>(region->getParentOp());
109   }
110 };
111 
112 /// This class defines the interface for handling inlining with standard
113 /// operations.
114 struct TestInlinerInterface : public DialectInlinerInterface {
115   using DialectInlinerInterface::DialectInlinerInterface;
116 
117   //===--------------------------------------------------------------------===//
118   // Analysis Hooks
119   //===--------------------------------------------------------------------===//
120 
isLegalToInline__anon5e0e5cfc0111::TestInlinerInterface121   bool isLegalToInline(Operation *call, Operation *callable,
122                        bool wouldBeCloned) const final {
123     // Don't allow inlining calls that are marked `noinline`.
124     return !call->hasAttr("noinline");
125   }
isLegalToInline__anon5e0e5cfc0111::TestInlinerInterface126   bool isLegalToInline(Region *, Region *, bool,
127                        BlockAndValueMapping &) const final {
128     // Inlining into test dialect regions is legal.
129     return true;
130   }
isLegalToInline__anon5e0e5cfc0111::TestInlinerInterface131   bool isLegalToInline(Operation *, Region *, bool,
132                        BlockAndValueMapping &) const final {
133     return true;
134   }
135 
shouldAnalyzeRecursively__anon5e0e5cfc0111::TestInlinerInterface136   bool shouldAnalyzeRecursively(Operation *op) const final {
137     // Analyze recursively if this is not a functional region operation, it
138     // froms a separate functional scope.
139     return !isa<FunctionalRegionOp>(op);
140   }
141 
142   //===--------------------------------------------------------------------===//
143   // Transformation Hooks
144   //===--------------------------------------------------------------------===//
145 
146   /// Handle the given inlined terminator by replacing it with a new operation
147   /// as necessary.
handleTerminator__anon5e0e5cfc0111::TestInlinerInterface148   void handleTerminator(Operation *op,
149                         ArrayRef<Value> valuesToRepl) const final {
150     // Only handle "test.return" here.
151     auto returnOp = dyn_cast<TestReturnOp>(op);
152     if (!returnOp)
153       return;
154 
155     // Replace the values directly with the return operands.
156     assert(returnOp.getNumOperands() == valuesToRepl.size());
157     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
158       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
159   }
160 
161   /// Attempt to materialize a conversion for a type mismatch between a call
162   /// from this dialect, and a callable region. This method should generate an
163   /// operation that takes 'input' as the only operand, and produces a single
164   /// result of 'resultType'. If a conversion can not be generated, nullptr
165   /// should be returned.
materializeCallConversion__anon5e0e5cfc0111::TestInlinerInterface166   Operation *materializeCallConversion(OpBuilder &builder, Value input,
167                                        Type resultType,
168                                        Location conversionLoc) const final {
169     // Only allow conversion for i16/i32 types.
170     if (!(resultType.isSignlessInteger(16) ||
171           resultType.isSignlessInteger(32)) ||
172         !(input.getType().isSignlessInteger(16) ||
173           input.getType().isSignlessInteger(32)))
174       return nullptr;
175     return builder.create<TestCastOp>(conversionLoc, resultType, input);
176   }
177 
processInlinedCallBlocks__anon5e0e5cfc0111::TestInlinerInterface178   void processInlinedCallBlocks(
179       Operation *call,
180       iterator_range<Region::iterator> inlinedBlocks) const final {
181     if (!isa<ConversionCallOp>(call))
182       return;
183 
184     // Set attributed on all ops in the inlined blocks.
185     for (Block &block : inlinedBlocks) {
186       block.walk([&](Operation *op) {
187         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
188       });
189     }
190   }
191 };
192 
193 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
194 public:
TestReductionPatternInterface__anon5e0e5cfc0111::TestReductionPatternInterface195   TestReductionPatternInterface(Dialect *dialect)
196       : DialectReductionPatternInterface(dialect) {}
197 
198   virtual void
populateReductionPatterns__anon5e0e5cfc0111::TestReductionPatternInterface199   populateReductionPatterns(RewritePatternSet &patterns) const final {
200     populateTestReductionPatterns(patterns);
201   }
202 };
203 
204 } // end anonymous namespace
205 
206 //===----------------------------------------------------------------------===//
207 // TestDialect
208 //===----------------------------------------------------------------------===//
209 
210 static void testSideEffectOpGetEffect(
211     Operation *op,
212     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
213 
214 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
215 struct TestOpEffectInterfaceFallback
216     : public TestEffectOpInterface::FallbackModel<
217           TestOpEffectInterfaceFallback> {
classofTestOpEffectInterfaceFallback218   static bool classof(Operation *op) {
219     bool isSupportedOp =
220         op->getName().getStringRef() == "test.unregistered_side_effect_op";
221     assert(isSupportedOp && "Unexpected dispatch");
222     return isSupportedOp;
223   }
224 
225   void
getEffectsTestOpEffectInterfaceFallback226   getEffects(Operation *op,
227              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
228                  &effects) const {
229     testSideEffectOpGetEffect(op, effects);
230   }
231 };
232 
initialize()233 void TestDialect::initialize() {
234   registerAttributes();
235   registerTypes();
236   addOperations<
237 #define GET_OP_LIST
238 #include "TestOps.cpp.inc"
239       >();
240   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
241                 TestInlinerInterface, TestReductionPatternInterface>();
242   allowUnknownOperations();
243 
244   // Instantiate our fallback op interface that we'll use on specific
245   // unregistered op.
246   fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
247 }
~TestDialect()248 TestDialect::~TestDialect() {
249   delete static_cast<TestOpEffectInterfaceFallback *>(
250       fallbackEffectOpInterfaces);
251 }
252 
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)253 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
254                                             Type type, Location loc) {
255   return builder.create<TestOpConstant>(loc, type, value);
256 }
257 
getRegisteredInterfaceForOp(TypeID typeID,OperationName opName)258 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
259                                                OperationName opName) {
260   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
261       typeID == TypeID::get<TestEffectOpInterface>())
262     return fallbackEffectOpInterfaces;
263   return nullptr;
264 }
265 
verifyOperationAttribute(Operation * op,NamedAttribute namedAttr)266 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
267                                                     NamedAttribute namedAttr) {
268   if (namedAttr.first == "test.invalid_attr")
269     return op->emitError() << "invalid to use 'test.invalid_attr'";
270   return success();
271 }
272 
verifyRegionArgAttribute(Operation * op,unsigned regionIndex,unsigned argIndex,NamedAttribute namedAttr)273 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
274                                                     unsigned regionIndex,
275                                                     unsigned argIndex,
276                                                     NamedAttribute namedAttr) {
277   if (namedAttr.first == "test.invalid_attr")
278     return op->emitError() << "invalid to use 'test.invalid_attr'";
279   return success();
280 }
281 
282 LogicalResult
verifyRegionResultAttribute(Operation * op,unsigned regionIndex,unsigned resultIndex,NamedAttribute namedAttr)283 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
284                                          unsigned resultIndex,
285                                          NamedAttribute namedAttr) {
286   if (namedAttr.first == "test.invalid_attr")
287     return op->emitError() << "invalid to use 'test.invalid_attr'";
288   return success();
289 }
290 
291 Optional<Dialect::ParseOpHook>
getParseOperationHook(StringRef opName) const292 TestDialect::getParseOperationHook(StringRef opName) const {
293   if (opName == "test.dialect_custom_printer") {
294     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
295       return parser.parseKeyword("custom_format");
296     }};
297   }
298   return None;
299 }
300 
printOperation(Operation * op,OpAsmPrinter & printer) const301 LogicalResult TestDialect::printOperation(Operation *op,
302                                           OpAsmPrinter &printer) const {
303   StringRef opName = op->getName().getStringRef();
304   if (opName == "test.dialect_custom_printer") {
305     printer.getStream() << opName << " custom_format";
306     return success();
307   }
308   return failure();
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // TestBranchOp
313 //===----------------------------------------------------------------------===//
314 
315 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)316 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
317   assert(index == 0 && "invalid successor index");
318   return targetOperandsMutable();
319 }
320 
321 //===----------------------------------------------------------------------===//
322 // TestDialectCanonicalizerOp
323 //===----------------------------------------------------------------------===//
324 
325 static LogicalResult
dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,PatternRewriter & rewriter)326 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
327                                PatternRewriter &rewriter) {
328   rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
329                                           rewriter.getI32IntegerAttr(42));
330   return success();
331 }
332 
getCanonicalizationPatterns(RewritePatternSet & results) const333 void TestDialect::getCanonicalizationPatterns(
334     RewritePatternSet &results) const {
335   results.add(&dialectCanonicalizationPattern);
336 }
337 
338 //===----------------------------------------------------------------------===//
339 // TestFoldToCallOp
340 //===----------------------------------------------------------------------===//
341 
342 namespace {
343 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
344   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
345 
matchAndRewrite__anon5e0e5cfc0411::FoldToCallOpPattern346   LogicalResult matchAndRewrite(FoldToCallOp op,
347                                 PatternRewriter &rewriter) const override {
348     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
349                                         ValueRange());
350     return success();
351   }
352 };
353 } // end anonymous namespace
354 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)355 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
356                                                MLIRContext *context) {
357   results.add<FoldToCallOpPattern>(context);
358 }
359 
360 //===----------------------------------------------------------------------===//
361 // Test Format* operations
362 //===----------------------------------------------------------------------===//
363 
364 //===----------------------------------------------------------------------===//
365 // Parsing
366 
parseCustomDirectiveOperands(OpAsmParser & parser,OpAsmParser::OperandType & operand,Optional<OpAsmParser::OperandType> & optOperand,SmallVectorImpl<OpAsmParser::OperandType> & varOperands)367 static ParseResult parseCustomDirectiveOperands(
368     OpAsmParser &parser, OpAsmParser::OperandType &operand,
369     Optional<OpAsmParser::OperandType> &optOperand,
370     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
371   if (parser.parseOperand(operand))
372     return failure();
373   if (succeeded(parser.parseOptionalComma())) {
374     optOperand.emplace();
375     if (parser.parseOperand(*optOperand))
376       return failure();
377   }
378   if (parser.parseArrow() || parser.parseLParen() ||
379       parser.parseOperandList(varOperands) || parser.parseRParen())
380     return failure();
381   return success();
382 }
383 static ParseResult
parseCustomDirectiveResults(OpAsmParser & parser,Type & operandType,Type & optOperandType,SmallVectorImpl<Type> & varOperandTypes)384 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
385                             Type &optOperandType,
386                             SmallVectorImpl<Type> &varOperandTypes) {
387   if (parser.parseColon())
388     return failure();
389 
390   if (parser.parseType(operandType))
391     return failure();
392   if (succeeded(parser.parseOptionalComma())) {
393     if (parser.parseType(optOperandType))
394       return failure();
395   }
396   if (parser.parseArrow() || parser.parseLParen() ||
397       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
398     return failure();
399   return success();
400 }
401 static ParseResult
parseCustomDirectiveWithTypeRefs(OpAsmParser & parser,Type operandType,Type optOperandType,const SmallVectorImpl<Type> & varOperandTypes)402 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
403                                  Type optOperandType,
404                                  const SmallVectorImpl<Type> &varOperandTypes) {
405   if (parser.parseKeyword("type_refs_capture"))
406     return failure();
407 
408   Type operandType2, optOperandType2;
409   SmallVector<Type, 1> varOperandTypes2;
410   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
411                                   varOperandTypes2))
412     return failure();
413 
414   if (operandType != operandType2 || optOperandType != optOperandType2 ||
415       varOperandTypes != varOperandTypes2)
416     return failure();
417 
418   return success();
419 }
parseCustomDirectiveOperandsAndTypes(OpAsmParser & parser,OpAsmParser::OperandType & operand,Optional<OpAsmParser::OperandType> & optOperand,SmallVectorImpl<OpAsmParser::OperandType> & varOperands,Type & operandType,Type & optOperandType,SmallVectorImpl<Type> & varOperandTypes)420 static ParseResult parseCustomDirectiveOperandsAndTypes(
421     OpAsmParser &parser, OpAsmParser::OperandType &operand,
422     Optional<OpAsmParser::OperandType> &optOperand,
423     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
424     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
425   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
426       parseCustomDirectiveResults(parser, operandType, optOperandType,
427                                   varOperandTypes))
428     return failure();
429   return success();
430 }
parseCustomDirectiveRegions(OpAsmParser & parser,Region & region,SmallVectorImpl<std::unique_ptr<Region>> & varRegions)431 static ParseResult parseCustomDirectiveRegions(
432     OpAsmParser &parser, Region &region,
433     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
434   if (parser.parseRegion(region))
435     return failure();
436   if (failed(parser.parseOptionalComma()))
437     return success();
438   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
439   if (parser.parseRegion(*varRegion))
440     return failure();
441   varRegions.emplace_back(std::move(varRegion));
442   return success();
443 }
444 static ParseResult
parseCustomDirectiveSuccessors(OpAsmParser & parser,Block * & successor,SmallVectorImpl<Block * > & varSuccessors)445 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
446                                SmallVectorImpl<Block *> &varSuccessors) {
447   if (parser.parseSuccessor(successor))
448     return failure();
449   if (failed(parser.parseOptionalComma()))
450     return success();
451   Block *varSuccessor;
452   if (parser.parseSuccessor(varSuccessor))
453     return failure();
454   varSuccessors.append(2, varSuccessor);
455   return success();
456 }
parseCustomDirectiveAttributes(OpAsmParser & parser,IntegerAttr & attr,IntegerAttr & optAttr)457 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
458                                                   IntegerAttr &attr,
459                                                   IntegerAttr &optAttr) {
460   if (parser.parseAttribute(attr))
461     return failure();
462   if (succeeded(parser.parseOptionalComma())) {
463     if (parser.parseAttribute(optAttr))
464       return failure();
465   }
466   return success();
467 }
468 
parseCustomDirectiveAttrDict(OpAsmParser & parser,NamedAttrList & attrs)469 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
470                                                 NamedAttrList &attrs) {
471   return parser.parseOptionalAttrDict(attrs);
472 }
parseCustomDirectiveOptionalOperandRef(OpAsmParser & parser,Optional<OpAsmParser::OperandType> & optOperand)473 static ParseResult parseCustomDirectiveOptionalOperandRef(
474     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
475   int64_t operandCount = 0;
476   if (parser.parseInteger(operandCount))
477     return failure();
478   bool expectedOptionalOperand = operandCount == 0;
479   return success(expectedOptionalOperand != optOperand.hasValue());
480 }
481 
482 //===----------------------------------------------------------------------===//
483 // Printing
484 
printCustomDirectiveOperands(OpAsmPrinter & printer,Operation *,Value operand,Value optOperand,OperandRange varOperands)485 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
486                                          Value operand, Value optOperand,
487                                          OperandRange varOperands) {
488   printer << operand;
489   if (optOperand)
490     printer << ", " << optOperand;
491   printer << " -> (" << varOperands << ")";
492 }
printCustomDirectiveResults(OpAsmPrinter & printer,Operation *,Type operandType,Type optOperandType,TypeRange varOperandTypes)493 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
494                                         Type operandType, Type optOperandType,
495                                         TypeRange varOperandTypes) {
496   printer << " : " << operandType;
497   if (optOperandType)
498     printer << ", " << optOperandType;
499   printer << " -> (" << varOperandTypes << ")";
500 }
printCustomDirectiveWithTypeRefs(OpAsmPrinter & printer,Operation * op,Type operandType,Type optOperandType,TypeRange varOperandTypes)501 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
502                                              Operation *op, Type operandType,
503                                              Type optOperandType,
504                                              TypeRange varOperandTypes) {
505   printer << " type_refs_capture ";
506   printCustomDirectiveResults(printer, op, operandType, optOperandType,
507                               varOperandTypes);
508 }
printCustomDirectiveOperandsAndTypes(OpAsmPrinter & printer,Operation * op,Value operand,Value optOperand,OperandRange varOperands,Type operandType,Type optOperandType,TypeRange varOperandTypes)509 static void printCustomDirectiveOperandsAndTypes(
510     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
511     OperandRange varOperands, Type operandType, Type optOperandType,
512     TypeRange varOperandTypes) {
513   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
514   printCustomDirectiveResults(printer, op, operandType, optOperandType,
515                               varOperandTypes);
516 }
printCustomDirectiveRegions(OpAsmPrinter & printer,Operation *,Region & region,MutableArrayRef<Region> varRegions)517 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
518                                         Region &region,
519                                         MutableArrayRef<Region> varRegions) {
520   printer.printRegion(region);
521   if (!varRegions.empty()) {
522     printer << ", ";
523     for (Region &region : varRegions)
524       printer.printRegion(region);
525   }
526 }
printCustomDirectiveSuccessors(OpAsmPrinter & printer,Operation *,Block * successor,SuccessorRange varSuccessors)527 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
528                                            Block *successor,
529                                            SuccessorRange varSuccessors) {
530   printer << successor;
531   if (!varSuccessors.empty())
532     printer << ", " << varSuccessors.front();
533 }
printCustomDirectiveAttributes(OpAsmPrinter & printer,Operation *,Attribute attribute,Attribute optAttribute)534 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
535                                            Attribute attribute,
536                                            Attribute optAttribute) {
537   printer << attribute;
538   if (optAttribute)
539     printer << ", " << optAttribute;
540 }
541 
printCustomDirectiveAttrDict(OpAsmPrinter & printer,Operation * op,DictionaryAttr attrs)542 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
543                                          DictionaryAttr attrs) {
544   printer.printOptionalAttrDict(attrs.getValue());
545 }
546 
printCustomDirectiveOptionalOperandRef(OpAsmPrinter & printer,Operation * op,Value optOperand)547 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
548                                                    Operation *op,
549                                                    Value optOperand) {
550   printer << (optOperand ? "1" : "0");
551 }
552 
553 //===----------------------------------------------------------------------===//
554 // Test IsolatedRegionOp - parse passthrough region arguments.
555 //===----------------------------------------------------------------------===//
556 
parseIsolatedRegionOp(OpAsmParser & parser,OperationState & result)557 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
558                                          OperationState &result) {
559   OpAsmParser::OperandType argInfo;
560   Type argType = parser.getBuilder().getIndexType();
561 
562   // Parse the input operand.
563   if (parser.parseOperand(argInfo) ||
564       parser.resolveOperand(argInfo, argType, result.operands))
565     return failure();
566 
567   // Parse the body region, and reuse the operand info as the argument info.
568   Region *body = result.addRegion();
569   return parser.parseRegion(*body, argInfo, argType,
570                             /*enableNameShadowing=*/true);
571 }
572 
print(OpAsmPrinter & p,IsolatedRegionOp op)573 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
574   p << "test.isolated_region ";
575   p.printOperand(op.getOperand());
576   p.shadowRegionArgs(op.region(), op.getOperand());
577   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
578 }
579 
580 //===----------------------------------------------------------------------===//
581 // Test SSACFGRegionOp
582 //===----------------------------------------------------------------------===//
583 
getRegionKind(unsigned index)584 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
585   return RegionKind::SSACFG;
586 }
587 
588 //===----------------------------------------------------------------------===//
589 // Test GraphRegionOp
590 //===----------------------------------------------------------------------===//
591 
parseGraphRegionOp(OpAsmParser & parser,OperationState & result)592 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
593                                       OperationState &result) {
594   // Parse the body region, and reuse the operand info as the argument info.
595   Region *body = result.addRegion();
596   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
597 }
598 
print(OpAsmPrinter & p,GraphRegionOp op)599 static void print(OpAsmPrinter &p, GraphRegionOp op) {
600   p << "test.graph_region ";
601   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
602 }
603 
getRegionKind(unsigned index)604 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
605   return RegionKind::Graph;
606 }
607 
608 //===----------------------------------------------------------------------===//
609 // Test AffineScopeOp
610 //===----------------------------------------------------------------------===//
611 
parseAffineScopeOp(OpAsmParser & parser,OperationState & result)612 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
613                                       OperationState &result) {
614   // Parse the body region, and reuse the operand info as the argument info.
615   Region *body = result.addRegion();
616   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
617 }
618 
print(OpAsmPrinter & p,AffineScopeOp op)619 static void print(OpAsmPrinter &p, AffineScopeOp op) {
620   p << "test.affine_scope ";
621   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
622 }
623 
624 //===----------------------------------------------------------------------===//
625 // Test parser.
626 //===----------------------------------------------------------------------===//
627 
parseParseIntegerLiteralOp(OpAsmParser & parser,OperationState & result)628 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
629                                               OperationState &result) {
630   if (parser.parseOptionalColon())
631     return success();
632   uint64_t numResults;
633   if (parser.parseInteger(numResults))
634     return failure();
635 
636   IndexType type = parser.getBuilder().getIndexType();
637   for (unsigned i = 0; i < numResults; ++i)
638     result.addTypes(type);
639   return success();
640 }
641 
print(OpAsmPrinter & p,ParseIntegerLiteralOp op)642 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
643   p << ParseIntegerLiteralOp::getOperationName();
644   if (unsigned numResults = op->getNumResults())
645     p << " : " << numResults;
646 }
647 
parseParseWrappedKeywordOp(OpAsmParser & parser,OperationState & result)648 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
649                                               OperationState &result) {
650   StringRef keyword;
651   if (parser.parseKeyword(&keyword))
652     return failure();
653   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
654   return success();
655 }
656 
print(OpAsmPrinter & p,ParseWrappedKeywordOp op)657 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
658   p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
659 }
660 
661 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
663 
parseWrappingRegionOp(OpAsmParser & parser,OperationState & result)664 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
665                                          OperationState &result) {
666   if (parser.parseKeyword("wraps"))
667     return failure();
668 
669   // Parse the wrapped op in a region
670   Region &body = *result.addRegion();
671   body.push_back(new Block);
672   Block &block = body.back();
673   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
674   if (!wrapped_op)
675     return failure();
676 
677   // Create a return terminator in the inner region, pass as operand to the
678   // terminator the returned values from the wrapped operation.
679   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
680   OpBuilder builder(parser.getBuilder().getContext());
681   builder.setInsertionPointToEnd(&block);
682   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
683 
684   // Get the results type for the wrapping op from the terminator operands.
685   Operation &return_op = body.back().back();
686   result.types.append(return_op.operand_type_begin(),
687                       return_op.operand_type_end());
688 
689   // Use the location of the wrapped op for the "test.wrapping_region" op.
690   result.location = wrapped_op->getLoc();
691 
692   return success();
693 }
694 
print(OpAsmPrinter & p,WrappingRegionOp op)695 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
696   p << op.getOperationName() << " wraps ";
697   p.printGenericOp(&op.region().front().front());
698 }
699 
700 //===----------------------------------------------------------------------===//
701 // Test PolyForOp - parse list of region arguments.
702 //===----------------------------------------------------------------------===//
703 
parsePolyForOp(OpAsmParser & parser,OperationState & result)704 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
705   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
706   // Parse list of region arguments without a delimiter.
707   if (parser.parseRegionArgumentList(ivsInfo))
708     return failure();
709 
710   // Parse the body region.
711   Region *body = result.addRegion();
712   auto &builder = parser.getBuilder();
713   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
714   return parser.parseRegion(*body, ivsInfo, argTypes);
715 }
716 
717 //===----------------------------------------------------------------------===//
718 // Test removing op with inner ops.
719 //===----------------------------------------------------------------------===//
720 
721 namespace {
722 struct TestRemoveOpWithInnerOps
723     : public OpRewritePattern<TestOpWithRegionPattern> {
724   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
725 
initialize__anon5e0e5cfc0511::TestRemoveOpWithInnerOps726   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
727 
matchAndRewrite__anon5e0e5cfc0511::TestRemoveOpWithInnerOps728   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
729                                 PatternRewriter &rewriter) const override {
730     rewriter.eraseOp(op);
731     return success();
732   }
733 };
734 } // end anonymous namespace
735 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)736 void TestOpWithRegionPattern::getCanonicalizationPatterns(
737     RewritePatternSet &results, MLIRContext *context) {
738   results.add<TestRemoveOpWithInnerOps>(context);
739 }
740 
fold(ArrayRef<Attribute> operands)741 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
742   return operand();
743 }
744 
fold(ArrayRef<Attribute> operands)745 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
746   return getValue();
747 }
748 
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)749 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
750     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
751   for (Value input : this->operands()) {
752     results.push_back(input);
753   }
754   return success();
755 }
756 
fold(ArrayRef<Attribute> operands)757 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
758   assert(operands.size() == 1);
759   if (operands.front()) {
760     (*this)->setAttr("attr", operands.front());
761     return getResult();
762   }
763   return {};
764 }
765 
fold(ArrayRef<Attribute> operands)766 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
767   return getOperand();
768 }
769 
inferReturnTypes(MLIRContext *,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)770 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
771     MLIRContext *, Optional<Location> location, ValueRange operands,
772     DictionaryAttr attributes, RegionRange regions,
773     SmallVectorImpl<Type> &inferredReturnTypes) {
774   if (operands[0].getType() != operands[1].getType()) {
775     return emitOptionalError(location, "operand type mismatch ",
776                              operands[0].getType(), " vs ",
777                              operands[1].getType());
778   }
779   inferredReturnTypes.assign({operands[0].getType()});
780   return success();
781 }
782 
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)783 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
784     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
785     DictionaryAttr attributes, RegionRange regions,
786     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
787   // Create return type consisting of the last element of the first operand.
788   auto operandType = operands.front().getType();
789   auto sval = operandType.dyn_cast<ShapedType>();
790   if (!sval) {
791     return emitOptionalError(location, "only shaped type operands allowed");
792   }
793   int64_t dim =
794       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
795   auto type = IntegerType::get(context, 17);
796   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
797   return success();
798 }
799 
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,llvm::SmallVectorImpl<Value> & shapes)800 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
801     OpBuilder &builder, ValueRange operands,
802     llvm::SmallVectorImpl<Value> &shapes) {
803   shapes = SmallVector<Value, 1>{
804       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
805   return success();
806 }
807 
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,llvm::SmallVectorImpl<Value> & shapes)808 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
809     OpBuilder &builder, ValueRange operands,
810     llvm::SmallVectorImpl<Value> &shapes) {
811   Location loc = getLoc();
812   shapes.reserve(operands.size());
813   for (Value operand : llvm::reverse(operands)) {
814     auto currShape = llvm::to_vector<4>(llvm::map_range(
815         llvm::seq<int64_t>(
816             0, operand.getType().cast<RankedTensorType>().getRank()),
817         [&](int64_t dim) -> Value {
818           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
819         }));
820     shapes.push_back(builder.create<tensor::FromElementsOp>(
821         getLoc(), builder.getIndexType(), currShape));
822   }
823   return success();
824 }
825 
reifyResultShapes(OpBuilder & builder,ReifiedRankedShapedTypeDims & shapes)826 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
827     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
828   Location loc = getLoc();
829   shapes.reserve(getNumOperands());
830   for (Value operand : llvm::reverse(getOperands())) {
831     auto currShape = llvm::to_vector<4>(llvm::map_range(
832         llvm::seq<int64_t>(
833             0, operand.getType().cast<RankedTensorType>().getRank()),
834         [&](int64_t dim) -> Value {
835           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
836         }));
837     shapes.emplace_back(std::move(currShape));
838   }
839   return success();
840 }
841 
842 //===----------------------------------------------------------------------===//
843 // Test SideEffect interfaces
844 //===----------------------------------------------------------------------===//
845 
846 namespace {
847 /// A test resource for side effects.
848 struct TestResource : public SideEffects::Resource::Base<TestResource> {
getName__anon5e0e5cfc0811::TestResource849   StringRef getName() final { return "<Test>"; }
850 };
851 } // end anonymous namespace
852 
testSideEffectOpGetEffect(Operation * op,SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> & effects)853 static void testSideEffectOpGetEffect(
854     Operation *op,
855     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
856         &effects) {
857   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
858   if (!effectsAttr)
859     return;
860 
861   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
862 }
863 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)864 void SideEffectOp::getEffects(
865     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
866   // Check for an effects attribute on the op instance.
867   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
868   if (!effectsAttr)
869     return;
870 
871   // If there is one, it is an array of dictionary attributes that hold
872   // information on the effects of this operation.
873   for (Attribute element : effectsAttr) {
874     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
875 
876     // Get the specific memory effect.
877     MemoryEffects::Effect *effect =
878         StringSwitch<MemoryEffects::Effect *>(
879             effectElement.get("effect").cast<StringAttr>().getValue())
880             .Case("allocate", MemoryEffects::Allocate::get())
881             .Case("free", MemoryEffects::Free::get())
882             .Case("read", MemoryEffects::Read::get())
883             .Case("write", MemoryEffects::Write::get());
884 
885     // Check for a non-default resource to use.
886     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
887     if (effectElement.get("test_resource"))
888       resource = TestResource::get();
889 
890     // Check for a result to affect.
891     if (effectElement.get("on_result"))
892       effects.emplace_back(effect, getResult(), resource);
893     else if (Attribute ref = effectElement.get("on_reference"))
894       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
895     else
896       effects.emplace_back(effect, resource);
897   }
898 }
899 
getEffects(SmallVectorImpl<TestEffects::EffectInstance> & effects)900 void SideEffectOp::getEffects(
901     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
902   testSideEffectOpGetEffect(getOperation(), effects);
903 }
904 
905 //===----------------------------------------------------------------------===//
906 // StringAttrPrettyNameOp
907 //===----------------------------------------------------------------------===//
908 
909 // This op has fancy handling of its SSA result name.
parseStringAttrPrettyNameOp(OpAsmParser & parser,OperationState & result)910 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
911                                                OperationState &result) {
912   // Add the result types.
913   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
914     result.addTypes(parser.getBuilder().getIntegerType(32));
915 
916   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
917     return failure();
918 
919   // If the attribute dictionary contains no 'names' attribute, infer it from
920   // the SSA name (if specified).
921   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
922     return attr.first == "names";
923   });
924 
925   // If there was no name specified, check to see if there was a useful name
926   // specified in the asm file.
927   if (hadNames || parser.getNumResults() == 0)
928     return success();
929 
930   SmallVector<StringRef, 4> names;
931   auto *context = result.getContext();
932 
933   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
934     auto resultName = parser.getResultName(i);
935     StringRef nameStr;
936     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
937       nameStr = resultName.first;
938 
939     names.push_back(nameStr);
940   }
941 
942   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
943   result.attributes.push_back({Identifier::get("names", context), namesAttr});
944   return success();
945 }
946 
print(OpAsmPrinter & p,StringAttrPrettyNameOp op)947 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
948   p << "test.string_attr_pretty_name";
949 
950   // Note that we only need to print the "name" attribute if the asmprinter
951   // result name disagrees with it.  This can happen in strange cases, e.g.
952   // when there are conflicts.
953   bool namesDisagree = op.names().size() != op.getNumResults();
954 
955   SmallString<32> resultNameStr;
956   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
957     resultNameStr.clear();
958     llvm::raw_svector_ostream tmpStream(resultNameStr);
959     p.printOperand(op.getResult(i), tmpStream);
960 
961     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
962     if (!expectedName ||
963         tmpStream.str().drop_front() != expectedName.getValue()) {
964       namesDisagree = true;
965     }
966   }
967 
968   if (namesDisagree)
969     p.printOptionalAttrDictWithKeyword(op->getAttrs());
970   else
971     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
972 }
973 
974 // We set the SSA name in the asm syntax to the contents of the name
975 // attribute.
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)976 void StringAttrPrettyNameOp::getAsmResultNames(
977     function_ref<void(Value, StringRef)> setNameFn) {
978 
979   auto value = names();
980   for (size_t i = 0, e = value.size(); i != e; ++i)
981     if (auto str = value[i].dyn_cast<StringAttr>())
982       if (!str.getValue().empty())
983         setNameFn(getResult(i), str.getValue());
984 }
985 
986 //===----------------------------------------------------------------------===//
987 // RegionIfOp
988 //===----------------------------------------------------------------------===//
989 
print(OpAsmPrinter & p,RegionIfOp op)990 static void print(OpAsmPrinter &p, RegionIfOp op) {
991   p << RegionIfOp::getOperationName() << " ";
992   p.printOperands(op.getOperands());
993   p << ": " << op.getOperandTypes();
994   p.printArrowTypeList(op.getResultTypes());
995   p << " then";
996   p.printRegion(op.thenRegion(),
997                 /*printEntryBlockArgs=*/true,
998                 /*printBlockTerminators=*/true);
999   p << " else";
1000   p.printRegion(op.elseRegion(),
1001                 /*printEntryBlockArgs=*/true,
1002                 /*printBlockTerminators=*/true);
1003   p << " join";
1004   p.printRegion(op.joinRegion(),
1005                 /*printEntryBlockArgs=*/true,
1006                 /*printBlockTerminators=*/true);
1007 }
1008 
parseRegionIfOp(OpAsmParser & parser,OperationState & result)1009 static ParseResult parseRegionIfOp(OpAsmParser &parser,
1010                                    OperationState &result) {
1011   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
1012   SmallVector<Type, 2> operandTypes;
1013 
1014   result.regions.reserve(3);
1015   Region *thenRegion = result.addRegion();
1016   Region *elseRegion = result.addRegion();
1017   Region *joinRegion = result.addRegion();
1018 
1019   // Parse operand, type and arrow type lists.
1020   if (parser.parseOperandList(operandInfos) ||
1021       parser.parseColonTypeList(operandTypes) ||
1022       parser.parseArrowTypeList(result.types))
1023     return failure();
1024 
1025   // Parse all attached regions.
1026   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1027       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1028       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1029     return failure();
1030 
1031   return parser.resolveOperands(operandInfos, operandTypes,
1032                                 parser.getCurrentLocation(), result.operands);
1033 }
1034 
getSuccessorEntryOperands(unsigned index)1035 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1036   assert(index < 2 && "invalid region index");
1037   return getOperands();
1038 }
1039 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1040 void RegionIfOp::getSuccessorRegions(
1041     Optional<unsigned> index, ArrayRef<Attribute> operands,
1042     SmallVectorImpl<RegionSuccessor> &regions) {
1043   // We always branch to the join region.
1044   if (index.hasValue()) {
1045     if (index.getValue() < 2)
1046       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
1047     else
1048       regions.push_back(RegionSuccessor(getResults()));
1049     return;
1050   }
1051 
1052   // The then and else regions are the entry regions of this op.
1053   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
1054   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
1055 }
1056 
1057 //===----------------------------------------------------------------------===//
1058 // SingleNoTerminatorCustomAsmOp
1059 //===----------------------------------------------------------------------===//
1060 
parseSingleNoTerminatorCustomAsmOp(OpAsmParser & parser,OperationState & state)1061 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1062                                                       OperationState &state) {
1063   Region *body = state.addRegion();
1064   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1065     return failure();
1066   return success();
1067 }
1068 
print(SingleNoTerminatorCustomAsmOp op,OpAsmPrinter & printer)1069 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1070   printer << op.getOperationName();
1071   printer.printRegion(
1072       op.getRegion(), /*printEntryBlockArgs=*/false,
1073       // This op has a single block without terminators. But explicitly mark
1074       // as not printing block terminators for testing.
1075       /*printBlockTerminators=*/false);
1076 }
1077 
1078 #include "TestOpEnums.cpp.inc"
1079 #include "TestOpInterfaces.cpp.inc"
1080 #include "TestOpStructs.cpp.inc"
1081 #include "TestTypeInterfaces.cpp.inc"
1082 
1083 #define GET_OP_CLASSES
1084 #include "TestOps.cpp.inc"
1085