1 //===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements helper classes for implementing the "Op" types. This
// includes the Op type, which is the base class for Op class definitions,
// as well as a number of traits in the OpTrait namespace that provide a
// declarative way to specify properties of Ops.
//
// The purpose of these types is to allow light-weight implementation of
// concrete ops (like DimOp) with very little boilerplate.
16 //
17 //===----------------------------------------------------------------------===//
18
19 #ifndef MLIR_IR_OPDEFINITION_H
20 #define MLIR_IR_OPDEFINITION_H
21
22 #include "mlir/IR/Operation.h"
23 #include "llvm/Support/PointerLikeTypeTraits.h"
24
25 #include <type_traits>
26
27 namespace mlir {
28 class Builder;
29 class OpBuilder;
30
31 /// This class represents success/failure for operation parsing. It is
32 /// essentially a simple wrapper class around LogicalResult that allows for
33 /// explicit conversion to bool. This allows for the parser to chain together
34 /// parse rules without the clutter of "failed/succeeded".
35 class ParseResult : public LogicalResult {
36 public:
LogicalResult(result)37 ParseResult(LogicalResult result = success()) : LogicalResult(result) {}
38
39 // Allow diagnostics emitted during parsing to be converted to failure.
ParseResult(const InFlightDiagnostic &)40 ParseResult(const InFlightDiagnostic &) : LogicalResult(failure()) {}
ParseResult(const Diagnostic &)41 ParseResult(const Diagnostic &) : LogicalResult(failure()) {}
42
43 /// Failure is true in a boolean context.
44 explicit operator bool() const { return failed(*this); }
45 };
46 /// This class implements `Optional` functionality for ParseResult. We don't
47 /// directly use Optional here, because it provides an implicit conversion
48 /// to 'bool' which we want to avoid. This class is used to implement tri-state
49 /// 'parseOptional' functions that may have a failure mode when parsing that
50 /// shouldn't be attributed to "not present".
51 class OptionalParseResult {
52 public:
53 OptionalParseResult() = default;
OptionalParseResult(LogicalResult result)54 OptionalParseResult(LogicalResult result) : impl(result) {}
OptionalParseResult(ParseResult result)55 OptionalParseResult(ParseResult result) : impl(result) {}
OptionalParseResult(const InFlightDiagnostic &)56 OptionalParseResult(const InFlightDiagnostic &)
57 : OptionalParseResult(failure()) {}
OptionalParseResult(llvm::NoneType)58 OptionalParseResult(llvm::NoneType) : impl(llvm::None) {}
59
60 /// Returns true if we contain a valid ParseResult value.
hasValue()61 bool hasValue() const { return impl.hasValue(); }
62
63 /// Access the internal ParseResult value.
getValue()64 ParseResult getValue() const { return impl.getValue(); }
65 ParseResult operator*() const { return getValue(); }
66
67 private:
68 Optional<ParseResult> impl;
69 };
70
71 // These functions are out-of-line utilities, which avoids them being template
72 // instantiated/duplicated.
73 namespace impl {
74 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the
75 /// region's only block if it does not have a terminator already. If the region
76 /// is empty, insert a new block first. `buildTerminatorOp` should return the
77 /// terminator operation to insert.
78 void ensureRegionTerminator(
79 Region ®ion, OpBuilder &builder, Location loc,
80 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
81 void ensureRegionTerminator(
82 Region ®ion, Builder &builder, Location loc,
83 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
84
85 } // namespace impl
86
87 /// This is the concrete base class that holds the operation pointer and has
88 /// non-generic methods that only depend on State (to avoid having them
89 /// instantiated on template types that don't affect them.
90 ///
91 /// This also has the fallback implementations of customization hooks for when
92 /// they aren't customized.
93 class OpState {
94 public:
95 /// Ops are pointer-like, so we allow implicit conversion to bool.
96 operator bool() { return getOperation() != nullptr; }
97
98 /// This implicitly converts to Operation*.
99 operator Operation *() const { return state; }
100
101 /// Shortcut of `->` to access a member of Operation.
102 Operation *operator->() const { return state; }
103
104 /// Return the operation that this refers to.
getOperation()105 Operation *getOperation() { return state; }
106
107 /// Return the context this operation belongs to.
getContext()108 MLIRContext *getContext() { return getOperation()->getContext(); }
109
110 /// Print the operation to the given stream.
111 void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
112 state->print(os, flags);
113 }
114 void print(raw_ostream &os, AsmState &asmState,
115 OpPrintingFlags flags = llvm::None) {
116 state->print(os, asmState, flags);
117 }
118
119 /// Dump this operation.
dump()120 void dump() { state->dump(); }
121
122 /// The source location the operation was defined or derived from.
getLoc()123 Location getLoc() { return state->getLoc(); }
setLoc(Location loc)124 void setLoc(Location loc) { state->setLoc(loc); }
125
126 /// Return all of the attributes on this operation.
getAttrs()127 ArrayRef<NamedAttribute> getAttrs() { return state->getAttrs(); }
128
129 /// A utility iterator that filters out non-dialect attributes.
130 using dialect_attr_iterator = Operation::dialect_attr_iterator;
131 using dialect_attr_range = Operation::dialect_attr_range;
132
133 /// Set the dialect attributes for this operation, and preserve all dependent.
134 template <typename DialectAttrs>
setDialectAttrs(DialectAttrs && attrs)135 void setDialectAttrs(DialectAttrs &&attrs) {
136 state->setDialectAttrs(std::forward<DialectAttrs>(attrs));
137 }
138
139 /// Remove the attribute with the specified name if it exists. Return the
140 /// attribute that was erased, or nullptr if there was no attribute with such
141 /// name.
removeAttr(Identifier name)142 Attribute removeAttr(Identifier name) { return state->removeAttr(name); }
removeAttr(StringRef name)143 Attribute removeAttr(StringRef name) {
144 return state->removeAttr(Identifier::get(name, getContext()));
145 }
146
147 /// Return true if there are no users of any results of this operation.
use_empty()148 bool use_empty() { return state->use_empty(); }
149
150 /// Remove this operation from its parent block and delete it.
erase()151 void erase() { state->erase(); }
152
153 /// Emit an error with the op name prefixed, like "'dim' op " which is
154 /// convenient for verifiers.
155 InFlightDiagnostic emitOpError(const Twine &message = {});
156
157 /// Emit an error about fatal conditions with this operation, reporting up to
158 /// any diagnostic handlers that may be listening.
159 InFlightDiagnostic emitError(const Twine &message = {});
160
161 /// Emit a warning about this operation, reporting up to any diagnostic
162 /// handlers that may be listening.
163 InFlightDiagnostic emitWarning(const Twine &message = {});
164
165 /// Emit a remark about this operation, reporting up to any diagnostic
166 /// handlers that may be listening.
167 InFlightDiagnostic emitRemark(const Twine &message = {});
168
169 /// Walk the operation in postorder, calling the callback for each nested
170 /// operation(including this one).
171 /// See Operation::walk for more details.
172 template <typename FnT, typename RetT = detail::walkResultType<FnT>>
walk(FnT && callback)173 RetT walk(FnT &&callback) {
174 return state->walk(std::forward<FnT>(callback));
175 }
176
177 // These are default implementations of customization hooks.
178 public:
179 /// This hook returns any canonicalization pattern rewrites that the operation
180 /// supports, for use by the canonicalization pass.
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)181 static void getCanonicalizationPatterns(OwningRewritePatternList &results,
182 MLIRContext *context) {}
183
184 protected:
185 /// If the concrete type didn't implement a custom verifier hook, just fall
186 /// back to this one which accepts everything.
verify()187 LogicalResult verify() { return success(); }
188
189 /// Unless overridden, the custom assembly form of an op is always rejected.
190 /// Op implementations should implement this to return failure.
191 /// On success, they should fill in result with the fields to use.
192 static ParseResult parse(OpAsmParser &parser, OperationState &result);
193
194 // The fallback for the printer is to print it the generic assembly form.
195 static void print(Operation *op, OpAsmPrinter &p);
196
197 /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
198 /// so we can cast it away here.
OpState(Operation * state)199 explicit OpState(Operation *state) : state(state) {}
200
201 private:
202 Operation *state;
203
204 /// Allow access to internal hook implementation methods.
205 friend AbstractOperation;
206 };
207
208 // Allow comparing operators.
209 inline bool operator==(OpState lhs, OpState rhs) {
210 return lhs.getOperation() == rhs.getOperation();
211 }
212 inline bool operator!=(OpState lhs, OpState rhs) {
213 return lhs.getOperation() != rhs.getOperation();
214 }
215
216 raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr);
217
218 /// This class represents a single result from folding an operation.
219 class OpFoldResult : public PointerUnion<Attribute, Value> {
220 using PointerUnion<Attribute, Value>::PointerUnion;
221
222 public:
dump()223 void dump() { llvm::errs() << *this << "\n"; }
224 };
225
226 /// Allow printing to a stream.
227 inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) {
228 if (Value value = ofr.dyn_cast<Value>())
229 value.print(os);
230 else
231 ofr.dyn_cast<Attribute>().print(os);
232 return os;
233 }
234
235 /// Allow printing to a stream.
236 inline raw_ostream &operator<<(raw_ostream &os, OpState &op) {
237 op.print(os, OpPrintingFlags().useLocalScope());
238 return os;
239 }
240
241 //===----------------------------------------------------------------------===//
242 // Operation Trait Types
243 //===----------------------------------------------------------------------===//
244
245 namespace OpTrait {
246
247 // These functions are out-of-line implementations of the methods in the
248 // corresponding trait classes. This avoids them being template
249 // instantiated/duplicated.
250 namespace impl {
251 OpFoldResult foldIdempotent(Operation *op);
252 OpFoldResult foldInvolution(Operation *op);
253 LogicalResult verifyZeroOperands(Operation *op);
254 LogicalResult verifyOneOperand(Operation *op);
255 LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
256 LogicalResult verifyIsIdempotent(Operation *op);
257 LogicalResult verifyIsInvolution(Operation *op);
258 LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
259 LogicalResult verifyOperandsAreFloatLike(Operation *op);
260 LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
261 LogicalResult verifySameTypeOperands(Operation *op);
262 LogicalResult verifyZeroRegion(Operation *op);
263 LogicalResult verifyOneRegion(Operation *op);
264 LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
265 LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);
266 LogicalResult verifyZeroResult(Operation *op);
267 LogicalResult verifyOneResult(Operation *op);
268 LogicalResult verifyNResults(Operation *op, unsigned numOperands);
269 LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
270 LogicalResult verifySameOperandsShape(Operation *op);
271 LogicalResult verifySameOperandsAndResultShape(Operation *op);
272 LogicalResult verifySameOperandsElementType(Operation *op);
273 LogicalResult verifySameOperandsAndResultElementType(Operation *op);
274 LogicalResult verifySameOperandsAndResultType(Operation *op);
275 LogicalResult verifyResultsAreBoolLike(Operation *op);
276 LogicalResult verifyResultsAreFloatLike(Operation *op);
277 LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
278 LogicalResult verifyIsTerminator(Operation *op);
279 LogicalResult verifyZeroSuccessor(Operation *op);
280 LogicalResult verifyOneSuccessor(Operation *op);
281 LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
282 LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
283 LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
284 LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
285 LogicalResult verifyNoRegionArguments(Operation *op);
286 LogicalResult verifyElementwiseMappable(Operation *op);
287 } // namespace impl
288
289 /// Helper class for implementing traits. Clients are not expected to interact
290 /// with this directly, so its members are all protected.
291 template <typename ConcreteType, template <typename> class TraitType>
292 class TraitBase {
293 protected:
294 /// Return the ultimate Operation being worked on.
getOperation()295 Operation *getOperation() {
296 // We have to cast up to the trait type, then to the concrete type, then to
297 // the BaseState class in explicit hops because the concrete type will
298 // multiply derive from the (content free) TraitBase class, and we need to
299 // be able to disambiguate the path for the C++ compiler.
300 auto *trait = static_cast<TraitType<ConcreteType> *>(this);
301 auto *concrete = static_cast<ConcreteType *>(trait);
302 auto *base = static_cast<OpState *>(concrete);
303 return base->getOperation();
304 }
305 };
306
307 //===----------------------------------------------------------------------===//
308 // Operand Traits
309
310 namespace detail {
311 /// Utility trait base that provides accessors for derived traits that have
312 /// multiple operands.
313 template <typename ConcreteType, template <typename> class TraitType>
314 struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
315 using operand_iterator = Operation::operand_iterator;
316 using operand_range = Operation::operand_range;
317 using operand_type_iterator = Operation::operand_type_iterator;
318 using operand_type_range = Operation::operand_type_range;
319
320 /// Return the number of operands.
getNumOperandsMultiOperandTraitBase321 unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
322
323 /// Return the operand at index 'i'.
getOperandMultiOperandTraitBase324 Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
325
326 /// Set the operand at index 'i' to 'value'.
setOperandMultiOperandTraitBase327 void setOperand(unsigned i, Value value) {
328 this->getOperation()->setOperand(i, value);
329 }
330
331 /// Operand iterator access.
operand_beginMultiOperandTraitBase332 operand_iterator operand_begin() {
333 return this->getOperation()->operand_begin();
334 }
operand_endMultiOperandTraitBase335 operand_iterator operand_end() { return this->getOperation()->operand_end(); }
getOperandsMultiOperandTraitBase336 operand_range getOperands() { return this->getOperation()->getOperands(); }
337
338 /// Operand type access.
operand_type_beginMultiOperandTraitBase339 operand_type_iterator operand_type_begin() {
340 return this->getOperation()->operand_type_begin();
341 }
operand_type_endMultiOperandTraitBase342 operand_type_iterator operand_type_end() {
343 return this->getOperation()->operand_type_end();
344 }
getOperandTypesMultiOperandTraitBase345 operand_type_range getOperandTypes() {
346 return this->getOperation()->getOperandTypes();
347 }
348 };
349 } // end namespace detail
350
351 /// This class provides the API for ops that are known to have no
352 /// SSA operand.
353 template <typename ConcreteType>
354 class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
355 public:
verifyTrait(Operation * op)356 static LogicalResult verifyTrait(Operation *op) {
357 return impl::verifyZeroOperands(op);
358 }
359
360 private:
361 // Disable these.
getOperand()362 void getOperand() {}
setOperand()363 void setOperand() {}
364 };
365
366 /// This class provides the API for ops that are known to have exactly one
367 /// SSA operand.
368 template <typename ConcreteType>
369 class OneOperand : public TraitBase<ConcreteType, OneOperand> {
370 public:
getOperand()371 Value getOperand() { return this->getOperation()->getOperand(0); }
372
setOperand(Value value)373 void setOperand(Value value) { this->getOperation()->setOperand(0, value); }
374
verifyTrait(Operation * op)375 static LogicalResult verifyTrait(Operation *op) {
376 return impl::verifyOneOperand(op);
377 }
378 };
379
380 /// This class provides the API for ops that are known to have a specified
381 /// number of operands. This is used as a trait like this:
382 ///
383 /// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
384 ///
385 template <unsigned N>
386 class NOperands {
387 public:
388 static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
389
390 template <typename ConcreteType>
391 class Impl
392 : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
393 public:
verifyTrait(Operation * op)394 static LogicalResult verifyTrait(Operation *op) {
395 return impl::verifyNOperands(op, N);
396 }
397 };
398 };
399
400 /// This class provides the API for ops that are known to have a at least a
401 /// specified number of operands. This is used as a trait like this:
402 ///
403 /// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
404 ///
405 template <unsigned N>
406 class AtLeastNOperands {
407 public:
408 template <typename ConcreteType>
409 class Impl : public detail::MultiOperandTraitBase<ConcreteType,
410 AtLeastNOperands<N>::Impl> {
411 public:
verifyTrait(Operation * op)412 static LogicalResult verifyTrait(Operation *op) {
413 return impl::verifyAtLeastNOperands(op, N);
414 }
415 };
416 };
417
418 /// This class provides the API for ops which have an unknown number of
419 /// SSA operands.
420 template <typename ConcreteType>
421 class VariadicOperands
422 : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
423
424 //===----------------------------------------------------------------------===//
425 // Region Traits
426
427 /// This class provides verification for ops that are known to have zero
428 /// regions.
429 template <typename ConcreteType>
430 class ZeroRegion : public TraitBase<ConcreteType, ZeroRegion> {
431 public:
verifyTrait(Operation * op)432 static LogicalResult verifyTrait(Operation *op) {
433 return impl::verifyZeroRegion(op);
434 }
435 };
436
437 namespace detail {
438 /// Utility trait base that provides accessors for derived traits that have
439 /// multiple regions.
440 template <typename ConcreteType, template <typename> class TraitType>
441 struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> {
442 using region_iterator = MutableArrayRef<Region>;
443 using region_range = RegionRange;
444
445 /// Return the number of regions.
getNumRegionsMultiRegionTraitBase446 unsigned getNumRegions() { return this->getOperation()->getNumRegions(); }
447
448 /// Return the region at `index`.
getRegionMultiRegionTraitBase449 Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); }
450
451 /// Region iterator access.
region_beginMultiRegionTraitBase452 region_iterator region_begin() {
453 return this->getOperation()->region_begin();
454 }
region_endMultiRegionTraitBase455 region_iterator region_end() { return this->getOperation()->region_end(); }
getRegionsMultiRegionTraitBase456 region_range getRegions() { return this->getOperation()->getRegions(); }
457 };
458 } // end namespace detail
459
460 /// This class provides APIs for ops that are known to have a single region.
461 template <typename ConcreteType>
462 class OneRegion : public TraitBase<ConcreteType, OneRegion> {
463 public:
getRegion()464 Region &getRegion() { return this->getOperation()->getRegion(0); }
465
466 /// Returns a range of operations within the region of this operation.
getOps()467 auto getOps() { return getRegion().getOps(); }
468 template <typename OpT>
getOps()469 auto getOps() {
470 return getRegion().template getOps<OpT>();
471 }
472
verifyTrait(Operation * op)473 static LogicalResult verifyTrait(Operation *op) {
474 return impl::verifyOneRegion(op);
475 }
476 };
477
478 /// This class provides the API for ops that are known to have a specified
479 /// number of regions.
480 template <unsigned N>
481 class NRegions {
482 public:
483 static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");
484
485 template <typename ConcreteType>
486 class Impl
487 : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
488 public:
verifyTrait(Operation * op)489 static LogicalResult verifyTrait(Operation *op) {
490 return impl::verifyNRegions(op, N);
491 }
492 };
493 };
494
495 /// This class provides APIs for ops that are known to have at least a specified
496 /// number of regions.
497 template <unsigned N>
498 class AtLeastNRegions {
499 public:
500 template <typename ConcreteType>
501 class Impl : public detail::MultiRegionTraitBase<ConcreteType,
502 AtLeastNRegions<N>::Impl> {
503 public:
verifyTrait(Operation * op)504 static LogicalResult verifyTrait(Operation *op) {
505 return impl::verifyAtLeastNRegions(op, N);
506 }
507 };
508 };
509
510 /// This class provides the API for ops which have an unknown number of
511 /// regions.
512 template <typename ConcreteType>
513 class VariadicRegions
514 : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {};
515
516 //===----------------------------------------------------------------------===//
517 // Result Traits
518
519 /// This class provides return value APIs for ops that are known to have
520 /// zero results.
521 template <typename ConcreteType>
522 class ZeroResult : public TraitBase<ConcreteType, ZeroResult> {
523 public:
verifyTrait(Operation * op)524 static LogicalResult verifyTrait(Operation *op) {
525 return impl::verifyZeroResult(op);
526 }
527 };
528
529 namespace detail {
530 /// Utility trait base that provides accessors for derived traits that have
531 /// multiple results.
532 template <typename ConcreteType, template <typename> class TraitType>
533 struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
534 using result_iterator = Operation::result_iterator;
535 using result_range = Operation::result_range;
536 using result_type_iterator = Operation::result_type_iterator;
537 using result_type_range = Operation::result_type_range;
538
539 /// Return the number of results.
getNumResultsMultiResultTraitBase540 unsigned getNumResults() { return this->getOperation()->getNumResults(); }
541
542 /// Return the result at index 'i'.
getResultMultiResultTraitBase543 Value getResult(unsigned i) { return this->getOperation()->getResult(i); }
544
545 /// Replace all uses of results of this operation with the provided 'values'.
546 /// 'values' may correspond to an existing operation, or a range of 'Value'.
547 template <typename ValuesT>
replaceAllUsesWithMultiResultTraitBase548 void replaceAllUsesWith(ValuesT &&values) {
549 this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
550 }
551
552 /// Return the type of the `i`-th result.
getTypeMultiResultTraitBase553 Type getType(unsigned i) { return getResult(i).getType(); }
554
555 /// Result iterator access.
result_beginMultiResultTraitBase556 result_iterator result_begin() {
557 return this->getOperation()->result_begin();
558 }
result_endMultiResultTraitBase559 result_iterator result_end() { return this->getOperation()->result_end(); }
getResultsMultiResultTraitBase560 result_range getResults() { return this->getOperation()->getResults(); }
561
562 /// Result type access.
result_type_beginMultiResultTraitBase563 result_type_iterator result_type_begin() {
564 return this->getOperation()->result_type_begin();
565 }
result_type_endMultiResultTraitBase566 result_type_iterator result_type_end() {
567 return this->getOperation()->result_type_end();
568 }
getResultTypesMultiResultTraitBase569 result_type_range getResultTypes() {
570 return this->getOperation()->getResultTypes();
571 }
572 };
573 } // end namespace detail
574
575 /// This class provides return value APIs for ops that are known to have a
576 /// single result. ResultType is the concrete type returned by getType().
577 template <typename ConcreteType>
578 class OneResult : public TraitBase<ConcreteType, OneResult> {
579 public:
getResult()580 Value getResult() { return this->getOperation()->getResult(0); }
581
582 /// If the operation returns a single value, then the Op can be implicitly
583 /// converted to an Value. This yields the value of the only result.
Value()584 operator Value() { return getResult(); }
585
586 /// Replace all uses of 'this' value with the new value, updating anything
587 /// in the IR that uses 'this' to use the other value instead. When this
588 /// returns there are zero uses of 'this'.
replaceAllUsesWith(Value newValue)589 void replaceAllUsesWith(Value newValue) {
590 getResult().replaceAllUsesWith(newValue);
591 }
592
593 /// Replace all uses of 'this' value with the result of 'op'.
replaceAllUsesWith(Operation * op)594 void replaceAllUsesWith(Operation *op) {
595 this->getOperation()->replaceAllUsesWith(op);
596 }
597
verifyTrait(Operation * op)598 static LogicalResult verifyTrait(Operation *op) {
599 return impl::verifyOneResult(op);
600 }
601 };
602
603 /// This trait is used for return value APIs for ops that are known to have a
604 /// specific type other than `Type`. This allows the "getType()" member to be
605 /// more specific for an op. This should be used in conjunction with OneResult,
606 /// and occur in the trait list before OneResult.
607 template <typename ResultType>
608 class OneTypedResult {
609 public:
610 /// This class provides return value APIs for ops that are known to have a
611 /// single result. ResultType is the concrete type returned by getType().
612 template <typename ConcreteType>
613 class Impl
614 : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
615 public:
getType()616 ResultType getType() {
617 auto resultTy = this->getOperation()->getResult(0).getType();
618 return resultTy.template cast<ResultType>();
619 }
620 };
621 };
622
623 /// This class provides the API for ops that are known to have a specified
624 /// number of results. This is used as a trait like this:
625 ///
626 /// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
627 ///
628 template <unsigned N>
629 class NResults {
630 public:
631 static_assert(N > 1, "use ZeroResult/OneResult for N < 2");
632
633 template <typename ConcreteType>
634 class Impl
635 : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
636 public:
verifyTrait(Operation * op)637 static LogicalResult verifyTrait(Operation *op) {
638 return impl::verifyNResults(op, N);
639 }
640 };
641 };
642
643 /// This class provides the API for ops that are known to have at least a
644 /// specified number of results. This is used as a trait like this:
645 ///
646 /// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
647 ///
648 template <unsigned N>
649 class AtLeastNResults {
650 public:
651 template <typename ConcreteType>
652 class Impl : public detail::MultiResultTraitBase<ConcreteType,
653 AtLeastNResults<N>::Impl> {
654 public:
verifyTrait(Operation * op)655 static LogicalResult verifyTrait(Operation *op) {
656 return impl::verifyAtLeastNResults(op, N);
657 }
658 };
659 };
660
661 /// This class provides the API for ops which have an unknown number of
662 /// results.
663 template <typename ConcreteType>
664 class VariadicResults
665 : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
666
667 //===----------------------------------------------------------------------===//
668 // Terminator Traits
669
670 /// This class provides the API for ops that are known to be terminators.
671 template <typename ConcreteType>
672 class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
673 public:
getTraitProperties()674 static AbstractOperation::OperationProperties getTraitProperties() {
675 return static_cast<AbstractOperation::OperationProperties>(
676 OperationProperty::Terminator);
677 }
verifyTrait(Operation * op)678 static LogicalResult verifyTrait(Operation *op) {
679 return impl::verifyIsTerminator(op);
680 }
681 };
682
683 /// This class provides verification for ops that are known to have zero
684 /// successors.
685 template <typename ConcreteType>
686 class ZeroSuccessor : public TraitBase<ConcreteType, ZeroSuccessor> {
687 public:
verifyTrait(Operation * op)688 static LogicalResult verifyTrait(Operation *op) {
689 return impl::verifyZeroSuccessor(op);
690 }
691 };
692
693 namespace detail {
694 /// Utility trait base that provides accessors for derived traits that have
695 /// multiple successors.
696 template <typename ConcreteType, template <typename> class TraitType>
697 struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
698 using succ_iterator = Operation::succ_iterator;
699 using succ_range = SuccessorRange;
700
701 /// Return the number of successors.
getNumSuccessorsMultiSuccessorTraitBase702 unsigned getNumSuccessors() {
703 return this->getOperation()->getNumSuccessors();
704 }
705
706 /// Return the successor at `index`.
getSuccessorMultiSuccessorTraitBase707 Block *getSuccessor(unsigned i) {
708 return this->getOperation()->getSuccessor(i);
709 }
710
711 /// Set the successor at `index`.
setSuccessorMultiSuccessorTraitBase712 void setSuccessor(Block *block, unsigned i) {
713 return this->getOperation()->setSuccessor(block, i);
714 }
715
716 /// Successor iterator access.
succ_beginMultiSuccessorTraitBase717 succ_iterator succ_begin() { return this->getOperation()->succ_begin(); }
succ_endMultiSuccessorTraitBase718 succ_iterator succ_end() { return this->getOperation()->succ_end(); }
getSuccessorsMultiSuccessorTraitBase719 succ_range getSuccessors() { return this->getOperation()->getSuccessors(); }
720 };
721 } // end namespace detail
722
723 /// This class provides APIs for ops that are known to have a single successor.
724 template <typename ConcreteType>
725 class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
726 public:
getSuccessor()727 Block *getSuccessor() { return this->getOperation()->getSuccessor(0); }
setSuccessor(Block * succ)728 void setSuccessor(Block *succ) {
729 this->getOperation()->setSuccessor(succ, 0);
730 }
731
verifyTrait(Operation * op)732 static LogicalResult verifyTrait(Operation *op) {
733 return impl::verifyOneSuccessor(op);
734 }
735 };
736
737 /// This class provides the API for ops that are known to have a specified
738 /// number of successors.
739 template <unsigned N>
740 class NSuccessors {
741 public:
742 static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2");
743
744 template <typename ConcreteType>
745 class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
746 NSuccessors<N>::Impl> {
747 public:
verifyTrait(Operation * op)748 static LogicalResult verifyTrait(Operation *op) {
749 return impl::verifyNSuccessors(op, N);
750 }
751 };
752 };
753
754 /// This class provides APIs for ops that are known to have at least a specified
755 /// number of successors.
756 template <unsigned N>
757 class AtLeastNSuccessors {
758 public:
759 template <typename ConcreteType>
760 class Impl
761 : public detail::MultiSuccessorTraitBase<ConcreteType,
762 AtLeastNSuccessors<N>::Impl> {
763 public:
verifyTrait(Operation * op)764 static LogicalResult verifyTrait(Operation *op) {
765 return impl::verifyAtLeastNSuccessors(op, N);
766 }
767 };
768 };
769
770 /// This class provides the API for ops which have an unknown number of
771 /// successors.
772 template <typename ConcreteType>
773 class VariadicSuccessors
774 : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
775 };
776
777 //===----------------------------------------------------------------------===//
778 // SingleBlockImplicitTerminator
779
780 /// This class provides APIs and verifiers for ops with regions having a single
781 /// block that must terminate with `TerminatorOpType`.
782 template <typename TerminatorOpType>
783 struct SingleBlockImplicitTerminator {
784 template <typename ConcreteType>
785 class Impl : public TraitBase<ConcreteType, Impl> {
786 private:
787 /// Builds a terminator operation without relying on OpBuilder APIs to avoid
788 /// cyclic header inclusion.
buildTerminatorSingleBlockImplicitTerminator789 static Operation *buildTerminator(OpBuilder &builder, Location loc) {
790 OperationState state(loc, TerminatorOpType::getOperationName());
791 TerminatorOpType::build(builder, state);
792 return Operation::create(state);
793 }
794
795 public:
796 /// The type of the operation used as the implicit terminator type.
797 using ImplicitTerminatorOpT = TerminatorOpType;
798
verifyTraitSingleBlockImplicitTerminator799 static LogicalResult verifyTrait(Operation *op) {
800 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
801 Region ®ion = op->getRegion(i);
802
803 // Empty regions are fine.
804 if (region.empty())
805 continue;
806
807 // Non-empty regions must contain a single basic block.
808 if (std::next(region.begin()) != region.end())
809 return op->emitOpError("expects region #")
810 << i << " to have 0 or 1 blocks";
811
812 Block &block = region.front();
813 if (block.empty())
814 return op->emitOpError() << "expects a non-empty block";
815 Operation &terminator = block.back();
816 if (isa<TerminatorOpType>(terminator))
817 continue;
818
819 return op->emitOpError("expects regions to end with '" +
820 TerminatorOpType::getOperationName() +
821 "', found '" +
822 terminator.getName().getStringRef() + "'")
823 .attachNote()
824 << "in custom textual format, the absence of terminator implies "
825 "'"
826 << TerminatorOpType::getOperationName() << '\'';
827 }
828
829 return success();
830 }
831
832 /// Ensure that the given region has the terminator required by this trait.
833 /// If OpBuilder is provided, use it to build the terminator and notify the
834 /// OpBuilder listeners accordingly. If only a Builder is provided, locally
835 /// construct an OpBuilder with no listeners; this should only be used if no
836 /// OpBuilder is available at the call site, e.g., in the parser.
ensureTerminatorSingleBlockImplicitTerminator837 static void ensureTerminator(Region ®ion, Builder &builder,
838 Location loc) {
839 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
840 buildTerminator);
841 }
ensureTerminatorSingleBlockImplicitTerminator842 static void ensureTerminator(Region ®ion, OpBuilder &builder,
843 Location loc) {
844 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
845 buildTerminator);
846 }
847
848 Block *getBody(unsigned idx = 0) {
849 Region ®ion = this->getOperation()->getRegion(idx);
850 assert(!region.empty() && "unexpected empty region");
851 return ®ion.front();
852 }
853 Region &getBodyRegion(unsigned idx = 0) {
854 return this->getOperation()->getRegion(idx);
855 }
856
857 //===------------------------------------------------------------------===//
858 // Single Region Utilities
859 //===------------------------------------------------------------------===//
860
861 /// The following are a set of methods only enabled when the parent
862 /// operation has a single region. Each of these methods take an additional
863 /// template parameter that represents the concrete operation so that we
864 /// can use SFINAE to disable the methods for non-single region operations.
865 template <typename OpT, typename T = void>
866 using enable_if_single_region =
867 typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
868
869 template <typename OpT = ConcreteType>
beginSingleBlockImplicitTerminator870 enable_if_single_region<OpT, Block::iterator> begin() {
871 return getBody()->begin();
872 }
873 template <typename OpT = ConcreteType>
endSingleBlockImplicitTerminator874 enable_if_single_region<OpT, Block::iterator> end() {
875 return getBody()->end();
876 }
877 template <typename OpT = ConcreteType>
frontSingleBlockImplicitTerminator878 enable_if_single_region<OpT, Operation &> front() {
879 return *begin();
880 }
881
882 /// Insert the operation into the back of the body, before the terminator.
883 template <typename OpT = ConcreteType>
push_backSingleBlockImplicitTerminator884 enable_if_single_region<OpT> push_back(Operation *op) {
885 insert(Block::iterator(getBody()->getTerminator()), op);
886 }
887
888 /// Insert the operation at the given insertion point. Note: The operation
889 /// is never inserted after the terminator, even if the insertion point is
890 /// end().
891 template <typename OpT = ConcreteType>
insertSingleBlockImplicitTerminator892 enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
893 insert(Block::iterator(insertPt), op);
894 }
895 template <typename OpT = ConcreteType>
insertSingleBlockImplicitTerminator896 enable_if_single_region<OpT> insert(Block::iterator insertPt,
897 Operation *op) {
898 auto *body = getBody();
899 if (insertPt == body->end())
900 insertPt = Block::iterator(body->getTerminator());
901 body->getOperations().insert(insertPt, op);
902 }
903 };
904 };
905
906 //===----------------------------------------------------------------------===//
907 // Misc Traits
908
909 /// This class provides verification for ops that are known to have the same
910 /// operand shape: all operands are scalars, vectors/tensors of the same
911 /// shape.
912 template <typename ConcreteType>
913 class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
914 public:
verifyTrait(Operation * op)915 static LogicalResult verifyTrait(Operation *op) {
916 return impl::verifySameOperandsShape(op);
917 }
918 };
919
920 /// This class provides verification for ops that are known to have the same
921 /// operand and result shape: both are scalars, vectors/tensors of the same
922 /// shape.
923 template <typename ConcreteType>
924 class SameOperandsAndResultShape
925 : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
926 public:
verifyTrait(Operation * op)927 static LogicalResult verifyTrait(Operation *op) {
928 return impl::verifySameOperandsAndResultShape(op);
929 }
930 };
931
932 /// This class provides verification for ops that are known to have the same
933 /// operand element type (or the type itself if it is scalar).
934 ///
935 template <typename ConcreteType>
936 class SameOperandsElementType
937 : public TraitBase<ConcreteType, SameOperandsElementType> {
938 public:
verifyTrait(Operation * op)939 static LogicalResult verifyTrait(Operation *op) {
940 return impl::verifySameOperandsElementType(op);
941 }
942 };
943
944 /// This class provides verification for ops that are known to have the same
945 /// operand and result element type (or the type itself if it is scalar).
946 ///
947 template <typename ConcreteType>
948 class SameOperandsAndResultElementType
949 : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
950 public:
verifyTrait(Operation * op)951 static LogicalResult verifyTrait(Operation *op) {
952 return impl::verifySameOperandsAndResultElementType(op);
953 }
954 };
955
956 /// This class provides verification for ops that are known to have the same
957 /// operand and result type.
958 ///
959 /// Note: this trait subsumes the SameOperandsAndResultShape and
960 /// SameOperandsAndResultElementType traits.
961 template <typename ConcreteType>
962 class SameOperandsAndResultType
963 : public TraitBase<ConcreteType, SameOperandsAndResultType> {
964 public:
verifyTrait(Operation * op)965 static LogicalResult verifyTrait(Operation *op) {
966 return impl::verifySameOperandsAndResultType(op);
967 }
968 };
969
970 /// This class verifies that any results of the specified op have a boolean
971 /// type, a vector thereof, or a tensor thereof.
972 template <typename ConcreteType>
973 class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
974 public:
verifyTrait(Operation * op)975 static LogicalResult verifyTrait(Operation *op) {
976 return impl::verifyResultsAreBoolLike(op);
977 }
978 };
979
980 /// This class verifies that any results of the specified op have a floating
981 /// point type, a vector thereof, or a tensor thereof.
982 template <typename ConcreteType>
983 class ResultsAreFloatLike
984 : public TraitBase<ConcreteType, ResultsAreFloatLike> {
985 public:
verifyTrait(Operation * op)986 static LogicalResult verifyTrait(Operation *op) {
987 return impl::verifyResultsAreFloatLike(op);
988 }
989 };
990
991 /// This class verifies that any results of the specified op have a signless
992 /// integer or index type, a vector thereof, or a tensor thereof.
993 template <typename ConcreteType>
994 class ResultsAreSignlessIntegerLike
995 : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
996 public:
verifyTrait(Operation * op)997 static LogicalResult verifyTrait(Operation *op) {
998 return impl::verifyResultsAreSignlessIntegerLike(op);
999 }
1000 };
1001
1002 /// This class adds property that the operation is commutative.
1003 template <typename ConcreteType>
1004 class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
1005 public:
getTraitProperties()1006 static AbstractOperation::OperationProperties getTraitProperties() {
1007 return static_cast<AbstractOperation::OperationProperties>(
1008 OperationProperty::Commutative);
1009 }
1010 };
1011
1012 /// This class adds property that the operation is an involution.
1013 /// This means a unary to unary operation "f" that satisfies f(f(x)) = x
1014 template <typename ConcreteType>
1015 class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
1016 public:
verifyTrait(Operation * op)1017 static LogicalResult verifyTrait(Operation *op) {
1018 static_assert(ConcreteType::template hasTrait<OneResult>(),
1019 "expected operation to produce one result");
1020 static_assert(ConcreteType::template hasTrait<OneOperand>(),
1021 "expected operation to take one operand");
1022 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1023 "expected operation to preserve type");
1024 // Involution requires the operation to be side effect free as well
1025 // but currently this check is under a FIXME and is not actually done.
1026 return impl::verifyIsInvolution(op);
1027 }
1028
foldTrait(Operation * op,ArrayRef<Attribute> operands)1029 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1030 return impl::foldInvolution(op);
1031 }
1032 };
1033
1034 /// This class adds property that the operation is idempotent.
1035 /// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x)
1036 template <typename ConcreteType>
1037 class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
1038 public:
verifyTrait(Operation * op)1039 static LogicalResult verifyTrait(Operation *op) {
1040 static_assert(ConcreteType::template hasTrait<OneResult>(),
1041 "expected operation to produce one result");
1042 static_assert(ConcreteType::template hasTrait<OneOperand>(),
1043 "expected operation to take one operand");
1044 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1045 "expected operation to preserve type");
1046 // Idempotent requires the operation to be side effect free as well
1047 // but currently this check is under a FIXME and is not actually done.
1048 return impl::verifyIsIdempotent(op);
1049 }
1050
foldTrait(Operation * op,ArrayRef<Attribute> operands)1051 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1052 return impl::foldIdempotent(op);
1053 }
1054 };
1055
1056 /// This class verifies that all operands of the specified op have a float type,
1057 /// a vector thereof, or a tensor thereof.
1058 template <typename ConcreteType>
1059 class OperandsAreFloatLike
1060 : public TraitBase<ConcreteType, OperandsAreFloatLike> {
1061 public:
verifyTrait(Operation * op)1062 static LogicalResult verifyTrait(Operation *op) {
1063 return impl::verifyOperandsAreFloatLike(op);
1064 }
1065 };
1066
1067 /// This class verifies that all operands of the specified op have a signless
1068 /// integer or index type, a vector thereof, or a tensor thereof.
1069 template <typename ConcreteType>
1070 class OperandsAreSignlessIntegerLike
1071 : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
1072 public:
verifyTrait(Operation * op)1073 static LogicalResult verifyTrait(Operation *op) {
1074 return impl::verifyOperandsAreSignlessIntegerLike(op);
1075 }
1076 };
1077
1078 /// This class verifies that all operands of the specified op have the same
1079 /// type.
1080 template <typename ConcreteType>
1081 class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
1082 public:
verifyTrait(Operation * op)1083 static LogicalResult verifyTrait(Operation *op) {
1084 return impl::verifySameTypeOperands(op);
1085 }
1086 };
1087
1088 /// This class provides the API for a sub-set of ops that are known to be
1089 /// constant-like. These are non-side effecting operations with one result and
1090 /// zero operands that can always be folded to a specific attribute value.
1091 template <typename ConcreteType>
1092 class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
1093 public:
verifyTrait(Operation * op)1094 static LogicalResult verifyTrait(Operation *op) {
1095 static_assert(ConcreteType::template hasTrait<OneResult>(),
1096 "expected operation to produce one result");
1097 static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
1098 "expected operation to take zero operands");
1099 // TODO: We should verify that the operation can always be folded, but this
1100 // requires that the attributes of the op already be verified. We should add
1101 // support for verifying traits "after" the operation to enable this use
1102 // case.
1103 return success();
1104 }
1105 };
1106
1107 /// This class provides the API for ops that are known to be isolated from
1108 /// above.
1109 template <typename ConcreteType>
1110 class IsIsolatedFromAbove
1111 : public TraitBase<ConcreteType, IsIsolatedFromAbove> {
1112 public:
getTraitProperties()1113 static AbstractOperation::OperationProperties getTraitProperties() {
1114 return static_cast<AbstractOperation::OperationProperties>(
1115 OperationProperty::IsolatedFromAbove);
1116 }
verifyTrait(Operation * op)1117 static LogicalResult verifyTrait(Operation *op) {
1118 for (auto ®ion : op->getRegions())
1119 if (!region.isIsolatedFromAbove(op->getLoc()))
1120 return failure();
1121 return success();
1122 }
1123 };
1124
1125 /// A trait of region holding operations that defines a new scope for polyhedral
1126 /// optimization purposes. Any SSA values of 'index' type that either dominate
1127 /// such an operation or are used at the top-level of such an operation
1128 /// automatically become valid symbols for the polyhedral scope defined by that
1129 /// operation. For more details, see `Traits.md#AffineScope`.
1130 template <typename ConcreteType>
1131 class AffineScope : public TraitBase<ConcreteType, AffineScope> {
1132 public:
verifyTrait(Operation * op)1133 static LogicalResult verifyTrait(Operation *op) {
1134 static_assert(!ConcreteType::template hasTrait<ZeroRegion>(),
1135 "expected operation to have one or more regions");
1136 return success();
1137 }
1138 };
1139
1140 /// A trait of region holding operations that define a new scope for automatic
1141 /// allocations, i.e., allocations that are freed when control is transferred
1142 /// back from the operation's region. Any operations performing such allocations
1143 /// (for eg. std.alloca) will have their allocations automatically freed at
1144 /// their closest enclosing operation with this trait.
1145 template <typename ConcreteType>
1146 class AutomaticAllocationScope
1147 : public TraitBase<ConcreteType, AutomaticAllocationScope> {
1148 public:
verifyTrait(Operation * op)1149 static LogicalResult verifyTrait(Operation *op) {
1150 if (op->hasTrait<ZeroRegion>())
1151 return op->emitOpError("is expected to have regions");
1152 return success();
1153 }
1154 };
1155
1156 /// This class provides a verifier for ops that are expecting their parent
1157 /// to be one of the given parent ops
1158 template <typename... ParentOpTypes>
1159 struct HasParent {
1160 template <typename ConcreteType>
1161 class Impl : public TraitBase<ConcreteType, Impl> {
1162 public:
verifyTraitHasParent1163 static LogicalResult verifyTrait(Operation *op) {
1164 if (llvm::isa<ParentOpTypes...>(op->getParentOp()))
1165 return success();
1166
1167 return op->emitOpError()
1168 << "expects parent op "
1169 << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
1170 << llvm::makeArrayRef({ParentOpTypes::getOperationName()...})
1171 << "'";
1172 }
1173 };
1174 };
1175
1176 /// A trait for operations that have an attribute specifying operand segments.
1177 ///
1178 /// Certain operations can have multiple variadic operands and their size
1179 /// relationship is not always known statically. For such cases, we need
1180 /// a per-op-instance specification to divide the operands into logical groups
1181 /// or segments. This can be modeled by attributes. The attribute will be named
1182 /// as `operand_segment_sizes`.
1183 ///
1184 /// This trait verifies the attribute for specifying operand segments has
1185 /// the correct type (1D vector) and values (non-negative), etc.
1186 template <typename ConcreteType>
1187 class AttrSizedOperandSegments
1188 : public TraitBase<ConcreteType, AttrSizedOperandSegments> {
1189 public:
getOperandSegmentSizeAttr()1190 static StringRef getOperandSegmentSizeAttr() {
1191 return "operand_segment_sizes";
1192 }
1193
verifyTrait(Operation * op)1194 static LogicalResult verifyTrait(Operation *op) {
1195 return ::mlir::OpTrait::impl::verifyOperandSizeAttr(
1196 op, getOperandSegmentSizeAttr());
1197 }
1198 };
1199
1200 /// Similar to AttrSizedOperandSegments but used for results.
1201 template <typename ConcreteType>
1202 class AttrSizedResultSegments
1203 : public TraitBase<ConcreteType, AttrSizedResultSegments> {
1204 public:
getResultSegmentSizeAttr()1205 static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; }
1206
verifyTrait(Operation * op)1207 static LogicalResult verifyTrait(Operation *op) {
1208 return ::mlir::OpTrait::impl::verifyResultSizeAttr(
1209 op, getResultSegmentSizeAttr());
1210 }
1211 };
1212
1213 /// This trait provides a verifier for ops that are expecting their regions to
1214 /// not have any arguments
1215 template <typename ConcrentType>
1216 struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> {
verifyTraitNoRegionArguments1217 static LogicalResult verifyTrait(Operation *op) {
1218 return ::mlir::OpTrait::impl::verifyNoRegionArguments(op);
1219 }
1220 };
1221
1222 // This trait is used to flag operations that consume or produce
1223 // values of `MemRef` type where those references can be 'normalized'.
1224 // TODO: Right now, the operands of an operation are either all normalizable,
1225 // or not. In the future, we may want to allow some of the operands to be
1226 // normalizable.
1227 template <typename ConcrentType>
1228 struct MemRefsNormalizable
1229 : public TraitBase<ConcrentType, MemRefsNormalizable> {};
1230
1231 /// This trait tags scalar ops that also can be applied to vectors/tensors, with
1232 /// their semantics on vectors/tensors being elementwise application.
1233 ///
1234 /// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
1235 /// trait. In particular, broadcasting behavior is not allowed. This trait
1236 /// describes a set of invariants that allow systematic
1237 /// vectorization/tensorization, and the reverse, scalarization. The properties
1238 /// needed for this also can be used to implement a number of
1239 /// transformations/analyses/interfaces.
1240 ///
1241 /// An `ElementwiseMappable` op must satisfy the following properties:
1242 ///
1243 /// 1. If any result is a vector (resp. tensor), then at least one operand must
1244 /// be a vector (resp. tensor).
1245 /// 2. If any operand is a vector (resp. tensor), then there must be at least
1246 /// one result, and all results must be vectors (resp. tensors).
1247 /// 3. The static types of all vector (resp. tensor) operands and results must
1248 /// have the same shape.
1249 /// 4. In the case of tensor operands, the dynamic shapes of all tensor operands
1250 /// must be the same, otherwise the op has undefined behavior.
1251 /// 5. ("systematic scalarization" property) If an op has vector/tensor
1252 /// operands/results, then the same op, with the operand/result types changed to
1253 /// their corresponding element type, shall be a verifier-valid op.
1254 /// 6. The semantics of the op on vectors (resp. tensors) shall be the same as
1255 /// applying the scalarized version of the op for each corresponding element of
1256 /// the vector (resp. tensor) operands in parallel.
1257 /// 7. ("systematic vectorization/tensorization" property) If an op has
1258 /// scalar operands/results, the op shall remain verifier-valid if all scalar
1259 /// operands are replaced with vectors/tensors of the same shape and
1260 /// corresponding element types.
1261 ///
1262 /// Together, these properties provide an easy way for scalar operations to
1263 /// conveniently generalize their behavior to vectors/tensors, and systematize
1264 /// conversion between these forms.
1265 ///
1266 /// Examples:
1267 /// ```
1268 /// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32
1269 /// // Applying the systematic vectorization/tensorization property, this op
1270 /// // must also be valid:
1271 /// %tensor = "std.addf"(%a_tensor, %b_tensor)
///            : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
1273 ///
1274 /// // These properties generalize well to the cases of non-scalar operands.
1275 /// %select_scalar_pred = "std.select"(%pred, %true_val, %false_val)
1276 /// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
1277 /// // Applying the systematic vectorization / tensorization property, this
1278 /// // op must also be valid:
1279 /// %select_tensor_pred = "std.select"(%pred_tensor, %true_val, %false_val)
1280 /// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1281 /// -> tensor<?xf32>
1282 /// // Applying the systematic scalarization property, this op must also
1283 /// // be valid.
1284 /// %select_scalar = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
1285 /// : (i1, f32, f32) -> f32
1286 /// ```
1287 ///
1288 /// TODO: Avoid hardcoding vector/tensor, and generalize this to any type
1289 /// implementing a new "ElementwiseMappableTypeInterface" that describes types
1290 /// for which it makes sense to apply a scalar function to each element.
1291 ///
1292 /// Rationale:
1293 /// - 1. and 2. guarantee a well-defined iteration space for 6.
1294 /// - These also exclude the cases of 0 non-scalar operands or 0 non-scalar
1295 /// results, which complicate a generic definition of the iteration space.
1296 /// - 3. guarantees that folding can be done across scalars/vectors/tensors
1297 /// with the same pattern, as otherwise lots of special handling of type
1298 /// mismatches would be needed.
1299 /// - 4. guarantees that no error handling cases need to be considered.
1300 /// - Higher-level dialects should reify any needed guards / error handling
1301 /// code before lowering to an ElementwiseMappable op.
1302 /// - 5. and 6. allow defining the semantics on vectors/tensors via the scalar
1303 /// semantics and provide a constructive procedure for IR transformations
1304 /// to e.g. create scalar loop bodies from tensor ops.
1305 /// - 7. provides the reverse of 5., which when chained together allows
1306 /// reasoning about the relationship between the tensor and vector case.
1307 /// Additionally, it permits reasoning about promoting scalars to
1308 /// vectors/tensors via broadcasting in cases like `%select_scalar_pred`
1309 /// above.
1310 template <typename ConcreteType>
1311 struct ElementwiseMappable
1312 : public TraitBase<ConcreteType, ElementwiseMappable> {
verifyTraitElementwiseMappable1313 static LogicalResult verifyTrait(Operation *op) {
1314 return ::mlir::OpTrait::impl::verifyElementwiseMappable(op);
1315 }
1316 };
1317
1318 } // end namespace OpTrait
1319
1320 //===----------------------------------------------------------------------===//
1321 // Internal Trait Utilities
1322 //===----------------------------------------------------------------------===//
1323
1324 namespace op_definition_impl {
1325 //===----------------------------------------------------------------------===//
1326 // Trait Existence
1327
1328 /// Returns true if this given Trait ID matches the IDs of any of the provided
1329 /// trait types `Traits`.
1330 template <template <typename T> class... Traits>
hasTrait(TypeID traitID)1331 static bool hasTrait(TypeID traitID) {
1332 TypeID traitIDs[] = {TypeID::get<Traits>()...};
1333 for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
1334 if (traitIDs[i] == traitID)
1335 return true;
1336 return false;
1337 }
1338
1339 //===----------------------------------------------------------------------===//
1340 // Trait Folding
1341
/// Trait to check if T provides a 'foldTrait' method for single result
/// operations. The alias is well-formed only when
/// `T::foldTrait(Operation *, ArrayRef<Attribute>)` is a valid call; its type
/// is that call's return type. `Args` does not appear in the expression.
/// NOTE(review): `Args` presumably exists so the alias can be passed to the
/// variadic `llvm::is_detected` — confirm before removing it.
template <typename T, typename... Args>
using has_single_result_fold_trait = decltype(T::foldTrait(
    std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>()));
/// Applies `llvm::is_detected` to `has_single_result_fold_trait` for `T`.
template <typename T>
using detect_has_single_result_fold_trait =
    llvm::is_detected<has_single_result_fold_trait, T>;
/// Trait to check if T provides a general 'foldTrait' method, i.e. one that
/// can be called as
/// `T::foldTrait(Operation *, ArrayRef<Attribute>,
///               SmallVectorImpl<OpFoldResult> &)`.
/// `Args` does not appear in the expression.
template <typename T, typename... Args>
using has_fold_trait =
    decltype(T::foldTrait(std::declval<Operation *>(),
                          std::declval<ArrayRef<Attribute>>(),
                          std::declval<SmallVectorImpl<OpFoldResult> &>()));
/// Applies `llvm::is_detected` to `has_fold_trait` for `T`.
template <typename T>
using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>;
/// Trait to check if T provides any `foldTrait` method: selects the general
/// detector when it holds for `T`, and the single-result detector otherwise.
/// NOTE: This should use std::disjunction when C++17 is available.
template <typename T>
using detect_has_any_fold_trait =
    std::conditional_t<bool(detect_has_fold_trait<T>::value),
                       detect_has_fold_trait<T>,
                       detect_has_single_result_fold_trait<T>>;
1365
1366 /// Returns the result of folding a trait that implements a `foldTrait` function
1367 /// that is specialized for operations that have a single result.
1368 template <typename Trait>
1369 static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value,
1370 LogicalResult>
foldTrait(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1371 foldTrait(Operation *op, ArrayRef<Attribute> operands,
1372 SmallVectorImpl<OpFoldResult> &results) {
1373 assert(op->hasTrait<OpTrait::OneResult>() &&
1374 "expected trait on non single-result operation to implement the "
1375 "general `foldTrait` method");
1376 // If a previous trait has already been folded and replaced this operation, we
1377 // fail to fold this trait.
1378 if (!results.empty())
1379 return failure();
1380
1381 if (OpFoldResult result = Trait::foldTrait(op, operands)) {
1382 if (result.template dyn_cast<Value>() != op->getResult(0))
1383 results.push_back(result);
1384 return success();
1385 }
1386 return failure();
1387 }
1388 /// Returns the result of folding a trait that implements a generalized
1389 /// `foldTrait` function that is supports any operation type.
1390 template <typename Trait>
1391 static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult>
foldTrait(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1392 foldTrait(Operation *op, ArrayRef<Attribute> operands,
1393 SmallVectorImpl<OpFoldResult> &results) {
1394 // If a previous trait has already been folded and replaced this operation, we
1395 // fail to fold this trait.
1396 return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
1397 }
1398
1399 /// The internal implementation of `foldTraits` below that returns the result of
1400 /// folding a set of trait types `Ts` that implement a `foldTrait` method.
1401 template <typename... Ts>
foldTraitsImpl(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results,std::tuple<Ts...> *)1402 static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands,
1403 SmallVectorImpl<OpFoldResult> &results,
1404 std::tuple<Ts...> *) {
1405 bool anyFolded = false;
1406 (void)std::initializer_list<int>{
1407 (anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
1408 return success(anyFolded);
1409 }
1410
1411 /// Given a tuple type containing a set of traits that contain a `foldTrait`
1412 /// method, return the result of folding the given operation.
1413 template <typename TraitTupleT>
1414 static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult>
foldTraits(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1415 foldTraits(Operation *op, ArrayRef<Attribute> operands,
1416 SmallVectorImpl<OpFoldResult> &results) {
1417 return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr);
1418 }
1419 /// A variant of the method above that is specialized when there are no traits
1420 /// that contain a `foldTrait` method.
1421 template <typename TraitTupleT>
1422 static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult>
foldTraits(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1423 foldTraits(Operation *op, ArrayRef<Attribute> operands,
1424 SmallVectorImpl<OpFoldResult> &results) {
1425 return failure();
1426 }
1427
1428 //===----------------------------------------------------------------------===//
1429 // Trait Properties
1430
/// Trait to check if T provides a `getTraitProperties` method. The alias is
/// well-formed only when `T::getTraitProperties()` is a valid call with no
/// arguments. `Args` does not appear in the expression.
template <typename T, typename... Args>
using has_get_trait_properties = decltype(T::getTraitProperties());
/// Applies `llvm::is_detected` to `has_get_trait_properties` for `T`.
template <typename T>
using detect_has_get_trait_properties =
    llvm::is_detected<has_get_trait_properties, T>;
1437
1438 /// The internal implementation of `getTraitProperties` below that returns the
1439 /// OR of invoking `getTraitProperties` on all of the provided trait types `Ts`.
1440 template <typename... Ts>
1441 static AbstractOperation::OperationProperties
getTraitPropertiesImpl(std::tuple<Ts...> *)1442 getTraitPropertiesImpl(std::tuple<Ts...> *) {
1443 AbstractOperation::OperationProperties result = 0;
1444 (void)std::initializer_list<int>{(result |= Ts::getTraitProperties(), 0)...};
1445 return result;
1446 }
1447
1448 /// Given a tuple type containing a set of traits that contain a
1449 /// `getTraitProperties` method, return the OR of all of the results of invoking
1450 /// those methods.
1451 template <typename TraitTupleT>
getTraitProperties()1452 static AbstractOperation::OperationProperties getTraitProperties() {
1453 return getTraitPropertiesImpl((TraitTupleT *)nullptr);
1454 }
1455
1456 //===----------------------------------------------------------------------===//
1457 // Trait Verification
1458
/// Trait to check if T provides a `verifyTrait` method. The alias is
/// well-formed only when `T::verifyTrait(Operation *)` is a valid call.
/// `Args` does not appear in the expression.
template <typename T, typename... Args>
using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
/// Applies `llvm::is_detected` to `has_verify_trait` for `T`.
template <typename T>
using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;
1464
1465 /// The internal implementation of `verifyTraits` below that returns the result
1466 /// of verifying the current operation with all of the provided trait types
1467 /// `Ts`.
1468 template <typename... Ts>
verifyTraitsImpl(Operation * op,std::tuple<Ts...> *)1469 static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) {
1470 LogicalResult result = success();
1471 (void)std::initializer_list<int>{
1472 (result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...};
1473 return result;
1474 }
1475
1476 /// Given a tuple type containing a set of traits that contain a
1477 /// `verifyTrait` method, return the result of verifying the given operation.
1478 template <typename TraitTupleT>
verifyTraits(Operation * op)1479 static LogicalResult verifyTraits(Operation *op) {
1480 return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
1481 }
1482 } // namespace op_definition_impl
1483
1484 //===----------------------------------------------------------------------===//
1485 // Operation Definition classes
1486 //===----------------------------------------------------------------------===//
1487
1488 /// This provides public APIs that all operations should have. The template
1489 /// argument 'ConcreteType' should be the concrete type by CRTP and the others
1490 /// are base classes by the policy pattern.
1491 template <typename ConcreteType, template <typename T> class... Traits>
1492 class Op : public OpState, public Traits<ConcreteType>... {
1493 public:
1494 /// Inherit getOperation from `OpState`.
1495 using OpState::getOperation;
1496
1497 /// Return if this operation contains the provided trait.
1498 template <template <typename T> class Trait>
hasTrait()1499 static constexpr bool hasTrait() {
1500 return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
1501 }
1502
1503 /// Create a deep copy of this operation.
clone()1504 ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }
1505
1506 /// Create a partial copy of this operation without traversing into attached
1507 /// regions. The new operation will have the same number of regions as the
1508 /// original one, but they will be left empty.
cloneWithoutRegions()1509 ConcreteType cloneWithoutRegions() {
1510 return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
1511 }
1512
1513 /// Return true if this "op class" can match against the specified operation.
classof(Operation * op)1514 static bool classof(Operation *op) {
1515 if (auto *abstractOp = op->getAbstractOperation())
1516 return TypeID::get<ConcreteType>() == abstractOp->typeID;
1517 #ifndef NDEBUG
1518 if (op->getName().getStringRef() == ConcreteType::getOperationName())
1519 llvm::report_fatal_error(
1520 "classof on '" + ConcreteType::getOperationName() +
1521 "' failed due to the operation not being registered");
1522 #endif
1523 return false;
1524 }
1525
1526 /// Expose the type we are instantiated on to template machinery that may want
1527 /// to introspect traits on this operation.
1528 using ConcreteOpType = ConcreteType;
1529
1530 /// This is a public constructor. Any op can be initialized to null.
Op()1531 explicit Op() : OpState(nullptr) {}
Op(std::nullptr_t)1532 Op(std::nullptr_t) : OpState(nullptr) {}
1533
1534 /// This is a public constructor to enable access via the llvm::cast family of
1535 /// methods. This should not be used directly.
Op(Operation * state)1536 explicit Op(Operation *state) : OpState(state) {}
1537
1538 /// Methods for supporting PointerLikeTypeTraits.
getAsOpaquePointer()1539 const void *getAsOpaquePointer() const {
1540 return static_cast<const void *>((Operation *)*this);
1541 }
getFromOpaquePointer(const void * pointer)1542 static ConcreteOpType getFromOpaquePointer(const void *pointer) {
1543 return ConcreteOpType(
1544 reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
1545 }
1546
1547 private:
1548 /// Trait to check if T provides a 'fold' method for a single result op.
1549 template <typename T, typename... Args>
1550 using has_single_result_fold =
1551 decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
1552 template <typename T>
1553 using detect_has_single_result_fold =
1554 llvm::is_detected<has_single_result_fold, T>;
1555 /// Trait to check if T provides a general 'fold' method.
1556 template <typename T, typename... Args>
1557 using has_fold = decltype(
1558 std::declval<T>().fold(std::declval<ArrayRef<Attribute>>(),
1559 std::declval<SmallVectorImpl<OpFoldResult> &>()));
1560 template <typename T>
1561 using detect_has_fold = llvm::is_detected<has_fold, T>;
1562 /// Trait to check if T provides a 'print' method.
1563 template <typename T, typename... Args>
1564 using has_print =
1565 decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
1566 template <typename T>
1567 using detect_has_print = llvm::is_detected<has_print, T>;
1568 /// A tuple type containing the traits that have a `foldTrait` function.
1569 using FoldableTraitsTupleT = typename detail::FilterTypes<
1570 op_definition_impl::detect_has_any_fold_trait,
1571 Traits<ConcreteType>...>::type;
1572 /// A tuple type containing the traits that have a verify function.
1573 using VerifiableTraitsTupleT =
1574 typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
1575 Traits<ConcreteType>...>::type;
1576
1577 /// Returns the properties of this operation by combining the properties
1578 /// defined by the traits.
getOperationProperties()1579 static AbstractOperation::OperationProperties getOperationProperties() {
1580 return op_definition_impl::getTraitProperties<typename detail::FilterTypes<
1581 op_definition_impl::detect_has_get_trait_properties,
1582 Traits<ConcreteType>...>::type>();
1583 }
1584
1585 /// Returns an interface map containing the interfaces registered to this
1586 /// operation.
getInterfaceMap()1587 static detail::InterfaceMap getInterfaceMap() {
1588 return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
1589 }
1590
1591 /// Return the internal implementations of each of the AbstractOperation
1592 /// hooks.
1593 /// Implementation of `FoldHookFn` AbstractOperation hook.
getFoldHookFn()1594 static AbstractOperation::FoldHookFn getFoldHookFn() {
1595 return getFoldHookFnImpl<ConcreteType>();
1596 }
1597 /// The internal implementation of `getFoldHookFn` above that is invoked if
1598 /// the operation is single result and defines a `fold` method.
1599 template <typename ConcreteOpT>
1600 static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
1601 Traits<ConcreteOpT>...>::value &&
1602 detect_has_single_result_fold<ConcreteOpT>::value,
1603 AbstractOperation::FoldHookFn>
getFoldHookFnImpl()1604 getFoldHookFnImpl() {
1605 return &foldSingleResultHook<ConcreteOpT>;
1606 }
1607 /// The internal implementation of `getFoldHookFn` above that is invoked if
1608 /// the operation is not single result and defines a `fold` method.
1609 template <typename ConcreteOpT>
1610 static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
1611 Traits<ConcreteOpT>...>::value &&
1612 detect_has_fold<ConcreteOpT>::value,
1613 AbstractOperation::FoldHookFn>
getFoldHookFnImpl()1614 getFoldHookFnImpl() {
1615 return &foldHook<ConcreteOpT>;
1616 }
1617 /// The internal implementation of `getFoldHookFn` above that is invoked if
1618 /// the operation does not define a `fold` method.
1619 template <typename ConcreteOpT>
1620 static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
1621 !detect_has_fold<ConcreteOpT>::value,
1622 AbstractOperation::FoldHookFn>
getFoldHookFnImpl()1623 getFoldHookFnImpl() {
1624 // In this case, we only need to fold the traits of the operation.
1625 return &op_definition_impl::foldTraits<FoldableTraitsTupleT>;
1626 }
1627 /// Return the result of folding a single result operation that defines a
1628 /// `fold` method.
1629 template <typename ConcreteOpT>
1630 static LogicalResult
foldSingleResultHook(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1631 foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
1632 SmallVectorImpl<OpFoldResult> &results) {
1633 OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);
1634
1635 // If the fold failed or was in-place, try to fold the traits of the
1636 // operation.
1637 if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
1638 if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
1639 op, operands, results)))
1640 return success();
1641 return success(static_cast<bool>(result));
1642 }
1643 results.push_back(result);
1644 return success();
1645 }
1646 /// Return the result of folding an operation that defines a `fold` method.
1647 template <typename ConcreteOpT>
foldHook(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1648 static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
1649 SmallVectorImpl<OpFoldResult> &results) {
1650 LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);
1651
1652 // If the fold failed or was in-place, try to fold the traits of the
1653 // operation.
1654 if (failed(result) || results.empty()) {
1655 if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
1656 op, operands, results)))
1657 return success();
1658 }
1659 return result;
1660 }
1661
1662 /// Implementation of `GetCanonicalizationPatternsFn` AbstractOperation hook.
1663 static AbstractOperation::GetCanonicalizationPatternsFn
getGetCanonicalizationPatternsFn()1664 getGetCanonicalizationPatternsFn() {
1665 return &ConcreteType::getCanonicalizationPatterns;
1666 }
1667 /// Implementation of `GetHasTraitFn`
getHasTraitFn()1668 static AbstractOperation::HasTraitFn getHasTraitFn() {
1669 return &op_definition_impl::hasTrait<Traits...>;
1670 }
1671 /// Implementation of `ParseAssemblyFn` AbstractOperation hook.
getParseAssemblyFn()1672 static AbstractOperation::ParseAssemblyFn getParseAssemblyFn() {
1673 return &ConcreteType::parse;
1674 }
1675 /// Implementation of `PrintAssemblyFn` AbstractOperation hook.
getPrintAssemblyFn()1676 static AbstractOperation::PrintAssemblyFn getPrintAssemblyFn() {
1677 return getPrintAssemblyFnImpl<ConcreteType>();
1678 }
1679 /// The internal implementation of `getPrintAssemblyFn` that is invoked when
1680 /// the concrete operation does not define a `print` method.
1681 template <typename ConcreteOpT>
1682 static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
1683 AbstractOperation::PrintAssemblyFn>
getPrintAssemblyFnImpl()1684 getPrintAssemblyFnImpl() {
1685 return &OpState::print;
1686 }
1687 /// The internal implementation of `getPrintAssemblyFn` that is invoked when
1688 /// the concrete operation defines a `print` method.
1689 template <typename ConcreteOpT>
1690 static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
1691 AbstractOperation::PrintAssemblyFn>
getPrintAssemblyFnImpl()1692 getPrintAssemblyFnImpl() {
1693 return &printAssembly;
1694 }
printAssembly(Operation * op,OpAsmPrinter & p)1695 static void printAssembly(Operation *op, OpAsmPrinter &p) {
1696 return cast<ConcreteType>(op).print(p);
1697 }
1698 /// Implementation of `VerifyInvariantsFn` AbstractOperation hook.
getVerifyInvariantsFn()1699 static AbstractOperation::VerifyInvariantsFn getVerifyInvariantsFn() {
1700 return &verifyInvariants;
1701 }
verifyInvariants(Operation * op)1702 static LogicalResult verifyInvariants(Operation *op) {
1703 return failure(
1704 failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
1705 failed(cast<ConcreteType>(op).verify()));
1706 }
1707
1708 /// Allow access to internal implementation methods.
1709 friend AbstractOperation;
1710 };
1711
1712 /// This class represents the base of an operation interface. See the definition
1713 /// of `detail::Interface` for requirements on the `Traits` type.
1714 template <typename ConcreteType, typename Traits>
1715 class OpInterface
1716 : public detail::Interface<ConcreteType, Operation *, Traits,
1717 Op<ConcreteType>, OpTrait::TraitBase> {
1718 public:
1719 using Base = OpInterface<ConcreteType, Traits>;
1720 using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
1721 Op<ConcreteType>, OpTrait::TraitBase>;
1722
1723 /// Inherit the base class constructor.
1724 using InterfaceBase::InterfaceBase;
1725
1726 protected:
1727 /// Returns the impl interface instance for the given operation.
getInterfaceFor(Operation * op)1728 static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
1729 // Access the raw interface from the abstract operation.
1730 auto *abstractOp = op->getAbstractOperation();
1731 return abstractOp ? abstractOp->getInterface<ConcreteType>() : nullptr;
1732 }
1733
1734 /// Allow access to `getInterfaceFor`.
1735 friend InterfaceBase;
1736 };
1737
1738 //===----------------------------------------------------------------------===//
1739 // Common Operation Folders/Parsers/Printers
1740 //===----------------------------------------------------------------------===//
1741
1742 // These functions are out-of-line implementations of the methods in UnaryOp and
1743 // BinaryOp, which avoids them being template instantiated/duplicated.
1744 namespace impl {
1745 ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
1746 OperationState &result);
1747
1748 void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
1749 Value rhs);
1750 ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
1751 OperationState &result);
1752
1753 // Prints the given binary `op` in custom assembly form if both the two operands
1754 // and the result have the same time. Otherwise, prints the generic assembly
1755 // form.
1756 void printOneResultOp(Operation *op, OpAsmPrinter &p);
1757 } // namespace impl
1758
1759 // These functions are out-of-line implementations of the methods in
1760 // CastOpInterface, which avoids them being template instantiated/duplicated.
1761 namespace impl {
1762 /// Attempt to fold the given cast operation.
1763 LogicalResult foldCastInterfaceOp(Operation *op,
1764 ArrayRef<Attribute> attrOperands,
1765 SmallVectorImpl<OpFoldResult> &foldResults);
1766 /// Attempt to verify the given cast operation.
1767 LogicalResult verifyCastInterfaceOp(
1768 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
1769
1770 // TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
1771 // need for them, but some older ODS code in `std` still depends on them).
1772 void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
1773 Type destType);
1774 ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
1775 void printCastOp(Operation *op, OpAsmPrinter &p);
1776 // TODO: These methods are deprecated in favor of CastOpInterface. Remove them
1777 // when all uses have been updated. Also, consider adding functionality to
1778 // CastOpInterface to be able to perform the ChainedTensorCast canonicalization
1779 // generically.
1780 Value foldCastOp(Operation *op);
1781 LogicalResult verifyCastOp(Operation *op,
1782 function_ref<bool(Type, Type)> areCastCompatible);
1783 } // namespace impl
1784 } // end namespace mlir
1785
1786 #endif
1787