1 //===- Visitors.h - Utilities for visiting operations -----------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines utilities for walking and visiting operations.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef MLIR_IR_VISITORS_H
14 #define MLIR_IR_VISITORS_H
15
16 #include "mlir/Support/LLVM.h"
17 #include "mlir/Support/LogicalResult.h"
18 #include "llvm/ADT/STLExtras.h"
19
20 namespace mlir {
21 class Diagnostic;
22 class InFlightDiagnostic;
23 class Operation;
24
25 /// A utility result that is used to signal if a walk method should be
26 /// interrupted or advance.
27 class WalkResult {
28 enum ResultEnum { Interrupt, Advance } result;
29
30 public:
WalkResult(ResultEnum result)31 WalkResult(ResultEnum result) : result(result) {}
32
33 /// Allow LogicalResult to interrupt the walk on failure.
WalkResult(LogicalResult result)34 WalkResult(LogicalResult result)
35 : result(failed(result) ? Interrupt : Advance) {}
36
37 /// Allow diagnostics to interrupt the walk.
WalkResult(Diagnostic &&)38 WalkResult(Diagnostic &&) : result(Interrupt) {}
WalkResult(InFlightDiagnostic &&)39 WalkResult(InFlightDiagnostic &&) : result(Interrupt) {}
40
41 bool operator==(const WalkResult &rhs) const { return result == rhs.result; }
42
interrupt()43 static WalkResult interrupt() { return {Interrupt}; }
advance()44 static WalkResult advance() { return {Advance}; }
45
46 /// Returns if the walk was interrupted.
wasInterrupted()47 bool wasInterrupted() const { return result == Interrupt; }
48 };
49
50 namespace detail {
/// Overload set used only inside `decltype` to deduce the type of the first
/// parameter of a single-argument callable. None of these is ever defined.

// Plain function pointer: `Result (*)(Param)`.
template <typename Result, typename Param>
Param first_argument_type(Result (*)(Param));

// Non-const member function pointer (e.g. the call operator of a `mutable`
// lambda).
template <typename Result, typename Class, typename Param>
Param first_argument_type(Result (Class::*)(Param));

// Const member function pointer (e.g. the call operator of an ordinary lambda).
template <typename Result, typename Class, typename Param>
Param first_argument_type(Result (Class::*)(Param) const);

// Function object: defer to the pointer to its (single, non-template) call
// operator. Substitution fails when `Callable` has no such operator.
template <typename Callable>
auto first_argument_type(Callable)
    -> decltype(first_argument_type(&Callable::operator()));

/// Type of the first parameter of the callable type 'CallableT'.
template <typename CallableT>
using first_argument =
    decltype(first_argument_type(std::declval<CallableT>()));
63
64 /// Walk all of the operations nested under and including the given operation.
65 void walkOperations(Operation *op, function_ref<void(Operation *op)> callback);
66
67 /// Walk all of the operations nested under and including the given operation.
68 /// This methods walks operations until an interrupt result is returned by the
69 /// callback.
70 WalkResult walkOperations(Operation *op,
71 function_ref<WalkResult(Operation *op)> callback);
72
// Below is a set of functions to walk nested operations. Users should favor
// the direct `walk` methods on the IR classes (Operation/Block/etc.) over these
// methods. They are also templated to allow for statically dispatching based
// upon the type of the callback function.
77
78 /// Walk all of the operations nested under and including the given operation.
79 /// This method is selected for callbacks that operation on Operation*.
80 ///
81 /// Example:
82 /// op->walk([](Operation *op) { ... });
83 template <
84 typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
85 typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
86 typename std::enable_if<std::is_same<ArgT, Operation *>::value, RetT>::type
walkOperations(Operation * op,FuncTy && callback)87 walkOperations(Operation *op, FuncTy &&callback) {
88 return detail::walkOperations(op, function_ref<RetT(ArgT)>(callback));
89 }
90
91 /// Walk all of the operations of type 'ArgT' nested under and including the
92 /// given operation. This method is selected for void returning callbacks that
93 /// operate on a specific derived operation type.
94 ///
95 /// Example:
96 /// op->walk([](ReturnOp op) { ... });
97 template <
98 typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
99 typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
100 typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
101 std::is_same<RetT, void>::value,
102 RetT>::type
walkOperations(Operation * op,FuncTy && callback)103 walkOperations(Operation *op, FuncTy &&callback) {
104 auto wrapperFn = [&](Operation *op) {
105 if (auto derivedOp = dyn_cast<ArgT>(op))
106 callback(derivedOp);
107 };
108 return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn));
109 }
110
111 /// Walk all of the operations of type 'ArgT' nested under and including the
112 /// given operation. This method is selected for WalkReturn returning
113 /// interruptible callbacks that operate on a specific derived operation type.
114 ///
115 /// Example:
116 /// op->walk([](ReturnOp op) {
117 /// if (some_invariant)
118 /// return WalkResult::interrupt();
119 /// return WalkResult::advance();
120 /// });
121 template <
122 typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
123 typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
124 typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
125 std::is_same<RetT, WalkResult>::value,
126 RetT>::type
walkOperations(Operation * op,FuncTy && callback)127 walkOperations(Operation *op, FuncTy &&callback) {
128 auto wrapperFn = [&](Operation *op) {
129 if (auto derivedOp = dyn_cast<ArgT>(op))
130 return callback(derivedOp);
131 return WalkResult::advance();
132 };
133 return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn));
134 }
135
136 /// Utility to provide the return type of a templated walk method.
137 template <typename FnT>
138 using walkResultType = decltype(walkOperations(nullptr, std::declval<FnT>()));
139 } // end namespace detail
140
141 } // namespace mlir
142
143 #endif
144