1 //===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 ///  This file implements the TypeSwitch template, which mimics a switch()
11 ///  statement whose cases are type names.
12 ///
13 //===-----------------------------------------------------------------------===/
14 
15 #ifndef LLVM_ADT_TYPESWITCH_H
16 #define LLVM_ADT_TYPESWITCH_H
17 
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include <optional>
#include <type_traits>
21 
22 namespace llvm {
23 namespace detail {
24 
25 template <typename DerivedT, typename T> class TypeSwitchBase {
26 public:
27   TypeSwitchBase(const T &value) : value(value) {}
28   TypeSwitchBase(TypeSwitchBase &&other) : value(other.value) {}
29   ~TypeSwitchBase() = default;
30 
31   /// TypeSwitchBase is not copyable.
32   TypeSwitchBase(const TypeSwitchBase &) = delete;
33   void operator=(const TypeSwitchBase &) = delete;
34   void operator=(TypeSwitchBase &&other) = delete;
35 
36   /// Invoke a case on the derived class with multiple case types.
37   template <typename CaseT, typename CaseT2, typename... CaseTs,
38             typename CallableT>
39   // This is marked always_inline and nodebug so it doesn't show up in stack
40   // traces at -O0 (or other optimization levels).  Large TypeSwitch's are
41   // common, are equivalent to a switch, and don't add any value to stack
42   // traces.
43   LLVM_ATTRIBUTE_ALWAYS_INLINE LLVM_ATTRIBUTE_NODEBUG DerivedT &
44   Case(CallableT &&caseFn) {
45     DerivedT &derived = static_cast<DerivedT &>(*this);
46     return derived.template Case<CaseT>(caseFn)
47         .template Case<CaseT2, CaseTs...>(caseFn);
48   }
49 
50   /// Invoke a case on the derived class, inferring the type of the Case from
51   /// the first input of the given callable.
52   /// Note: This inference rules for this overload are very simple: strip
53   ///       pointers and references.
54   template <typename CallableT> DerivedT &Case(CallableT &&caseFn) {
55     using Traits = function_traits<std::decay_t<CallableT>>;
56     using CaseT = std::remove_cv_t<std::remove_pointer_t<
57         std::remove_reference_t<typename Traits::template arg_t<0>>>>;
58 
59     DerivedT &derived = static_cast<DerivedT &>(*this);
60     return derived.template Case<CaseT>(std::forward<CallableT>(caseFn));
61   }
62 
63 protected:
64   /// Trait to check whether `ValueT` provides a 'dyn_cast' method with type
65   /// `CastT`.
66   template <typename ValueT, typename CastT>
67   using has_dyn_cast_t =
68       decltype(std::declval<ValueT &>().template dyn_cast<CastT>());
69 
70   /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
71   /// selected if `value` already has a suitable dyn_cast method.
72   template <typename CastT, typename ValueT>
73   static decltype(auto) castValue(
74       ValueT &&value,
75       std::enable_if_t<is_detected<has_dyn_cast_t, ValueT, CastT>::value> * =
76           nullptr) {
77     return value.template dyn_cast<CastT>();
78   }
79 
80   /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
81   /// selected if llvm::dyn_cast should be used.
82   template <typename CastT, typename ValueT>
83   static decltype(auto) castValue(
84       ValueT &&value,
85       std::enable_if_t<!is_detected<has_dyn_cast_t, ValueT, CastT>::value> * =
86           nullptr) {
87     return dyn_cast<CastT>(value);
88   }
89 
90   /// The root value we are switching on.
91   const T value;
92 };
93 } // end namespace detail
94 
95 /// This class implements a switch-like dispatch statement for a value of 'T'
96 /// using dyn_cast functionality. Each `Case<T>` takes a callable to be invoked
97 /// if the root value isa<T>, the callable is invoked with the result of
98 /// dyn_cast<T>() as a parameter.
99 ///
100 /// Example:
101 ///  Operation *op = ...;
102 ///  LogicalResult result = TypeSwitch<Operation *, LogicalResult>(op)
103 ///    .Case<ConstantOp>([](ConstantOp op) { ... })
104 ///    .Default([](Operation *op) { ... });
105 ///
106 template <typename T, typename ResultT = void>
107 class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> {
108 public:
109   using BaseT = detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T>;
110   using BaseT::BaseT;
111   using BaseT::Case;
112   TypeSwitch(TypeSwitch &&other) = default;
113 
114   /// Add a case on the given type.
115   template <typename CaseT, typename CallableT>
116   TypeSwitch<T, ResultT> &Case(CallableT &&caseFn) {
117     if (result)
118       return *this;
119 
120     // Check to see if CaseT applies to 'value'.
121     if (auto caseValue = BaseT::template castValue<CaseT>(this->value))
122       result.emplace(caseFn(caseValue));
123     return *this;
124   }
125 
126   /// As a default, invoke the given callable within the root value.
127   template <typename CallableT>
128   [[nodiscard]] ResultT Default(CallableT &&defaultFn) {
129     if (result)
130       return std::move(*result);
131     return defaultFn(this->value);
132   }
133   /// As a default, return the given value.
134   [[nodiscard]] ResultT Default(ResultT defaultResult) {
135     if (result)
136       return std::move(*result);
137     return defaultResult;
138   }
139 
140   [[nodiscard]] operator ResultT() {
141     assert(result && "Fell off the end of a type-switch");
142     return std::move(*result);
143   }
144 
145 private:
146   /// The pointer to the result of this switch statement, once known,
147   /// null before that.
148   std::optional<ResultT> result;
149 };
150 
151 /// Specialization of TypeSwitch for void returning callables.
152 template <typename T>
153 class TypeSwitch<T, void>
154     : public detail::TypeSwitchBase<TypeSwitch<T, void>, T> {
155 public:
156   using BaseT = detail::TypeSwitchBase<TypeSwitch<T, void>, T>;
157   using BaseT::BaseT;
158   using BaseT::Case;
159   TypeSwitch(TypeSwitch &&other) = default;
160 
161   /// Add a case on the given type.
162   template <typename CaseT, typename CallableT>
163   TypeSwitch<T, void> &Case(CallableT &&caseFn) {
164     if (foundMatch)
165       return *this;
166 
167     // Check to see if any of the types apply to 'value'.
168     if (auto caseValue = BaseT::template castValue<CaseT>(this->value)) {
169       caseFn(caseValue);
170       foundMatch = true;
171     }
172     return *this;
173   }
174 
175   /// As a default, invoke the given callable within the root value.
176   template <typename CallableT> void Default(CallableT &&defaultFn) {
177     if (!foundMatch)
178       defaultFn(this->value);
179   }
180 
181 private:
182   /// A flag detailing if we have already found a match.
183   bool foundMatch = false;
184 };
185 } // end namespace llvm
186 
187 #endif // LLVM_ADT_TYPESWITCH_H
188