1 //===-- include/flang/Evaluate/shape.h --------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 // GetShape() analyzes an expression and determines its shape, if possible,
10 // representing the result as a vector of scalar integer expressions.
11
12 #ifndef FORTRAN_EVALUATE_SHAPE_H_
13 #define FORTRAN_EVALUATE_SHAPE_H_
14
15 #include "expression.h"
16 #include "fold.h"
17 #include "traverse.h"
18 #include "variable.h"
19 #include "flang/Common/indirection.h"
20 #include "flang/Evaluate/tools.h"
21 #include "flang/Evaluate/type.h"
22 #include <optional>
23 #include <variant>
24
25 namespace Fortran::parser {
26 class ContextualMessages;
27 }
28
29 namespace Fortran::evaluate {
30
31 class FoldingContext;
32
// Extents (and bounds) are scalar integer expressions of this kind.
using ExtentType = SubscriptInteger;
using ExtentExpr = Expr<ExtentType>;
// An extent that may be absent; std::nullopt marks an extent that could not
// be computed.
using MaybeExtentExpr = std::optional<ExtentExpr>;
// A shape holds one (possibly absent) extent per dimension, so its size is
// the rank; an empty Shape describes a scalar.
using Shape = std::vector<MaybeExtentExpr>;
37
38 bool IsImpliedShape(const Symbol &);
39 bool IsExplicitShape(const Symbol &);
40
41 // Conversions between various representations of shapes.
42 std::optional<ExtentExpr> AsExtentArrayExpr(const Shape &);
43
44 std::optional<Constant<ExtentType>> AsConstantShape(
45 FoldingContext &, const Shape &);
46 Constant<ExtentType> AsConstantShape(const ConstantSubscripts &);
47
48 ConstantSubscripts AsConstantExtents(const Constant<ExtentType> &);
49 std::optional<ConstantSubscripts> AsConstantExtents(
50 FoldingContext &, const Shape &);
51 Shape AsShape(const ConstantSubscripts &);
52 std::optional<Shape> AsShape(const std::optional<ConstantSubscripts> &);
53
GetRank(const Shape & s)54 inline int GetRank(const Shape &s) { return static_cast<int>(s.size()); }
55
56 Shape Fold(FoldingContext &, Shape &&);
57 std::optional<Shape> Fold(FoldingContext &, std::optional<Shape> &&);
58
59 template <typename A>
60 std::optional<Shape> GetShape(FoldingContext &, const A &);
61 template <typename A> std::optional<Shape> GetShape(const A &);
62
63 // The dimension argument to these inquiries is zero-based,
64 // unlike the DIM= arguments to many intrinsics.
65 ExtentExpr GetLowerBound(const NamedEntity &, int dimension);
66 ExtentExpr GetLowerBound(FoldingContext &, const NamedEntity &, int dimension);
67 MaybeExtentExpr GetUpperBound(const NamedEntity &, int dimension);
68 MaybeExtentExpr GetUpperBound(
69 FoldingContext &, const NamedEntity &, int dimension);
70 MaybeExtentExpr ComputeUpperBound(ExtentExpr &&lower, MaybeExtentExpr &&extent);
71 MaybeExtentExpr ComputeUpperBound(
72 FoldingContext &, ExtentExpr &&lower, MaybeExtentExpr &&extent);
73 Shape GetLowerBounds(const NamedEntity &);
74 Shape GetLowerBounds(FoldingContext &, const NamedEntity &);
75 Shape GetUpperBounds(const NamedEntity &);
76 Shape GetUpperBounds(FoldingContext &, const NamedEntity &);
77 MaybeExtentExpr GetExtent(const NamedEntity &, int dimension);
78 MaybeExtentExpr GetExtent(FoldingContext &, const NamedEntity &, int dimension);
79 MaybeExtentExpr GetExtent(
80 const Subscript &, const NamedEntity &, int dimension);
81 MaybeExtentExpr GetExtent(
82 FoldingContext &, const Subscript &, const NamedEntity &, int dimension);
83
84 // Compute an element count for a triplet or trip count for a DO.
85 ExtentExpr CountTrips(
86 ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride);
87 ExtentExpr CountTrips(
88 const ExtentExpr &lower, const ExtentExpr &upper, const ExtentExpr &stride);
89 MaybeExtentExpr CountTrips(
90 MaybeExtentExpr &&lower, MaybeExtentExpr &&upper, MaybeExtentExpr &&stride);
91
92 // Computes SIZE() == PRODUCT(shape)
93 MaybeExtentExpr GetSize(Shape &&);
94 ConstantSubscript GetSize(const ConstantSubscripts &);
95
96 // Utility predicate: does an expression reference any implied DO index?
97 bool ContainsAnyImpliedDoIndex(const ExtentExpr &);
98
99 class GetShapeHelper
100 : public AnyTraverse<GetShapeHelper, std::optional<Shape>> {
101 public:
102 using Result = std::optional<Shape>;
103 using Base = AnyTraverse<GetShapeHelper, Result>;
104 using Base::operator();
GetShapeHelper()105 GetShapeHelper() : Base{*this} {}
GetShapeHelper(FoldingContext & c)106 explicit GetShapeHelper(FoldingContext &c) : Base{*this}, context_{&c} {}
107
operator()108 Result operator()(const ImpliedDoIndex &) const { return ScalarShape(); }
operator()109 Result operator()(const DescriptorInquiry &) const { return ScalarShape(); }
operator()110 Result operator()(const TypeParamInquiry &) const { return ScalarShape(); }
operator()111 Result operator()(const BOZLiteralConstant &) const { return ScalarShape(); }
operator()112 Result operator()(const StaticDataObject::Pointer &) const {
113 return ScalarShape();
114 }
operator()115 Result operator()(const StructureConstructor &) const {
116 return ScalarShape();
117 }
118
operator()119 template <typename T> Result operator()(const Constant<T> &c) const {
120 return ConstantShape(c.SHAPE());
121 }
122
123 Result operator()(const Symbol &) const;
124 Result operator()(const Component &) const;
125 Result operator()(const ArrayRef &) const;
126 Result operator()(const CoarrayRef &) const;
127 Result operator()(const Substring &) const;
128 Result operator()(const ProcedureRef &) const;
129
130 template <typename T>
operator()131 Result operator()(const ArrayConstructor<T> &aconst) const {
132 return Shape{GetArrayConstructorExtent(aconst)};
133 }
134 template <typename D, typename R, typename LO, typename RO>
operator()135 Result operator()(const Operation<D, R, LO, RO> &operation) const {
136 if (operation.right().Rank() > 0) {
137 return (*this)(operation.right());
138 } else {
139 return (*this)(operation.left());
140 }
141 }
142
143 private:
ScalarShape()144 static Result ScalarShape() { return Shape{}; }
145 static Shape ConstantShape(const Constant<ExtentType> &);
146 Result AsShape(ExtentExpr &&) const;
147 static Shape CreateShape(int rank, NamedEntity &);
148
149 template <typename T>
GetArrayConstructorValueExtent(const ArrayConstructorValue<T> & value)150 MaybeExtentExpr GetArrayConstructorValueExtent(
151 const ArrayConstructorValue<T> &value) const {
152 return std::visit(
153 common::visitors{
154 [&](const Expr<T> &x) -> MaybeExtentExpr {
155 if (auto xShape{
156 context_ ? GetShape(*context_, x) : GetShape(x)}) {
157 // Array values in array constructors get linearized.
158 return GetSize(std::move(*xShape));
159 } else {
160 return std::nullopt;
161 }
162 },
163 [&](const ImpliedDo<T> &ido) -> MaybeExtentExpr {
164 // Don't be heroic and try to figure out triangular implied DO
165 // nests.
166 if (!ContainsAnyImpliedDoIndex(ido.lower()) &&
167 !ContainsAnyImpliedDoIndex(ido.upper()) &&
168 !ContainsAnyImpliedDoIndex(ido.stride())) {
169 if (auto nValues{GetArrayConstructorExtent(ido.values())}) {
170 return std::move(*nValues) *
171 CountTrips(ido.lower(), ido.upper(), ido.stride());
172 }
173 }
174 return std::nullopt;
175 },
176 },
177 value.u);
178 }
179
180 template <typename T>
GetArrayConstructorExtent(const ArrayConstructorValues<T> & values)181 MaybeExtentExpr GetArrayConstructorExtent(
182 const ArrayConstructorValues<T> &values) const {
183 ExtentExpr result{0};
184 for (const auto &value : values) {
185 if (MaybeExtentExpr n{GetArrayConstructorValueExtent(value)}) {
186 result = std::move(result) + std::move(*n);
187 if (context_) {
188 // Fold during expression creation to avoid creating an expression so
189 // large we can't evalute it without overflowing the stack.
190 result = Fold(*context_, std::move(result));
191 }
192 } else {
193 return std::nullopt;
194 }
195 }
196 return result;
197 }
198
199 FoldingContext *context_{nullptr};
200 };
201
202 template <typename A>
GetShape(FoldingContext & context,const A & x)203 std::optional<Shape> GetShape(FoldingContext &context, const A &x) {
204 if (auto shape{GetShapeHelper{context}(x)}) {
205 return Fold(context, std::move(shape));
206 } else {
207 return std::nullopt;
208 }
209 }
210
GetShape(const A & x)211 template <typename A> std::optional<Shape> GetShape(const A &x) {
212 return GetShapeHelper{}(x);
213 }
214
215 template <typename A>
GetShape(FoldingContext * context,const A & x)216 std::optional<Shape> GetShape(FoldingContext *context, const A &x) {
217 if (context) {
218 return GetShape(*context, x);
219 } else {
220 return GetShapeHelper{}(x);
221 }
222 }
223
224 template <typename A>
GetConstantShape(FoldingContext & context,const A & x)225 std::optional<Constant<ExtentType>> GetConstantShape(
226 FoldingContext &context, const A &x) {
227 if (auto shape{GetShape(context, x)}) {
228 return AsConstantShape(context, *shape);
229 } else {
230 return std::nullopt;
231 }
232 }
233
234 template <typename A>
GetConstantExtents(FoldingContext & context,const A & x)235 std::optional<ConstantSubscripts> GetConstantExtents(
236 FoldingContext &context, const A &x) {
237 if (auto shape{GetShape(context, x)}) {
238 return AsConstantExtents(context, *shape);
239 } else {
240 return std::nullopt;
241 }
242 }
243
244 // Compilation-time shape conformance checking, when corresponding extents
245 // are or should be known. The result is an optional Boolean:
246 // - nullopt: no error found or reported, but conformance cannot
247 // be guaranteed during compilation; this result is possible only
248 // when one or both arrays are allowed to have deferred shape
249 // - true: no error found or reported, arrays conform
250 // - false: errors found and reported
251 // Use "CheckConformance(...).value_or()" to specify a default result
252 // when you don't care whether messages have been emitted.
// Bit flags for the `flags` argument of CheckConformance(); individual
// flags are distinct powers of two and may be combined with bitwise OR.
struct CheckConformanceFlags {
  enum Flags {
    None = 0,
    LeftScalarExpandable = 1 << 0,
    RightScalarExpandable = 1 << 1,
    LeftIsDeferredShape = 1 << 2,
    RightIsDeferredShape = 1 << 3,
    // Commonly used combinations of the single-bit flags above.
    EitherScalarExpandable = LeftScalarExpandable | RightScalarExpandable,
    BothDeferredShape = LeftIsDeferredShape | RightIsDeferredShape,
    RightIsExpandableDeferred = RightScalarExpandable | RightIsDeferredShape,
  };
};
265 std::optional<bool> CheckConformance(parser::ContextualMessages &,
266 const Shape &left, const Shape &right,
267 CheckConformanceFlags::Flags flags = CheckConformanceFlags::None,
268 const char *leftIs = "left operand", const char *rightIs = "right operand");
269
// Increments one-based subscripts in element order (first varies fastest)
// and returns true when they remain in range; resets them all to one and
// returns false otherwise (including the case where one or more of the
// extents are zero).
274 bool IncrementSubscripts(
275 ConstantSubscripts &, const ConstantSubscripts &extents);
276
277 } // namespace Fortran::evaluate
278 #endif // FORTRAN_EVALUATE_SHAPE_H_
279