1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/DLTI/DLTI.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Reducer/ReductionPatternInterface.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/StringSwitch.h"
24
25 // Include this before the using namespace lines below to
26 // test that we don't have namespace dependencies.
27 #include "TestOpsDialect.cpp.inc"
28
29 using namespace mlir;
30 using namespace test;
31
registerTestDialect(DialectRegistry & registry)32 void test::registerTestDialect(DialectRegistry ®istry) {
33 registry.insert<TestDialect>();
34 }
35
36 //===----------------------------------------------------------------------===//
37 // TestDialect Interfaces
38 //===----------------------------------------------------------------------===//
39
40 namespace {
41
42 /// Testing the correctness of some traits.
43 static_assert(
44 llvm::is_detected<OpTrait::has_implicit_terminator_t,
45 SingleBlockImplicitTerminatorOp>::value,
46 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
47 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
48 SingleBlockImplicitTerminatorOp>::value,
49 "hasSingleBlockImplicitTerminator does not match "
50 "SingleBlockImplicitTerminatorOp");
51
52 // Test support for interacting with the AsmPrinter.
53 struct TestOpAsmInterface : public OpAsmDialectInterface {
54 using OpAsmDialectInterface::OpAsmDialectInterface;
55
getAlias__anon140dcf5b0111::TestOpAsmInterface56 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
57 StringAttr strAttr = attr.dyn_cast<StringAttr>();
58 if (!strAttr)
59 return AliasResult::NoAlias;
60
61 // Check the contents of the string attribute to see what the test alias
62 // should be named.
63 Optional<StringRef> aliasName =
64 StringSwitch<Optional<StringRef>>(strAttr.getValue())
65 .Case("alias_test:dot_in_name", StringRef("test.alias"))
66 .Case("alias_test:trailing_digit", StringRef("test_alias0"))
67 .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
68 .Case("alias_test:sanitize_conflict_a",
69 StringRef("test_alias_conflict0"))
70 .Case("alias_test:sanitize_conflict_b",
71 StringRef("test_alias_conflict0_"))
72 .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
73 .Default(llvm::None);
74 if (!aliasName)
75 return AliasResult::NoAlias;
76
77 os << *aliasName;
78 return AliasResult::FinalAlias;
79 }
80
getAlias__anon140dcf5b0111::TestOpAsmInterface81 AliasResult getAlias(Type type, raw_ostream &os) const final {
82 if (auto tupleType = type.dyn_cast<TupleType>()) {
83 if (tupleType.size() > 0 &&
84 llvm::all_of(tupleType.getTypes(), [](Type elemType) {
85 return elemType.isa<SimpleAType>();
86 })) {
87 os << "test_tuple";
88 return AliasResult::FinalAlias;
89 }
90 }
91 if (auto intType = type.dyn_cast<TestIntegerType>()) {
92 if (intType.getSignedness() ==
93 TestIntegerType::SignednessSemantics::Unsigned &&
94 intType.getWidth() == 8) {
95 os << "test_ui8";
96 return AliasResult::FinalAlias;
97 }
98 }
99 return AliasResult::NoAlias;
100 }
101
getAsmResultNames__anon140dcf5b0111::TestOpAsmInterface102 void getAsmResultNames(Operation *op,
103 OpAsmSetValueNameFn setNameFn) const final {
104 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
105 setNameFn(asmOp, "result");
106 }
107
getAsmBlockArgumentNames__anon140dcf5b0111::TestOpAsmInterface108 void getAsmBlockArgumentNames(Block *block,
109 OpAsmSetValueNameFn setNameFn) const final {
110 auto op = block->getParentOp();
111 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
112 if (!arrayAttr)
113 return;
114 auto args = block->getArguments();
115 auto e = std::min(arrayAttr.size(), args.size());
116 for (unsigned i = 0; i < e; ++i) {
117 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
118 setNameFn(args[i], strAttr.getValue());
119 }
120 }
121 };
122
123 struct TestDialectFoldInterface : public DialectFoldInterface {
124 using DialectFoldInterface::DialectFoldInterface;
125
126 /// Registered hook to check if the given region, which is attached to an
127 /// operation that is *not* isolated from above, should be used when
128 /// materializing constants.
shouldMaterializeInto__anon140dcf5b0111::TestDialectFoldInterface129 bool shouldMaterializeInto(Region *region) const final {
130 // If this is a one region operation, then insert into it.
131 return isa<OneRegionOp>(region->getParentOp());
132 }
133 };
134
135 /// This class defines the interface for handling inlining with standard
136 /// operations.
137 struct TestInlinerInterface : public DialectInlinerInterface {
138 using DialectInlinerInterface::DialectInlinerInterface;
139
140 //===--------------------------------------------------------------------===//
141 // Analysis Hooks
142 //===--------------------------------------------------------------------===//
143
isLegalToInline__anon140dcf5b0111::TestInlinerInterface144 bool isLegalToInline(Operation *call, Operation *callable,
145 bool wouldBeCloned) const final {
146 // Don't allow inlining calls that are marked `noinline`.
147 return !call->hasAttr("noinline");
148 }
isLegalToInline__anon140dcf5b0111::TestInlinerInterface149 bool isLegalToInline(Region *, Region *, bool,
150 BlockAndValueMapping &) const final {
151 // Inlining into test dialect regions is legal.
152 return true;
153 }
isLegalToInline__anon140dcf5b0111::TestInlinerInterface154 bool isLegalToInline(Operation *, Region *, bool,
155 BlockAndValueMapping &) const final {
156 return true;
157 }
158
shouldAnalyzeRecursively__anon140dcf5b0111::TestInlinerInterface159 bool shouldAnalyzeRecursively(Operation *op) const final {
160 // Analyze recursively if this is not a functional region operation, it
161 // froms a separate functional scope.
162 return !isa<FunctionalRegionOp>(op);
163 }
164
165 //===--------------------------------------------------------------------===//
166 // Transformation Hooks
167 //===--------------------------------------------------------------------===//
168
169 /// Handle the given inlined terminator by replacing it with a new operation
170 /// as necessary.
handleTerminator__anon140dcf5b0111::TestInlinerInterface171 void handleTerminator(Operation *op,
172 ArrayRef<Value> valuesToRepl) const final {
173 // Only handle "test.return" here.
174 auto returnOp = dyn_cast<TestReturnOp>(op);
175 if (!returnOp)
176 return;
177
178 // Replace the values directly with the return operands.
179 assert(returnOp.getNumOperands() == valuesToRepl.size());
180 for (const auto &it : llvm::enumerate(returnOp.getOperands()))
181 valuesToRepl[it.index()].replaceAllUsesWith(it.value());
182 }
183
184 /// Attempt to materialize a conversion for a type mismatch between a call
185 /// from this dialect, and a callable region. This method should generate an
186 /// operation that takes 'input' as the only operand, and produces a single
187 /// result of 'resultType'. If a conversion can not be generated, nullptr
188 /// should be returned.
materializeCallConversion__anon140dcf5b0111::TestInlinerInterface189 Operation *materializeCallConversion(OpBuilder &builder, Value input,
190 Type resultType,
191 Location conversionLoc) const final {
192 // Only allow conversion for i16/i32 types.
193 if (!(resultType.isSignlessInteger(16) ||
194 resultType.isSignlessInteger(32)) ||
195 !(input.getType().isSignlessInteger(16) ||
196 input.getType().isSignlessInteger(32)))
197 return nullptr;
198 return builder.create<TestCastOp>(conversionLoc, resultType, input);
199 }
200
processInlinedCallBlocks__anon140dcf5b0111::TestInlinerInterface201 void processInlinedCallBlocks(
202 Operation *call,
203 iterator_range<Region::iterator> inlinedBlocks) const final {
204 if (!isa<ConversionCallOp>(call))
205 return;
206
207 // Set attributed on all ops in the inlined blocks.
208 for (Block &block : inlinedBlocks) {
209 block.walk([&](Operation *op) {
210 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
211 });
212 }
213 }
214 };
215
216 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
217 public:
TestReductionPatternInterface__anon140dcf5b0111::TestReductionPatternInterface218 TestReductionPatternInterface(Dialect *dialect)
219 : DialectReductionPatternInterface(dialect) {}
220
populateReductionPatterns__anon140dcf5b0111::TestReductionPatternInterface221 void populateReductionPatterns(RewritePatternSet &patterns) const final {
222 populateTestReductionPatterns(patterns);
223 }
224 };
225
226 } // end anonymous namespace
227
228 //===----------------------------------------------------------------------===//
229 // TestDialect
230 //===----------------------------------------------------------------------===//
231
232 static void testSideEffectOpGetEffect(
233 Operation *op,
234 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
235
236 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
237 struct TestOpEffectInterfaceFallback
238 : public TestEffectOpInterface::FallbackModel<
239 TestOpEffectInterfaceFallback> {
classofTestOpEffectInterfaceFallback240 static bool classof(Operation *op) {
241 bool isSupportedOp =
242 op->getName().getStringRef() == "test.unregistered_side_effect_op";
243 assert(isSupportedOp && "Unexpected dispatch");
244 return isSupportedOp;
245 }
246
247 void
getEffectsTestOpEffectInterfaceFallback248 getEffects(Operation *op,
249 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
250 &effects) const {
251 testSideEffectOpGetEffect(op, effects);
252 }
253 };
254
initialize()255 void TestDialect::initialize() {
256 registerAttributes();
257 registerTypes();
258 addOperations<
259 #define GET_OP_LIST
260 #include "TestOps.cpp.inc"
261 >();
262 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
263 TestInlinerInterface, TestReductionPatternInterface>();
264 allowUnknownOperations();
265
266 // Instantiate our fallback op interface that we'll use on specific
267 // unregistered op.
268 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
269 }
~TestDialect()270 TestDialect::~TestDialect() {
271 delete static_cast<TestOpEffectInterfaceFallback *>(
272 fallbackEffectOpInterfaces);
273 }
274
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)275 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
276 Type type, Location loc) {
277 return builder.create<TestOpConstant>(loc, type, value);
278 }
279
getRegisteredInterfaceForOp(TypeID typeID,OperationName opName)280 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
281 OperationName opName) {
282 if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
283 typeID == TypeID::get<TestEffectOpInterface>())
284 return fallbackEffectOpInterfaces;
285 return nullptr;
286 }
287
verifyOperationAttribute(Operation * op,NamedAttribute namedAttr)288 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
289 NamedAttribute namedAttr) {
290 if (namedAttr.first == "test.invalid_attr")
291 return op->emitError() << "invalid to use 'test.invalid_attr'";
292 return success();
293 }
294
verifyRegionArgAttribute(Operation * op,unsigned regionIndex,unsigned argIndex,NamedAttribute namedAttr)295 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
296 unsigned regionIndex,
297 unsigned argIndex,
298 NamedAttribute namedAttr) {
299 if (namedAttr.first == "test.invalid_attr")
300 return op->emitError() << "invalid to use 'test.invalid_attr'";
301 return success();
302 }
303
304 LogicalResult
verifyRegionResultAttribute(Operation * op,unsigned regionIndex,unsigned resultIndex,NamedAttribute namedAttr)305 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
306 unsigned resultIndex,
307 NamedAttribute namedAttr) {
308 if (namedAttr.first == "test.invalid_attr")
309 return op->emitError() << "invalid to use 'test.invalid_attr'";
310 return success();
311 }
312
313 Optional<Dialect::ParseOpHook>
getParseOperationHook(StringRef opName) const314 TestDialect::getParseOperationHook(StringRef opName) const {
315 if (opName == "test.dialect_custom_printer") {
316 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
317 return parser.parseKeyword("custom_format");
318 }};
319 }
320 return None;
321 }
322
323 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
getOperationPrinter(Operation * op) const324 TestDialect::getOperationPrinter(Operation *op) const {
325 StringRef opName = op->getName().getStringRef();
326 if (opName == "test.dialect_custom_printer") {
327 return [](Operation *op, OpAsmPrinter &printer) {
328 printer.getStream() << " custom_format";
329 };
330 }
331 return {};
332 }
333
334 //===----------------------------------------------------------------------===//
335 // TestBranchOp
336 //===----------------------------------------------------------------------===//
337
338 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)339 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
340 assert(index == 0 && "invalid successor index");
341 return targetOperandsMutable();
342 }
343
344 //===----------------------------------------------------------------------===//
345 // TestDialectCanonicalizerOp
346 //===----------------------------------------------------------------------===//
347
348 static LogicalResult
dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,PatternRewriter & rewriter)349 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
350 PatternRewriter &rewriter) {
351 rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
352 rewriter.getI32IntegerAttr(42));
353 return success();
354 }
355
getCanonicalizationPatterns(RewritePatternSet & results) const356 void TestDialect::getCanonicalizationPatterns(
357 RewritePatternSet &results) const {
358 results.add(&dialectCanonicalizationPattern);
359 }
360
361 //===----------------------------------------------------------------------===//
362 // TestFoldToCallOp
363 //===----------------------------------------------------------------------===//
364
365 namespace {
366 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
367 using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
368
matchAndRewrite__anon140dcf5b0611::FoldToCallOpPattern369 LogicalResult matchAndRewrite(FoldToCallOp op,
370 PatternRewriter &rewriter) const override {
371 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
372 ValueRange());
373 return success();
374 }
375 };
376 } // end anonymous namespace
377
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)378 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
379 MLIRContext *context) {
380 results.add<FoldToCallOpPattern>(context);
381 }
382
383 //===----------------------------------------------------------------------===//
384 // Test Format* operations
385 //===----------------------------------------------------------------------===//
386
387 //===----------------------------------------------------------------------===//
388 // Parsing
389
parseCustomDirectiveOperands(OpAsmParser & parser,OpAsmParser::OperandType & operand,Optional<OpAsmParser::OperandType> & optOperand,SmallVectorImpl<OpAsmParser::OperandType> & varOperands)390 static ParseResult parseCustomDirectiveOperands(
391 OpAsmParser &parser, OpAsmParser::OperandType &operand,
392 Optional<OpAsmParser::OperandType> &optOperand,
393 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
394 if (parser.parseOperand(operand))
395 return failure();
396 if (succeeded(parser.parseOptionalComma())) {
397 optOperand.emplace();
398 if (parser.parseOperand(*optOperand))
399 return failure();
400 }
401 if (parser.parseArrow() || parser.parseLParen() ||
402 parser.parseOperandList(varOperands) || parser.parseRParen())
403 return failure();
404 return success();
405 }
406 static ParseResult
parseCustomDirectiveResults(OpAsmParser & parser,Type & operandType,Type & optOperandType,SmallVectorImpl<Type> & varOperandTypes)407 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
408 Type &optOperandType,
409 SmallVectorImpl<Type> &varOperandTypes) {
410 if (parser.parseColon())
411 return failure();
412
413 if (parser.parseType(operandType))
414 return failure();
415 if (succeeded(parser.parseOptionalComma())) {
416 if (parser.parseType(optOperandType))
417 return failure();
418 }
419 if (parser.parseArrow() || parser.parseLParen() ||
420 parser.parseTypeList(varOperandTypes) || parser.parseRParen())
421 return failure();
422 return success();
423 }
424 static ParseResult
parseCustomDirectiveWithTypeRefs(OpAsmParser & parser,Type operandType,Type optOperandType,const SmallVectorImpl<Type> & varOperandTypes)425 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
426 Type optOperandType,
427 const SmallVectorImpl<Type> &varOperandTypes) {
428 if (parser.parseKeyword("type_refs_capture"))
429 return failure();
430
431 Type operandType2, optOperandType2;
432 SmallVector<Type, 1> varOperandTypes2;
433 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
434 varOperandTypes2))
435 return failure();
436
437 if (operandType != operandType2 || optOperandType != optOperandType2 ||
438 varOperandTypes != varOperandTypes2)
439 return failure();
440
441 return success();
442 }
parseCustomDirectiveOperandsAndTypes(OpAsmParser & parser,OpAsmParser::OperandType & operand,Optional<OpAsmParser::OperandType> & optOperand,SmallVectorImpl<OpAsmParser::OperandType> & varOperands,Type & operandType,Type & optOperandType,SmallVectorImpl<Type> & varOperandTypes)443 static ParseResult parseCustomDirectiveOperandsAndTypes(
444 OpAsmParser &parser, OpAsmParser::OperandType &operand,
445 Optional<OpAsmParser::OperandType> &optOperand,
446 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
447 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
448 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
449 parseCustomDirectiveResults(parser, operandType, optOperandType,
450 varOperandTypes))
451 return failure();
452 return success();
453 }
parseCustomDirectiveRegions(OpAsmParser & parser,Region & region,SmallVectorImpl<std::unique_ptr<Region>> & varRegions)454 static ParseResult parseCustomDirectiveRegions(
455 OpAsmParser &parser, Region ®ion,
456 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
457 if (parser.parseRegion(region))
458 return failure();
459 if (failed(parser.parseOptionalComma()))
460 return success();
461 std::unique_ptr<Region> varRegion = std::make_unique<Region>();
462 if (parser.parseRegion(*varRegion))
463 return failure();
464 varRegions.emplace_back(std::move(varRegion));
465 return success();
466 }
467 static ParseResult
parseCustomDirectiveSuccessors(OpAsmParser & parser,Block * & successor,SmallVectorImpl<Block * > & varSuccessors)468 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
469 SmallVectorImpl<Block *> &varSuccessors) {
470 if (parser.parseSuccessor(successor))
471 return failure();
472 if (failed(parser.parseOptionalComma()))
473 return success();
474 Block *varSuccessor;
475 if (parser.parseSuccessor(varSuccessor))
476 return failure();
477 varSuccessors.append(2, varSuccessor);
478 return success();
479 }
parseCustomDirectiveAttributes(OpAsmParser & parser,IntegerAttr & attr,IntegerAttr & optAttr)480 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
481 IntegerAttr &attr,
482 IntegerAttr &optAttr) {
483 if (parser.parseAttribute(attr))
484 return failure();
485 if (succeeded(parser.parseOptionalComma())) {
486 if (parser.parseAttribute(optAttr))
487 return failure();
488 }
489 return success();
490 }
491
parseCustomDirectiveAttrDict(OpAsmParser & parser,NamedAttrList & attrs)492 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
493 NamedAttrList &attrs) {
494 return parser.parseOptionalAttrDict(attrs);
495 }
parseCustomDirectiveOptionalOperandRef(OpAsmParser & parser,Optional<OpAsmParser::OperandType> & optOperand)496 static ParseResult parseCustomDirectiveOptionalOperandRef(
497 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
498 int64_t operandCount = 0;
499 if (parser.parseInteger(operandCount))
500 return failure();
501 bool expectedOptionalOperand = operandCount == 0;
502 return success(expectedOptionalOperand != optOperand.hasValue());
503 }
504
505 //===----------------------------------------------------------------------===//
506 // Printing
507
printCustomDirectiveOperands(OpAsmPrinter & printer,Operation *,Value operand,Value optOperand,OperandRange varOperands)508 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
509 Value operand, Value optOperand,
510 OperandRange varOperands) {
511 printer << operand;
512 if (optOperand)
513 printer << ", " << optOperand;
514 printer << " -> (" << varOperands << ")";
515 }
printCustomDirectiveResults(OpAsmPrinter & printer,Operation *,Type operandType,Type optOperandType,TypeRange varOperandTypes)516 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
517 Type operandType, Type optOperandType,
518 TypeRange varOperandTypes) {
519 printer << " : " << operandType;
520 if (optOperandType)
521 printer << ", " << optOperandType;
522 printer << " -> (" << varOperandTypes << ")";
523 }
printCustomDirectiveWithTypeRefs(OpAsmPrinter & printer,Operation * op,Type operandType,Type optOperandType,TypeRange varOperandTypes)524 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
525 Operation *op, Type operandType,
526 Type optOperandType,
527 TypeRange varOperandTypes) {
528 printer << " type_refs_capture ";
529 printCustomDirectiveResults(printer, op, operandType, optOperandType,
530 varOperandTypes);
531 }
printCustomDirectiveOperandsAndTypes(OpAsmPrinter & printer,Operation * op,Value operand,Value optOperand,OperandRange varOperands,Type operandType,Type optOperandType,TypeRange varOperandTypes)532 static void printCustomDirectiveOperandsAndTypes(
533 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
534 OperandRange varOperands, Type operandType, Type optOperandType,
535 TypeRange varOperandTypes) {
536 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
537 printCustomDirectiveResults(printer, op, operandType, optOperandType,
538 varOperandTypes);
539 }
printCustomDirectiveRegions(OpAsmPrinter & printer,Operation *,Region & region,MutableArrayRef<Region> varRegions)540 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
541 Region ®ion,
542 MutableArrayRef<Region> varRegions) {
543 printer.printRegion(region);
544 if (!varRegions.empty()) {
545 printer << ", ";
546 for (Region ®ion : varRegions)
547 printer.printRegion(region);
548 }
549 }
printCustomDirectiveSuccessors(OpAsmPrinter & printer,Operation *,Block * successor,SuccessorRange varSuccessors)550 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
551 Block *successor,
552 SuccessorRange varSuccessors) {
553 printer << successor;
554 if (!varSuccessors.empty())
555 printer << ", " << varSuccessors.front();
556 }
printCustomDirectiveAttributes(OpAsmPrinter & printer,Operation *,Attribute attribute,Attribute optAttribute)557 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
558 Attribute attribute,
559 Attribute optAttribute) {
560 printer << attribute;
561 if (optAttribute)
562 printer << ", " << optAttribute;
563 }
564
printCustomDirectiveAttrDict(OpAsmPrinter & printer,Operation * op,DictionaryAttr attrs)565 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
566 DictionaryAttr attrs) {
567 printer.printOptionalAttrDict(attrs.getValue());
568 }
569
printCustomDirectiveOptionalOperandRef(OpAsmPrinter & printer,Operation * op,Value optOperand)570 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
571 Operation *op,
572 Value optOperand) {
573 printer << (optOperand ? "1" : "0");
574 }
575
576 //===----------------------------------------------------------------------===//
577 // Test IsolatedRegionOp - parse passthrough region arguments.
578 //===----------------------------------------------------------------------===//
579
parseIsolatedRegionOp(OpAsmParser & parser,OperationState & result)580 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
581 OperationState &result) {
582 OpAsmParser::OperandType argInfo;
583 Type argType = parser.getBuilder().getIndexType();
584
585 // Parse the input operand.
586 if (parser.parseOperand(argInfo) ||
587 parser.resolveOperand(argInfo, argType, result.operands))
588 return failure();
589
590 // Parse the body region, and reuse the operand info as the argument info.
591 Region *body = result.addRegion();
592 return parser.parseRegion(*body, argInfo, argType,
593 /*enableNameShadowing=*/true);
594 }
595
print(OpAsmPrinter & p,IsolatedRegionOp op)596 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
597 p << "test.isolated_region ";
598 p.printOperand(op.getOperand());
599 p.shadowRegionArgs(op.region(), op.getOperand());
600 p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
601 }
602
603 //===----------------------------------------------------------------------===//
604 // Test SSACFGRegionOp
605 //===----------------------------------------------------------------------===//
606
getRegionKind(unsigned index)607 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
608 return RegionKind::SSACFG;
609 }
610
611 //===----------------------------------------------------------------------===//
612 // Test GraphRegionOp
613 //===----------------------------------------------------------------------===//
614
parseGraphRegionOp(OpAsmParser & parser,OperationState & result)615 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
616 OperationState &result) {
617 // Parse the body region, and reuse the operand info as the argument info.
618 Region *body = result.addRegion();
619 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
620 }
621
print(OpAsmPrinter & p,GraphRegionOp op)622 static void print(OpAsmPrinter &p, GraphRegionOp op) {
623 p << "test.graph_region ";
624 p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
625 }
626
getRegionKind(unsigned index)627 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
628 return RegionKind::Graph;
629 }
630
631 //===----------------------------------------------------------------------===//
632 // Test AffineScopeOp
633 //===----------------------------------------------------------------------===//
634
parseAffineScopeOp(OpAsmParser & parser,OperationState & result)635 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
636 OperationState &result) {
637 // Parse the body region, and reuse the operand info as the argument info.
638 Region *body = result.addRegion();
639 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
640 }
641
print(OpAsmPrinter & p,AffineScopeOp op)642 static void print(OpAsmPrinter &p, AffineScopeOp op) {
643 p << "test.affine_scope ";
644 p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
645 }
646
647 //===----------------------------------------------------------------------===//
648 // Test parser.
649 //===----------------------------------------------------------------------===//
650
parseParseIntegerLiteralOp(OpAsmParser & parser,OperationState & result)651 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
652 OperationState &result) {
653 if (parser.parseOptionalColon())
654 return success();
655 uint64_t numResults;
656 if (parser.parseInteger(numResults))
657 return failure();
658
659 IndexType type = parser.getBuilder().getIndexType();
660 for (unsigned i = 0; i < numResults; ++i)
661 result.addTypes(type);
662 return success();
663 }
664
print(OpAsmPrinter & p,ParseIntegerLiteralOp op)665 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
666 if (unsigned numResults = op->getNumResults())
667 p << " : " << numResults;
668 }
669
parseParseWrappedKeywordOp(OpAsmParser & parser,OperationState & result)670 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
671 OperationState &result) {
672 StringRef keyword;
673 if (parser.parseKeyword(&keyword))
674 return failure();
675 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
676 return success();
677 }
678
print(OpAsmPrinter & p,ParseWrappedKeywordOp op)679 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
680 p << " " << op.keyword();
681 }
682
683 //===----------------------------------------------------------------------===//
684 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`.
685
parseWrappingRegionOp(OpAsmParser & parser,OperationState & result)686 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
687 OperationState &result) {
688 if (parser.parseKeyword("wraps"))
689 return failure();
690
691 // Parse the wrapped op in a region
692 Region &body = *result.addRegion();
693 body.push_back(new Block);
694 Block &block = body.back();
695 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
696 if (!wrapped_op)
697 return failure();
698
699 // Create a return terminator in the inner region, pass as operand to the
700 // terminator the returned values from the wrapped operation.
701 SmallVector<Value, 8> return_operands(wrapped_op->getResults());
702 OpBuilder builder(parser.getContext());
703 builder.setInsertionPointToEnd(&block);
704 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
705
706 // Get the results type for the wrapping op from the terminator operands.
707 Operation &return_op = body.back().back();
708 result.types.append(return_op.operand_type_begin(),
709 return_op.operand_type_end());
710
711 // Use the location of the wrapped op for the "test.wrapping_region" op.
712 result.location = wrapped_op->getLoc();
713
714 return success();
715 }
716
print(OpAsmPrinter & p,WrappingRegionOp op)717 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
718 p << " wraps ";
719 p.printGenericOp(&op.region().front().front());
720 }
721
722 //===----------------------------------------------------------------------===//
723 // Test PolyForOp - parse list of region arguments.
724 //===----------------------------------------------------------------------===//
725
parsePolyForOp(OpAsmParser & parser,OperationState & result)726 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
727 SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
728 // Parse list of region arguments without a delimiter.
729 if (parser.parseRegionArgumentList(ivsInfo))
730 return failure();
731
732 // Parse the body region.
733 Region *body = result.addRegion();
734 auto &builder = parser.getBuilder();
735 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
736 return parser.parseRegion(*body, ivsInfo, argTypes);
737 }
738
739 //===----------------------------------------------------------------------===//
740 // Test removing op with inner ops.
741 //===----------------------------------------------------------------------===//
742
743 namespace {
744 struct TestRemoveOpWithInnerOps
745 : public OpRewritePattern<TestOpWithRegionPattern> {
746 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
747
initialize__anon140dcf5b0711::TestRemoveOpWithInnerOps748 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
749
matchAndRewrite__anon140dcf5b0711::TestRemoveOpWithInnerOps750 LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
751 PatternRewriter &rewriter) const override {
752 rewriter.eraseOp(op);
753 return success();
754 }
755 };
756 } // end anonymous namespace
757
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)758 void TestOpWithRegionPattern::getCanonicalizationPatterns(
759 RewritePatternSet &results, MLIRContext *context) {
760 results.add<TestRemoveOpWithInnerOps>(context);
761 }
762
fold(ArrayRef<Attribute> operands)763 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
764 return operand();
765 }
766
fold(ArrayRef<Attribute> operands)767 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
768 return getValue();
769 }
770
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)771 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
772 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
773 for (Value input : this->operands()) {
774 results.push_back(input);
775 }
776 return success();
777 }
778
fold(ArrayRef<Attribute> operands)779 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
780 assert(operands.size() == 1);
781 if (operands.front()) {
782 (*this)->setAttr("attr", operands.front());
783 return getResult();
784 }
785 return {};
786 }
787
fold(ArrayRef<Attribute> operands)788 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
789 return getOperand();
790 }
791
inferReturnTypes(MLIRContext *,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)792 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
793 MLIRContext *, Optional<Location> location, ValueRange operands,
794 DictionaryAttr attributes, RegionRange regions,
795 SmallVectorImpl<Type> &inferredReturnTypes) {
796 if (operands[0].getType() != operands[1].getType()) {
797 return emitOptionalError(location, "operand type mismatch ",
798 operands[0].getType(), " vs ",
799 operands[1].getType());
800 }
801 inferredReturnTypes.assign({operands[0].getType()});
802 return success();
803 }
804
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)805 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
806 MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
807 DictionaryAttr attributes, RegionRange regions,
808 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
809 // Create return type consisting of the last element of the first operand.
810 auto operandType = operands.front().getType();
811 auto sval = operandType.dyn_cast<ShapedType>();
812 if (!sval) {
813 return emitOptionalError(location, "only shaped type operands allowed");
814 }
815 int64_t dim =
816 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
817 auto type = IntegerType::get(context, 17);
818 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
819 return success();
820 }
821
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,llvm::SmallVectorImpl<Value> & shapes)822 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
823 OpBuilder &builder, ValueRange operands,
824 llvm::SmallVectorImpl<Value> &shapes) {
825 shapes = SmallVector<Value, 1>{
826 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
827 return success();
828 }
829
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,llvm::SmallVectorImpl<Value> & shapes)830 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
831 OpBuilder &builder, ValueRange operands,
832 llvm::SmallVectorImpl<Value> &shapes) {
833 Location loc = getLoc();
834 shapes.reserve(operands.size());
835 for (Value operand : llvm::reverse(operands)) {
836 auto currShape = llvm::to_vector<4>(llvm::map_range(
837 llvm::seq<int64_t>(
838 0, operand.getType().cast<RankedTensorType>().getRank()),
839 [&](int64_t dim) -> Value {
840 return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
841 }));
842 shapes.push_back(builder.create<tensor::FromElementsOp>(
843 getLoc(), builder.getIndexType(), currShape));
844 }
845 return success();
846 }
847
reifyResultShapes(OpBuilder & builder,ReifiedRankedShapedTypeDims & shapes)848 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
849 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
850 Location loc = getLoc();
851 shapes.reserve(getNumOperands());
852 for (Value operand : llvm::reverse(getOperands())) {
853 auto currShape = llvm::to_vector<4>(llvm::map_range(
854 llvm::seq<int64_t>(
855 0, operand.getType().cast<RankedTensorType>().getRank()),
856 [&](int64_t dim) -> Value {
857 return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
858 }));
859 shapes.emplace_back(std::move(currShape));
860 }
861 return success();
862 }
863
864 //===----------------------------------------------------------------------===//
865 // Test SideEffect interfaces
866 //===----------------------------------------------------------------------===//
867
868 namespace {
869 /// A test resource for side effects.
870 struct TestResource : public SideEffects::Resource::Base<TestResource> {
getName__anon140dcf5b0a11::TestResource871 StringRef getName() final { return "<Test>"; }
872 };
873 } // end anonymous namespace
874
testSideEffectOpGetEffect(Operation * op,SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> & effects)875 static void testSideEffectOpGetEffect(
876 Operation *op,
877 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
878 &effects) {
879 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
880 if (!effectsAttr)
881 return;
882
883 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
884 }
885
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)886 void SideEffectOp::getEffects(
887 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
888 // Check for an effects attribute on the op instance.
889 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
890 if (!effectsAttr)
891 return;
892
893 // If there is one, it is an array of dictionary attributes that hold
894 // information on the effects of this operation.
895 for (Attribute element : effectsAttr) {
896 DictionaryAttr effectElement = element.cast<DictionaryAttr>();
897
898 // Get the specific memory effect.
899 MemoryEffects::Effect *effect =
900 StringSwitch<MemoryEffects::Effect *>(
901 effectElement.get("effect").cast<StringAttr>().getValue())
902 .Case("allocate", MemoryEffects::Allocate::get())
903 .Case("free", MemoryEffects::Free::get())
904 .Case("read", MemoryEffects::Read::get())
905 .Case("write", MemoryEffects::Write::get());
906
907 // Check for a non-default resource to use.
908 SideEffects::Resource *resource = SideEffects::DefaultResource::get();
909 if (effectElement.get("test_resource"))
910 resource = TestResource::get();
911
912 // Check for a result to affect.
913 if (effectElement.get("on_result"))
914 effects.emplace_back(effect, getResult(), resource);
915 else if (Attribute ref = effectElement.get("on_reference"))
916 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
917 else
918 effects.emplace_back(effect, resource);
919 }
920 }
921
getEffects(SmallVectorImpl<TestEffects::EffectInstance> & effects)922 void SideEffectOp::getEffects(
923 SmallVectorImpl<TestEffects::EffectInstance> &effects) {
924 testSideEffectOpGetEffect(getOperation(), effects);
925 }
926
927 //===----------------------------------------------------------------------===//
928 // StringAttrPrettyNameOp
929 //===----------------------------------------------------------------------===//
930
931 // This op has fancy handling of its SSA result name.
parseStringAttrPrettyNameOp(OpAsmParser & parser,OperationState & result)932 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
933 OperationState &result) {
934 // Add the result types.
935 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
936 result.addTypes(parser.getBuilder().getIntegerType(32));
937
938 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
939 return failure();
940
941 // If the attribute dictionary contains no 'names' attribute, infer it from
942 // the SSA name (if specified).
943 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
944 return attr.first == "names";
945 });
946
947 // If there was no name specified, check to see if there was a useful name
948 // specified in the asm file.
949 if (hadNames || parser.getNumResults() == 0)
950 return success();
951
952 SmallVector<StringRef, 4> names;
953 auto *context = result.getContext();
954
955 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
956 auto resultName = parser.getResultName(i);
957 StringRef nameStr;
958 if (!resultName.first.empty() && !isdigit(resultName.first[0]))
959 nameStr = resultName.first;
960
961 names.push_back(nameStr);
962 }
963
964 auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
965 result.attributes.push_back({Identifier::get("names", context), namesAttr});
966 return success();
967 }
968
print(OpAsmPrinter & p,StringAttrPrettyNameOp op)969 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
970 // Note that we only need to print the "name" attribute if the asmprinter
971 // result name disagrees with it. This can happen in strange cases, e.g.
972 // when there are conflicts.
973 bool namesDisagree = op.names().size() != op.getNumResults();
974
975 SmallString<32> resultNameStr;
976 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
977 resultNameStr.clear();
978 llvm::raw_svector_ostream tmpStream(resultNameStr);
979 p.printOperand(op.getResult(i), tmpStream);
980
981 auto expectedName = op.names()[i].dyn_cast<StringAttr>();
982 if (!expectedName ||
983 tmpStream.str().drop_front() != expectedName.getValue()) {
984 namesDisagree = true;
985 }
986 }
987
988 if (namesDisagree)
989 p.printOptionalAttrDictWithKeyword(op->getAttrs());
990 else
991 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
992 }
993
994 // We set the SSA name in the asm syntax to the contents of the name
995 // attribute.
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)996 void StringAttrPrettyNameOp::getAsmResultNames(
997 function_ref<void(Value, StringRef)> setNameFn) {
998
999 auto value = names();
1000 for (size_t i = 0, e = value.size(); i != e; ++i)
1001 if (auto str = value[i].dyn_cast<StringAttr>())
1002 if (!str.getValue().empty())
1003 setNameFn(getResult(i), str.getValue());
1004 }
1005
1006 //===----------------------------------------------------------------------===//
1007 // RegionIfOp
1008 //===----------------------------------------------------------------------===//
1009
print(OpAsmPrinter & p,RegionIfOp op)1010 static void print(OpAsmPrinter &p, RegionIfOp op) {
1011 p << " ";
1012 p.printOperands(op.getOperands());
1013 p << ": " << op.getOperandTypes();
1014 p.printArrowTypeList(op.getResultTypes());
1015 p << " then";
1016 p.printRegion(op.thenRegion(),
1017 /*printEntryBlockArgs=*/true,
1018 /*printBlockTerminators=*/true);
1019 p << " else";
1020 p.printRegion(op.elseRegion(),
1021 /*printEntryBlockArgs=*/true,
1022 /*printBlockTerminators=*/true);
1023 p << " join";
1024 p.printRegion(op.joinRegion(),
1025 /*printEntryBlockArgs=*/true,
1026 /*printBlockTerminators=*/true);
1027 }
1028
parseRegionIfOp(OpAsmParser & parser,OperationState & result)1029 static ParseResult parseRegionIfOp(OpAsmParser &parser,
1030 OperationState &result) {
1031 SmallVector<OpAsmParser::OperandType, 2> operandInfos;
1032 SmallVector<Type, 2> operandTypes;
1033
1034 result.regions.reserve(3);
1035 Region *thenRegion = result.addRegion();
1036 Region *elseRegion = result.addRegion();
1037 Region *joinRegion = result.addRegion();
1038
1039 // Parse operand, type and arrow type lists.
1040 if (parser.parseOperandList(operandInfos) ||
1041 parser.parseColonTypeList(operandTypes) ||
1042 parser.parseArrowTypeList(result.types))
1043 return failure();
1044
1045 // Parse all attached regions.
1046 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1047 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1048 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1049 return failure();
1050
1051 return parser.resolveOperands(operandInfos, operandTypes,
1052 parser.getCurrentLocation(), result.operands);
1053 }
1054
getSuccessorEntryOperands(unsigned index)1055 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1056 assert(index < 2 && "invalid region index");
1057 return getOperands();
1058 }
1059
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1060 void RegionIfOp::getSuccessorRegions(
1061 Optional<unsigned> index, ArrayRef<Attribute> operands,
1062 SmallVectorImpl<RegionSuccessor> ®ions) {
1063 // We always branch to the join region.
1064 if (index.hasValue()) {
1065 if (index.getValue() < 2)
1066 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
1067 else
1068 regions.push_back(RegionSuccessor(getResults()));
1069 return;
1070 }
1071
1072 // The then and else regions are the entry regions of this op.
1073 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
1074 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
1075 }
1076
1077 //===----------------------------------------------------------------------===//
1078 // SingleNoTerminatorCustomAsmOp
1079 //===----------------------------------------------------------------------===//
1080
parseSingleNoTerminatorCustomAsmOp(OpAsmParser & parser,OperationState & state)1081 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1082 OperationState &state) {
1083 Region *body = state.addRegion();
1084 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1085 return failure();
1086 return success();
1087 }
1088
print(SingleNoTerminatorCustomAsmOp op,OpAsmPrinter & printer)1089 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1090 printer.printRegion(
1091 op.getRegion(), /*printEntryBlockArgs=*/false,
1092 // This op has a single block without terminators. But explicitly mark
1093 // as not printing block terminators for testing.
1094 /*printBlockTerminators=*/false);
1095 }
1096
1097 #include "TestOpEnums.cpp.inc"
1098 #include "TestOpInterfaces.cpp.inc"
1099 #include "TestOpStructs.cpp.inc"
1100 #include "TestTypeInterfaces.cpp.inc"
1101
1102 #define GET_OP_CLASSES
1103 #include "TestOps.cpp.inc"
1104