1 //===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements helper classes for implementing the "Op" types.  This
// includes the Op type, which is the base class for Op class definitions,
// as well as a number of traits in the OpTrait namespace that provide a
// declarative way to specify properties of Ops.
//
// The purpose of these types is to allow light-weight implementation of
// concrete ops (like DimOp) with very little boilerplate.
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #ifndef MLIR_IR_OPDEFINITION_H
20 #define MLIR_IR_OPDEFINITION_H
21 
22 #include "mlir/IR/Operation.h"
23 #include "llvm/Support/PointerLikeTypeTraits.h"
24 
25 #include <type_traits>
26 
27 namespace mlir {
28 class Builder;
29 class OpBuilder;
30 
31 /// This class represents success/failure for operation parsing. It is
32 /// essentially a simple wrapper class around LogicalResult that allows for
33 /// explicit conversion to bool. This allows for the parser to chain together
34 /// parse rules without the clutter of "failed/succeeded".
35 class ParseResult : public LogicalResult {
36 public:
LogicalResult(result)37   ParseResult(LogicalResult result = success()) : LogicalResult(result) {}
38 
39   // Allow diagnostics emitted during parsing to be converted to failure.
ParseResult(const InFlightDiagnostic &)40   ParseResult(const InFlightDiagnostic &) : LogicalResult(failure()) {}
ParseResult(const Diagnostic &)41   ParseResult(const Diagnostic &) : LogicalResult(failure()) {}
42 
43   /// Failure is true in a boolean context.
44   explicit operator bool() const { return failed(*this); }
45 };
46 /// This class implements `Optional` functionality for ParseResult. We don't
47 /// directly use Optional here, because it provides an implicit conversion
48 /// to 'bool' which we want to avoid. This class is used to implement tri-state
49 /// 'parseOptional' functions that may have a failure mode when parsing that
50 /// shouldn't be attributed to "not present".
51 class OptionalParseResult {
52 public:
53   OptionalParseResult() = default;
OptionalParseResult(LogicalResult result)54   OptionalParseResult(LogicalResult result) : impl(result) {}
OptionalParseResult(ParseResult result)55   OptionalParseResult(ParseResult result) : impl(result) {}
OptionalParseResult(const InFlightDiagnostic &)56   OptionalParseResult(const InFlightDiagnostic &)
57       : OptionalParseResult(failure()) {}
OptionalParseResult(llvm::NoneType)58   OptionalParseResult(llvm::NoneType) : impl(llvm::None) {}
59 
60   /// Returns true if we contain a valid ParseResult value.
hasValue()61   bool hasValue() const { return impl.hasValue(); }
62 
63   /// Access the internal ParseResult value.
getValue()64   ParseResult getValue() const { return impl.getValue(); }
65   ParseResult operator*() const { return getValue(); }
66 
67 private:
68   Optional<ParseResult> impl;
69 };
70 
71 // These functions are out-of-line utilities, which avoids them being template
72 // instantiated/duplicated.
73 namespace impl {
74 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the
75 /// region's only block if it does not have a terminator already. If the region
76 /// is empty, insert a new block first. `buildTerminatorOp` should return the
77 /// terminator operation to insert.
78 void ensureRegionTerminator(
79     Region &region, OpBuilder &builder, Location loc,
80     function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
81 void ensureRegionTerminator(
82     Region &region, Builder &builder, Location loc,
83     function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
84 
85 } // namespace impl
86 
87 /// This is the concrete base class that holds the operation pointer and has
88 /// non-generic methods that only depend on State (to avoid having them
89 /// instantiated on template types that don't affect them.
90 ///
91 /// This also has the fallback implementations of customization hooks for when
92 /// they aren't customized.
93 class OpState {
94 public:
95   /// Ops are pointer-like, so we allow implicit conversion to bool.
96   operator bool() { return getOperation() != nullptr; }
97 
98   /// This implicitly converts to Operation*.
99   operator Operation *() const { return state; }
100 
101   /// Shortcut of `->` to access a member of Operation.
102   Operation *operator->() const { return state; }
103 
104   /// Return the operation that this refers to.
getOperation()105   Operation *getOperation() { return state; }
106 
107   /// Return the context this operation belongs to.
getContext()108   MLIRContext *getContext() { return getOperation()->getContext(); }
109 
110   /// Print the operation to the given stream.
111   void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
112     state->print(os, flags);
113   }
114   void print(raw_ostream &os, AsmState &asmState,
115              OpPrintingFlags flags = llvm::None) {
116     state->print(os, asmState, flags);
117   }
118 
119   /// Dump this operation.
dump()120   void dump() { state->dump(); }
121 
122   /// The source location the operation was defined or derived from.
getLoc()123   Location getLoc() { return state->getLoc(); }
setLoc(Location loc)124   void setLoc(Location loc) { state->setLoc(loc); }
125 
126   /// Return all of the attributes on this operation.
getAttrs()127   ArrayRef<NamedAttribute> getAttrs() { return state->getAttrs(); }
128 
129   /// A utility iterator that filters out non-dialect attributes.
130   using dialect_attr_iterator = Operation::dialect_attr_iterator;
131   using dialect_attr_range = Operation::dialect_attr_range;
132 
133   /// Set the dialect attributes for this operation, and preserve all dependent.
134   template <typename DialectAttrs>
setDialectAttrs(DialectAttrs && attrs)135   void setDialectAttrs(DialectAttrs &&attrs) {
136     state->setDialectAttrs(std::forward<DialectAttrs>(attrs));
137   }
138 
139   /// Remove the attribute with the specified name if it exists. Return the
140   /// attribute that was erased, or nullptr if there was no attribute with such
141   /// name.
removeAttr(Identifier name)142   Attribute removeAttr(Identifier name) { return state->removeAttr(name); }
removeAttr(StringRef name)143   Attribute removeAttr(StringRef name) {
144     return state->removeAttr(Identifier::get(name, getContext()));
145   }
146 
147   /// Return true if there are no users of any results of this operation.
use_empty()148   bool use_empty() { return state->use_empty(); }
149 
150   /// Remove this operation from its parent block and delete it.
erase()151   void erase() { state->erase(); }
152 
153   /// Emit an error with the op name prefixed, like "'dim' op " which is
154   /// convenient for verifiers.
155   InFlightDiagnostic emitOpError(const Twine &message = {});
156 
157   /// Emit an error about fatal conditions with this operation, reporting up to
158   /// any diagnostic handlers that may be listening.
159   InFlightDiagnostic emitError(const Twine &message = {});
160 
161   /// Emit a warning about this operation, reporting up to any diagnostic
162   /// handlers that may be listening.
163   InFlightDiagnostic emitWarning(const Twine &message = {});
164 
165   /// Emit a remark about this operation, reporting up to any diagnostic
166   /// handlers that may be listening.
167   InFlightDiagnostic emitRemark(const Twine &message = {});
168 
169   /// Walk the operation in postorder, calling the callback for each nested
170   /// operation(including this one).
171   /// See Operation::walk for more details.
172   template <typename FnT, typename RetT = detail::walkResultType<FnT>>
walk(FnT && callback)173   RetT walk(FnT &&callback) {
174     return state->walk(std::forward<FnT>(callback));
175   }
176 
177   // These are default implementations of customization hooks.
178 public:
179   /// This hook returns any canonicalization pattern rewrites that the operation
180   /// supports, for use by the canonicalization pass.
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)181   static void getCanonicalizationPatterns(OwningRewritePatternList &results,
182                                           MLIRContext *context) {}
183 
184 protected:
185   /// If the concrete type didn't implement a custom verifier hook, just fall
186   /// back to this one which accepts everything.
verify()187   LogicalResult verify() { return success(); }
188 
189   /// Unless overridden, the custom assembly form of an op is always rejected.
190   /// Op implementations should implement this to return failure.
191   /// On success, they should fill in result with the fields to use.
192   static ParseResult parse(OpAsmParser &parser, OperationState &result);
193 
194   // The fallback for the printer is to print it the generic assembly form.
195   static void print(Operation *op, OpAsmPrinter &p);
196 
197   /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
198   /// so we can cast it away here.
OpState(Operation * state)199   explicit OpState(Operation *state) : state(state) {}
200 
201 private:
202   Operation *state;
203 
204   /// Allow access to internal hook implementation methods.
205   friend AbstractOperation;
206 };
207 
208 // Allow comparing operators.
209 inline bool operator==(OpState lhs, OpState rhs) {
210   return lhs.getOperation() == rhs.getOperation();
211 }
212 inline bool operator!=(OpState lhs, OpState rhs) {
213   return lhs.getOperation() != rhs.getOperation();
214 }
215 
216 raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr);
217 
218 /// This class represents a single result from folding an operation.
219 class OpFoldResult : public PointerUnion<Attribute, Value> {
220   using PointerUnion<Attribute, Value>::PointerUnion;
221 
222 public:
dump()223   void dump() { llvm::errs() << *this << "\n"; }
224 };
225 
226 /// Allow printing to a stream.
227 inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) {
228   if (Value value = ofr.dyn_cast<Value>())
229     value.print(os);
230   else
231     ofr.dyn_cast<Attribute>().print(os);
232   return os;
233 }
234 
235 /// Allow printing to a stream.
236 inline raw_ostream &operator<<(raw_ostream &os, OpState &op) {
237   op.print(os, OpPrintingFlags().useLocalScope());
238   return os;
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // Operation Trait Types
243 //===----------------------------------------------------------------------===//
244 
245 namespace OpTrait {
246 
247 // These functions are out-of-line implementations of the methods in the
248 // corresponding trait classes.  This avoids them being template
249 // instantiated/duplicated.
250 namespace impl {
251 OpFoldResult foldIdempotent(Operation *op);
252 OpFoldResult foldInvolution(Operation *op);
253 LogicalResult verifyZeroOperands(Operation *op);
254 LogicalResult verifyOneOperand(Operation *op);
255 LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
256 LogicalResult verifyIsIdempotent(Operation *op);
257 LogicalResult verifyIsInvolution(Operation *op);
258 LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
259 LogicalResult verifyOperandsAreFloatLike(Operation *op);
260 LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
261 LogicalResult verifySameTypeOperands(Operation *op);
262 LogicalResult verifyZeroRegion(Operation *op);
263 LogicalResult verifyOneRegion(Operation *op);
264 LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
265 LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);
266 LogicalResult verifyZeroResult(Operation *op);
267 LogicalResult verifyOneResult(Operation *op);
268 LogicalResult verifyNResults(Operation *op, unsigned numOperands);
269 LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
270 LogicalResult verifySameOperandsShape(Operation *op);
271 LogicalResult verifySameOperandsAndResultShape(Operation *op);
272 LogicalResult verifySameOperandsElementType(Operation *op);
273 LogicalResult verifySameOperandsAndResultElementType(Operation *op);
274 LogicalResult verifySameOperandsAndResultType(Operation *op);
275 LogicalResult verifyResultsAreBoolLike(Operation *op);
276 LogicalResult verifyResultsAreFloatLike(Operation *op);
277 LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
278 LogicalResult verifyIsTerminator(Operation *op);
279 LogicalResult verifyZeroSuccessor(Operation *op);
280 LogicalResult verifyOneSuccessor(Operation *op);
281 LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
282 LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
283 LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
284 LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
285 LogicalResult verifyNoRegionArguments(Operation *op);
286 LogicalResult verifyElementwiseMappable(Operation *op);
287 } // namespace impl
288 
289 /// Helper class for implementing traits.  Clients are not expected to interact
290 /// with this directly, so its members are all protected.
291 template <typename ConcreteType, template <typename> class TraitType>
292 class TraitBase {
293 protected:
294   /// Return the ultimate Operation being worked on.
getOperation()295   Operation *getOperation() {
296     // We have to cast up to the trait type, then to the concrete type, then to
297     // the BaseState class in explicit hops because the concrete type will
298     // multiply derive from the (content free) TraitBase class, and we need to
299     // be able to disambiguate the path for the C++ compiler.
300     auto *trait = static_cast<TraitType<ConcreteType> *>(this);
301     auto *concrete = static_cast<ConcreteType *>(trait);
302     auto *base = static_cast<OpState *>(concrete);
303     return base->getOperation();
304   }
305 };
306 
307 //===----------------------------------------------------------------------===//
308 // Operand Traits
309 
310 namespace detail {
311 /// Utility trait base that provides accessors for derived traits that have
312 /// multiple operands.
313 template <typename ConcreteType, template <typename> class TraitType>
314 struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
315   using operand_iterator = Operation::operand_iterator;
316   using operand_range = Operation::operand_range;
317   using operand_type_iterator = Operation::operand_type_iterator;
318   using operand_type_range = Operation::operand_type_range;
319 
320   /// Return the number of operands.
getNumOperandsMultiOperandTraitBase321   unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
322 
323   /// Return the operand at index 'i'.
getOperandMultiOperandTraitBase324   Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
325 
326   /// Set the operand at index 'i' to 'value'.
setOperandMultiOperandTraitBase327   void setOperand(unsigned i, Value value) {
328     this->getOperation()->setOperand(i, value);
329   }
330 
331   /// Operand iterator access.
operand_beginMultiOperandTraitBase332   operand_iterator operand_begin() {
333     return this->getOperation()->operand_begin();
334   }
operand_endMultiOperandTraitBase335   operand_iterator operand_end() { return this->getOperation()->operand_end(); }
getOperandsMultiOperandTraitBase336   operand_range getOperands() { return this->getOperation()->getOperands(); }
337 
338   /// Operand type access.
operand_type_beginMultiOperandTraitBase339   operand_type_iterator operand_type_begin() {
340     return this->getOperation()->operand_type_begin();
341   }
operand_type_endMultiOperandTraitBase342   operand_type_iterator operand_type_end() {
343     return this->getOperation()->operand_type_end();
344   }
getOperandTypesMultiOperandTraitBase345   operand_type_range getOperandTypes() {
346     return this->getOperation()->getOperandTypes();
347   }
348 };
349 } // end namespace detail
350 
351 /// This class provides the API for ops that are known to have no
352 /// SSA operand.
353 template <typename ConcreteType>
354 class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
355 public:
verifyTrait(Operation * op)356   static LogicalResult verifyTrait(Operation *op) {
357     return impl::verifyZeroOperands(op);
358   }
359 
360 private:
361   // Disable these.
getOperand()362   void getOperand() {}
setOperand()363   void setOperand() {}
364 };
365 
366 /// This class provides the API for ops that are known to have exactly one
367 /// SSA operand.
368 template <typename ConcreteType>
369 class OneOperand : public TraitBase<ConcreteType, OneOperand> {
370 public:
getOperand()371   Value getOperand() { return this->getOperation()->getOperand(0); }
372 
setOperand(Value value)373   void setOperand(Value value) { this->getOperation()->setOperand(0, value); }
374 
verifyTrait(Operation * op)375   static LogicalResult verifyTrait(Operation *op) {
376     return impl::verifyOneOperand(op);
377   }
378 };
379 
380 /// This class provides the API for ops that are known to have a specified
381 /// number of operands.  This is used as a trait like this:
382 ///
383 ///   class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
384 ///
385 template <unsigned N>
386 class NOperands {
387 public:
388   static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
389 
390   template <typename ConcreteType>
391   class Impl
392       : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
393   public:
verifyTrait(Operation * op)394     static LogicalResult verifyTrait(Operation *op) {
395       return impl::verifyNOperands(op, N);
396     }
397   };
398 };
399 
400 /// This class provides the API for ops that are known to have a at least a
401 /// specified number of operands.  This is used as a trait like this:
402 ///
403 ///   class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
404 ///
405 template <unsigned N>
406 class AtLeastNOperands {
407 public:
408   template <typename ConcreteType>
409   class Impl : public detail::MultiOperandTraitBase<ConcreteType,
410                                                     AtLeastNOperands<N>::Impl> {
411   public:
verifyTrait(Operation * op)412     static LogicalResult verifyTrait(Operation *op) {
413       return impl::verifyAtLeastNOperands(op, N);
414     }
415   };
416 };
417 
418 /// This class provides the API for ops which have an unknown number of
419 /// SSA operands.
420 template <typename ConcreteType>
421 class VariadicOperands
422     : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
423 
424 //===----------------------------------------------------------------------===//
425 // Region Traits
426 
427 /// This class provides verification for ops that are known to have zero
428 /// regions.
429 template <typename ConcreteType>
430 class ZeroRegion : public TraitBase<ConcreteType, ZeroRegion> {
431 public:
verifyTrait(Operation * op)432   static LogicalResult verifyTrait(Operation *op) {
433     return impl::verifyZeroRegion(op);
434   }
435 };
436 
437 namespace detail {
438 /// Utility trait base that provides accessors for derived traits that have
439 /// multiple regions.
440 template <typename ConcreteType, template <typename> class TraitType>
441 struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> {
442   using region_iterator = MutableArrayRef<Region>;
443   using region_range = RegionRange;
444 
445   /// Return the number of regions.
getNumRegionsMultiRegionTraitBase446   unsigned getNumRegions() { return this->getOperation()->getNumRegions(); }
447 
448   /// Return the region at `index`.
getRegionMultiRegionTraitBase449   Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); }
450 
451   /// Region iterator access.
region_beginMultiRegionTraitBase452   region_iterator region_begin() {
453     return this->getOperation()->region_begin();
454   }
region_endMultiRegionTraitBase455   region_iterator region_end() { return this->getOperation()->region_end(); }
getRegionsMultiRegionTraitBase456   region_range getRegions() { return this->getOperation()->getRegions(); }
457 };
458 } // end namespace detail
459 
460 /// This class provides APIs for ops that are known to have a single region.
461 template <typename ConcreteType>
462 class OneRegion : public TraitBase<ConcreteType, OneRegion> {
463 public:
getRegion()464   Region &getRegion() { return this->getOperation()->getRegion(0); }
465 
466   /// Returns a range of operations within the region of this operation.
getOps()467   auto getOps() { return getRegion().getOps(); }
468   template <typename OpT>
getOps()469   auto getOps() {
470     return getRegion().template getOps<OpT>();
471   }
472 
verifyTrait(Operation * op)473   static LogicalResult verifyTrait(Operation *op) {
474     return impl::verifyOneRegion(op);
475   }
476 };
477 
478 /// This class provides the API for ops that are known to have a specified
479 /// number of regions.
480 template <unsigned N>
481 class NRegions {
482 public:
483   static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");
484 
485   template <typename ConcreteType>
486   class Impl
487       : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
488   public:
verifyTrait(Operation * op)489     static LogicalResult verifyTrait(Operation *op) {
490       return impl::verifyNRegions(op, N);
491     }
492   };
493 };
494 
495 /// This class provides APIs for ops that are known to have at least a specified
496 /// number of regions.
497 template <unsigned N>
498 class AtLeastNRegions {
499 public:
500   template <typename ConcreteType>
501   class Impl : public detail::MultiRegionTraitBase<ConcreteType,
502                                                    AtLeastNRegions<N>::Impl> {
503   public:
verifyTrait(Operation * op)504     static LogicalResult verifyTrait(Operation *op) {
505       return impl::verifyAtLeastNRegions(op, N);
506     }
507   };
508 };
509 
510 /// This class provides the API for ops which have an unknown number of
511 /// regions.
512 template <typename ConcreteType>
513 class VariadicRegions
514     : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {};
515 
516 //===----------------------------------------------------------------------===//
517 // Result Traits
518 
519 /// This class provides return value APIs for ops that are known to have
520 /// zero results.
521 template <typename ConcreteType>
522 class ZeroResult : public TraitBase<ConcreteType, ZeroResult> {
523 public:
verifyTrait(Operation * op)524   static LogicalResult verifyTrait(Operation *op) {
525     return impl::verifyZeroResult(op);
526   }
527 };
528 
529 namespace detail {
530 /// Utility trait base that provides accessors for derived traits that have
531 /// multiple results.
532 template <typename ConcreteType, template <typename> class TraitType>
533 struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
534   using result_iterator = Operation::result_iterator;
535   using result_range = Operation::result_range;
536   using result_type_iterator = Operation::result_type_iterator;
537   using result_type_range = Operation::result_type_range;
538 
539   /// Return the number of results.
getNumResultsMultiResultTraitBase540   unsigned getNumResults() { return this->getOperation()->getNumResults(); }
541 
542   /// Return the result at index 'i'.
getResultMultiResultTraitBase543   Value getResult(unsigned i) { return this->getOperation()->getResult(i); }
544 
545   /// Replace all uses of results of this operation with the provided 'values'.
546   /// 'values' may correspond to an existing operation, or a range of 'Value'.
547   template <typename ValuesT>
replaceAllUsesWithMultiResultTraitBase548   void replaceAllUsesWith(ValuesT &&values) {
549     this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
550   }
551 
552   /// Return the type of the `i`-th result.
getTypeMultiResultTraitBase553   Type getType(unsigned i) { return getResult(i).getType(); }
554 
555   /// Result iterator access.
result_beginMultiResultTraitBase556   result_iterator result_begin() {
557     return this->getOperation()->result_begin();
558   }
result_endMultiResultTraitBase559   result_iterator result_end() { return this->getOperation()->result_end(); }
getResultsMultiResultTraitBase560   result_range getResults() { return this->getOperation()->getResults(); }
561 
562   /// Result type access.
result_type_beginMultiResultTraitBase563   result_type_iterator result_type_begin() {
564     return this->getOperation()->result_type_begin();
565   }
result_type_endMultiResultTraitBase566   result_type_iterator result_type_end() {
567     return this->getOperation()->result_type_end();
568   }
getResultTypesMultiResultTraitBase569   result_type_range getResultTypes() {
570     return this->getOperation()->getResultTypes();
571   }
572 };
573 } // end namespace detail
574 
575 /// This class provides return value APIs for ops that are known to have a
576 /// single result.  ResultType is the concrete type returned by getType().
577 template <typename ConcreteType>
578 class OneResult : public TraitBase<ConcreteType, OneResult> {
579 public:
getResult()580   Value getResult() { return this->getOperation()->getResult(0); }
581 
582   /// If the operation returns a single value, then the Op can be implicitly
583   /// converted to an Value. This yields the value of the only result.
Value()584   operator Value() { return getResult(); }
585 
586   /// Replace all uses of 'this' value with the new value, updating anything
587   /// in the IR that uses 'this' to use the other value instead.  When this
588   /// returns there are zero uses of 'this'.
replaceAllUsesWith(Value newValue)589   void replaceAllUsesWith(Value newValue) {
590     getResult().replaceAllUsesWith(newValue);
591   }
592 
593   /// Replace all uses of 'this' value with the result of 'op'.
replaceAllUsesWith(Operation * op)594   void replaceAllUsesWith(Operation *op) {
595     this->getOperation()->replaceAllUsesWith(op);
596   }
597 
verifyTrait(Operation * op)598   static LogicalResult verifyTrait(Operation *op) {
599     return impl::verifyOneResult(op);
600   }
601 };
602 
603 /// This trait is used for return value APIs for ops that are known to have a
604 /// specific type other than `Type`.  This allows the "getType()" member to be
605 /// more specific for an op.  This should be used in conjunction with OneResult,
606 /// and occur in the trait list before OneResult.
607 template <typename ResultType>
608 class OneTypedResult {
609 public:
610   /// This class provides return value APIs for ops that are known to have a
611   /// single result.  ResultType is the concrete type returned by getType().
612   template <typename ConcreteType>
613   class Impl
614       : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
615   public:
getType()616     ResultType getType() {
617       auto resultTy = this->getOperation()->getResult(0).getType();
618       return resultTy.template cast<ResultType>();
619     }
620   };
621 };
622 
623 /// This class provides the API for ops that are known to have a specified
624 /// number of results.  This is used as a trait like this:
625 ///
626 ///   class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
627 ///
628 template <unsigned N>
629 class NResults {
630 public:
631   static_assert(N > 1, "use ZeroResult/OneResult for N < 2");
632 
633   template <typename ConcreteType>
634   class Impl
635       : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
636   public:
verifyTrait(Operation * op)637     static LogicalResult verifyTrait(Operation *op) {
638       return impl::verifyNResults(op, N);
639     }
640   };
641 };
642 
643 /// This class provides the API for ops that are known to have at least a
644 /// specified number of results.  This is used as a trait like this:
645 ///
646 ///   class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
647 ///
648 template <unsigned N>
649 class AtLeastNResults {
650 public:
651   template <typename ConcreteType>
652   class Impl : public detail::MultiResultTraitBase<ConcreteType,
653                                                    AtLeastNResults<N>::Impl> {
654   public:
verifyTrait(Operation * op)655     static LogicalResult verifyTrait(Operation *op) {
656       return impl::verifyAtLeastNResults(op, N);
657     }
658   };
659 };
660 
661 /// This class provides the API for ops which have an unknown number of
662 /// results.
663 template <typename ConcreteType>
664 class VariadicResults
665     : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
666 
667 //===----------------------------------------------------------------------===//
668 // Terminator Traits
669 
670 /// This class provides the API for ops that are known to be terminators.
671 template <typename ConcreteType>
672 class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
673 public:
getTraitProperties()674   static AbstractOperation::OperationProperties getTraitProperties() {
675     return static_cast<AbstractOperation::OperationProperties>(
676         OperationProperty::Terminator);
677   }
verifyTrait(Operation * op)678   static LogicalResult verifyTrait(Operation *op) {
679     return impl::verifyIsTerminator(op);
680   }
681 };
682 
683 /// This class provides verification for ops that are known to have zero
684 /// successors.
685 template <typename ConcreteType>
686 class ZeroSuccessor : public TraitBase<ConcreteType, ZeroSuccessor> {
687 public:
verifyTrait(Operation * op)688   static LogicalResult verifyTrait(Operation *op) {
689     return impl::verifyZeroSuccessor(op);
690   }
691 };
692 
693 namespace detail {
694 /// Utility trait base that provides accessors for derived traits that have
695 /// multiple successors.
696 template <typename ConcreteType, template <typename> class TraitType>
697 struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
698   using succ_iterator = Operation::succ_iterator;
699   using succ_range = SuccessorRange;
700 
701   /// Return the number of successors.
getNumSuccessorsMultiSuccessorTraitBase702   unsigned getNumSuccessors() {
703     return this->getOperation()->getNumSuccessors();
704   }
705 
706   /// Return the successor at `index`.
getSuccessorMultiSuccessorTraitBase707   Block *getSuccessor(unsigned i) {
708     return this->getOperation()->getSuccessor(i);
709   }
710 
711   /// Set the successor at `index`.
setSuccessorMultiSuccessorTraitBase712   void setSuccessor(Block *block, unsigned i) {
713     return this->getOperation()->setSuccessor(block, i);
714   }
715 
716   /// Successor iterator access.
succ_beginMultiSuccessorTraitBase717   succ_iterator succ_begin() { return this->getOperation()->succ_begin(); }
succ_endMultiSuccessorTraitBase718   succ_iterator succ_end() { return this->getOperation()->succ_end(); }
getSuccessorsMultiSuccessorTraitBase719   succ_range getSuccessors() { return this->getOperation()->getSuccessors(); }
720 };
721 } // end namespace detail
722 
723 /// This class provides APIs for ops that are known to have a single successor.
724 template <typename ConcreteType>
725 class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
726 public:
getSuccessor()727   Block *getSuccessor() { return this->getOperation()->getSuccessor(0); }
setSuccessor(Block * succ)728   void setSuccessor(Block *succ) {
729     this->getOperation()->setSuccessor(succ, 0);
730   }
731 
verifyTrait(Operation * op)732   static LogicalResult verifyTrait(Operation *op) {
733     return impl::verifyOneSuccessor(op);
734   }
735 };
736 
737 /// This class provides the API for ops that are known to have a specified
738 /// number of successors.
739 template <unsigned N>
740 class NSuccessors {
741 public:
742   static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2");
743 
744   template <typename ConcreteType>
745   class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
746                                                       NSuccessors<N>::Impl> {
747   public:
verifyTrait(Operation * op)748     static LogicalResult verifyTrait(Operation *op) {
749       return impl::verifyNSuccessors(op, N);
750     }
751   };
752 };
753 
754 /// This class provides APIs for ops that are known to have at least a specified
755 /// number of successors.
756 template <unsigned N>
757 class AtLeastNSuccessors {
758 public:
759   template <typename ConcreteType>
760   class Impl
761       : public detail::MultiSuccessorTraitBase<ConcreteType,
762                                                AtLeastNSuccessors<N>::Impl> {
763   public:
verifyTrait(Operation * op)764     static LogicalResult verifyTrait(Operation *op) {
765       return impl::verifyAtLeastNSuccessors(op, N);
766     }
767   };
768 };
769 
/// This class provides the API for ops which have an unknown number of
/// successors. It adds nothing of its own: all successor accessors are
/// inherited from detail::MultiSuccessorTraitBase, and no `verifyTrait` is
/// declared in this class.
template <typename ConcreteType>
class VariadicSuccessors
    : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
};
776 
777 //===----------------------------------------------------------------------===//
778 // SingleBlockImplicitTerminator
779 
780 /// This class provides APIs and verifiers for ops with regions having a single
781 /// block that must terminate with `TerminatorOpType`.
782 template <typename TerminatorOpType>
783 struct SingleBlockImplicitTerminator {
784   template <typename ConcreteType>
785   class Impl : public TraitBase<ConcreteType, Impl> {
786   private:
787     /// Builds a terminator operation without relying on OpBuilder APIs to avoid
788     /// cyclic header inclusion.
buildTerminatorSingleBlockImplicitTerminator789     static Operation *buildTerminator(OpBuilder &builder, Location loc) {
790       OperationState state(loc, TerminatorOpType::getOperationName());
791       TerminatorOpType::build(builder, state);
792       return Operation::create(state);
793     }
794 
795   public:
796     /// The type of the operation used as the implicit terminator type.
797     using ImplicitTerminatorOpT = TerminatorOpType;
798 
verifyTraitSingleBlockImplicitTerminator799     static LogicalResult verifyTrait(Operation *op) {
800       for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
801         Region &region = op->getRegion(i);
802 
803         // Empty regions are fine.
804         if (region.empty())
805           continue;
806 
807         // Non-empty regions must contain a single basic block.
808         if (std::next(region.begin()) != region.end())
809           return op->emitOpError("expects region #")
810                  << i << " to have 0 or 1 blocks";
811 
812         Block &block = region.front();
813         if (block.empty())
814           return op->emitOpError() << "expects a non-empty block";
815         Operation &terminator = block.back();
816         if (isa<TerminatorOpType>(terminator))
817           continue;
818 
819         return op->emitOpError("expects regions to end with '" +
820                                TerminatorOpType::getOperationName() +
821                                "', found '" +
822                                terminator.getName().getStringRef() + "'")
823                    .attachNote()
824                << "in custom textual format, the absence of terminator implies "
825                   "'"
826                << TerminatorOpType::getOperationName() << '\'';
827       }
828 
829       return success();
830     }
831 
832     /// Ensure that the given region has the terminator required by this trait.
833     /// If OpBuilder is provided, use it to build the terminator and notify the
834     /// OpBuilder listeners accordingly. If only a Builder is provided, locally
835     /// construct an OpBuilder with no listeners; this should only be used if no
836     /// OpBuilder is available at the call site, e.g., in the parser.
ensureTerminatorSingleBlockImplicitTerminator837     static void ensureTerminator(Region &region, Builder &builder,
838                                  Location loc) {
839       ::mlir::impl::ensureRegionTerminator(region, builder, loc,
840                                            buildTerminator);
841     }
ensureTerminatorSingleBlockImplicitTerminator842     static void ensureTerminator(Region &region, OpBuilder &builder,
843                                  Location loc) {
844       ::mlir::impl::ensureRegionTerminator(region, builder, loc,
845                                            buildTerminator);
846     }
847 
848     Block *getBody(unsigned idx = 0) {
849       Region &region = this->getOperation()->getRegion(idx);
850       assert(!region.empty() && "unexpected empty region");
851       return &region.front();
852     }
853     Region &getBodyRegion(unsigned idx = 0) {
854       return this->getOperation()->getRegion(idx);
855     }
856 
857     //===------------------------------------------------------------------===//
858     // Single Region Utilities
859     //===------------------------------------------------------------------===//
860 
861     /// The following are a set of methods only enabled when the parent
862     /// operation has a single region. Each of these methods take an additional
863     /// template parameter that represents the concrete operation so that we
864     /// can use SFINAE to disable the methods for non-single region operations.
865     template <typename OpT, typename T = void>
866     using enable_if_single_region =
867         typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
868 
869     template <typename OpT = ConcreteType>
beginSingleBlockImplicitTerminator870     enable_if_single_region<OpT, Block::iterator> begin() {
871       return getBody()->begin();
872     }
873     template <typename OpT = ConcreteType>
endSingleBlockImplicitTerminator874     enable_if_single_region<OpT, Block::iterator> end() {
875       return getBody()->end();
876     }
877     template <typename OpT = ConcreteType>
frontSingleBlockImplicitTerminator878     enable_if_single_region<OpT, Operation &> front() {
879       return *begin();
880     }
881 
882     /// Insert the operation into the back of the body, before the terminator.
883     template <typename OpT = ConcreteType>
push_backSingleBlockImplicitTerminator884     enable_if_single_region<OpT> push_back(Operation *op) {
885       insert(Block::iterator(getBody()->getTerminator()), op);
886     }
887 
888     /// Insert the operation at the given insertion point. Note: The operation
889     /// is never inserted after the terminator, even if the insertion point is
890     /// end().
891     template <typename OpT = ConcreteType>
insertSingleBlockImplicitTerminator892     enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
893       insert(Block::iterator(insertPt), op);
894     }
895     template <typename OpT = ConcreteType>
insertSingleBlockImplicitTerminator896     enable_if_single_region<OpT> insert(Block::iterator insertPt,
897                                         Operation *op) {
898       auto *body = getBody();
899       if (insertPt == body->end())
900         insertPt = Block::iterator(body->getTerminator());
901       body->getOperations().insert(insertPt, op);
902     }
903   };
904 };
905 
906 //===----------------------------------------------------------------------===//
907 // Misc Traits
908 
909 /// This class provides verification for ops that are known to have the same
910 /// operand shape: all operands are scalars, vectors/tensors of the same
911 /// shape.
912 template <typename ConcreteType>
913 class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
914 public:
verifyTrait(Operation * op)915   static LogicalResult verifyTrait(Operation *op) {
916     return impl::verifySameOperandsShape(op);
917   }
918 };
919 
920 /// This class provides verification for ops that are known to have the same
921 /// operand and result shape: both are scalars, vectors/tensors of the same
922 /// shape.
923 template <typename ConcreteType>
924 class SameOperandsAndResultShape
925     : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
926 public:
verifyTrait(Operation * op)927   static LogicalResult verifyTrait(Operation *op) {
928     return impl::verifySameOperandsAndResultShape(op);
929   }
930 };
931 
932 /// This class provides verification for ops that are known to have the same
933 /// operand element type (or the type itself if it is scalar).
934 ///
935 template <typename ConcreteType>
936 class SameOperandsElementType
937     : public TraitBase<ConcreteType, SameOperandsElementType> {
938 public:
verifyTrait(Operation * op)939   static LogicalResult verifyTrait(Operation *op) {
940     return impl::verifySameOperandsElementType(op);
941   }
942 };
943 
944 /// This class provides verification for ops that are known to have the same
945 /// operand and result element type (or the type itself if it is scalar).
946 ///
947 template <typename ConcreteType>
948 class SameOperandsAndResultElementType
949     : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
950 public:
verifyTrait(Operation * op)951   static LogicalResult verifyTrait(Operation *op) {
952     return impl::verifySameOperandsAndResultElementType(op);
953   }
954 };
955 
956 /// This class provides verification for ops that are known to have the same
957 /// operand and result type.
958 ///
959 /// Note: this trait subsumes the SameOperandsAndResultShape and
960 /// SameOperandsAndResultElementType traits.
961 template <typename ConcreteType>
962 class SameOperandsAndResultType
963     : public TraitBase<ConcreteType, SameOperandsAndResultType> {
964 public:
verifyTrait(Operation * op)965   static LogicalResult verifyTrait(Operation *op) {
966     return impl::verifySameOperandsAndResultType(op);
967   }
968 };
969 
970 /// This class verifies that any results of the specified op have a boolean
971 /// type, a vector thereof, or a tensor thereof.
972 template <typename ConcreteType>
973 class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
974 public:
verifyTrait(Operation * op)975   static LogicalResult verifyTrait(Operation *op) {
976     return impl::verifyResultsAreBoolLike(op);
977   }
978 };
979 
980 /// This class verifies that any results of the specified op have a floating
981 /// point type, a vector thereof, or a tensor thereof.
982 template <typename ConcreteType>
983 class ResultsAreFloatLike
984     : public TraitBase<ConcreteType, ResultsAreFloatLike> {
985 public:
verifyTrait(Operation * op)986   static LogicalResult verifyTrait(Operation *op) {
987     return impl::verifyResultsAreFloatLike(op);
988   }
989 };
990 
991 /// This class verifies that any results of the specified op have a signless
992 /// integer or index type, a vector thereof, or a tensor thereof.
993 template <typename ConcreteType>
994 class ResultsAreSignlessIntegerLike
995     : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
996 public:
verifyTrait(Operation * op)997   static LogicalResult verifyTrait(Operation *op) {
998     return impl::verifyResultsAreSignlessIntegerLike(op);
999   }
1000 };
1001 
1002 /// This class adds property that the operation is commutative.
1003 template <typename ConcreteType>
1004 class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
1005 public:
getTraitProperties()1006   static AbstractOperation::OperationProperties getTraitProperties() {
1007     return static_cast<AbstractOperation::OperationProperties>(
1008         OperationProperty::Commutative);
1009   }
1010 };
1011 
1012 /// This class adds property that the operation is an involution.
1013 /// This means a unary to unary operation "f" that satisfies f(f(x)) = x
1014 template <typename ConcreteType>
1015 class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
1016 public:
verifyTrait(Operation * op)1017   static LogicalResult verifyTrait(Operation *op) {
1018     static_assert(ConcreteType::template hasTrait<OneResult>(),
1019                   "expected operation to produce one result");
1020     static_assert(ConcreteType::template hasTrait<OneOperand>(),
1021                   "expected operation to take one operand");
1022     static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1023                   "expected operation to preserve type");
1024     // Involution requires the operation to be side effect free as well
1025     // but currently this check is under a FIXME and is not actually done.
1026     return impl::verifyIsInvolution(op);
1027   }
1028 
foldTrait(Operation * op,ArrayRef<Attribute> operands)1029   static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1030     return impl::foldInvolution(op);
1031   }
1032 };
1033 
1034 /// This class adds property that the operation is idempotent.
1035 /// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x)
1036 template <typename ConcreteType>
1037 class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
1038 public:
verifyTrait(Operation * op)1039   static LogicalResult verifyTrait(Operation *op) {
1040     static_assert(ConcreteType::template hasTrait<OneResult>(),
1041                   "expected operation to produce one result");
1042     static_assert(ConcreteType::template hasTrait<OneOperand>(),
1043                   "expected operation to take one operand");
1044     static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1045                   "expected operation to preserve type");
1046     // Idempotent requires the operation to be side effect free as well
1047     // but currently this check is under a FIXME and is not actually done.
1048     return impl::verifyIsIdempotent(op);
1049   }
1050 
foldTrait(Operation * op,ArrayRef<Attribute> operands)1051   static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1052     return impl::foldIdempotent(op);
1053   }
1054 };
1055 
1056 /// This class verifies that all operands of the specified op have a float type,
1057 /// a vector thereof, or a tensor thereof.
1058 template <typename ConcreteType>
1059 class OperandsAreFloatLike
1060     : public TraitBase<ConcreteType, OperandsAreFloatLike> {
1061 public:
verifyTrait(Operation * op)1062   static LogicalResult verifyTrait(Operation *op) {
1063     return impl::verifyOperandsAreFloatLike(op);
1064   }
1065 };
1066 
1067 /// This class verifies that all operands of the specified op have a signless
1068 /// integer or index type, a vector thereof, or a tensor thereof.
1069 template <typename ConcreteType>
1070 class OperandsAreSignlessIntegerLike
1071     : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
1072 public:
verifyTrait(Operation * op)1073   static LogicalResult verifyTrait(Operation *op) {
1074     return impl::verifyOperandsAreSignlessIntegerLike(op);
1075   }
1076 };
1077 
1078 /// This class verifies that all operands of the specified op have the same
1079 /// type.
1080 template <typename ConcreteType>
1081 class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
1082 public:
verifyTrait(Operation * op)1083   static LogicalResult verifyTrait(Operation *op) {
1084     return impl::verifySameTypeOperands(op);
1085   }
1086 };
1087 
1088 /// This class provides the API for a sub-set of ops that are known to be
1089 /// constant-like. These are non-side effecting operations with one result and
1090 /// zero operands that can always be folded to a specific attribute value.
1091 template <typename ConcreteType>
1092 class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
1093 public:
verifyTrait(Operation * op)1094   static LogicalResult verifyTrait(Operation *op) {
1095     static_assert(ConcreteType::template hasTrait<OneResult>(),
1096                   "expected operation to produce one result");
1097     static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
1098                   "expected operation to take zero operands");
1099     // TODO: We should verify that the operation can always be folded, but this
1100     // requires that the attributes of the op already be verified. We should add
1101     // support for verifying traits "after" the operation to enable this use
1102     // case.
1103     return success();
1104   }
1105 };
1106 
1107 /// This class provides the API for ops that are known to be isolated from
1108 /// above.
1109 template <typename ConcreteType>
1110 class IsIsolatedFromAbove
1111     : public TraitBase<ConcreteType, IsIsolatedFromAbove> {
1112 public:
getTraitProperties()1113   static AbstractOperation::OperationProperties getTraitProperties() {
1114     return static_cast<AbstractOperation::OperationProperties>(
1115         OperationProperty::IsolatedFromAbove);
1116   }
verifyTrait(Operation * op)1117   static LogicalResult verifyTrait(Operation *op) {
1118     for (auto &region : op->getRegions())
1119       if (!region.isIsolatedFromAbove(op->getLoc()))
1120         return failure();
1121     return success();
1122   }
1123 };
1124 
1125 /// A trait of region holding operations that defines a new scope for polyhedral
1126 /// optimization purposes. Any SSA values of 'index' type that either dominate
1127 /// such an operation or are used at the top-level of such an operation
1128 /// automatically become valid symbols for the polyhedral scope defined by that
1129 /// operation. For more details, see `Traits.md#AffineScope`.
1130 template <typename ConcreteType>
1131 class AffineScope : public TraitBase<ConcreteType, AffineScope> {
1132 public:
verifyTrait(Operation * op)1133   static LogicalResult verifyTrait(Operation *op) {
1134     static_assert(!ConcreteType::template hasTrait<ZeroRegion>(),
1135                   "expected operation to have one or more regions");
1136     return success();
1137   }
1138 };
1139 
1140 /// A trait of region holding operations that define a new scope for automatic
1141 /// allocations, i.e., allocations that are freed when control is transferred
1142 /// back from the operation's region. Any operations performing such allocations
1143 /// (for eg. std.alloca) will have their allocations automatically freed at
1144 /// their closest enclosing operation with this trait.
1145 template <typename ConcreteType>
1146 class AutomaticAllocationScope
1147     : public TraitBase<ConcreteType, AutomaticAllocationScope> {
1148 public:
verifyTrait(Operation * op)1149   static LogicalResult verifyTrait(Operation *op) {
1150     if (op->hasTrait<ZeroRegion>())
1151       return op->emitOpError("is expected to have regions");
1152     return success();
1153   }
1154 };
1155 
1156 /// This class provides a verifier for ops that are expecting their parent
1157 /// to be one of the given parent ops
1158 template <typename... ParentOpTypes>
1159 struct HasParent {
1160   template <typename ConcreteType>
1161   class Impl : public TraitBase<ConcreteType, Impl> {
1162   public:
verifyTraitHasParent1163     static LogicalResult verifyTrait(Operation *op) {
1164       if (llvm::isa<ParentOpTypes...>(op->getParentOp()))
1165         return success();
1166 
1167       return op->emitOpError()
1168              << "expects parent op "
1169              << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
1170              << llvm::makeArrayRef({ParentOpTypes::getOperationName()...})
1171              << "'";
1172     }
1173   };
1174 };
1175 
1176 /// A trait for operations that have an attribute specifying operand segments.
1177 ///
1178 /// Certain operations can have multiple variadic operands and their size
1179 /// relationship is not always known statically. For such cases, we need
1180 /// a per-op-instance specification to divide the operands into logical groups
1181 /// or segments. This can be modeled by attributes. The attribute will be named
1182 /// as `operand_segment_sizes`.
1183 ///
1184 /// This trait verifies the attribute for specifying operand segments has
1185 /// the correct type (1D vector) and values (non-negative), etc.
1186 template <typename ConcreteType>
1187 class AttrSizedOperandSegments
1188     : public TraitBase<ConcreteType, AttrSizedOperandSegments> {
1189 public:
getOperandSegmentSizeAttr()1190   static StringRef getOperandSegmentSizeAttr() {
1191     return "operand_segment_sizes";
1192   }
1193 
verifyTrait(Operation * op)1194   static LogicalResult verifyTrait(Operation *op) {
1195     return ::mlir::OpTrait::impl::verifyOperandSizeAttr(
1196         op, getOperandSegmentSizeAttr());
1197   }
1198 };
1199 
1200 /// Similar to AttrSizedOperandSegments but used for results.
1201 template <typename ConcreteType>
1202 class AttrSizedResultSegments
1203     : public TraitBase<ConcreteType, AttrSizedResultSegments> {
1204 public:
getResultSegmentSizeAttr()1205   static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; }
1206 
verifyTrait(Operation * op)1207   static LogicalResult verifyTrait(Operation *op) {
1208     return ::mlir::OpTrait::impl::verifyResultSizeAttr(
1209         op, getResultSegmentSizeAttr());
1210   }
1211 };
1212 
1213 /// This trait provides a verifier for ops that are expecting their regions to
1214 /// not have any arguments
1215 template <typename ConcrentType>
1216 struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> {
verifyTraitNoRegionArguments1217   static LogicalResult verifyTrait(Operation *op) {
1218     return ::mlir::OpTrait::impl::verifyNoRegionArguments(op);
1219   }
1220 };
1221 
1222 // This trait is used to flag operations that consume or produce
1223 // values of `MemRef` type where those references can be 'normalized'.
1224 // TODO: Right now, the operands of an operation are either all normalizable,
1225 // or not. In the future, we may want to allow some of the operands to be
1226 // normalizable.
1227 template <typename ConcrentType>
1228 struct MemRefsNormalizable
1229     : public TraitBase<ConcrentType, MemRefsNormalizable> {};
1230 
1231 /// This trait tags scalar ops that also can be applied to vectors/tensors, with
1232 /// their semantics on vectors/tensors being elementwise application.
1233 ///
1234 /// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
1235 /// trait. In particular, broadcasting behavior is not allowed. This trait
1236 /// describes a set of invariants that allow systematic
1237 /// vectorization/tensorization, and the reverse, scalarization. The properties
1238 /// needed for this also can be used to implement a number of
1239 /// transformations/analyses/interfaces.
1240 ///
1241 /// An `ElementwiseMappable` op must satisfy the following properties:
1242 ///
1243 /// 1. If any result is a vector (resp. tensor), then at least one operand must
1244 /// be a vector (resp. tensor).
1245 /// 2. If any operand is a vector (resp. tensor), then there must be at least
1246 /// one result, and all results must be vectors (resp. tensors).
1247 /// 3. The static types of all vector (resp. tensor) operands and results must
1248 /// have the same shape.
1249 /// 4. In the case of tensor operands, the dynamic shapes of all tensor operands
1250 /// must be the same, otherwise the op has undefined behavior.
1251 /// 5. ("systematic scalarization" property) If an op has vector/tensor
1252 /// operands/results, then the same op, with the operand/result types changed to
1253 /// their corresponding element type, shall be a verifier-valid op.
1254 /// 6. The semantics of the op on vectors (resp. tensors) shall be the same as
1255 /// applying the scalarized version of the op for each corresponding element of
1256 /// the vector (resp. tensor) operands in parallel.
1257 /// 7. ("systematic vectorization/tensorization" property) If an op has
1258 /// scalar operands/results, the op shall remain verifier-valid if all scalar
1259 /// operands are replaced with vectors/tensors of the same shape and
1260 /// corresponding element types.
1261 ///
1262 /// Together, these properties provide an easy way for scalar operations to
1263 /// conveniently generalize their behavior to vectors/tensors, and systematize
1264 /// conversion between these forms.
1265 ///
1266 /// Examples:
1267 /// ```
1268 /// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32
1269 /// // Applying the systematic vectorization/tensorization property, this op
1270 /// // must also be valid:
1271 /// %tensor = "std.addf"(%a_tensor, %b_tensor)
///           : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
1273 ///
1274 /// // These properties generalize well to the cases of non-scalar operands.
1275 /// %select_scalar_pred = "std.select"(%pred, %true_val, %false_val)
1276 ///                       : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
1277 /// // Applying the systematic vectorization / tensorization property, this
1278 /// // op must also be valid:
1279 /// %select_tensor_pred = "std.select"(%pred_tensor, %true_val, %false_val)
1280 ///                       : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1281 ///                       -> tensor<?xf32>
1282 /// // Applying the systematic scalarization property, this op must also
1283 /// // be valid.
1284 /// %select_scalar = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
1285 ///                  : (i1, f32, f32) -> f32
1286 /// ```
1287 ///
1288 /// TODO: Avoid hardcoding vector/tensor, and generalize this to any type
1289 /// implementing a new "ElementwiseMappableTypeInterface" that describes types
1290 /// for which it makes sense to apply a scalar function to each element.
1291 ///
1292 /// Rationale:
1293 /// - 1. and 2. guarantee a well-defined iteration space for 6.
1294 ///   - These also exclude the cases of 0 non-scalar operands or 0 non-scalar
1295 ///     results, which complicate a generic definition of the iteration space.
1296 /// - 3. guarantees that folding can be done across scalars/vectors/tensors
1297 ///   with the same pattern, as otherwise lots of special handling of type
1298 ///   mismatches would be needed.
1299 /// - 4. guarantees that no error handling cases need to be considered.
1300 ///   - Higher-level dialects should reify any needed guards / error handling
1301 ///   code before lowering to an ElementwiseMappable op.
1302 /// - 5. and 6. allow defining the semantics on vectors/tensors via the scalar
1303 ///   semantics and provide a constructive procedure for IR transformations
1304 ///   to e.g. create scalar loop bodies from tensor ops.
1305 /// - 7. provides the reverse of 5., which when chained together allows
1306 ///   reasoning about the relationship between the tensor and vector case.
1307 ///   Additionally, it permits reasoning about promoting scalars to
1308 ///   vectors/tensors via broadcasting in cases like `%select_scalar_pred`
1309 ///   above.
1310 template <typename ConcreteType>
1311 struct ElementwiseMappable
1312     : public TraitBase<ConcreteType, ElementwiseMappable> {
verifyTraitElementwiseMappable1313   static LogicalResult verifyTrait(Operation *op) {
1314     return ::mlir::OpTrait::impl::verifyElementwiseMappable(op);
1315   }
1316 };
1317 
1318 } // end namespace OpTrait
1319 
1320 //===----------------------------------------------------------------------===//
1321 // Internal Trait Utilities
1322 //===----------------------------------------------------------------------===//
1323 
1324 namespace op_definition_impl {
1325 //===----------------------------------------------------------------------===//
1326 // Trait Existence
1327 
1328 /// Returns true if this given Trait ID matches the IDs of any of the provided
1329 /// trait types `Traits`.
1330 template <template <typename T> class... Traits>
hasTrait(TypeID traitID)1331 static bool hasTrait(TypeID traitID) {
1332   TypeID traitIDs[] = {TypeID::get<Traits>()...};
1333   for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
1334     if (traitIDs[i] == traitID)
1335       return true;
1336   return false;
1337 }
1338 
1339 //===----------------------------------------------------------------------===//
1340 // Trait Folding
1341 
/// Trait to check if T provides a 'foldTrait' method for single result
/// operations, i.e. one that can be called as
/// `T::foldTrait(Operation *, ArrayRef<Attribute>)`. The alias names the type
/// of that call expression, so it is only well-formed when the call is valid.
/// The `Args` pack is not referenced by the alias body.
template <typename T, typename... Args>
using has_single_result_fold_trait = decltype(T::foldTrait(
    std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>()));
/// Detector (via `llvm::is_detected`) for the single-result `foldTrait` form
/// described above.
template <typename T>
using detect_has_single_result_fold_trait =
    llvm::is_detected<has_single_result_fold_trait, T>;
/// Trait to check if T provides a general 'foldTrait' method, i.e. one that
/// can be called as `T::foldTrait(Operation *, ArrayRef<Attribute>,
/// SmallVectorImpl<OpFoldResult> &)`. The `Args` pack is not referenced by the
/// alias body.
template <typename T, typename... Args>
using has_fold_trait =
    decltype(T::foldTrait(std::declval<Operation *>(),
                          std::declval<ArrayRef<Attribute>>(),
                          std::declval<SmallVectorImpl<OpFoldResult> &>()));
/// Detector (via `llvm::is_detected`) for the general `foldTrait` form
/// described above.
template <typename T>
using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>;
/// Trait to check if T provides any `foldTrait` method. Selects the general
/// detector when it reports true and the single-result detector otherwise, so
/// the resulting `::value` is true when either form is present.
/// NOTE: This should use std::disjunction when C++17 is available.
template <typename T>
using detect_has_any_fold_trait =
    std::conditional_t<bool(detect_has_fold_trait<T>::value),
                       detect_has_fold_trait<T>,
                       detect_has_single_result_fold_trait<T>>;
1365 
1366 /// Returns the result of folding a trait that implements a `foldTrait` function
1367 /// that is specialized for operations that have a single result.
1368 template <typename Trait>
1369 static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value,
1370                         LogicalResult>
foldTrait(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1371 foldTrait(Operation *op, ArrayRef<Attribute> operands,
1372           SmallVectorImpl<OpFoldResult> &results) {
1373   assert(op->hasTrait<OpTrait::OneResult>() &&
1374          "expected trait on non single-result operation to implement the "
1375          "general `foldTrait` method");
1376   // If a previous trait has already been folded and replaced this operation, we
1377   // fail to fold this trait.
1378   if (!results.empty())
1379     return failure();
1380 
1381   if (OpFoldResult result = Trait::foldTrait(op, operands)) {
1382     if (result.template dyn_cast<Value>() != op->getResult(0))
1383       results.push_back(result);
1384     return success();
1385   }
1386   return failure();
1387 }
1388 /// Returns the result of folding a trait that implements a generalized
1389 /// `foldTrait` function that is supports any operation type.
1390 template <typename Trait>
1391 static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult>
foldTrait(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1392 foldTrait(Operation *op, ArrayRef<Attribute> operands,
1393           SmallVectorImpl<OpFoldResult> &results) {
1394   // If a previous trait has already been folded and replaced this operation, we
1395   // fail to fold this trait.
1396   return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
1397 }
1398 
1399 /// The internal implementation of `foldTraits` below that returns the result of
1400 /// folding a set of trait types `Ts` that implement a `foldTrait` method.
1401 template <typename... Ts>
foldTraitsImpl(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results,std::tuple<Ts...> *)1402 static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands,
1403                                     SmallVectorImpl<OpFoldResult> &results,
1404                                     std::tuple<Ts...> *) {
1405   bool anyFolded = false;
1406   (void)std::initializer_list<int>{
1407       (anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
1408   return success(anyFolded);
1409 }
1410 
1411 /// Given a tuple type containing a set of traits that contain a `foldTrait`
1412 /// method, return the result of folding the given operation.
1413 template <typename TraitTupleT>
1414 static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult>
foldTraits(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1415 foldTraits(Operation *op, ArrayRef<Attribute> operands,
1416            SmallVectorImpl<OpFoldResult> &results) {
1417   return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr);
1418 }
1419 /// A variant of the method above that is specialized when there are no traits
1420 /// that contain a `foldTrait` method.
1421 template <typename TraitTupleT>
1422 static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult>
foldTraits(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1423 foldTraits(Operation *op, ArrayRef<Attribute> operands,
1424            SmallVectorImpl<OpFoldResult> &results) {
1425   return failure();
1426 }
1427 
1428 //===----------------------------------------------------------------------===//
1429 // Trait Properties
1430 
/// Trait to check if T provides a `getTraitProperties` method.
/// The alias is only well-formed when `T::getTraitProperties()` can be called
/// without an object and without arguments. `Args` is not referenced by the
/// alias itself.
template <typename T, typename... Args>
using has_get_trait_properties = decltype(T::getTraitProperties());
/// Detection wrapper that feeds `has_get_trait_properties` and `T` to
/// `llvm::is_detected`.
template <typename T>
using detect_has_get_trait_properties =
    llvm::is_detected<has_get_trait_properties, T>;
1437 
1438 /// The internal implementation of `getTraitProperties` below that returns the
1439 /// OR of invoking `getTraitProperties` on all of the provided trait types `Ts`.
1440 template <typename... Ts>
1441 static AbstractOperation::OperationProperties
getTraitPropertiesImpl(std::tuple<Ts...> *)1442 getTraitPropertiesImpl(std::tuple<Ts...> *) {
1443   AbstractOperation::OperationProperties result = 0;
1444   (void)std::initializer_list<int>{(result |= Ts::getTraitProperties(), 0)...};
1445   return result;
1446 }
1447 
1448 /// Given a tuple type containing a set of traits that contain a
1449 /// `getTraitProperties` method, return the OR of all of the results of invoking
1450 /// those methods.
1451 template <typename TraitTupleT>
getTraitProperties()1452 static AbstractOperation::OperationProperties getTraitProperties() {
1453   return getTraitPropertiesImpl((TraitTupleT *)nullptr);
1454 }
1455 
1456 //===----------------------------------------------------------------------===//
1457 // Trait Verification
1458 
/// Trait to check if T provides a `verifyTrait` method.
/// The alias is only well-formed when `T::verifyTrait` can be called without
/// an object and with a single `Operation *` argument. `Args` is not
/// referenced by the alias itself.
template <typename T, typename... Args>
using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
/// Detection wrapper that feeds `has_verify_trait` and `T` to
/// `llvm::is_detected`.
template <typename T>
using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;
1464 
1465 /// The internal implementation of `verifyTraits` below that returns the result
1466 /// of verifying the current operation with all of the provided trait types
1467 /// `Ts`.
1468 template <typename... Ts>
verifyTraitsImpl(Operation * op,std::tuple<Ts...> *)1469 static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) {
1470   LogicalResult result = success();
1471   (void)std::initializer_list<int>{
1472       (result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...};
1473   return result;
1474 }
1475 
1476 /// Given a tuple type containing a set of traits that contain a
1477 /// `verifyTrait` method, return the result of verifying the given operation.
1478 template <typename TraitTupleT>
verifyTraits(Operation * op)1479 static LogicalResult verifyTraits(Operation *op) {
1480   return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
1481 }
1482 } // namespace op_definition_impl
1483 
1484 //===----------------------------------------------------------------------===//
1485 // Operation Definition classes
1486 //===----------------------------------------------------------------------===//
1487 
/// This provides public APIs that all operations should have.  The template
/// argument 'ConcreteType' should be the concrete type by CRTP and the others
/// are base classes by the policy pattern.
///
/// Besides the public API, the private section assembles the hooks (fold,
/// canonicalization patterns, trait query, parse, print, verify) that are
/// exposed to `AbstractOperation`, which is declared a friend at the bottom of
/// the class.
template <typename ConcreteType, template <typename T> class... Traits>
class Op : public OpState, public Traits<ConcreteType>... {
public:
  /// Inherit getOperation from `OpState`.
  using OpState::getOperation;

  /// Return if this operation contains the provided trait.
  /// This is a compile-time check of `Trait<ConcreteType>` against the
  /// `Traits` template arguments this class was instantiated with.
  template <template <typename T> class Trait>
  static constexpr bool hasTrait() {
    return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
  }

  /// Create a deep copy of this operation.
  /// The copy is produced by `Operation::clone()` and cast back to
  /// `ConcreteType`.
  ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }

  /// Create a partial copy of this operation without traversing into attached
  /// regions. The new operation will have the same number of regions as the
  /// original one, but they will be left empty.
  ConcreteType cloneWithoutRegions() {
    return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
  }

  /// Return true if this "op class" can match against the specified operation.
  /// When `op` has an abstract operation, the match is decided by comparing
  /// its type ID with the type ID of `ConcreteType`. Otherwise the result is
  /// false; in builds without NDEBUG a fatal error is reported first if the
  /// operation name nevertheless equals `ConcreteType::getOperationName()`.
  static bool classof(Operation *op) {
    if (auto *abstractOp = op->getAbstractOperation())
      return TypeID::get<ConcreteType>() == abstractOp->typeID;
#ifndef NDEBUG
    // The name matches but no abstract operation is attached: surface the
    // missing registration instead of silently returning false.
    if (op->getName().getStringRef() == ConcreteType::getOperationName())
      llvm::report_fatal_error(
          "classof on '" + ConcreteType::getOperationName() +
          "' failed due to the operation not being registered");
#endif
    return false;
  }

  /// Expose the type we are instantiated on to template machinery that may want
  /// to introspect traits on this operation.
  using ConcreteOpType = ConcreteType;

  /// This is a public constructor.  Any op can be initialized to null.
  explicit Op() : OpState(nullptr) {}
  /// Allow implicit construction of a null op from `nullptr`.
  Op(std::nullptr_t) : OpState(nullptr) {}

  /// This is a public constructor to enable access via the llvm::cast family of
  /// methods. This should not be used directly.
  explicit Op(Operation *state) : OpState(state) {}

  /// Methods for supporting PointerLikeTypeTraits.
  /// The opaque pointer is the `Operation *` this op converts to.
  const void *getAsOpaquePointer() const {
    return static_cast<const void *>((Operation *)*this);
  }
  /// Rebuild an op of the concrete type from a pointer previously obtained via
  /// `getAsOpaquePointer`.
  static ConcreteOpType getFromOpaquePointer(const void *pointer) {
    return ConcreteOpType(
        reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
  }

private:
  /// Trait to check if T provides a 'fold' method for a single result op.
  /// The checked call form is `fold(ArrayRef<Attribute>)`.
  template <typename T, typename... Args>
  using has_single_result_fold =
      decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
  template <typename T>
  using detect_has_single_result_fold =
      llvm::is_detected<has_single_result_fold, T>;
  /// Trait to check if T provides a general 'fold' method.
  /// The checked call form is
  /// `fold(ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &)`.
  template <typename T, typename... Args>
  using has_fold = decltype(
      std::declval<T>().fold(std::declval<ArrayRef<Attribute>>(),
                             std::declval<SmallVectorImpl<OpFoldResult> &>()));
  template <typename T>
  using detect_has_fold = llvm::is_detected<has_fold, T>;
  /// Trait to check if T provides a 'print' method.
  /// The checked call form is `print(OpAsmPrinter &)`.
  template <typename T, typename... Args>
  using has_print =
      decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
  template <typename T>
  using detect_has_print = llvm::is_detected<has_print, T>;
  /// A tuple type containing the traits that have a `foldTrait` function.
  using FoldableTraitsTupleT = typename detail::FilterTypes<
      op_definition_impl::detect_has_any_fold_trait,
      Traits<ConcreteType>...>::type;
  /// A tuple type containing the traits that have a verify function.
  using VerifiableTraitsTupleT =
      typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
                                   Traits<ConcreteType>...>::type;

  /// Returns the properties of this operation by combining the properties
  /// defined by the traits.
  /// Only traits that provide a `getTraitProperties` method take part.
  static AbstractOperation::OperationProperties getOperationProperties() {
    return op_definition_impl::getTraitProperties<typename detail::FilterTypes<
        op_definition_impl::detect_has_get_trait_properties,
        Traits<ConcreteType>...>::type>();
  }

  /// Returns an interface map containing the interfaces registered to this
  /// operation.
  static detail::InterfaceMap getInterfaceMap() {
    return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
  }

  /// Return the internal implementations of each of the AbstractOperation
  /// hooks.
  /// Implementation of `FoldHookFn` AbstractOperation hook.
  /// The concrete hook is picked at compile time by the `getFoldHookFnImpl`
  /// overloads below.
  // NOTE(review): the three overload conditions do not appear to cover every
  // combination. For example a `OneResult` op that only defines the general
  // two-argument `fold`, or an op without `OneResult` that only defines the
  // single-argument `fold`, seems to enable none of them. Confirm whether
  // such ops are intended to fail to compile.
  static AbstractOperation::FoldHookFn getFoldHookFn() {
    return getFoldHookFnImpl<ConcreteType>();
  }
  /// The internal implementation of `getFoldHookFn` above that is invoked if
  /// the operation is single result and defines a `fold` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
                                          Traits<ConcreteOpT>...>::value &&
                              detect_has_single_result_fold<ConcreteOpT>::value,
                          AbstractOperation::FoldHookFn>
  getFoldHookFnImpl() {
    return &foldSingleResultHook<ConcreteOpT>;
  }
  /// The internal implementation of `getFoldHookFn` above that is invoked if
  /// the operation is not single result and defines a `fold` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
                                           Traits<ConcreteOpT>...>::value &&
                              detect_has_fold<ConcreteOpT>::value,
                          AbstractOperation::FoldHookFn>
  getFoldHookFnImpl() {
    return &foldHook<ConcreteOpT>;
  }
  /// The internal implementation of `getFoldHookFn` above that is invoked if
  /// the operation does not define a `fold` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
                              !detect_has_fold<ConcreteOpT>::value,
                          AbstractOperation::FoldHookFn>
  getFoldHookFnImpl() {
    // In this case, we only need to fold the traits of the operation.
    return &op_definition_impl::foldTraits<FoldableTraitsTupleT>;
  }
  /// Return the result of folding a single result operation that defines a
  /// `fold` method.
  /// The op's own `fold` runs first. A fold result that differs from the op's
  /// own result is appended to `results`. Otherwise the foldable traits are
  /// tried.
  template <typename ConcreteOpT>
  static LogicalResult
  foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
                       SmallVectorImpl<OpFoldResult> &results) {
    OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);

    // If the fold failed or was in-place, try to fold the traits of the
    // operation.
    if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
      if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
              op, operands, results)))
        return success();
      // No trait folded: report success only when the op's own fold returned
      // a non-null (in-place) result.
      return success(static_cast<bool>(result));
    }
    results.push_back(result);
    return success();
  }
  /// Return the result of folding an operation that defines a `fold` method.
  /// The op's own `fold` runs first; if it failed or left `results` empty, the
  /// foldable traits are tried as well.
  template <typename ConcreteOpT>
  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
                                SmallVectorImpl<OpFoldResult> &results) {
    LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);

    // If the fold failed or was in-place, try to fold the traits of the
    // operation.
    if (failed(result) || results.empty()) {
      if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
              op, operands, results)))
        return success();
    }
    // The traits did not fold (or were not tried): forward the outcome of the
    // op's own `fold`.
    return result;
  }

  /// Implementation of `GetCanonicalizationPatternsFn` AbstractOperation hook.
  /// Forwards to the static `getCanonicalizationPatterns` of the concrete op.
  static AbstractOperation::GetCanonicalizationPatternsFn
  getGetCanonicalizationPatternsFn() {
    return &ConcreteType::getCanonicalizationPatterns;
  }
  /// Implementation of `GetHasTraitFn`
  static AbstractOperation::HasTraitFn getHasTraitFn() {
    return &op_definition_impl::hasTrait<Traits...>;
  }
  /// Implementation of `ParseAssemblyFn` AbstractOperation hook.
  /// Forwards to the static `parse` of the concrete op.
  static AbstractOperation::ParseAssemblyFn getParseAssemblyFn() {
    return &ConcreteType::parse;
  }
  /// Implementation of `PrintAssemblyFn` AbstractOperation hook.
  /// The concrete hook is picked at compile time by the
  /// `getPrintAssemblyFnImpl` overloads below.
  static AbstractOperation::PrintAssemblyFn getPrintAssemblyFn() {
    return getPrintAssemblyFnImpl<ConcreteType>();
  }
  /// The internal implementation of `getPrintAssemblyFn` that is invoked when
  /// the concrete operation does not define a `print` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
                          AbstractOperation::PrintAssemblyFn>
  getPrintAssemblyFnImpl() {
    // Fall back to the printer provided by `OpState`.
    return &OpState::print;
  }
  /// The internal implementation of `getPrintAssemblyFn` that is invoked when
  /// the concrete operation defines a `print` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
                          AbstractOperation::PrintAssemblyFn>
  getPrintAssemblyFnImpl() {
    return &printAssembly;
  }
  /// Print `op` by casting it to the concrete op type and calling its `print`
  /// method.
  static void printAssembly(Operation *op, OpAsmPrinter &p) {
    return cast<ConcreteType>(op).print(p);
  }
  /// Implementation of `VerifyInvariantsFn` AbstractOperation hook.
  static AbstractOperation::VerifyInvariantsFn getVerifyInvariantsFn() {
    return &verifyInvariants;
  }
  /// Verify `op` against all verifiable traits and then against the concrete
  /// op's own `verify` method. Because of the short-circuiting `||`, `verify`
  /// is not called when trait verification already failed.
  static LogicalResult verifyInvariants(Operation *op) {
    return failure(
        failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
        failed(cast<ConcreteType>(op).verify()));
  }

  /// Allow access to internal implementation methods.
  friend AbstractOperation;
};
1711 
1712 /// This class represents the base of an operation interface. See the definition
1713 /// of `detail::Interface` for requirements on the `Traits` type.
1714 template <typename ConcreteType, typename Traits>
1715 class OpInterface
1716     : public detail::Interface<ConcreteType, Operation *, Traits,
1717                                Op<ConcreteType>, OpTrait::TraitBase> {
1718 public:
1719   using Base = OpInterface<ConcreteType, Traits>;
1720   using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
1721                                           Op<ConcreteType>, OpTrait::TraitBase>;
1722 
1723   /// Inherit the base class constructor.
1724   using InterfaceBase::InterfaceBase;
1725 
1726 protected:
1727   /// Returns the impl interface instance for the given operation.
getInterfaceFor(Operation * op)1728   static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
1729     // Access the raw interface from the abstract operation.
1730     auto *abstractOp = op->getAbstractOperation();
1731     return abstractOp ? abstractOp->getInterface<ConcreteType>() : nullptr;
1732   }
1733 
1734   /// Allow access to `getInterfaceFor`.
1735   friend InterfaceBase;
1736 };
1737 
1738 //===----------------------------------------------------------------------===//
1739 // Common Operation Folders/Parsers/Printers
1740 //===----------------------------------------------------------------------===//
1741 
1742 // These functions are out-of-line implementations of the methods in UnaryOp and
1743 // BinaryOp, which avoids them being template instantiated/duplicated.
// Declarations only; the definitions are not part of this header.
namespace impl {
// NOTE(review): judging by the name, parses an op with one result and a single
// operand type into `result` — confirm against the definition.
ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
                                           OperationState &result);

