1 //===-- include/flang/Evaluate/traverse.h -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef FORTRAN_EVALUATE_TRAVERSE_H_
10 #define FORTRAN_EVALUATE_TRAVERSE_H_
11 
12 // A utility for scanning all of the constituent objects in an Expr<>
13 // expression representation using a collection of mutually recursive
14 // functions to compose a function object.
15 //
16 // The class template Traverse<> below implements a function object that
17 // can handle every type that can appear in or around an Expr<>.
18 // Each of its overloads for operator() should be viewed as a *default*
19 // handler; some of these must be overridden by the client to accomplish
20 // its particular task.
21 //
22 // The client (Visitor) of Traverse<Visitor,Result> must define:
23 // - a member function "Result Default();"
24 // - a member function "Result Combine(Result &&, Result &&)"
25 // - overrides for "Result operator()"
26 //
27 // Boilerplate classes also appear below to ease construction of visitors.
28 // See CheckSpecificationExpr() in check-expression.cpp for an example client.
29 //
30 // How this works:
31 // - The operator() overloads in Traverse<> invoke the visitor's Default() for
32 //   expression leaf nodes.  They invoke the visitor's operator() for the
33 //   subtrees of interior nodes, and the visitor's Combine() to merge their
34 //   results together.
35 // - Overloads of operator() in each visitor handle the cases of interest.
36 //
37 // The default handler for semantics::Symbol will descend into the associated
38 // expression of an ASSOCIATE (or related) construct entity.
39 
40 #include "expression.h"
41 #include "flang/Semantics/symbol.h"
42 #include "flang/Semantics/type.h"
43 #include <set>
44 #include <type_traits>
45 
46 namespace Fortran::evaluate {
47 template <typename Visitor, typename Result> class Traverse {
48 public:
Traverse(Visitor & v)49   explicit Traverse(Visitor &v) : visitor_{v} {}
50 
51   // Packaging
52   template <typename A, bool C>
operator()53   Result operator()(const common::Indirection<A, C> &x) const {
54     return visitor_(x.value());
55   }
operator()56   template <typename A> Result operator()(const SymbolRef x) const {
57     return visitor_(*x);
58   }
operator()59   template <typename A> Result operator()(const std::unique_ptr<A> &x) const {
60     return visitor_(x.get());
61   }
operator()62   template <typename A> Result operator()(const std::shared_ptr<A> &x) const {
63     return visitor_(x.get());
64   }
operator()65   template <typename A> Result operator()(const A *x) const {
66     if (x) {
67       return visitor_(*x);
68     } else {
69       return visitor_.Default();
70     }
71   }
operator()72   template <typename A> Result operator()(const std::optional<A> &x) const {
73     if (x) {
74       return visitor_(*x);
75     } else {
76       return visitor_.Default();
77     }
78   }
79   template <typename... A>
operator()80   Result operator()(const std::variant<A...> &u) const {
81     return std::visit(visitor_, u);
82   }
operator()83   template <typename A> Result operator()(const std::vector<A> &x) const {
84     return CombineContents(x);
85   }
86 
87   // Leaves
operator()88   Result operator()(const BOZLiteralConstant &) const {
89     return visitor_.Default();
90   }
operator()91   Result operator()(const NullPointer &) const { return visitor_.Default(); }
operator()92   template <typename T> Result operator()(const Constant<T> &x) const {
93     if constexpr (T::category == TypeCategory::Derived) {
94       std::optional<Result> result;
95       for (const StructureConstructorValues &map : x.values()) {
96         for (const auto &pair : map) {
97           auto value{visitor_(pair.second.value())};
98           result = result
99               ? visitor_.Combine(std::move(*result), std::move(value))
100               : std::move(value);
101         }
102       }
103       return result ? *result : visitor_.Default();
104     } else {
105       return visitor_.Default();
106     }
107   }
operator()108   Result operator()(const Symbol &symbol) const {
109     const Symbol &ultimate{symbol.GetUltimate()};
110     if (const auto *assoc{
111             ultimate.detailsIf<semantics::AssocEntityDetails>()}) {
112       return visitor_(assoc->expr());
113     } else {
114       return visitor_.Default();
115     }
116   }
operator()117   Result operator()(const StaticDataObject &) const {
118     return visitor_.Default();
119   }
operator()120   Result operator()(const ImpliedDoIndex &) const { return visitor_.Default(); }
121 
122   // Variables
operator()123   Result operator()(const BaseObject &x) const { return visitor_(x.u); }
operator()124   Result operator()(const Component &x) const {
125     return Combine(x.base(), x.GetLastSymbol());
126   }
operator()127   Result operator()(const NamedEntity &x) const {
128     if (const Component * component{x.UnwrapComponent()}) {
129       return visitor_(*component);
130     } else {
131       return visitor_(x.GetFirstSymbol());
132     }
133   }
operator()134   Result operator()(const TypeParamInquiry &x) const {
135     return visitor_(x.base());
136   }
operator()137   Result operator()(const Triplet &x) const {
138     return Combine(x.lower(), x.upper(), x.stride());
139   }
operator()140   Result operator()(const Subscript &x) const { return visitor_(x.u); }
operator()141   Result operator()(const ArrayRef &x) const {
142     return Combine(x.base(), x.subscript());
143   }
operator()144   Result operator()(const CoarrayRef &x) const {
145     return Combine(
146         x.base(), x.subscript(), x.cosubscript(), x.stat(), x.team());
147   }
operator()148   Result operator()(const DataRef &x) const { return visitor_(x.u); }
operator()149   Result operator()(const Substring &x) const {
150     return Combine(x.parent(), x.lower(), x.upper());
151   }
operator()152   Result operator()(const ComplexPart &x) const {
153     return visitor_(x.complex());
154   }
operator()155   template <typename T> Result operator()(const Designator<T> &x) const {
156     return visitor_(x.u);
157   }
operator()158   template <typename T> Result operator()(const Variable<T> &x) const {
159     return visitor_(x.u);
160   }
operator()161   Result operator()(const DescriptorInquiry &x) const {
162     return visitor_(x.base());
163   }
164 
165   // Calls
operator()166   Result operator()(const SpecificIntrinsic &) const {
167     return visitor_.Default();
168   }
operator()169   Result operator()(const ProcedureDesignator &x) const {
170     if (const Component * component{x.GetComponent()}) {
171       return visitor_(*component);
172     } else if (const Symbol * symbol{x.GetSymbol()}) {
173       return visitor_(*symbol);
174     } else {
175       return visitor_(DEREF(x.GetSpecificIntrinsic()));
176     }
177   }
operator()178   Result operator()(const ActualArgument &x) const {
179     if (const auto *symbol{x.GetAssumedTypeDummy()}) {
180       return visitor_(*symbol);
181     } else {
182       return visitor_(x.UnwrapExpr());
183     }
184   }
operator()185   Result operator()(const ProcedureRef &x) const {
186     return Combine(x.proc(), x.arguments());
187   }
operator()188   template <typename T> Result operator()(const FunctionRef<T> &x) const {
189     return visitor_(static_cast<const ProcedureRef &>(x));
190   }
191 
192   // Other primaries
193   template <typename T>
operator()194   Result operator()(const ArrayConstructorValue<T> &x) const {
195     return visitor_(x.u);
196   }
197   template <typename T>
operator()198   Result operator()(const ArrayConstructorValues<T> &x) const {
199     return CombineContents(x);
200   }
operator()201   template <typename T> Result operator()(const ImpliedDo<T> &x) const {
202     return Combine(x.lower(), x.upper(), x.stride(), x.values());
203   }
operator()204   Result operator()(const semantics::ParamValue &x) const {
205     return visitor_(x.GetExplicit());
206   }
operator()207   Result operator()(
208       const semantics::DerivedTypeSpec::ParameterMapType::value_type &x) const {
209     return visitor_(x.second);
210   }
operator()211   Result operator()(const semantics::DerivedTypeSpec &x) const {
212     return CombineContents(x.parameters());
213   }
operator()214   Result operator()(const StructureConstructorValues::value_type &x) const {
215     return visitor_(x.second);
216   }
operator()217   Result operator()(const StructureConstructor &x) const {
218     return visitor_.Combine(visitor_(x.derivedTypeSpec()), CombineContents(x));
219   }
220 
221   // Operations and wrappers
222   template <typename D, typename R, typename O>
operator()223   Result operator()(const Operation<D, R, O> &op) const {
224     return visitor_(op.left());
225   }
226   template <typename D, typename R, typename LO, typename RO>
operator()227   Result operator()(const Operation<D, R, LO, RO> &op) const {
228     return Combine(op.left(), op.right());
229   }
operator()230   Result operator()(const Relational<SomeType> &x) const {
231     return visitor_(x.u);
232   }
operator()233   template <typename T> Result operator()(const Expr<T> &x) const {
234     return visitor_(x.u);
235   }
236 
237 private:
CombineRange(ITER iter,ITER end)238   template <typename ITER> Result CombineRange(ITER iter, ITER end) const {
239     if (iter == end) {
240       return visitor_.Default();
241     } else {
242       Result result{visitor_(*iter++)};
243       for (; iter != end; ++iter) {
244         result = visitor_.Combine(std::move(result), visitor_(*iter));
245       }
246       return result;
247     }
248   }
249 
CombineContents(const A & x)250   template <typename A> Result CombineContents(const A &x) const {
251     return CombineRange(x.begin(), x.end());
252   }
253 
254   template <typename A, typename... Bs>
Combine(const A & x,const Bs &...ys)255   Result Combine(const A &x, const Bs &...ys) const {
256     if constexpr (sizeof...(Bs) == 0) {
257       return visitor_(x);
258     } else {
259       return visitor_.Combine(visitor_(x), Combine(ys...));
260     }
261   }
262 
263   Visitor &visitor_;
264 };
265 
266 // For validity checks across an expression: if any operator() result is
267 // false, so is the overall result.
268 template <typename Visitor, bool DefaultValue,
269     typename Base = Traverse<Visitor, bool>>
270 struct AllTraverse : public Base {
AllTraverseAllTraverse271   explicit AllTraverse(Visitor &v) : Base{v} {}
272   using Base::operator();
DefaultAllTraverse273   static bool Default() { return DefaultValue; }
CombineAllTraverse274   static bool Combine(bool x, bool y) { return x && y; }
275 };
276 
277 // For searches over an expression: the first operator() result that
278 // is truthful is the final result.  Works for Booleans, pointers,
279 // and std::optional<>.
280 template <typename Visitor, typename Result = bool,
281     typename Base = Traverse<Visitor, Result>>
282 class AnyTraverse : public Base {
283 public:
AnyTraverse(Visitor & v)284   explicit AnyTraverse(Visitor &v) : Base{v} {}
285   using Base::operator();
Default()286   Result Default() const { return default_; }
Combine(Result && x,Result && y)287   static Result Combine(Result &&x, Result &&y) {
288     if (x) {
289       return std::move(x);
290     } else {
291       return std::move(y);
292     }
293   }
294 
295 private:
296   Result default_{};
297 };
298 
299 template <typename Visitor, typename Set,
300     typename Base = Traverse<Visitor, Set>>
301 struct SetTraverse : public Base {
SetTraverseSetTraverse302   explicit SetTraverse(Visitor &v) : Base{v} {}
303   using Base::operator();
DefaultSetTraverse304   static Set Default() { return {}; }
CombineSetTraverse305   static Set Combine(Set &&x, Set &&y) {
306 #if defined __GNUC__ && !defined __APPLE__ && !(CLANG_LIBRARIES)
307     x.merge(y);
308 #else
309     // std::set::merge() not available (yet)
310     for (auto &value : y) {
311       x.insert(std::move(value));
312     }
313 #endif
314     return std::move(x);
315   }
316 };
317 
318 } // namespace Fortran::evaluate
319 #endif
320