1 //===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 ///  This file implements the TypeSwitch template, which mimics a switch()
11 ///  statement whose cases are type names.
12 ///
//===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_ADT_TYPESWITCH_H
16 #define LLVM_ADT_TYPESWITCH_H
17 
18 #include "llvm/ADT/Optional.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/Support/Casting.h"
21 
22 namespace llvm {
23 namespace detail {
24 
25 template <typename DerivedT, typename T> class TypeSwitchBase {
26 public:
27   TypeSwitchBase(const T &value) : value(value) {}
28   TypeSwitchBase(TypeSwitchBase &&other) : value(other.value) {}
29   ~TypeSwitchBase() = default;
30 
31   /// TypeSwitchBase is not copyable.
32   TypeSwitchBase(const TypeSwitchBase &) = delete;
33   void operator=(const TypeSwitchBase &) = delete;
34   void operator=(TypeSwitchBase &&other) = delete;
35 
36   /// Invoke a case on the derived class with multiple case types.
37   template <typename CaseT, typename CaseT2, typename... CaseTs,
38             typename CallableT>
39   // This is marked always_inline and nodebug so it doesn't show up in stack
40   // traces at -O0 (or other optimization levels).  Large TypeSwitch's are
41   // common, are equivalent to a switch, and don't add any value to stack
42   // traces.
43   LLVM_ATTRIBUTE_ALWAYS_INLINE LLVM_ATTRIBUTE_NODEBUG DerivedT &
44   Case(CallableT &&caseFn) {
45     DerivedT &derived = static_cast<DerivedT &>(*this);
46     return derived.template Case<CaseT>(caseFn)
47         .template Case<CaseT2, CaseTs...>(caseFn);
48   }
49 
50   /// Invoke a case on the derived class, inferring the type of the Case from
51   /// the first input of the given callable.
52   /// Note: This inference rules for this overload are very simple: strip
53   ///       pointers and references.
54   template <typename CallableT> DerivedT &Case(CallableT &&caseFn) {
55     using Traits = function_traits<std::decay_t<CallableT>>;
56     using CaseT = std::remove_cv_t<std::remove_pointer_t<
57         std::remove_reference_t<typename Traits::template arg_t<0>>>>;
58 
59     DerivedT &derived = static_cast<DerivedT &>(*this);
60     return derived.template Case<CaseT>(std::forward<CallableT>(caseFn));
61   }
62 
63 protected:
64   /// Trait to check whether `ValueT` provides a 'dyn_cast' method with type
65   /// `CastT`.
66   template <typename ValueT, typename CastT>
67   using has_dyn_cast_t =
68       decltype(std::declval<ValueT &>().template dyn_cast<CastT>());
69 
70   /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
71   /// selected if `value` already has a suitable dyn_cast method.
72   template <typename CastT, typename ValueT>
73   static auto castValue(
74       ValueT value,
75       typename std::enable_if_t<
76           is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
77     return value.template dyn_cast<CastT>();
78   }
79 
80   /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
81   /// selected if llvm::dyn_cast should be used.
82   template <typename CastT, typename ValueT>
83   static auto castValue(
84       ValueT value,
85       typename std::enable_if_t<
86           !is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
87     return dyn_cast<CastT>(value);
88   }
89 
90   /// The root value we are switching on.
91   const T value;
92 };
93 } // end namespace detail
94 
95 /// This class implements a switch-like dispatch statement for a value of 'T'
96 /// using dyn_cast functionality. Each `Case<T>` takes a callable to be invoked
97 /// if the root value isa<T>, the callable is invoked with the result of
98 /// dyn_cast<T>() as a parameter.
99 ///
100 /// Example:
101 ///  Operation *op = ...;
102 ///  LogicalResult result = TypeSwitch<Operation *, LogicalResult>(op)
103 ///    .Case<ConstantOp>([](ConstantOp op) { ... })
104 ///    .Default([](Operation *op) { ... });
105 ///
106 template <typename T, typename ResultT = void>
107 class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> {
108 public:
109   using BaseT = detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T>;
110   using BaseT::BaseT;
111   using BaseT::Case;
112   TypeSwitch(TypeSwitch &&other) = default;
113 
114   /// Add a case on the given type.
115   template <typename CaseT, typename CallableT>
116   TypeSwitch<T, ResultT> &Case(CallableT &&caseFn) {
117     if (result)
118       return *this;
119 
120     // Check to see if CaseT applies to 'value'.
121     if (auto caseValue = BaseT::template castValue<CaseT>(this->value))
122       result = caseFn(caseValue);
123     return *this;
124   }
125 
126   /// As a default, invoke the given callable within the root value.
127   template <typename CallableT>
128   LLVM_NODISCARD ResultT Default(CallableT &&defaultFn) {
129     if (result)
130       return std::move(*result);
131     return defaultFn(this->value);
132   }
133   /// As a default, return the given value.
134   LLVM_NODISCARD ResultT Default(ResultT defaultResult) {
135     if (result)
136       return std::move(*result);
137     return defaultResult;
138   }
139 
140   LLVM_NODISCARD
141   operator ResultT() {
142     assert(result && "Fell off the end of a type-switch");
143     return std::move(*result);
144   }
145 
146 private:
147   /// The pointer to the result of this switch statement, once known,
148   /// null before that.
149   Optional<ResultT> result;
150 };
151 
152 /// Specialization of TypeSwitch for void returning callables.
153 template <typename T>
154 class TypeSwitch<T, void>
155     : public detail::TypeSwitchBase<TypeSwitch<T, void>, T> {
156 public:
157   using BaseT = detail::TypeSwitchBase<TypeSwitch<T, void>, T>;
158   using BaseT::BaseT;
159   using BaseT::Case;
160   TypeSwitch(TypeSwitch &&other) = default;
161 
162   /// Add a case on the given type.
163   template <typename CaseT, typename CallableT>
164   TypeSwitch<T, void> &Case(CallableT &&caseFn) {
165     if (foundMatch)
166       return *this;
167 
168     // Check to see if any of the types apply to 'value'.
169     if (auto caseValue = BaseT::template castValue<CaseT>(this->value)) {
170       caseFn(caseValue);
171       foundMatch = true;
172     }
173     return *this;
174   }
175 
176   /// As a default, invoke the given callable within the root value.
177   template <typename CallableT> void Default(CallableT &&defaultFn) {
178     if (!foundMatch)
179       defaultFn(this->value);
180   }
181 
182 private:
183   /// A flag detailing if we have already found a match.
184   bool foundMatch = false;
185 };
186 } // end namespace llvm
187 
188 #endif // LLVM_ADT_TYPESWITCH_H
189