1 //===-- include/flang/Evaluate/traverse.h -----------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef FORTRAN_EVALUATE_TRAVERSE_H_ 10 #define FORTRAN_EVALUATE_TRAVERSE_H_ 11 12 // A utility for scanning all of the constituent objects in an Expr<> 13 // expression representation using a collection of mutually recursive 14 // functions to compose a function object. 15 // 16 // The class template Traverse<> below implements a function object that 17 // can handle every type that can appear in or around an Expr<>. 18 // Each of its overloads for operator() should be viewed as a *default* 19 // handler; some of these must be overridden by the client to accomplish 20 // its particular task. 21 // 22 // The client (Visitor) of Traverse<Visitor,Result> must define: 23 // - a member function "Result Default();" 24 // - a member function "Result Combine(Result &&, Result &&)" 25 // - overrides for "Result operator()" 26 // 27 // Boilerplate classes also appear below to ease construction of visitors. 28 // See CheckSpecificationExpr() in check-expression.cpp for an example client. 29 // 30 // How this works: 31 // - The operator() overloads in Traverse<> invoke the visitor's Default() for 32 // expression leaf nodes. They invoke the visitor's operator() for the 33 // subtrees of interior nodes, and the visitor's Combine() to merge their 34 // results together. 35 // - Overloads of operator() in each visitor handle the cases of interest. 36 // 37 // The default handler for semantics::Symbol will descend into the associated 38 // expression of an ASSOCIATE (or related) construct entity. 39 40 #include "expression.h" 41 #include "flang/Semantics/symbol.h" 42 #include "flang/Semantics/type.h" 43 #include <set> 44 #include <type_traits> 45 46 namespace Fortran::evaluate { 47 template <typename Visitor, typename Result> class Traverse { 48 public: Traverse(Visitor & v)49 explicit Traverse(Visitor &v) : visitor_{v} {} 50 51 // Packaging 52 template <typename A, bool C> operator()53 Result operator()(const common::Indirection<A, C> &x) const { 54 return visitor_(x.value()); 55 } operator()56 template <typename A> Result operator()(const SymbolRef x) const { 57 return visitor_(*x); 58 } operator()59 template <typename A> Result operator()(const std::unique_ptr<A> &x) const { 60 return visitor_(x.get()); 61 } operator()62 template <typename A> Result operator()(const std::shared_ptr<A> &x) const { 63 return visitor_(x.get()); 64 } operator()65 template <typename A> Result operator()(const A *x) const { 66 if (x) { 67 return visitor_(*x); 68 } else { 69 return visitor_.Default(); 70 } 71 } operator()72 template <typename A> Result operator()(const std::optional<A> &x) const { 73 if (x) { 74 return visitor_(*x); 75 } else { 76 return visitor_.Default(); 77 } 78 } 79 template <typename... A> operator()80 Result operator()(const std::variant<A...> &u) const { 81 return std::visit(visitor_, u); 82 } operator()83 template <typename A> Result operator()(const std::vector<A> &x) const { 84 return CombineContents(x); 85 } 86 87 // Leaves operator()88 Result operator()(const BOZLiteralConstant &) const { 89 return visitor_.Default(); 90 } operator()91 Result operator()(const NullPointer &) const { return visitor_.Default(); } operator()92 template <typename T> Result operator()(const Constant<T> &x) const { 93 if constexpr (T::category == TypeCategory::Derived) { 94 std::optional<Result> result; 95 for (const StructureConstructorValues &map : x.values()) { 96 for (const auto &pair : map) { 97 auto value{visitor_(pair.second.value())}; 98 result = result 99 ? visitor_.Combine(std::move(*result), std::move(value)) 100 : std::move(value); 101 } 102 } 103 return result ? *result : visitor_.Default(); 104 } else { 105 return visitor_.Default(); 106 } 107 } operator()108 Result operator()(const Symbol &symbol) const { 109 const Symbol &ultimate{symbol.GetUltimate()}; 110 if (const auto *assoc{ 111 ultimate.detailsIf<semantics::AssocEntityDetails>()}) { 112 return visitor_(assoc->expr()); 113 } else { 114 return visitor_.Default(); 115 } 116 } operator()117 Result operator()(const StaticDataObject &) const { 118 return visitor_.Default(); 119 } operator()120 Result operator()(const ImpliedDoIndex &) const { return visitor_.Default(); } 121 122 // Variables operator()123 Result operator()(const BaseObject &x) const { return visitor_(x.u); } operator()124 Result operator()(const Component &x) const { 125 return Combine(x.base(), x.GetLastSymbol()); 126 } operator()127 Result operator()(const NamedEntity &x) const { 128 if (const Component * component{x.UnwrapComponent()}) { 129 return visitor_(*component); 130 } else { 131 return visitor_(x.GetFirstSymbol()); 132 } 133 } operator()134 Result operator()(const TypeParamInquiry &x) const { 135 return visitor_(x.base()); 136 } operator()137 Result operator()(const Triplet &x) const { 138 return Combine(x.lower(), x.upper(), x.stride()); 139 } operator()140 Result operator()(const Subscript &x) const { return visitor_(x.u); } operator()141 Result operator()(const ArrayRef &x) const { 142 return Combine(x.base(), x.subscript()); 143 } operator()144 Result operator()(const CoarrayRef &x) const { 145 return Combine( 146 x.base(), x.subscript(), x.cosubscript(), x.stat(), x.team()); 147 } operator()148 Result operator()(const DataRef &x) const { return visitor_(x.u); } operator()149 Result operator()(const Substring &x) const { 150 return Combine(x.parent(), x.lower(), x.upper()); 151 } operator()152 Result operator()(const ComplexPart &x) const { 153 return visitor_(x.complex()); 154 } operator()155 template <typename T> Result operator()(const Designator<T> &x) const { 156 return visitor_(x.u); 157 } operator()158 template <typename T> Result operator()(const Variable<T> &x) const { 159 return visitor_(x.u); 160 } operator()161 Result operator()(const DescriptorInquiry &x) const { 162 return visitor_(x.base()); 163 } 164 165 // Calls operator()166 Result operator()(const SpecificIntrinsic &) const { 167 return visitor_.Default(); 168 } operator()169 Result operator()(const ProcedureDesignator &x) const { 170 if (const Component * component{x.GetComponent()}) { 171 return visitor_(*component); 172 } else if (const Symbol * symbol{x.GetSymbol()}) { 173 return visitor_(*symbol); 174 } else { 175 return visitor_(DEREF(x.GetSpecificIntrinsic())); 176 } 177 } operator()178 Result operator()(const ActualArgument &x) const { 179 if (const auto *symbol{x.GetAssumedTypeDummy()}) { 180 return visitor_(*symbol); 181 } else { 182 return visitor_(x.UnwrapExpr()); 183 } 184 } operator()185 Result operator()(const ProcedureRef &x) const { 186 return Combine(x.proc(), x.arguments()); 187 } operator()188 template <typename T> Result operator()(const FunctionRef<T> &x) const { 189 return visitor_(static_cast<const ProcedureRef &>(x)); 190 } 191 192 // Other primaries 193 template <typename T> operator()194 Result operator()(const ArrayConstructorValue<T> &x) const { 195 return visitor_(x.u); 196 } 197 template <typename T> operator()198 Result operator()(const ArrayConstructorValues<T> &x) const { 199 return CombineContents(x); 200 } operator()201 template <typename T> Result operator()(const ImpliedDo<T> &x) const { 202 return Combine(x.lower(), x.upper(), x.stride(), x.values()); 203 } operator()204 Result operator()(const semantics::ParamValue &x) const { 205 return visitor_(x.GetExplicit()); 206 } operator()207 Result operator()( 208 const semantics::DerivedTypeSpec::ParameterMapType::value_type &x) const { 209 return visitor_(x.second); 210 } operator()211 Result operator()(const semantics::DerivedTypeSpec &x) const { 212 return CombineContents(x.parameters()); 213 } operator()214 Result operator()(const StructureConstructorValues::value_type &x) const { 215 return visitor_(x.second); 216 } operator()217 Result operator()(const StructureConstructor &x) const { 218 return visitor_.Combine(visitor_(x.derivedTypeSpec()), CombineContents(x)); 219 } 220 221 // Operations and wrappers 222 template <typename D, typename R, typename O> operator()223 Result operator()(const Operation<D, R, O> &op) const { 224 return visitor_(op.left()); 225 } 226 template <typename D, typename R, typename LO, typename RO> operator()227 Result operator()(const Operation<D, R, LO, RO> &op) const { 228 return Combine(op.left(), op.right()); 229 } operator()230 Result operator()(const Relational<SomeType> &x) const { 231 return visitor_(x.u); 232 } operator()233 template <typename T> Result operator()(const Expr<T> &x) const { 234 return visitor_(x.u); 235 } 236 237 private: CombineRange(ITER iter,ITER end)238 template <typename ITER> Result CombineRange(ITER iter, ITER end) const { 239 if (iter == end) { 240 return visitor_.Default(); 241 } else { 242 Result result{visitor_(*iter++)}; 243 for (; iter != end; ++iter) { 244 result = visitor_.Combine(std::move(result), visitor_(*iter)); 245 } 246 return result; 247 } 248 } 249 CombineContents(const A & x)250 template <typename A> Result CombineContents(const A &x) const { 251 return CombineRange(x.begin(), x.end()); 252 } 253 254 template <typename A, typename... Bs> Combine(const A & x,const Bs &...ys)255 Result Combine(const A &x, const Bs &...ys) const { 256 if constexpr (sizeof...(Bs) == 0) { 257 return visitor_(x); 258 } else { 259 return visitor_.Combine(visitor_(x), Combine(ys...)); 260 } 261 } 262 263 Visitor &visitor_; 264 }; 265 266 // For validity checks across an expression: if any operator() result is 267 // false, so is the overall result. 268 template <typename Visitor, bool DefaultValue, 269 typename Base = Traverse<Visitor, bool>> 270 struct AllTraverse : public Base { AllTraverseAllTraverse271 explicit AllTraverse(Visitor &v) : Base{v} {} 272 using Base::operator(); DefaultAllTraverse273 static bool Default() { return DefaultValue; } CombineAllTraverse274 static bool Combine(bool x, bool y) { return x && y; } 275 }; 276 277 // For searches over an expression: the first operator() result that 278 // is truthful is the final result. Works for Booleans, pointers, 279 // and std::optional<>. 280 template <typename Visitor, typename Result = bool, 281 typename Base = Traverse<Visitor, Result>> 282 class AnyTraverse : public Base { 283 public: AnyTraverse(Visitor & v)284 explicit AnyTraverse(Visitor &v) : Base{v} {} 285 using Base::operator(); Default()286 Result Default() const { return default_; } Combine(Result && x,Result && y)287 static Result Combine(Result &&x, Result &&y) { 288 if (x) { 289 return std::move(x); 290 } else { 291 return std::move(y); 292 } 293 } 294 295 private: 296 Result default_{}; 297 }; 298 299 template <typename Visitor, typename Set, 300 typename Base = Traverse<Visitor, Set>> 301 struct SetTraverse : public Base { SetTraverseSetTraverse302 explicit SetTraverse(Visitor &v) : Base{v} {} 303 using Base::operator(); DefaultSetTraverse304 static Set Default() { return {}; } CombineSetTraverse305 static Set Combine(Set &&x, Set &&y) { 306 #if defined __GNUC__ && !defined __APPLE__ && !(CLANG_LIBRARIES) 307 x.merge(y); 308 #else 309 // std::set::merge() not available (yet) 310 for (auto &value : y) { 311 x.insert(std::move(value)); 312 } 313 #endif 314 return std::move(x); 315 } 316 }; 317 318 } // namespace Fortran::evaluate 319 #endif 320