1 //===-- include/flang/Evaluate/shape.h --------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 // GetShape() analyzes an expression and determines its shape, if possible,
10 // representing the result as a vector of scalar integer expressions.
11
12 #ifndef FORTRAN_EVALUATE_SHAPE_H_
13 #define FORTRAN_EVALUATE_SHAPE_H_
14
15 #include "expression.h"
16 #include "traverse.h"
17 #include "variable.h"
18 #include "flang/Common/indirection.h"
19 #include "flang/Evaluate/tools.h"
20 #include "flang/Evaluate/type.h"
21 #include <optional>
22 #include <variant>
23
24 namespace Fortran::parser {
25 class ContextualMessages;
26 }
27
28 namespace Fortran::evaluate {
29
30 class FoldingContext;
31
32 using ExtentType = SubscriptInteger;
33 using ExtentExpr = Expr<ExtentType>;
34 using MaybeExtentExpr = std::optional<ExtentExpr>;
35 using Shape = std::vector<MaybeExtentExpr>;
36
37 bool IsImpliedShape(const Symbol &);
38 bool IsExplicitShape(const Symbol &);
39
40 // Conversions between various representations of shapes.
41 Shape AsShape(const Constant<ExtentType> &);
42 std::optional<Shape> AsShape(FoldingContext &, ExtentExpr &&);
43
44 std::optional<ExtentExpr> AsExtentArrayExpr(const Shape &);
45
46 std::optional<Constant<ExtentType>> AsConstantShape(
47 FoldingContext &, const Shape &);
48 Constant<ExtentType> AsConstantShape(const ConstantSubscripts &);
49
50 ConstantSubscripts AsConstantExtents(const Constant<ExtentType> &);
51 std::optional<ConstantSubscripts> AsConstantExtents(
52 FoldingContext &, const Shape &);
53
GetRank(const Shape & s)54 inline int GetRank(const Shape &s) { return static_cast<int>(s.size()); }
55
56 template <typename A>
57 std::optional<Shape> GetShape(FoldingContext &, const A &);
58
59 // The dimension argument to these inquiries is zero-based,
60 // unlike the DIM= arguments to many intrinsics.
61 ExtentExpr GetLowerBound(FoldingContext &, const NamedEntity &, int dimension);
62 MaybeExtentExpr GetUpperBound(
63 FoldingContext &, const NamedEntity &, int dimension);
64 MaybeExtentExpr ComputeUpperBound(
65 FoldingContext &, ExtentExpr &&lower, MaybeExtentExpr &&extent);
66 Shape GetLowerBounds(FoldingContext &, const NamedEntity &);
67 Shape GetUpperBounds(FoldingContext &, const NamedEntity &);
68 MaybeExtentExpr GetExtent(FoldingContext &, const NamedEntity &, int dimension);
69 MaybeExtentExpr GetExtent(
70 FoldingContext &, const Subscript &, const NamedEntity &, int dimension);
71
72 // Compute an element count for a triplet or trip count for a DO.
73 ExtentExpr CountTrips(FoldingContext &, ExtentExpr &&lower, ExtentExpr &&upper,
74 ExtentExpr &&stride);
75 ExtentExpr CountTrips(FoldingContext &, const ExtentExpr &lower,
76 const ExtentExpr &upper, const ExtentExpr &stride);
77 MaybeExtentExpr CountTrips(FoldingContext &, MaybeExtentExpr &&lower,
78 MaybeExtentExpr &&upper, MaybeExtentExpr &&stride);
79
// Computes SIZE(), i.e. PRODUCT(shape), the product of the shape's extents.
81 MaybeExtentExpr GetSize(Shape &&);
82
83 // Utility predicate: does an expression reference any implied DO index?
84 bool ContainsAnyImpliedDoIndex(const ExtentExpr &);
85
86 class GetShapeHelper
87 : public AnyTraverse<GetShapeHelper, std::optional<Shape>> {
88 public:
89 using Result = std::optional<Shape>;
90 using Base = AnyTraverse<GetShapeHelper, Result>;
91 using Base::operator();
GetShapeHelper(FoldingContext & c)92 explicit GetShapeHelper(FoldingContext &c) : Base{*this}, context_{c} {}
93
operator()94 Result operator()(const ImpliedDoIndex &) const { return Scalar(); }
operator()95 Result operator()(const DescriptorInquiry &) const { return Scalar(); }
operator()96 template <int KIND> Result operator()(const TypeParamInquiry<KIND> &) const {
97 return Scalar();
98 }
operator()99 Result operator()(const BOZLiteralConstant &) const { return Scalar(); }
operator()100 Result operator()(const StaticDataObject::Pointer &) const {
101 return Scalar();
102 }
operator()103 Result operator()(const StructureConstructor &) const { return Scalar(); }
104
operator()105 template <typename T> Result operator()(const Constant<T> &c) const {
106 return AsShape(c.SHAPE());
107 }
108
109 Result operator()(const Symbol &) const;
110 Result operator()(const Component &) const;
111 Result operator()(const ArrayRef &) const;
112 Result operator()(const CoarrayRef &) const;
113 Result operator()(const Substring &) const;
114 Result operator()(const ProcedureRef &) const;
115
116 template <typename T>
operator()117 Result operator()(const ArrayConstructor<T> &aconst) const {
118 return Shape{GetArrayConstructorExtent(aconst)};
119 }
120 template <typename D, typename R, typename LO, typename RO>
operator()121 Result operator()(const Operation<D, R, LO, RO> &operation) const {
122 if (operation.right().Rank() > 0) {
123 return (*this)(operation.right());
124 } else {
125 return (*this)(operation.left());
126 }
127 }
128
129 private:
Scalar()130 static Result Scalar() { return Shape{}; }
CreateShape(int rank,NamedEntity & base)131 Shape CreateShape(int rank, NamedEntity &base) const {
132 Shape shape;
133 for (int dimension{0}; dimension < rank; ++dimension) {
134 shape.emplace_back(GetExtent(context_, base, dimension));
135 }
136 return shape;
137 }
138 template <typename T>
GetArrayConstructorValueExtent(const ArrayConstructorValue<T> & value)139 MaybeExtentExpr GetArrayConstructorValueExtent(
140 const ArrayConstructorValue<T> &value) const {
141 return std::visit(
142 common::visitors{
143 [&](const Expr<T> &x) -> MaybeExtentExpr {
144 if (std::optional<Shape> xShape{GetShape(context_, x)}) {
145 // Array values in array constructors get linearized.
146 return GetSize(std::move(*xShape));
147 } else {
148 return std::nullopt;
149 }
150 },
151 [&](const ImpliedDo<T> &ido) -> MaybeExtentExpr {
152 // Don't be heroic and try to figure out triangular implied DO
153 // nests.
154 if (!ContainsAnyImpliedDoIndex(ido.lower()) &&
155 !ContainsAnyImpliedDoIndex(ido.upper()) &&
156 !ContainsAnyImpliedDoIndex(ido.stride())) {
157 if (auto nValues{GetArrayConstructorExtent(ido.values())}) {
158 return std::move(*nValues) *
159 CountTrips(
160 context_, ido.lower(), ido.upper(), ido.stride());
161 }
162 }
163 return std::nullopt;
164 },
165 },
166 value.u);
167 }
168
169 template <typename T>
GetArrayConstructorExtent(const ArrayConstructorValues<T> & values)170 MaybeExtentExpr GetArrayConstructorExtent(
171 const ArrayConstructorValues<T> &values) const {
172 ExtentExpr result{0};
173 for (const auto &value : values) {
174 if (MaybeExtentExpr n{GetArrayConstructorValueExtent(value)}) {
175 result = std::move(result) + std::move(*n);
176 } else {
177 return std::nullopt;
178 }
179 }
180 return result;
181 }
182
183 FoldingContext &context_;
184 };
185
186 template <typename A>
GetShape(FoldingContext & context,const A & x)187 std::optional<Shape> GetShape(FoldingContext &context, const A &x) {
188 return GetShapeHelper{context}(x);
189 }
190
191 template <typename A>
GetConstantShape(FoldingContext & context,const A & x)192 std::optional<Constant<ExtentType>> GetConstantShape(
193 FoldingContext &context, const A &x) {
194 if (auto shape{GetShape(context, x)}) {
195 return AsConstantShape(context, *shape);
196 } else {
197 return std::nullopt;
198 }
199 }
200
201 template <typename A>
GetConstantExtents(FoldingContext & context,const A & x)202 std::optional<ConstantSubscripts> GetConstantExtents(
203 FoldingContext &context, const A &x) {
204 if (auto shape{GetShape(context, x)}) {
205 return AsConstantExtents(context, *shape);
206 } else {
207 return std::nullopt;
208 }
209 }
210
// Compile-time shape conformance checking, performed when the corresponding
// extents are known.
213 bool CheckConformance(parser::ContextualMessages &, const Shape &left,
214 const Shape &right, const char *leftIs = "left operand",
215 const char *rightIs = "right operand");
216
// Increments one-based subscripts in element order (first varies fastest)
// and returns true when they remain in range; resets them all to one and
// returns false otherwise (including the case where one or more of the
// extents are zero).
221 bool IncrementSubscripts(
222 ConstantSubscripts &, const ConstantSubscripts &extents);
223
224 } // namespace Fortran::evaluate
225 #endif // FORTRAN_EVALUATE_SHAPE_H_
226