1 //===-- lib/Evaluate/fold.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "flang/Evaluate/fold.h"
#include "fold-implementation.h"
#include "flang/Evaluate/characteristics.h"
#include <cstdint>
12 
13 namespace Fortran::evaluate {
14 
Fold(FoldingContext & context,characteristics::TypeAndShape && x)15 characteristics::TypeAndShape Fold(
16     FoldingContext &context, characteristics::TypeAndShape &&x) {
17   x.Rewrite(context);
18   return std::move(x);
19 }
20 
GetConstantSubscript(FoldingContext & context,Subscript & ss,const NamedEntity & base,int dim)21 std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
22     FoldingContext &context, Subscript &ss, const NamedEntity &base, int dim) {
23   ss = FoldOperation(context, std::move(ss));
24   return std::visit(
25       common::visitors{
26           [](IndirectSubscriptIntegerExpr &expr)
27               -> std::optional<Constant<SubscriptInteger>> {
28             if (const auto *constant{
29                     UnwrapConstantValue<SubscriptInteger>(expr.value())}) {
30               return *constant;
31             } else {
32               return std::nullopt;
33             }
34           },
35           [&](Triplet &triplet) -> std::optional<Constant<SubscriptInteger>> {
36             auto lower{triplet.lower()}, upper{triplet.upper()};
37             std::optional<ConstantSubscript> stride{ToInt64(triplet.stride())};
38             if (!lower) {
39               lower = GetLowerBound(context, base, dim);
40             }
41             if (!upper) {
42               upper =
43                   ComputeUpperBound(context, GetLowerBound(context, base, dim),
44                       GetExtent(context, base, dim));
45             }
46             auto lbi{ToInt64(lower)}, ubi{ToInt64(upper)};
47             if (lbi && ubi && stride && *stride != 0) {
48               std::vector<SubscriptInteger::Scalar> values;
49               while ((*stride > 0 && *lbi <= *ubi) ||
50                   (*stride < 0 && *lbi >= *ubi)) {
51                 values.emplace_back(*lbi);
52                 *lbi += *stride;
53               }
54               return Constant<SubscriptInteger>{std::move(values),
55                   ConstantSubscripts{
56                       static_cast<ConstantSubscript>(values.size())}};
57             } else {
58               return std::nullopt;
59             }
60           },
61       },
62       ss.u);
63 }
64 
FoldOperation(FoldingContext & context,StructureConstructor && structure)65 Expr<SomeDerived> FoldOperation(
66     FoldingContext &context, StructureConstructor &&structure) {
67   StructureConstructor ctor{structure.derivedTypeSpec()};
68   bool constantExtents{true};
69   for (auto &&[symbol, value] : std::move(structure)) {
70     auto expr{Fold(context, std::move(value.value()))};
71     if (!IsPointer(symbol)) {
72       bool ok{false};
73       if (auto valueShape{GetConstantExtents(context, expr)}) {
74         if (auto componentShape{GetConstantExtents(context, symbol)}) {
75           if (GetRank(*componentShape) > 0 && GetRank(*valueShape) == 0) {
76             expr = ScalarConstantExpander{std::move(*componentShape)}.Expand(
77                 std::move(expr));
78             ok = expr.Rank() > 0;
79           } else {
80             ok = *valueShape == *componentShape;
81           }
82         }
83       }
84       if (!ok) {
85         constantExtents = false;
86       }
87     }
88     ctor.Add(symbol, Fold(context, std::move(expr)));
89   }
90   if (constantExtents && IsConstantExpr(ctor)) {
91     return Expr<SomeDerived>{Constant<SomeDerived>{std::move(ctor)}};
92   } else {
93     return Expr<SomeDerived>{std::move(ctor)};
94   }
95 }
96 
FoldOperation(FoldingContext & context,Component && component)97 Component FoldOperation(FoldingContext &context, Component &&component) {
98   return {FoldOperation(context, std::move(component.base())),
99       component.GetLastSymbol()};
100 }
101 
FoldOperation(FoldingContext & context,NamedEntity && x)102 NamedEntity FoldOperation(FoldingContext &context, NamedEntity &&x) {
103   if (Component * c{x.UnwrapComponent()}) {
104     return NamedEntity{FoldOperation(context, std::move(*c))};
105   } else {
106     return std::move(x);
107   }
108 }
109 
FoldOperation(FoldingContext & context,Triplet && triplet)110 Triplet FoldOperation(FoldingContext &context, Triplet &&triplet) {
111   MaybeExtentExpr lower{triplet.lower()};
112   MaybeExtentExpr upper{triplet.upper()};
113   return {Fold(context, std::move(lower)), Fold(context, std::move(upper)),
114       Fold(context, triplet.stride())};
115 }
116 
FoldOperation(FoldingContext & context,Subscript && subscript)117 Subscript FoldOperation(FoldingContext &context, Subscript &&subscript) {
118   return std::visit(common::visitors{
119                         [&](IndirectSubscriptIntegerExpr &&expr) {
120                           expr.value() = Fold(context, std::move(expr.value()));
121                           return Subscript(std::move(expr));
122                         },
123                         [&](Triplet &&triplet) {
124                           return Subscript(
125                               FoldOperation(context, std::move(triplet)));
126                         },
127                     },
128       std::move(subscript.u));
129 }
130 
FoldOperation(FoldingContext & context,ArrayRef && arrayRef)131 ArrayRef FoldOperation(FoldingContext &context, ArrayRef &&arrayRef) {
132   NamedEntity base{FoldOperation(context, std::move(arrayRef.base()))};
133   for (Subscript &subscript : arrayRef.subscript()) {
134     subscript = FoldOperation(context, std::move(subscript));
135   }
136   return ArrayRef{std::move(base), std::move(arrayRef.subscript())};
137 }
138 
FoldOperation(FoldingContext & context,CoarrayRef && coarrayRef)139 CoarrayRef FoldOperation(FoldingContext &context, CoarrayRef &&coarrayRef) {
140   std::vector<Subscript> subscript;
141   for (Subscript x : coarrayRef.subscript()) {
142     subscript.emplace_back(FoldOperation(context, std::move(x)));
143   }
144   std::vector<Expr<SubscriptInteger>> cosubscript;
145   for (Expr<SubscriptInteger> x : coarrayRef.cosubscript()) {
146     cosubscript.emplace_back(Fold(context, std::move(x)));
147   }
148   CoarrayRef folded{std::move(coarrayRef.base()), std::move(subscript),
149       std::move(cosubscript)};
150   if (std::optional<Expr<SomeInteger>> stat{coarrayRef.stat()}) {
151     folded.set_stat(Fold(context, std::move(*stat)));
152   }
153   if (std::optional<Expr<SomeInteger>> team{coarrayRef.team()}) {
154     folded.set_team(
155         Fold(context, std::move(*team)), coarrayRef.teamIsTeamNumber());
156   }
157   return folded;
158 }
159 
FoldOperation(FoldingContext & context,DataRef && dataRef)160 DataRef FoldOperation(FoldingContext &context, DataRef &&dataRef) {
161   return std::visit(common::visitors{
162                         [&](SymbolRef symbol) { return DataRef{*symbol}; },
163                         [&](auto &&x) {
164                           return DataRef{FoldOperation(context, std::move(x))};
165                         },
166                     },
167       std::move(dataRef.u));
168 }
169 
FoldOperation(FoldingContext & context,Substring && substring)170 Substring FoldOperation(FoldingContext &context, Substring &&substring) {
171   auto lower{Fold(context, substring.lower())};
172   auto upper{Fold(context, substring.upper())};
173   if (const DataRef * dataRef{substring.GetParentIf<DataRef>()}) {
174     return Substring{FoldOperation(context, DataRef{*dataRef}),
175         std::move(lower), std::move(upper)};
176   } else {
177     auto p{*substring.GetParentIf<StaticDataObject::Pointer>()};
178     return Substring{std::move(p), std::move(lower), std::move(upper)};
179   }
180 }
181 
FoldOperation(FoldingContext & context,ComplexPart && complexPart)182 ComplexPart FoldOperation(FoldingContext &context, ComplexPart &&complexPart) {
183   DataRef complex{complexPart.complex()};
184   return ComplexPart{
185       FoldOperation(context, std::move(complex)), complexPart.part()};
186 }
187 
GetInt64Arg(const std::optional<ActualArgument> & arg)188 std::optional<std::int64_t> GetInt64Arg(
189     const std::optional<ActualArgument> &arg) {
190   if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(arg)}) {
191     return ToInt64(*intExpr);
192   } else {
193     return std::nullopt;
194   }
195 }
196 
GetInt64ArgOr(const std::optional<ActualArgument> & arg,std::int64_t defaultValue)197 std::optional<std::int64_t> GetInt64ArgOr(
198     const std::optional<ActualArgument> &arg, std::int64_t defaultValue) {
199   if (!arg) {
200     return defaultValue;
201   } else if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(arg)}) {
202     return ToInt64(*intExpr);
203   } else {
204     return std::nullopt;
205   }
206 }
207 
FoldOperation(FoldingContext & context,ImpliedDoIndex && iDo)208 Expr<ImpliedDoIndex::Result> FoldOperation(
209     FoldingContext &context, ImpliedDoIndex &&iDo) {
210   if (std::optional<ConstantSubscript> value{context.GetImpliedDo(iDo.name)}) {
211     return Expr<ImpliedDoIndex::Result>{*value};
212   } else {
213     return Expr<ImpliedDoIndex::Result>{std::move(iDo)};
214   }
215 }
216 
// Explicit instantiations of ExpressionBase for the SomeDerived and SomeType
// expression categories, emitted in this translation unit.
template class ExpressionBase<SomeDerived>;
template class ExpressionBase<SomeType>;
219 
220 } // namespace Fortran::evaluate
221