1 //===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //  This file implements the TypeSwitch template, which mimics a switch()
10 //  statement whose cases are type names.
11 //
12 //===-----------------------------------------------------------------------===/
13 
14 #ifndef LLVM_ADT_TYPESWITCH_H
15 #define LLVM_ADT_TYPESWITCH_H
16 
17 #include "llvm/ADT/Optional.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/Support/Casting.h"
20 
21 namespace llvm {
22 namespace detail {
23 
24 template <typename DerivedT, typename T> class TypeSwitchBase {
25 public:
TypeSwitchBase(const T & value)26   TypeSwitchBase(const T &value) : value(value) {}
TypeSwitchBase(TypeSwitchBase && other)27   TypeSwitchBase(TypeSwitchBase &&other) : value(other.value) {}
28   ~TypeSwitchBase() = default;
29 
30   /// TypeSwitchBase is not copyable.
31   TypeSwitchBase(const TypeSwitchBase &) = delete;
32   void operator=(const TypeSwitchBase &) = delete;
33   void operator=(TypeSwitchBase &&other) = delete;
34 
35   /// Invoke a case on the derived class with multiple case types.
36   template <typename CaseT, typename CaseT2, typename... CaseTs,
37             typename CallableT>
Case(CallableT && caseFn)38   DerivedT &Case(CallableT &&caseFn) {
39     DerivedT &derived = static_cast<DerivedT &>(*this);
40     return derived.template Case<CaseT>(caseFn)
41         .template Case<CaseT2, CaseTs...>(caseFn);
42   }
43 
44   /// Invoke a case on the derived class, inferring the type of the Case from
45   /// the first input of the given callable.
46   /// Note: This inference rules for this overload are very simple: strip
47   ///       pointers and references.
Case(CallableT && caseFn)48   template <typename CallableT> DerivedT &Case(CallableT &&caseFn) {
49     using Traits = function_traits<std::decay_t<CallableT>>;
50     using CaseT = std::remove_cv_t<std::remove_pointer_t<
51         std::remove_reference_t<typename Traits::template arg_t<0>>>>;
52 
53     DerivedT &derived = static_cast<DerivedT &>(*this);
54     return derived.template Case<CaseT>(std::forward<CallableT>(caseFn));
55   }
56 
57 protected:
58   /// Trait to check whether `ValueT` provides a 'dyn_cast' method with type
59   /// `CastT`.
60   template <typename ValueT, typename CastT>
61   using has_dyn_cast_t =
62       decltype(std::declval<ValueT &>().template dyn_cast<CastT>());
63 
64   /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
65   /// selected if `value` already has a suitable dyn_cast method.
66   template <typename CastT, typename ValueT>
67   static auto castValue(
68       ValueT value,
69       typename std::enable_if_t<
70           is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
71     return value.template dyn_cast<CastT>();
72   }
73 
74   /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
75   /// selected if llvm::dyn_cast should be used.
76   template <typename CastT, typename ValueT>
77   static auto castValue(
78       ValueT value,
79       typename std::enable_if_t<
80           !is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
81     return dyn_cast<CastT>(value);
82   }
83 
84   /// The root value we are switching on.
85   const T value;
86 };
87 } // end namespace detail
88 
89 /// This class implements a switch-like dispatch statement for a value of 'T'
90 /// using dyn_cast functionality. Each `Case<T>` takes a callable to be invoked
91 /// if the root value isa<T>, the callable is invoked with the result of
92 /// dyn_cast<T>() as a parameter.
93 ///
94 /// Example:
95 ///  Operation *op = ...;
96 ///  LogicalResult result = TypeSwitch<Operation *, LogicalResult>(op)
97 ///    .Case<ConstantOp>([](ConstantOp op) { ... })
98 ///    .Default([](Operation *op) { ... });
99 ///
100 template <typename T, typename ResultT = void>
101 class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> {
102 public:
103   using BaseT = detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T>;
104   using BaseT::BaseT;
105   using BaseT::Case;
106   TypeSwitch(TypeSwitch &&other) = default;
107 
108   /// Add a case on the given type.
109   template <typename CaseT, typename CallableT>
Case(CallableT && caseFn)110   TypeSwitch<T, ResultT> &Case(CallableT &&caseFn) {
111     if (result)
112       return *this;
113 
114     // Check to see if CaseT applies to 'value'.
115     if (auto caseValue = BaseT::template castValue<CaseT>(this->value))
116       result = caseFn(caseValue);
117     return *this;
118   }
119 
120   /// As a default, invoke the given callable within the root value.
121   template <typename CallableT>
Default(CallableT && defaultFn)122   LLVM_NODISCARD ResultT Default(CallableT &&defaultFn) {
123     if (result)
124       return std::move(*result);
125     return defaultFn(this->value);
126   }
127 
128   LLVM_NODISCARD
ResultT()129   operator ResultT() {
130     assert(result && "Fell off the end of a type-switch");
131     return std::move(*result);
132   }
133 
134 private:
135   /// The pointer to the result of this switch statement, once known,
136   /// null before that.
137   Optional<ResultT> result;
138 };
139 
140 /// Specialization of TypeSwitch for void returning callables.
141 template <typename T>
142 class TypeSwitch<T, void>
143     : public detail::TypeSwitchBase<TypeSwitch<T, void>, T> {
144 public:
145   using BaseT = detail::TypeSwitchBase<TypeSwitch<T, void>, T>;
146   using BaseT::BaseT;
147   using BaseT::Case;
148   TypeSwitch(TypeSwitch &&other) = default;
149 
150   /// Add a case on the given type.
151   template <typename CaseT, typename CallableT>
Case(CallableT && caseFn)152   TypeSwitch<T, void> &Case(CallableT &&caseFn) {
153     if (foundMatch)
154       return *this;
155 
156     // Check to see if any of the types apply to 'value'.
157     if (auto caseValue = BaseT::template castValue<CaseT>(this->value)) {
158       caseFn(caseValue);
159       foundMatch = true;
160     }
161     return *this;
162   }
163 
164   /// As a default, invoke the given callable within the root value.
Default(CallableT && defaultFn)165   template <typename CallableT> void Default(CallableT &&defaultFn) {
166     if (!foundMatch)
167       defaultFn(this->value);
168   }
169 
170 private:
171   /// A flag detailing if we have already found a match.
172   bool foundMatch = false;
173 };
174 } // end namespace llvm
175 
176 #endif // LLVM_ADT_TYPESWITCH_H
177