1 //===- Visitors.h - Utilities for visiting operations -----------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines utilities for walking and visiting operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_VISITORS_H
14 #define MLIR_IR_VISITORS_H
15 
16 #include "mlir/Support/LLVM.h"
17 #include "mlir/Support/LogicalResult.h"
18 #include "llvm/ADT/STLExtras.h"
19 
20 namespace mlir {
21 class Diagnostic;
22 class InFlightDiagnostic;
23 class Operation;
24 
25 /// A utility result that is used to signal if a walk method should be
26 /// interrupted or advance.
27 class WalkResult {
28   enum ResultEnum { Interrupt, Advance } result;
29 
30 public:
WalkResult(ResultEnum result)31   WalkResult(ResultEnum result) : result(result) {}
32 
33   /// Allow LogicalResult to interrupt the walk on failure.
WalkResult(LogicalResult result)34   WalkResult(LogicalResult result)
35       : result(failed(result) ? Interrupt : Advance) {}
36 
37   /// Allow diagnostics to interrupt the walk.
WalkResult(Diagnostic &&)38   WalkResult(Diagnostic &&) : result(Interrupt) {}
WalkResult(InFlightDiagnostic &&)39   WalkResult(InFlightDiagnostic &&) : result(Interrupt) {}
40 
41   bool operator==(const WalkResult &rhs) const { return result == rhs.result; }
42 
interrupt()43   static WalkResult interrupt() { return {Interrupt}; }
advance()44   static WalkResult advance() { return {Advance}; }
45 
46   /// Returns if the walk was interrupted.
wasInterrupted()47   bool wasInterrupted() const { return result == Interrupt; }
48 };
49 
50 namespace detail {
/// Helper declarations used to deduce the type of the single parameter
/// accepted by a callback. Only declarations are provided here; they are
/// intended for use inside `decltype`.

/// Overload for a plain function pointer taking one parameter.
template <typename ResultT, typename ParamT>
ParamT first_argument_type(ResultT (*)(ParamT));

/// Overload for a pointer to a non-const member function taking one parameter.
template <typename ResultT, typename ClassT, typename ParamT>
ParamT first_argument_type(ResultT (ClassT::*)(ParamT));

/// Overload for a pointer to a const member function taking one parameter.
template <typename ResultT, typename ClassT, typename ParamT>
ParamT first_argument_type(ResultT (ClassT::*)(ParamT) const);

/// Overload for a callable object (e.g. a lambda): the answer is whatever the
/// overloads above deduce from the object's `operator()`.
template <typename CallableT>
decltype(first_argument_type(&CallableT::operator()))
first_argument_type(CallableT);

/// The type of the first (and only) parameter of the given callable 'T'.
template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>()));
63 
64 /// Walk all of the operations nested under and including the given operation.
65 void walkOperations(Operation *op, function_ref<void(Operation *op)> callback);
66 
67 /// Walk all of the operations nested under and including the given operation.
68 /// This methods walks operations until an interrupt result is returned by the
69 /// callback.
70 WalkResult walkOperations(Operation *op,
71                           function_ref<WalkResult(Operation *op)> callback);
72 
// Below is a set of functions to walk nested operations. Users should favor
// the direct `walk` methods on the IR classes (Operation/Block/etc.) over
// these functions. They are also templated to allow for statically
// dispatching based upon the type of the callback function.
77 
78 /// Walk all of the operations nested under and including the given operation.
79 /// This method is selected for callbacks that operation on Operation*.
80 ///
81 /// Example:
82 ///   op->walk([](Operation *op) { ... });
83 template <
84     typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
85     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
86 typename std::enable_if<std::is_same<ArgT, Operation *>::value, RetT>::type
walkOperations(Operation * op,FuncTy && callback)87 walkOperations(Operation *op, FuncTy &&callback) {
88   return detail::walkOperations(op, function_ref<RetT(ArgT)>(callback));
89 }
90 
91 /// Walk all of the operations of type 'ArgT' nested under and including the
92 /// given operation. This method is selected for void returning callbacks that
93 /// operate on a specific derived operation type.
94 ///
95 /// Example:
96 ///   op->walk([](ReturnOp op) { ... });
97 template <
98     typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
99     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
100 typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
101                             std::is_same<RetT, void>::value,
102                         RetT>::type
walkOperations(Operation * op,FuncTy && callback)103 walkOperations(Operation *op, FuncTy &&callback) {
104   auto wrapperFn = [&](Operation *op) {
105     if (auto derivedOp = dyn_cast<ArgT>(op))
106       callback(derivedOp);
107   };
108   return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn));
109 }
110 
111 /// Walk all of the operations of type 'ArgT' nested under and including the
112 /// given operation. This method is selected for WalkReturn returning
113 /// interruptible callbacks that operate on a specific derived operation type.
114 ///
115 /// Example:
116 ///   op->walk([](ReturnOp op) {
117 ///     if (some_invariant)
118 ///       return WalkResult::interrupt();
119 ///     return WalkResult::advance();
120 ///   });
121 template <
122     typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
123     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
124 typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
125                             std::is_same<RetT, WalkResult>::value,
126                         RetT>::type
walkOperations(Operation * op,FuncTy && callback)127 walkOperations(Operation *op, FuncTy &&callback) {
128   auto wrapperFn = [&](Operation *op) {
129     if (auto derivedOp = dyn_cast<ArgT>(op))
130       return callback(derivedOp);
131     return WalkResult::advance();
132   };
133   return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn));
134 }
135 
/// Alias for the return type of a templated walk method: the type produced by
/// calling `walkOperations` with a callback of type 'CallbackT'.
template <typename CallbackT>
using walkResultType =
    decltype(walkOperations(nullptr, std::declval<CallbackT>()));
139 } // end namespace detail
140 
141 } // namespace mlir
142 
143 #endif
144