1 //===-- lib/Evaluate/fold.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
#include "flang/Evaluate/fold.h"
#include "fold-implementation.h"
#include "flang/Evaluate/characteristics.h"
#include <cstdint>
12
13 namespace Fortran::evaluate {
14
Fold(FoldingContext & context,characteristics::TypeAndShape && x)15 characteristics::TypeAndShape Fold(
16 FoldingContext &context, characteristics::TypeAndShape &&x) {
17 x.Rewrite(context);
18 return std::move(x);
19 }
20
GetConstantSubscript(FoldingContext & context,Subscript & ss,const NamedEntity & base,int dim)21 std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
22 FoldingContext &context, Subscript &ss, const NamedEntity &base, int dim) {
23 ss = FoldOperation(context, std::move(ss));
24 return std::visit(
25 common::visitors{
26 [](IndirectSubscriptIntegerExpr &expr)
27 -> std::optional<Constant<SubscriptInteger>> {
28 if (const auto *constant{
29 UnwrapConstantValue<SubscriptInteger>(expr.value())}) {
30 return *constant;
31 } else {
32 return std::nullopt;
33 }
34 },
35 [&](Triplet &triplet) -> std::optional<Constant<SubscriptInteger>> {
36 auto lower{triplet.lower()}, upper{triplet.upper()};
37 std::optional<ConstantSubscript> stride{ToInt64(triplet.stride())};
38 if (!lower) {
39 lower = GetLowerBound(context, base, dim);
40 }
41 if (!upper) {
42 upper =
43 ComputeUpperBound(context, GetLowerBound(context, base, dim),
44 GetExtent(context, base, dim));
45 }
46 auto lbi{ToInt64(lower)}, ubi{ToInt64(upper)};
47 if (lbi && ubi && stride && *stride != 0) {
48 std::vector<SubscriptInteger::Scalar> values;
49 while ((*stride > 0 && *lbi <= *ubi) ||
50 (*stride < 0 && *lbi >= *ubi)) {
51 values.emplace_back(*lbi);
52 *lbi += *stride;
53 }
54 return Constant<SubscriptInteger>{std::move(values),
55 ConstantSubscripts{
56 static_cast<ConstantSubscript>(values.size())}};
57 } else {
58 return std::nullopt;
59 }
60 },
61 },
62 ss.u);
63 }
64
FoldOperation(FoldingContext & context,StructureConstructor && structure)65 Expr<SomeDerived> FoldOperation(
66 FoldingContext &context, StructureConstructor &&structure) {
67 StructureConstructor ctor{structure.derivedTypeSpec()};
68 bool constantExtents{true};
69 for (auto &&[symbol, value] : std::move(structure)) {
70 auto expr{Fold(context, std::move(value.value()))};
71 if (!IsPointer(symbol)) {
72 bool ok{false};
73 if (auto valueShape{GetConstantExtents(context, expr)}) {
74 if (auto componentShape{GetConstantExtents(context, symbol)}) {
75 if (GetRank(*componentShape) > 0 && GetRank(*valueShape) == 0) {
76 expr = ScalarConstantExpander{std::move(*componentShape)}.Expand(
77 std::move(expr));
78 ok = expr.Rank() > 0;
79 } else {
80 ok = *valueShape == *componentShape;
81 }
82 }
83 }
84 if (!ok) {
85 constantExtents = false;
86 }
87 }
88 ctor.Add(symbol, Fold(context, std::move(expr)));
89 }
90 if (constantExtents && IsConstantExpr(ctor)) {
91 return Expr<SomeDerived>{Constant<SomeDerived>{std::move(ctor)}};
92 } else {
93 return Expr<SomeDerived>{std::move(ctor)};
94 }
95 }
96
FoldOperation(FoldingContext & context,Component && component)97 Component FoldOperation(FoldingContext &context, Component &&component) {
98 return {FoldOperation(context, std::move(component.base())),
99 component.GetLastSymbol()};
100 }
101
FoldOperation(FoldingContext & context,NamedEntity && x)102 NamedEntity FoldOperation(FoldingContext &context, NamedEntity &&x) {
103 if (Component * c{x.UnwrapComponent()}) {
104 return NamedEntity{FoldOperation(context, std::move(*c))};
105 } else {
106 return std::move(x);
107 }
108 }
109
FoldOperation(FoldingContext & context,Triplet && triplet)110 Triplet FoldOperation(FoldingContext &context, Triplet &&triplet) {
111 MaybeExtentExpr lower{triplet.lower()};
112 MaybeExtentExpr upper{triplet.upper()};
113 return {Fold(context, std::move(lower)), Fold(context, std::move(upper)),
114 Fold(context, triplet.stride())};
115 }
116
FoldOperation(FoldingContext & context,Subscript && subscript)117 Subscript FoldOperation(FoldingContext &context, Subscript &&subscript) {
118 return std::visit(common::visitors{
119 [&](IndirectSubscriptIntegerExpr &&expr) {
120 expr.value() = Fold(context, std::move(expr.value()));
121 return Subscript(std::move(expr));
122 },
123 [&](Triplet &&triplet) {
124 return Subscript(
125 FoldOperation(context, std::move(triplet)));
126 },
127 },
128 std::move(subscript.u));
129 }
130
FoldOperation(FoldingContext & context,ArrayRef && arrayRef)131 ArrayRef FoldOperation(FoldingContext &context, ArrayRef &&arrayRef) {
132 NamedEntity base{FoldOperation(context, std::move(arrayRef.base()))};
133 for (Subscript &subscript : arrayRef.subscript()) {
134 subscript = FoldOperation(context, std::move(subscript));
135 }
136 return ArrayRef{std::move(base), std::move(arrayRef.subscript())};
137 }
138
FoldOperation(FoldingContext & context,CoarrayRef && coarrayRef)139 CoarrayRef FoldOperation(FoldingContext &context, CoarrayRef &&coarrayRef) {
140 std::vector<Subscript> subscript;
141 for (Subscript x : coarrayRef.subscript()) {
142 subscript.emplace_back(FoldOperation(context, std::move(x)));
143 }
144 std::vector<Expr<SubscriptInteger>> cosubscript;
145 for (Expr<SubscriptInteger> x : coarrayRef.cosubscript()) {
146 cosubscript.emplace_back(Fold(context, std::move(x)));
147 }
148 CoarrayRef folded{std::move(coarrayRef.base()), std::move(subscript),
149 std::move(cosubscript)};
150 if (std::optional<Expr<SomeInteger>> stat{coarrayRef.stat()}) {
151 folded.set_stat(Fold(context, std::move(*stat)));
152 }
153 if (std::optional<Expr<SomeInteger>> team{coarrayRef.team()}) {
154 folded.set_team(
155 Fold(context, std::move(*team)), coarrayRef.teamIsTeamNumber());
156 }
157 return folded;
158 }
159
FoldOperation(FoldingContext & context,DataRef && dataRef)160 DataRef FoldOperation(FoldingContext &context, DataRef &&dataRef) {
161 return std::visit(common::visitors{
162 [&](SymbolRef symbol) { return DataRef{*symbol}; },
163 [&](auto &&x) {
164 return DataRef{FoldOperation(context, std::move(x))};
165 },
166 },
167 std::move(dataRef.u));
168 }
169
FoldOperation(FoldingContext & context,Substring && substring)170 Substring FoldOperation(FoldingContext &context, Substring &&substring) {
171 auto lower{Fold(context, substring.lower())};
172 auto upper{Fold(context, substring.upper())};
173 if (const DataRef * dataRef{substring.GetParentIf<DataRef>()}) {
174 return Substring{FoldOperation(context, DataRef{*dataRef}),
175 std::move(lower), std::move(upper)};
176 } else {
177 auto p{*substring.GetParentIf<StaticDataObject::Pointer>()};
178 return Substring{std::move(p), std::move(lower), std::move(upper)};
179 }
180 }
181
FoldOperation(FoldingContext & context,ComplexPart && complexPart)182 ComplexPart FoldOperation(FoldingContext &context, ComplexPart &&complexPart) {
183 DataRef complex{complexPart.complex()};
184 return ComplexPart{
185 FoldOperation(context, std::move(complex)), complexPart.part()};
186 }
187
GetInt64Arg(const std::optional<ActualArgument> & arg)188 std::optional<std::int64_t> GetInt64Arg(
189 const std::optional<ActualArgument> &arg) {
190 if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(arg)}) {
191 return ToInt64(*intExpr);
192 } else {
193 return std::nullopt;
194 }
195 }
196
GetInt64ArgOr(const std::optional<ActualArgument> & arg,std::int64_t defaultValue)197 std::optional<std::int64_t> GetInt64ArgOr(
198 const std::optional<ActualArgument> &arg, std::int64_t defaultValue) {
199 if (!arg) {
200 return defaultValue;
201 } else if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(arg)}) {
202 return ToInt64(*intExpr);
203 } else {
204 return std::nullopt;
205 }
206 }
207
FoldOperation(FoldingContext & context,ImpliedDoIndex && iDo)208 Expr<ImpliedDoIndex::Result> FoldOperation(
209 FoldingContext &context, ImpliedDoIndex &&iDo) {
210 if (std::optional<ConstantSubscript> value{context.GetImpliedDo(iDo.name)}) {
211 return Expr<ImpliedDoIndex::Result>{*value};
212 } else {
213 return Expr<ImpliedDoIndex::Result>{std::move(iDo)};
214 }
215 }
216
// Explicit instantiations of ExpressionBase for derived-type and fully
// generic (SomeType) expressions. Presumably these exist so the class's
// member definitions are emitted in this translation unit; confirm against
// the declarations in the headers.
template class ExpressionBase<SomeDerived>;
template class ExpressionBase<SomeType>;
219
220 } // namespace Fortran::evaluate
221