// NOTE(review): presumably populates `result` for a binary op with operands
// `lhs` and `rhs` — confirm against the definition.
void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
                   Value rhs);
// NOTE(review): judging by the name, parses an op with one result whose
// operands all share one type — confirm against the definition.
ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
                                            OperationState &result);

// Prints the given binary `op` in custom assembly form if both the two operands
// and the result have the same type. Otherwise, prints the generic assembly
// form.
void printOneResultOp(Operation *op, OpAsmPrinter &p);
} // namespace impl
1758 
1759 // These functions are out-of-line implementations of the methods in
1760 // CastOpInterface, which avoids them being template instantiated/duplicated.
// Declarations only; the definitions are not part of this header.
namespace impl {
/// Attempt to fold the given cast operation.
/// Takes the operation, its attribute operands and the list that receives the
/// fold results, and reports the outcome as a `LogicalResult`.
LogicalResult foldCastInterfaceOp(Operation *op,
                                  ArrayRef<Attribute> attrOperands,
                                  SmallVectorImpl<OpFoldResult> &foldResults);
/// Attempt to verify the given cast operation.
/// `areCastCompatible` is a caller-provided predicate over two type ranges.
/// NOTE(review): presumably the operand types and the result types of `op` —
/// confirm against the definition.
LogicalResult verifyCastInterfaceOp(
    Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);

// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
// need for them, but some older ODS code in `std` still depends on them).
void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
                 Type destType);
ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
void printCastOp(Operation *op, OpAsmPrinter &p);
// TODO: These methods are deprecated in favor of CastOpInterface. Remove them
// when all uses have been updated. Also, consider adding functionality to
// CastOpInterface to be able to perform the ChainedTensorCast canonicalization
// generically.
Value foldCastOp(Operation *op);
// `areCastCompatible` is a caller-provided predicate over two types.
LogicalResult verifyCastOp(Operation *op,
                           function_ref<bool(Type, Type)> areCastCompatible);
} // namespace impl
1784 } // end namespace mlir
1785 
1786 #endif
1787