1 //===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the TypeSwitch template, which mimics a switch() 10 // statement whose cases are type names. 11 // 12 //===-----------------------------------------------------------------------===/ 13 14 #ifndef LLVM_ADT_TYPESWITCH_H 15 #define LLVM_ADT_TYPESWITCH_H 16 17 #include "llvm/ADT/Optional.h" 18 #include "llvm/ADT/STLExtras.h" 19 #include "llvm/Support/Casting.h" 20 21 namespace llvm { 22 namespace detail { 23 24 template <typename DerivedT, typename T> class TypeSwitchBase { 25 public: TypeSwitchBase(const T & value)26 TypeSwitchBase(const T &value) : value(value) {} TypeSwitchBase(TypeSwitchBase && other)27 TypeSwitchBase(TypeSwitchBase &&other) : value(other.value) {} 28 ~TypeSwitchBase() = default; 29 30 /// TypeSwitchBase is not copyable. 31 TypeSwitchBase(const TypeSwitchBase &) = delete; 32 void operator=(const TypeSwitchBase &) = delete; 33 void operator=(TypeSwitchBase &&other) = delete; 34 35 /// Invoke a case on the derived class with multiple case types. 36 template <typename CaseT, typename CaseT2, typename... CaseTs, 37 typename CallableT> Case(CallableT && caseFn)38 DerivedT &Case(CallableT &&caseFn) { 39 DerivedT &derived = static_cast<DerivedT &>(*this); 40 return derived.template Case<CaseT>(caseFn) 41 .template Case<CaseT2, CaseTs...>(caseFn); 42 } 43 44 /// Invoke a case on the derived class, inferring the type of the Case from 45 /// the first input of the given callable. 46 /// Note: This inference rules for this overload are very simple: strip 47 /// pointers and references. Case(CallableT && caseFn)48 template <typename CallableT> DerivedT &Case(CallableT &&caseFn) { 49 using Traits = function_traits<std::decay_t<CallableT>>; 50 using CaseT = std::remove_cv_t<std::remove_pointer_t< 51 std::remove_reference_t<typename Traits::template arg_t<0>>>>; 52 53 DerivedT &derived = static_cast<DerivedT &>(*this); 54 return derived.template Case<CaseT>(std::forward<CallableT>(caseFn)); 55 } 56 57 protected: 58 /// Trait to check whether `ValueT` provides a 'dyn_cast' method with type 59 /// `CastT`. 60 template <typename ValueT, typename CastT> 61 using has_dyn_cast_t = 62 decltype(std::declval<ValueT &>().template dyn_cast<CastT>()); 63 64 /// Attempt to dyn_cast the given `value` to `CastT`. This overload is 65 /// selected if `value` already has a suitable dyn_cast method. 66 template <typename CastT, typename ValueT> 67 static auto castValue( 68 ValueT value, 69 typename std::enable_if_t< 70 is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) { 71 return value.template dyn_cast<CastT>(); 72 } 73 74 /// Attempt to dyn_cast the given `value` to `CastT`. This overload is 75 /// selected if llvm::dyn_cast should be used. 76 template <typename CastT, typename ValueT> 77 static auto castValue( 78 ValueT value, 79 typename std::enable_if_t< 80 !is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) { 81 return dyn_cast<CastT>(value); 82 } 83 84 /// The root value we are switching on. 85 const T value; 86 }; 87 } // end namespace detail 88 89 /// This class implements a switch-like dispatch statement for a value of 'T' 90 /// using dyn_cast functionality. Each `Case<T>` takes a callable to be invoked 91 /// if the root value isa<T>, the callable is invoked with the result of 92 /// dyn_cast<T>() as a parameter. 93 /// 94 /// Example: 95 /// Operation *op = ...; 96 /// LogicalResult result = TypeSwitch<Operation *, LogicalResult>(op) 97 /// .Case<ConstantOp>([](ConstantOp op) { ... }) 98 /// .Default([](Operation *op) { ... }); 99 /// 100 template <typename T, typename ResultT = void> 101 class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> { 102 public: 103 using BaseT = detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T>; 104 using BaseT::BaseT; 105 using BaseT::Case; 106 TypeSwitch(TypeSwitch &&other) = default; 107 108 /// Add a case on the given type. 109 template <typename CaseT, typename CallableT> Case(CallableT && caseFn)110 TypeSwitch<T, ResultT> &Case(CallableT &&caseFn) { 111 if (result) 112 return *this; 113 114 // Check to see if CaseT applies to 'value'. 115 if (auto caseValue = BaseT::template castValue<CaseT>(this->value)) 116 result = caseFn(caseValue); 117 return *this; 118 } 119 120 /// As a default, invoke the given callable within the root value. 121 template <typename CallableT> Default(CallableT && defaultFn)122 LLVM_NODISCARD ResultT Default(CallableT &&defaultFn) { 123 if (result) 124 return std::move(*result); 125 return defaultFn(this->value); 126 } 127 128 LLVM_NODISCARD ResultT()129 operator ResultT() { 130 assert(result && "Fell off the end of a type-switch"); 131 return std::move(*result); 132 } 133 134 private: 135 /// The pointer to the result of this switch statement, once known, 136 /// null before that. 137 Optional<ResultT> result; 138 }; 139 140 /// Specialization of TypeSwitch for void returning callables. 141 template <typename T> 142 class TypeSwitch<T, void> 143 : public detail::TypeSwitchBase<TypeSwitch<T, void>, T> { 144 public: 145 using BaseT = detail::TypeSwitchBase<TypeSwitch<T, void>, T>; 146 using BaseT::BaseT; 147 using BaseT::Case; 148 TypeSwitch(TypeSwitch &&other) = default; 149 150 /// Add a case on the given type. 151 template <typename CaseT, typename CallableT> Case(CallableT && caseFn)152 TypeSwitch<T, void> &Case(CallableT &&caseFn) { 153 if (foundMatch) 154 return *this; 155 156 // Check to see if any of the types apply to 'value'. 157 if (auto caseValue = BaseT::template castValue<CaseT>(this->value)) { 158 caseFn(caseValue); 159 foundMatch = true; 160 } 161 return *this; 162 } 163 164 /// As a default, invoke the given callable within the root value. Default(CallableT && defaultFn)165 template <typename CallableT> void Default(CallableT &&defaultFn) { 166 if (!foundMatch) 167 defaultFn(this->value); 168 } 169 170 private: 171 /// A flag detailing if we have already found a match. 172 bool foundMatch = false; 173 }; 174 } // end namespace llvm 175 176 #endif // LLVM_ADT_TYPESWITCH_H 177