1 //===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file implements the TypeSwitch template, which mimics a switch() 11 /// statement whose cases are type names. 12 /// 13 //===-----------------------------------------------------------------------===/ 14 15 #ifndef LLVM_ADT_TYPESWITCH_H 16 #define LLVM_ADT_TYPESWITCH_H 17 18 #include "llvm/ADT/STLExtras.h" 19 #include "llvm/Support/Casting.h" 20 #include <optional> 21 22 namespace llvm { 23 namespace detail { 24 25 template <typename DerivedT, typename T> class TypeSwitchBase { 26 public: 27 TypeSwitchBase(const T &value) : value(value) {} 28 TypeSwitchBase(TypeSwitchBase &&other) : value(other.value) {} 29 ~TypeSwitchBase() = default; 30 31 /// TypeSwitchBase is not copyable. 32 TypeSwitchBase(const TypeSwitchBase &) = delete; 33 void operator=(const TypeSwitchBase &) = delete; 34 void operator=(TypeSwitchBase &&other) = delete; 35 36 /// Invoke a case on the derived class with multiple case types. 37 template <typename CaseT, typename CaseT2, typename... CaseTs, 38 typename CallableT> 39 // This is marked always_inline and nodebug so it doesn't show up in stack 40 // traces at -O0 (or other optimization levels). Large TypeSwitch's are 41 // common, are equivalent to a switch, and don't add any value to stack 42 // traces. 43 LLVM_ATTRIBUTE_ALWAYS_INLINE LLVM_ATTRIBUTE_NODEBUG DerivedT & 44 Case(CallableT &&caseFn) { 45 DerivedT &derived = static_cast<DerivedT &>(*this); 46 return derived.template Case<CaseT>(caseFn) 47 .template Case<CaseT2, CaseTs...>(caseFn); 48 } 49 50 /// Invoke a case on the derived class, inferring the type of the Case from 51 /// the first input of the given callable. 52 /// Note: This inference rules for this overload are very simple: strip 53 /// pointers and references. 54 template <typename CallableT> DerivedT &Case(CallableT &&caseFn) { 55 using Traits = function_traits<std::decay_t<CallableT>>; 56 using CaseT = std::remove_cv_t<std::remove_pointer_t< 57 std::remove_reference_t<typename Traits::template arg_t<0>>>>; 58 59 DerivedT &derived = static_cast<DerivedT &>(*this); 60 return derived.template Case<CaseT>(std::forward<CallableT>(caseFn)); 61 } 62 63 protected: 64 /// Trait to check whether `ValueT` provides a 'dyn_cast' method with type 65 /// `CastT`. 66 template <typename ValueT, typename CastT> 67 using has_dyn_cast_t = 68 decltype(std::declval<ValueT &>().template dyn_cast<CastT>()); 69 70 /// Attempt to dyn_cast the given `value` to `CastT`. This overload is 71 /// selected if `value` already has a suitable dyn_cast method. 72 template <typename CastT, typename ValueT> 73 static decltype(auto) castValue( 74 ValueT &&value, 75 std::enable_if_t<is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = 76 nullptr) { 77 return value.template dyn_cast<CastT>(); 78 } 79 80 /// Attempt to dyn_cast the given `value` to `CastT`. This overload is 81 /// selected if llvm::dyn_cast should be used. 82 template <typename CastT, typename ValueT> 83 static decltype(auto) castValue( 84 ValueT &&value, 85 std::enable_if_t<!is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = 86 nullptr) { 87 return dyn_cast<CastT>(value); 88 } 89 90 /// The root value we are switching on. 91 const T value; 92 }; 93 } // end namespace detail 94 95 /// This class implements a switch-like dispatch statement for a value of 'T' 96 /// using dyn_cast functionality. Each `Case<T>` takes a callable to be invoked 97 /// if the root value isa<T>, the callable is invoked with the result of 98 /// dyn_cast<T>() as a parameter. 99 /// 100 /// Example: 101 /// Operation *op = ...; 102 /// LogicalResult result = TypeSwitch<Operation *, LogicalResult>(op) 103 /// .Case<ConstantOp>([](ConstantOp op) { ... }) 104 /// .Default([](Operation *op) { ... }); 105 /// 106 template <typename T, typename ResultT = void> 107 class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> { 108 public: 109 using BaseT = detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T>; 110 using BaseT::BaseT; 111 using BaseT::Case; 112 TypeSwitch(TypeSwitch &&other) = default; 113 114 /// Add a case on the given type. 115 template <typename CaseT, typename CallableT> 116 TypeSwitch<T, ResultT> &Case(CallableT &&caseFn) { 117 if (result) 118 return *this; 119 120 // Check to see if CaseT applies to 'value'. 121 if (auto caseValue = BaseT::template castValue<CaseT>(this->value)) 122 result.emplace(caseFn(caseValue)); 123 return *this; 124 } 125 126 /// As a default, invoke the given callable within the root value. 127 template <typename CallableT> 128 [[nodiscard]] ResultT Default(CallableT &&defaultFn) { 129 if (result) 130 return std::move(*result); 131 return defaultFn(this->value); 132 } 133 /// As a default, return the given value. 134 [[nodiscard]] ResultT Default(ResultT defaultResult) { 135 if (result) 136 return std::move(*result); 137 return defaultResult; 138 } 139 140 [[nodiscard]] operator ResultT() { 141 assert(result && "Fell off the end of a type-switch"); 142 return std::move(*result); 143 } 144 145 private: 146 /// The pointer to the result of this switch statement, once known, 147 /// null before that. 148 std::optional<ResultT> result; 149 }; 150 151 /// Specialization of TypeSwitch for void returning callables. 152 template <typename T> 153 class TypeSwitch<T, void> 154 : public detail::TypeSwitchBase<TypeSwitch<T, void>, T> { 155 public: 156 using BaseT = detail::TypeSwitchBase<TypeSwitch<T, void>, T>; 157 using BaseT::BaseT; 158 using BaseT::Case; 159 TypeSwitch(TypeSwitch &&other) = default; 160 161 /// Add a case on the given type. 162 template <typename CaseT, typename CallableT> 163 TypeSwitch<T, void> &Case(CallableT &&caseFn) { 164 if (foundMatch) 165 return *this; 166 167 // Check to see if any of the types apply to 'value'. 168 if (auto caseValue = BaseT::template castValue<CaseT>(this->value)) { 169 caseFn(caseValue); 170 foundMatch = true; 171 } 172 return *this; 173 } 174 175 /// As a default, invoke the given callable within the root value. 176 template <typename CallableT> void Default(CallableT &&defaultFn) { 177 if (!foundMatch) 178 defaultFn(this->value); 179 } 180 181 private: 182 /// A flag detailing if we have already found a match. 183 bool foundMatch = false; 184 }; 185 } // end namespace llvm 186 187 #endif // LLVM_ADT_TYPESWITCH_H 188