1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/DLTI/DLTI.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Reducer/ReductionPatternInterface.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/StringSwitch.h"
24 
25 // Include this before the using namespace lines below to
26 // test that we don't have namespace dependencies.
27 #include "TestOpsDialect.cpp.inc"
28 
29 using namespace mlir;
30 using namespace test;
31 
registerTestDialect(DialectRegistry & registry)32 void test::registerTestDialect(DialectRegistry &registry) {
33   registry.insert<TestDialect>();
34 }
35 
36 //===----------------------------------------------------------------------===//
37 // TestDialect Interfaces
38 //===----------------------------------------------------------------------===//
39 
40 namespace {
41 
42 /// Testing the correctness of some traits.
43 static_assert(
44     llvm::is_detected<OpTrait::has_implicit_terminator_t,
45                       SingleBlockImplicitTerminatorOp>::value,
46     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
47 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
48                   SingleBlockImplicitTerminatorOp>::value,
49               "hasSingleBlockImplicitTerminator does not match "
50               "SingleBlockImplicitTerminatorOp");
51 
52 // Test support for interacting with the AsmPrinter.
53 struct TestOpAsmInterface : public OpAsmDialectInterface {
54   using OpAsmDialectInterface::OpAsmDialectInterface;
55 
getAlias__anon140dcf5b0111::TestOpAsmInterface56   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
57     StringAttr strAttr = attr.dyn_cast<StringAttr>();
58     if (!strAttr)
59       return AliasResult::NoAlias;
60 
61     // Check the contents of the string attribute to see what the test alias
62     // should be named.
63     Optional<StringRef> aliasName =
64         StringSwitch<Optional<StringRef>>(strAttr.getValue())
65             .Case("alias_test:dot_in_name", StringRef("test.alias"))
66             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
67             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
68             .Case("alias_test:sanitize_conflict_a",
69                   StringRef("test_alias_conflict0"))
70             .Case("alias_test:sanitize_conflict_b",
71                   StringRef("test_alias_conflict0_"))
72             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
73             .Default(llvm::None);
74     if (!aliasName)
75       return AliasResult::NoAlias;
76 
77     os << *aliasName;
78     return AliasResult::FinalAlias;
79   }
80 
getAlias__anon140dcf5b0111::TestOpAsmInterface81   AliasResult getAlias(Type type, raw_ostream &os) const final {
82     if (auto tupleType = type.dyn_cast<TupleType>()) {
83       if (tupleType.size() > 0 &&
84           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
85             return elemType.isa<SimpleAType>();
86           })) {
87         os << "test_tuple";
88         return AliasResult::FinalAlias;
89       }
90     }
91     if (auto intType = type.dyn_cast<TestIntegerType>()) {
92       if (intType.getSignedness() ==
93               TestIntegerType::SignednessSemantics::Unsigned &&
94           intType.getWidth() == 8) {
95         os << "test_ui8";
96         return AliasResult::FinalAlias;
97       }
98     }
99     return AliasResult::NoAlias;
100   }
101 
getAsmResultNames__anon140dcf5b0111::TestOpAsmInterface102   void getAsmResultNames(Operation *op,
103                          OpAsmSetValueNameFn setNameFn) const final {
104     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
105       setNameFn(asmOp, "result");
106   }
107 
getAsmBlockArgumentNames__anon140dcf5b0111::TestOpAsmInterface108   void getAsmBlockArgumentNames(Block *block,
109                                 OpAsmSetValueNameFn setNameFn) const final {
110     auto op = block->getParentOp();
111     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
112     if (!arrayAttr)
113       return;
114     auto args = block->getArguments();
115     auto e = std::min(arrayAttr.size(), args.size());
116     for (unsigned i = 0; i < e; ++i) {
117       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
118         setNameFn(args[i], strAttr.getValue());
119     }
120   }
121 };
122 
123 struct TestDialectFoldInterface : public DialectFoldInterface {
124   using DialectFoldInterface::DialectFoldInterface;
125 
126   /// Registered hook to check if the given region, which is attached to an
127   /// operation that is *not* isolated from above, should be used when
128   /// materializing constants.
shouldMaterializeInto__anon140dcf5b0111::TestDialectFoldInterface129   bool shouldMaterializeInto(Region *region) const final {
130     // If this is a one region operation, then insert into it.
131     return isa<OneRegionOp>(region->getParentOp());
132   }
133 };
134 
135 /// This class defines the interface for handling inlining with standard
136 /// operations.
137 struct TestInlinerInterface : public DialectInlinerInterface {
138   using DialectInlinerInterface::DialectInlinerInterface;
139 
140   //===--------------------------------------------------------------------===//
141   // Analysis Hooks
142   //===--------------------------------------------------------------------===//
143 
isLegalToInline__anon140dcf5b0111::TestInlinerInterface144   bool isLegalToInline(Operation *call, Operation *callable,
145                        bool wouldBeCloned) const final {
146     // Don't allow inlining calls that are marked `noinline`.
147     return !call->hasAttr("noinline");
148   }
isLegalToInline__anon140dcf5b0111::TestInlinerInterface149   bool isLegalToInline(Region *, Region *, bool,
150                        BlockAndValueMapping &) const final {
151     // Inlining into test dialect regions is legal.
152     return true;
153   }
isLegalToInline__anon140dcf5b0111::TestInlinerInterface154   bool isLegalToInline(Operation *, Region *, bool,
155                        BlockAndValueMapping &) const final {
156     return true;
157   }
158 
shouldAnalyzeRecursively__anon140dcf5b0111::TestInlinerInterface159   bool shouldAnalyzeRecursively(Operation *op) const final {
160     // Analyze recursively if this is not a functional region operation, it
161     // froms a separate functional scope.
162     return !isa<FunctionalRegionOp>(op);
163   }
164 
165   //===--------------------------------------------------------------------===//
166   // Transformation Hooks
167   //===--------------------------------------------------------------------===//
168 
169   /// Handle the given inlined terminator by replacing it with a new operation
170   /// as necessary.
handleTerminator__anon140dcf5b0111::TestInlinerInterface171   void handleTerminator(Operation *op,
172                         ArrayRef<Value> valuesToRepl) const final {
173     // Only handle "test.return" here.
174     auto returnOp = dyn_cast<TestReturnOp>(op);
175     if (!returnOp)
176       return;
177 
178     // Replace the values directly with the return operands.
179     assert(returnOp.getNumOperands() == valuesToRepl.size());
180     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
181       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
182   }
183 
184   /// Attempt to materialize a conversion for a type mismatch between a call
185   /// from this dialect, and a callable region. This method should generate an
186   /// operation that takes 'input' as the only operand, and produces a single
187   /// result of 'resultType'. If a conversion can not be generated, nullptr
188   /// should be returned.
materializeCallConversion__anon140dcf5b0111::TestInlinerInterface189   Operation *materializeCallConversion(OpBuilder &builder, Value input,
190                                        Type resultType,
191                                        Location conversionLoc) const final {
192     // Only allow conversion for i16/i32 types.
193     if (!(resultType.isSignlessInteger(16) ||
194           resultType.isSignlessInteger(32)) ||
195         !(input.getType().isSignlessInteger(16) ||
196           input.getType().isSignlessInteger(32)))
197       return nullptr;
198     return builder.create<TestCastOp>(conversionLoc, resultType, input);
199   }
200 
processInlinedCallBlocks__anon140dcf5b0111::TestInlinerInterface201   void processInlinedCallBlocks(
202       Operation *call,
203       iterator_range<Region::iterator> inlinedBlocks) const final {
204     if (!isa<ConversionCallOp>(call))
205       return;
206 
207     // Set attributed on all ops in the inlined blocks.
208     for (Block &block : inlinedBlocks) {
209       block.walk([&](Operation *op) {
210         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
211       });
212     }
213   }
214 };
215 
216 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
217 public:
TestReductionPatternInterface__anon140dcf5b0111::TestReductionPatternInterface218   TestReductionPatternInterface(Dialect *dialect)
219       : DialectReductionPatternInterface(dialect) {}
220 
populateReductionPatterns__anon140dcf5b0111::TestReductionPatternInterface221   void populateReductionPatterns(RewritePatternSet &patterns) const final {
222     populateTestReductionPatterns(patterns);
223   }
224 };
225 
226 } // end anonymous namespace
227 
228 //===----------------------------------------------------------------------===//
229 // TestDialect
230 //===----------------------------------------------------------------------===//
231 
232 static void testSideEffectOpGetEffect(
233     Operation *op,
234     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
235 
236 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
237 struct TestOpEffectInterfaceFallback
238     : public TestEffectOpInterface::FallbackModel<
239           TestOpEffectInterfaceFallback> {
classofTestOpEffectInterfaceFallback240   static bool classof(Operation *op) {
241     bool isSupportedOp =
242         op->getName().getStringRef() == "test.unregistered_side_effect_op";
243     assert(isSupportedOp && "Unexpected dispatch");
244     return isSupportedOp;
245   }
246 
247   void
getEffectsTestOpEffectInterfaceFallback248   getEffects(Operation *op,
249              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
250                  &effects) const {
251     testSideEffectOpGetEffect(op, effects);
252   }
253 };
254 
/// Sets up the test dialect. This registers its attributes and types, plus
/// every generated operation from TestOps.cpp.inc. It also adds the dialect
/// interfaces defined above, permits unregistered operations for this
/// dialect, and allocates the `TestEffectOpInterface` fallback model.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op. It is heap-allocated here and deleted in ~TestDialect.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
~TestDialect()270 TestDialect::~TestDialect() {
271   delete static_cast<TestOpEffectInterfaceFallback *>(
272       fallbackEffectOpInterfaces);
273 }
274 
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)275 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
276                                             Type type, Location loc) {
277   return builder.create<TestOpConstant>(loc, type, value);
278 }
279 
getRegisteredInterfaceForOp(TypeID typeID,OperationName opName)280 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
281                                                OperationName opName) {
282   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
283       typeID == TypeID::get<TestEffectOpInterface>())
284     return fallbackEffectOpInterfaces;
285   return nullptr;
286 }
287 
verifyOperationAttribute(Operation * op,NamedAttribute namedAttr)288 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
289                                                     NamedAttribute namedAttr) {
290   if (namedAttr.first == "test.invalid_attr")
291     return op->emitError() << "invalid to use 'test.invalid_attr'";
292   return success();
293 }
294 
verifyRegionArgAttribute(Operation * op,unsigned regionIndex,unsigned argIndex,NamedAttribute namedAttr)295 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
296                                                     unsigned regionIndex,
297                                                     unsigned argIndex,
298                                                     NamedAttribute namedAttr) {
299   if (namedAttr.first == "test.invalid_attr")
300     return op->emitError() << "invalid to use 'test.invalid_attr'";
301   return success();
302 }
303 
304 LogicalResult
verifyRegionResultAttribute(Operation * op,unsigned regionIndex,unsigned resultIndex,NamedAttribute namedAttr)305 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
306                                          unsigned resultIndex,
307                                          NamedAttribute namedAttr) {
308   if (namedAttr.first == "test.invalid_attr")
309     return op->emitError() << "invalid to use 'test.invalid_attr'";
310   return success();
311 }
312 
313 Optional<Dialect::ParseOpHook>
getParseOperationHook(StringRef opName) const314 TestDialect::getParseOperationHook(StringRef opName) const {
315   if (opName == "test.dialect_custom_printer") {
316     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
317       return parser.parseKeyword("custom_format");
318     }};
319   }
320   return None;
321 }
322 
323 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
getOperationPrinter(Operation * op) const324 TestDialect::getOperationPrinter(Operation *op) const {
325   StringRef opName = op->getName().getStringRef();
326   if (opName == "test.dialect_custom_printer") {
327     return [](Operation *op, OpAsmPrinter &printer) {
328       printer.getStream() << " custom_format";
329     };
330   }
331   return {};
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // TestBranchOp
336 //===----------------------------------------------------------------------===//
337 
338 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)339 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
340   assert(index == 0 && "invalid successor index");
341   return targetOperandsMutable();
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // TestDialectCanonicalizerOp
346 //===----------------------------------------------------------------------===//
347 
348 static LogicalResult
dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,PatternRewriter & rewriter)349 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
350                                PatternRewriter &rewriter) {
351   rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
352                                           rewriter.getI32IntegerAttr(42));
353   return success();
354 }
355 
getCanonicalizationPatterns(RewritePatternSet & results) const356 void TestDialect::getCanonicalizationPatterns(
357     RewritePatternSet &results) const {
358   results.add(&dialectCanonicalizationPattern);
359 }
360 
361 //===----------------------------------------------------------------------===//
362 // TestFoldToCallOp
363 //===----------------------------------------------------------------------===//
364 
365 namespace {
366 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
367   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
368 
matchAndRewrite__anon140dcf5b0611::FoldToCallOpPattern369   LogicalResult matchAndRewrite(FoldToCallOp op,
370                                 PatternRewriter &rewriter) const override {
371     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
372                                         ValueRange());
373     return success();
374   }
375 };
376 } // end anonymous namespace
377 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)378 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
379                                                MLIRContext *context) {
380   results.add<FoldToCallOpPattern>(context);
381 }
382 
383 //===----------------------------------------------------------------------===//
384 // Test Format* operations
385 //===----------------------------------------------------------------------===//
386 
387 //===----------------------------------------------------------------------===//
388 // Parsing
389 
parseCustomDirectiveOperands(OpAsmParser & parser,OpAsmParser::OperandType & operand,Optional<OpAsmParser::OperandType> & optOperand,SmallVectorImpl<OpAsmParser::OperandType> & varOperands)390 static ParseResult parseCustomDirectiveOperands(
391     OpAsmParser &parser, OpAsmParser::OperandType &operand,
392     Optional<OpAsmParser::OperandType> &optOperand,
393     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
394   if (parser.parseOperand(operand))
395     return failure();
396   if (succeeded(parser.parseOptionalComma())) {
397     optOperand.emplace();
398     if (parser.parseOperand(*optOperand))
399       return failure();
400   }
401   if (parser.parseArrow() || parser.parseLParen() ||
402       parser.parseOperandList(varOperands) || parser.parseRParen())
403     return failure();
404   return success();
405 }
406 static ParseResult
parseCustomDirectiveResults(OpAsmParser & parser,Type & operandType,Type & optOperandType,SmallVectorImpl<Type> & varOperandTypes)407 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
408                             Type &optOperandType,
409                             SmallVectorImpl<Type> &varOperandTypes) {
410   if (parser.parseColon())
411     return failure();
412 
413   if (parser.parseType(operandType))
414     return failure();
415   if (succeeded(parser.parseOptionalComma())) {
416     if (parser.parseType(optOperandType))
417       return failure();
418   }
419   if (parser.parseArrow() || parser.parseLParen() ||
420       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
421     return failure();
422   return success();
423 }
424 static ParseResult
parseCustomDirectiveWithTypeRefs(OpAsmParser & parser,Type operandType,Type optOperandType,const SmallVectorImpl<Type> & varOperandTypes)425 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
426                                  Type optOperandType,
427                                  const SmallVectorImpl<Type> &varOperandTypes) {
428   if (parser.parseKeyword("type_refs_capture"))
429     return failure();
430 
431   Type operandType2, optOperandType2;
432   SmallVector<Type, 1> varOperandTypes2;
433   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
434                                   varOperandTypes2))
435     return failure();
436 
437   if (operandType != operandType2 || optOperandType != optOperandType2 ||
438       varOperandTypes != varOperandTypes2)
439     return failure();
440 
441   return success();
442 }
parseCustomDirectiveOperandsAndTypes(OpAsmParser & parser,OpAsmParser::OperandType & operand,Optional<OpAsmParser::OperandType> & optOperand,SmallVectorImpl<OpAsmParser::OperandType> & varOperands,Type & operandType,Type & optOperandType,SmallVectorImpl<Type> & varOperandTypes)443 static ParseResult parseCustomDirectiveOperandsAndTypes(
444     OpAsmParser &parser, OpAsmParser::OperandType &operand,
445     Optional<OpAsmParser::OperandType> &optOperand,
446     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
447     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
448   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
449       parseCustomDirectiveResults(parser, operandType, optOperandType,
450                                   varOperandTypes))
451     return failure();
452   return success();
453 }
parseCustomDirectiveRegions(OpAsmParser & parser,Region & region,SmallVectorImpl<std::unique_ptr<Region>> & varRegions)454 static ParseResult parseCustomDirectiveRegions(
455     OpAsmParser &parser, Region &region,
456     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
457   if (parser.parseRegion(region))
458     return failure();
459   if (failed(parser.parseOptionalComma()))
460     return success();
461   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
462   if (parser.parseRegion(*varRegion))
463     return failure();
464   varRegions.emplace_back(std::move(varRegion));
465   return success();
466 }
467 static ParseResult
parseCustomDirectiveSuccessors(OpAsmParser & parser,Block * & successor,SmallVectorImpl<Block * > & varSuccessors)468 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
469                                SmallVectorImpl<Block *> &varSuccessors) {
470   if (parser.parseSuccessor(successor))
471     return failure();
472   if (failed(parser.parseOptionalComma()))
473     return success();
474   Block *varSuccessor;
475   if (parser.parseSuccessor(varSuccessor))
476     return failure();
477   varSuccessors.append(2, varSuccessor);
478   return success();
479 }
parseCustomDirectiveAttributes(OpAsmParser & parser,IntegerAttr & attr,IntegerAttr & optAttr)480 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
481                                                   IntegerAttr &attr,
482                                                   IntegerAttr &optAttr) {
483   if (parser.parseAttribute(attr))
484     return failure();
485   if (succeeded(parser.parseOptionalComma())) {
486     if (parser.parseAttribute(optAttr))
487       return failure();
488   }
489   return success();
490 }
491 
parseCustomDirectiveAttrDict(OpAsmParser & parser,NamedAttrList & attrs)492 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
493                                                 NamedAttrList &attrs) {
494   return parser.parseOptionalAttrDict(attrs);
495 }
parseCustomDirectiveOptionalOperandRef(OpAsmParser & parser,Optional<OpAsmParser::OperandType> & optOperand)496 static ParseResult parseCustomDirectiveOptionalOperandRef(
497     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
498   int64_t operandCount = 0;
499   if (parser.parseInteger(operandCount))
500     return failure();
501   bool expectedOptionalOperand = operandCount == 0;
502   return success(expectedOptionalOperand != optOperand.hasValue());
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // Printing
507 
printCustomDirectiveOperands(OpAsmPrinter & printer,Operation *,Value operand,Value optOperand,OperandRange varOperands)508 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
509                                          Value operand, Value optOperand,
510                                          OperandRange varOperands) {
511   printer << operand;
512   if (optOperand)
513     printer << ", " << optOperand;
514   printer << " -> (" << varOperands << ")";
515 }
printCustomDirectiveResults(OpAsmPrinter & printer,Operation *,Type operandType,Type optOperandType,TypeRange varOperandTypes)516 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
517                                         Type operandType, Type optOperandType,
518                                         TypeRange varOperandTypes) {
519   printer << " : " << operandType;
520   if (optOperandType)
521     printer << ", " << optOperandType;
522   printer << " -> (" << varOperandTypes << ")";
523 }
printCustomDirectiveWithTypeRefs(OpAsmPrinter & printer,Operation * op,Type operandType,Type optOperandType,TypeRange varOperandTypes)524 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
525                                              Operation *op, Type operandType,
526                                              Type optOperandType,
527                                              TypeRange varOperandTypes) {
528   printer << " type_refs_capture ";
529   printCustomDirectiveResults(printer, op, operandType, optOperandType,
530                               varOperandTypes);
531 }
printCustomDirectiveOperandsAndTypes(OpAsmPrinter & printer,Operation * op,Value operand,Value optOperand,OperandRange varOperands,Type operandType,Type optOperandType,TypeRange varOperandTypes)532 static void printCustomDirectiveOperandsAndTypes(
533     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
534     OperandRange varOperands, Type operandType, Type optOperandType,
535     TypeRange varOperandTypes) {
536   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
537   printCustomDirectiveResults(printer, op, operandType, optOperandType,
538                               varOperandTypes);
539 }
printCustomDirectiveRegions(OpAsmPrinter & printer,Operation *,Region & region,MutableArrayRef<Region> varRegions)540 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
541                                         Region &region,
542                                         MutableArrayRef<Region> varRegions) {
543   printer.printRegion(region);
544   if (!varRegions.empty()) {
545     printer << ", ";
546     for (Region &region : varRegions)
547       printer.printRegion(region);
548   }
549 }
printCustomDirectiveSuccessors(OpAsmPrinter & printer,Operation *,Block * successor,SuccessorRange varSuccessors)550 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
551                                            Block *successor,
552                                            SuccessorRange varSuccessors) {
553   printer << successor;
554   if (!varSuccessors.empty())
555     printer << ", " << varSuccessors.front();
556 }
printCustomDirectiveAttributes(OpAsmPrinter & printer,Operation *,Attribute attribute,Attribute optAttribute)557 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
558                                            Attribute attribute,
559                                            Attribute optAttribute) {
560   printer << attribute;
561   if (optAttribute)
562     printer << ", " << optAttribute;
563 }
564 
printCustomDirectiveAttrDict(OpAsmPrinter & printer,Operation * op,DictionaryAttr attrs)565 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
566                                          DictionaryAttr attrs) {
567   printer.printOptionalAttrDict(attrs.getValue());
568 }
569 
printCustomDirectiveOptionalOperandRef(OpAsmPrinter & printer,Operation * op,Value optOperand)570 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
571                                                    Operation *op,
572                                                    Value optOperand) {
573   printer << (optOperand ? "1" : "0");
574 }
575 
576 //===----------------------------------------------------------------------===//
577 // Test IsolatedRegionOp - parse passthrough region arguments.
578 //===----------------------------------------------------------------------===//
579 
parseIsolatedRegionOp(OpAsmParser & parser,OperationState & result)580 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
581                                          OperationState &result) {
582   OpAsmParser::OperandType argInfo;
583   Type argType = parser.getBuilder().getIndexType();
584 
585   // Parse the input operand.
586   if (parser.parseOperand(argInfo) ||
587       parser.resolveOperand(argInfo, argType, result.operands))
588     return failure();
589 
590   // Parse the body region, and reuse the operand info as the argument info.
591   Region *body = result.addRegion();
592   return parser.parseRegion(*body, argInfo, argType,
593                             /*enableNameShadowing=*/true);
594 }
595 
print(OpAsmPrinter & p,IsolatedRegionOp op)596 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
597   p << "test.isolated_region ";
598   p.printOperand(op.getOperand());
599   p.shadowRegionArgs(op.region(), op.getOperand());
600   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
601 }
602 
603 //===----------------------------------------------------------------------===//
604 // Test SSACFGRegionOp
605 //===----------------------------------------------------------------------===//
606 
getRegionKind(unsigned index)607 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
608   return RegionKind::SSACFG;
609 }
610 
611 //===----------------------------------------------------------------------===//
612 // Test GraphRegionOp
613 //===----------------------------------------------------------------------===//
614 
parseGraphRegionOp(OpAsmParser & parser,OperationState & result)615 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
616                                       OperationState &result) {
617   // Parse the body region, and reuse the operand info as the argument info.
618   Region *body = result.addRegion();
619   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
620 }
621 
print(OpAsmPrinter & p,GraphRegionOp op)622 static void print(OpAsmPrinter &p, GraphRegionOp op) {
623   p << "test.graph_region ";
624   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
625 }
626 
getRegionKind(unsigned index)627 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
628   return RegionKind::Graph;
629 }
630 
631 //===----------------------------------------------------------------------===//
632 // Test AffineScopeOp
633 //===----------------------------------------------------------------------===//
634 
parseAffineScopeOp(OpAsmParser & parser,OperationState & result)635 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
636                                       OperationState &result) {
637   // Parse the body region, and reuse the operand info as the argument info.
638   Region *body = result.addRegion();
639   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
640 }
641 
print(OpAsmPrinter & p,AffineScopeOp op)642 static void print(OpAsmPrinter &p, AffineScopeOp op) {
643   p << "test.affine_scope ";
644   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
645 }
646 
647 //===----------------------------------------------------------------------===//
648 // Test parser.
649 //===----------------------------------------------------------------------===//
650 
parseParseIntegerLiteralOp(OpAsmParser & parser,OperationState & result)651 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
652                                               OperationState &result) {
653   if (parser.parseOptionalColon())
654     return success();
655   uint64_t numResults;
656   if (parser.parseInteger(numResults))
657     return failure();
658 
659   IndexType type = parser.getBuilder().getIndexType();
660   for (unsigned i = 0; i < numResults; ++i)
661     result.addTypes(type);
662   return success();
663 }
664 
print(OpAsmPrinter & p,ParseIntegerLiteralOp op)665 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
666   if (unsigned numResults = op->getNumResults())
667     p << " : " << numResults;
668 }
669 
parseParseWrappedKeywordOp(OpAsmParser & parser,OperationState & result)670 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
671                                               OperationState &result) {
672   StringRef keyword;
673   if (parser.parseKeyword(&keyword))
674     return failure();
675   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
676   return success();
677 }
678 
print(OpAsmPrinter & p,ParseWrappedKeywordOp op)679 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
680   p << " " << op.keyword();
681 }
682 
683 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
//===----------------------------------------------------------------------===//
685 
parseWrappingRegionOp(OpAsmParser & parser,OperationState & result)686 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
687                                          OperationState &result) {
688   if (parser.parseKeyword("wraps"))
689     return failure();
690 
691   // Parse the wrapped op in a region
692   Region &body = *result.addRegion();
693   body.push_back(new Block);
694   Block &block = body.back();
695   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
696   if (!wrapped_op)
697     return failure();
698 
699   // Create a return terminator in the inner region, pass as operand to the
700   // terminator the returned values from the wrapped operation.
701   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
702   OpBuilder builder(parser.getContext());
703   builder.setInsertionPointToEnd(&block);
704   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
705 
706   // Get the results type for the wrapping op from the terminator operands.
707   Operation &return_op = body.back().back();
708   result.types.append(return_op.operand_type_begin(),
709                       return_op.operand_type_end());
710 
711   // Use the location of the wrapped op for the "test.wrapping_region" op.
712   result.location = wrapped_op->getLoc();
713 
714   return success();
715 }
716 
print(OpAsmPrinter & p,WrappingRegionOp op)717 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
718   p << " wraps ";
719   p.printGenericOp(&op.region().front().front());
720 }
721 
722 //===----------------------------------------------------------------------===//
723 // Test PolyForOp - parse list of region arguments.
724 //===----------------------------------------------------------------------===//
725 
parsePolyForOp(OpAsmParser & parser,OperationState & result)726 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
727   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
728   // Parse list of region arguments without a delimiter.
729   if (parser.parseRegionArgumentList(ivsInfo))
730     return failure();
731 
732   // Parse the body region.
733   Region *body = result.addRegion();
734   auto &builder = parser.getBuilder();
735   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
736   return parser.parseRegion(*body, ivsInfo, argTypes);
737 }
738 
739 //===----------------------------------------------------------------------===//
740 // Test removing op with inner ops.
741 //===----------------------------------------------------------------------===//
742 
743 namespace {
744 struct TestRemoveOpWithInnerOps
745     : public OpRewritePattern<TestOpWithRegionPattern> {
746   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
747 
initialize__anon140dcf5b0711::TestRemoveOpWithInnerOps748   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
749 
matchAndRewrite__anon140dcf5b0711::TestRemoveOpWithInnerOps750   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
751                                 PatternRewriter &rewriter) const override {
752     rewriter.eraseOp(op);
753     return success();
754   }
755 };
756 } // end anonymous namespace
757 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)758 void TestOpWithRegionPattern::getCanonicalizationPatterns(
759     RewritePatternSet &results, MLIRContext *context) {
760   results.add<TestRemoveOpWithInnerOps>(context);
761 }
762 
fold(ArrayRef<Attribute> operands)763 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
764   return operand();
765 }
766 
fold(ArrayRef<Attribute> operands)767 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
768   return getValue();
769 }
770 
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)771 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
772     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
773   for (Value input : this->operands()) {
774     results.push_back(input);
775   }
776   return success();
777 }
778 
fold(ArrayRef<Attribute> operands)779 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
780   assert(operands.size() == 1);
781   if (operands.front()) {
782     (*this)->setAttr("attr", operands.front());
783     return getResult();
784   }
785   return {};
786 }
787 
fold(ArrayRef<Attribute> operands)788 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
789   return getOperand();
790 }
791 
inferReturnTypes(MLIRContext *,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)792 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
793     MLIRContext *, Optional<Location> location, ValueRange operands,
794     DictionaryAttr attributes, RegionRange regions,
795     SmallVectorImpl<Type> &inferredReturnTypes) {
796   if (operands[0].getType() != operands[1].getType()) {
797     return emitOptionalError(location, "operand type mismatch ",
798                              operands[0].getType(), " vs ",
799                              operands[1].getType());
800   }
801   inferredReturnTypes.assign({operands[0].getType()});
802   return success();
803 }
804 
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)805 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
806     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
807     DictionaryAttr attributes, RegionRange regions,
808     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
809   // Create return type consisting of the last element of the first operand.
810   auto operandType = operands.front().getType();
811   auto sval = operandType.dyn_cast<ShapedType>();
812   if (!sval) {
813     return emitOptionalError(location, "only shaped type operands allowed");
814   }
815   int64_t dim =
816       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
817   auto type = IntegerType::get(context, 17);
818   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
819   return success();
820 }
821 
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,llvm::SmallVectorImpl<Value> & shapes)822 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
823     OpBuilder &builder, ValueRange operands,
824     llvm::SmallVectorImpl<Value> &shapes) {
825   shapes = SmallVector<Value, 1>{
826       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
827   return success();
828 }
829 
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,llvm::SmallVectorImpl<Value> & shapes)830 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
831     OpBuilder &builder, ValueRange operands,
832     llvm::SmallVectorImpl<Value> &shapes) {
833   Location loc = getLoc();
834   shapes.reserve(operands.size());
835   for (Value operand : llvm::reverse(operands)) {
836     auto currShape = llvm::to_vector<4>(llvm::map_range(
837         llvm::seq<int64_t>(
838             0, operand.getType().cast<RankedTensorType>().getRank()),
839         [&](int64_t dim) -> Value {
840           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
841         }));
842     shapes.push_back(builder.create<tensor::FromElementsOp>(
843         getLoc(), builder.getIndexType(), currShape));
844   }
845   return success();
846 }
847 
reifyResultShapes(OpBuilder & builder,ReifiedRankedShapedTypeDims & shapes)848 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
849     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
850   Location loc = getLoc();
851   shapes.reserve(getNumOperands());
852   for (Value operand : llvm::reverse(getOperands())) {
853     auto currShape = llvm::to_vector<4>(llvm::map_range(
854         llvm::seq<int64_t>(
855             0, operand.getType().cast<RankedTensorType>().getRank()),
856         [&](int64_t dim) -> Value {
857           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
858         }));
859     shapes.emplace_back(std::move(currShape));
860   }
861   return success();
862 }
863 
864 //===----------------------------------------------------------------------===//
865 // Test SideEffect interfaces
866 //===----------------------------------------------------------------------===//
867 
868 namespace {
869 /// A test resource for side effects.
870 struct TestResource : public SideEffects::Resource::Base<TestResource> {
getName__anon140dcf5b0a11::TestResource871   StringRef getName() final { return "<Test>"; }
872 };
873 } // end anonymous namespace
874 
testSideEffectOpGetEffect(Operation * op,SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> & effects)875 static void testSideEffectOpGetEffect(
876     Operation *op,
877     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
878         &effects) {
879   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
880   if (!effectsAttr)
881     return;
882 
883   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
884 }
885 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)886 void SideEffectOp::getEffects(
887     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
888   // Check for an effects attribute on the op instance.
889   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
890   if (!effectsAttr)
891     return;
892 
893   // If there is one, it is an array of dictionary attributes that hold
894   // information on the effects of this operation.
895   for (Attribute element : effectsAttr) {
896     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
897 
898     // Get the specific memory effect.
899     MemoryEffects::Effect *effect =
900         StringSwitch<MemoryEffects::Effect *>(
901             effectElement.get("effect").cast<StringAttr>().getValue())
902             .Case("allocate", MemoryEffects::Allocate::get())
903             .Case("free", MemoryEffects::Free::get())
904             .Case("read", MemoryEffects::Read::get())
905             .Case("write", MemoryEffects::Write::get());
906 
907     // Check for a non-default resource to use.
908     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
909     if (effectElement.get("test_resource"))
910       resource = TestResource::get();
911 
912     // Check for a result to affect.
913     if (effectElement.get("on_result"))
914       effects.emplace_back(effect, getResult(), resource);
915     else if (Attribute ref = effectElement.get("on_reference"))
916       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
917     else
918       effects.emplace_back(effect, resource);
919   }
920 }
921 
getEffects(SmallVectorImpl<TestEffects::EffectInstance> & effects)922 void SideEffectOp::getEffects(
923     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
924   testSideEffectOpGetEffect(getOperation(), effects);
925 }
926 
927 //===----------------------------------------------------------------------===//
928 // StringAttrPrettyNameOp
929 //===----------------------------------------------------------------------===//
930 
931 // This op has fancy handling of its SSA result name.
parseStringAttrPrettyNameOp(OpAsmParser & parser,OperationState & result)932 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
933                                                OperationState &result) {
934   // Add the result types.
935   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
936     result.addTypes(parser.getBuilder().getIntegerType(32));
937 
938   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
939     return failure();
940 
941   // If the attribute dictionary contains no 'names' attribute, infer it from
942   // the SSA name (if specified).
943   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
944     return attr.first == "names";
945   });
946 
947   // If there was no name specified, check to see if there was a useful name
948   // specified in the asm file.
949   if (hadNames || parser.getNumResults() == 0)
950     return success();
951 
952   SmallVector<StringRef, 4> names;
953   auto *context = result.getContext();
954 
955   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
956     auto resultName = parser.getResultName(i);
957     StringRef nameStr;
958     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
959       nameStr = resultName.first;
960 
961     names.push_back(nameStr);
962   }
963 
964   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
965   result.attributes.push_back({Identifier::get("names", context), namesAttr});
966   return success();
967 }
968 
print(OpAsmPrinter & p,StringAttrPrettyNameOp op)969 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
970   // Note that we only need to print the "name" attribute if the asmprinter
971   // result name disagrees with it.  This can happen in strange cases, e.g.
972   // when there are conflicts.
973   bool namesDisagree = op.names().size() != op.getNumResults();
974 
975   SmallString<32> resultNameStr;
976   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
977     resultNameStr.clear();
978     llvm::raw_svector_ostream tmpStream(resultNameStr);
979     p.printOperand(op.getResult(i), tmpStream);
980 
981     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
982     if (!expectedName ||
983         tmpStream.str().drop_front() != expectedName.getValue()) {
984       namesDisagree = true;
985     }
986   }
987 
988   if (namesDisagree)
989     p.printOptionalAttrDictWithKeyword(op->getAttrs());
990   else
991     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
992 }
993 
994 // We set the SSA name in the asm syntax to the contents of the name
995 // attribute.
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)996 void StringAttrPrettyNameOp::getAsmResultNames(
997     function_ref<void(Value, StringRef)> setNameFn) {
998 
999   auto value = names();
1000   for (size_t i = 0, e = value.size(); i != e; ++i)
1001     if (auto str = value[i].dyn_cast<StringAttr>())
1002       if (!str.getValue().empty())
1003         setNameFn(getResult(i), str.getValue());
1004 }
1005 
1006 //===----------------------------------------------------------------------===//
1007 // RegionIfOp
1008 //===----------------------------------------------------------------------===//
1009 
print(OpAsmPrinter & p,RegionIfOp op)1010 static void print(OpAsmPrinter &p, RegionIfOp op) {
1011   p << " ";
1012   p.printOperands(op.getOperands());
1013   p << ": " << op.getOperandTypes();
1014   p.printArrowTypeList(op.getResultTypes());
1015   p << " then";
1016   p.printRegion(op.thenRegion(),
1017                 /*printEntryBlockArgs=*/true,
1018                 /*printBlockTerminators=*/true);
1019   p << " else";
1020   p.printRegion(op.elseRegion(),
1021                 /*printEntryBlockArgs=*/true,
1022                 /*printBlockTerminators=*/true);
1023   p << " join";
1024   p.printRegion(op.joinRegion(),
1025                 /*printEntryBlockArgs=*/true,
1026                 /*printBlockTerminators=*/true);
1027 }
1028 
parseRegionIfOp(OpAsmParser & parser,OperationState & result)1029 static ParseResult parseRegionIfOp(OpAsmParser &parser,
1030                                    OperationState &result) {
1031   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
1032   SmallVector<Type, 2> operandTypes;
1033 
1034   result.regions.reserve(3);
1035   Region *thenRegion = result.addRegion();
1036   Region *elseRegion = result.addRegion();
1037   Region *joinRegion = result.addRegion();
1038 
1039   // Parse operand, type and arrow type lists.
1040   if (parser.parseOperandList(operandInfos) ||
1041       parser.parseColonTypeList(operandTypes) ||
1042       parser.parseArrowTypeList(result.types))
1043     return failure();
1044 
1045   // Parse all attached regions.
1046   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1047       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1048       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1049     return failure();
1050 
1051   return parser.resolveOperands(operandInfos, operandTypes,
1052                                 parser.getCurrentLocation(), result.operands);
1053 }
1054 
getSuccessorEntryOperands(unsigned index)1055 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1056   assert(index < 2 && "invalid region index");
1057   return getOperands();
1058 }
1059 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1060 void RegionIfOp::getSuccessorRegions(
1061     Optional<unsigned> index, ArrayRef<Attribute> operands,
1062     SmallVectorImpl<RegionSuccessor> &regions) {
1063   // We always branch to the join region.
1064   if (index.hasValue()) {
1065     if (index.getValue() < 2)
1066       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
1067     else
1068       regions.push_back(RegionSuccessor(getResults()));
1069     return;
1070   }
1071 
1072   // The then and else regions are the entry regions of this op.
1073   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
1074   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
1075 }
1076 
1077 //===----------------------------------------------------------------------===//
1078 // SingleNoTerminatorCustomAsmOp
1079 //===----------------------------------------------------------------------===//
1080 
parseSingleNoTerminatorCustomAsmOp(OpAsmParser & parser,OperationState & state)1081 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1082                                                       OperationState &state) {
1083   Region *body = state.addRegion();
1084   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1085     return failure();
1086   return success();
1087 }
1088 
print(SingleNoTerminatorCustomAsmOp op,OpAsmPrinter & printer)1089 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1090   printer.printRegion(
1091       op.getRegion(), /*printEntryBlockArgs=*/false,
1092       // This op has a single block without terminators. But explicitly mark
1093       // as not printing block terminators for testing.
1094       /*printBlockTerminators=*/false);
1095 }
1096 
1097 #include "TestOpEnums.cpp.inc"
1098 #include "TestOpInterfaces.cpp.inc"
1099 #include "TestOpStructs.cpp.inc"
1100 #include "TestTypeInterfaces.cpp.inc"
1101 
1102 #define GET_OP_CLASSES
1103 #include "TestOps.cpp.inc"
1104