1 //===-- lib/Evaluate/shape.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "flang/Evaluate/shape.h"
10 #include "flang/Common/idioms.h"
11 #include "flang/Common/template.h"
12 #include "flang/Evaluate/characteristics.h"
13 #include "flang/Evaluate/fold.h"
14 #include "flang/Evaluate/intrinsics.h"
15 #include "flang/Evaluate/tools.h"
16 #include "flang/Evaluate/type.h"
17 #include "flang/Parser/message.h"
18 #include "flang/Semantics/symbol.h"
19 #include <functional>
20
21 using namespace std::placeholders; // _1, _2, &c. for std::bind()
22
23 namespace Fortran::evaluate {
24
IsImpliedShape(const Symbol & original)25 bool IsImpliedShape(const Symbol &original) {
26 const Symbol &symbol{ResolveAssociations(original)};
27 const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()};
28 return details && symbol.attrs().test(semantics::Attr::PARAMETER) &&
29 details->shape().IsImpliedShape();
30 }
31
IsExplicitShape(const Symbol & original)32 bool IsExplicitShape(const Symbol &original) {
33 const Symbol &symbol{ResolveAssociations(original)};
34 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
35 const auto &shape{details->shape()};
36 return shape.Rank() == 0 ||
37 shape.IsExplicitShape(); // true when scalar, too
38 } else {
39 return symbol
40 .has<semantics::AssocEntityDetails>(); // exprs have explicit shape
41 }
42 }
43
ConstantShape(const Constant<ExtentType> & arrayConstant)44 Shape GetShapeHelper::ConstantShape(const Constant<ExtentType> &arrayConstant) {
45 CHECK(arrayConstant.Rank() == 1);
46 Shape result;
47 std::size_t dimensions{arrayConstant.size()};
48 for (std::size_t j{0}; j < dimensions; ++j) {
49 Scalar<ExtentType> extent{arrayConstant.values().at(j)};
50 result.emplace_back(MaybeExtentExpr{ExtentExpr{std::move(extent)}});
51 }
52 return result;
53 }
54
AsShape(ExtentExpr && arrayExpr) const55 auto GetShapeHelper::AsShape(ExtentExpr &&arrayExpr) const -> Result {
56 if (context_) {
57 arrayExpr = Fold(*context_, std::move(arrayExpr));
58 }
59 if (const auto *constArray{UnwrapConstantValue<ExtentType>(arrayExpr)}) {
60 return ConstantShape(*constArray);
61 }
62 if (auto *constructor{UnwrapExpr<ArrayConstructor<ExtentType>>(arrayExpr)}) {
63 Shape result;
64 for (auto &value : *constructor) {
65 if (auto *expr{std::get_if<ExtentExpr>(&value.u)}) {
66 if (expr->Rank() == 0) {
67 result.emplace_back(std::move(*expr));
68 continue;
69 }
70 }
71 return std::nullopt;
72 }
73 return result;
74 }
75 return std::nullopt;
76 }
77
CreateShape(int rank,NamedEntity & base)78 Shape GetShapeHelper::CreateShape(int rank, NamedEntity &base) {
79 Shape shape;
80 for (int dimension{0}; dimension < rank; ++dimension) {
81 shape.emplace_back(GetExtent(base, dimension));
82 }
83 return shape;
84 }
85
AsExtentArrayExpr(const Shape & shape)86 std::optional<ExtentExpr> AsExtentArrayExpr(const Shape &shape) {
87 ArrayConstructorValues<ExtentType> values;
88 for (const auto &dim : shape) {
89 if (dim) {
90 values.Push(common::Clone(*dim));
91 } else {
92 return std::nullopt;
93 }
94 }
95 return ExtentExpr{ArrayConstructor<ExtentType>{std::move(values)}};
96 }
97
AsConstantShape(FoldingContext & context,const Shape & shape)98 std::optional<Constant<ExtentType>> AsConstantShape(
99 FoldingContext &context, const Shape &shape) {
100 if (auto shapeArray{AsExtentArrayExpr(shape)}) {
101 auto folded{Fold(context, std::move(*shapeArray))};
102 if (auto *p{UnwrapConstantValue<ExtentType>(folded)}) {
103 return std::move(*p);
104 }
105 }
106 return std::nullopt;
107 }
108
AsConstantShape(const ConstantSubscripts & shape)109 Constant<SubscriptInteger> AsConstantShape(const ConstantSubscripts &shape) {
110 using IntType = Scalar<SubscriptInteger>;
111 std::vector<IntType> result;
112 for (auto dim : shape) {
113 result.emplace_back(dim);
114 }
115 return {std::move(result), ConstantSubscripts{GetRank(shape)}};
116 }
117
AsConstantExtents(const Constant<ExtentType> & shape)118 ConstantSubscripts AsConstantExtents(const Constant<ExtentType> &shape) {
119 ConstantSubscripts result;
120 for (const auto &extent : shape.values()) {
121 result.push_back(extent.ToInt64());
122 }
123 return result;
124 }
125
AsConstantExtents(FoldingContext & context,const Shape & shape)126 std::optional<ConstantSubscripts> AsConstantExtents(
127 FoldingContext &context, const Shape &shape) {
128 if (auto shapeConstant{AsConstantShape(context, shape)}) {
129 return AsConstantExtents(*shapeConstant);
130 } else {
131 return std::nullopt;
132 }
133 }
134
AsShape(const ConstantSubscripts & shape)135 Shape AsShape(const ConstantSubscripts &shape) {
136 Shape result;
137 for (const auto &extent : shape) {
138 result.emplace_back(ExtentExpr{extent});
139 }
140 return result;
141 }
142
AsShape(const std::optional<ConstantSubscripts> & shape)143 std::optional<Shape> AsShape(const std::optional<ConstantSubscripts> &shape) {
144 if (shape) {
145 return AsShape(*shape);
146 } else {
147 return std::nullopt;
148 }
149 }
150
Fold(FoldingContext & context,Shape && shape)151 Shape Fold(FoldingContext &context, Shape &&shape) {
152 for (auto &dim : shape) {
153 dim = Fold(context, std::move(dim));
154 }
155 return std::move(shape);
156 }
157
Fold(FoldingContext & context,std::optional<Shape> && shape)158 std::optional<Shape> Fold(
159 FoldingContext &context, std::optional<Shape> &&shape) {
160 if (shape) {
161 return Fold(context, std::move(*shape));
162 } else {
163 return std::nullopt;
164 }
165 }
166
ComputeTripCount(ExtentExpr && lower,ExtentExpr && upper,ExtentExpr && stride)167 static ExtentExpr ComputeTripCount(
168 ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
169 ExtentExpr strideCopy{common::Clone(stride)};
170 ExtentExpr span{
171 (std::move(upper) - std::move(lower) + std::move(strideCopy)) /
172 std::move(stride)};
173 return ExtentExpr{
174 Extremum<ExtentType>{Ordering::Greater, std::move(span), ExtentExpr{0}}};
175 }
176
CountTrips(ExtentExpr && lower,ExtentExpr && upper,ExtentExpr && stride)177 ExtentExpr CountTrips(
178 ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
179 return ComputeTripCount(
180 std::move(lower), std::move(upper), std::move(stride));
181 }
182
// Trip count from borrowed bounds; each bound is cloned before use.
ExtentExpr CountTrips(const ExtentExpr &lower, const ExtentExpr &upper,
    const ExtentExpr &stride) {
  ExtentExpr lowerCopy{common::Clone(lower)};
  ExtentExpr upperCopy{common::Clone(upper)};
  ExtentExpr strideCopy{common::Clone(stride)};
  return ComputeTripCount(
      std::move(lowerCopy), std::move(upperCopy), std::move(strideCopy));
}
188
CountTrips(MaybeExtentExpr && lower,MaybeExtentExpr && upper,MaybeExtentExpr && stride)189 MaybeExtentExpr CountTrips(MaybeExtentExpr &&lower, MaybeExtentExpr &&upper,
190 MaybeExtentExpr &&stride) {
191 std::function<ExtentExpr(ExtentExpr &&, ExtentExpr &&, ExtentExpr &&)> bound{
192 std::bind(ComputeTripCount, _1, _2, _3)};
193 return common::MapOptional(
194 std::move(bound), std::move(lower), std::move(upper), std::move(stride));
195 }
196
GetSize(Shape && shape)197 MaybeExtentExpr GetSize(Shape &&shape) {
198 ExtentExpr extent{1};
199 for (auto &&dim : std::move(shape)) {
200 if (dim) {
201 extent = std::move(extent) * std::move(*dim);
202 } else {
203 return std::nullopt;
204 }
205 }
206 return extent;
207 }
208
GetSize(const ConstantSubscripts & shape)209 ConstantSubscript GetSize(const ConstantSubscripts &shape) {
210 ConstantSubscript size{1};
211 for (auto dim : shape) {
212 CHECK(dim >= 0);
213 size *= dim;
214 }
215 return size;
216 }
217
ContainsAnyImpliedDoIndex(const ExtentExpr & expr)218 bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
219 struct MyVisitor : public AnyTraverse<MyVisitor> {
220 using Base = AnyTraverse<MyVisitor>;
221 MyVisitor() : Base{*this} {}
222 using Base::operator();
223 bool operator()(const ImpliedDoIndex &) { return true; }
224 };
225 return MyVisitor{}(expr);
226 }
227
// Determines the lower bound of one dimension. This can be other than 1 only
// for a reference to a whole array object or component (see LBOUND,
// 16.9.109). ASSOCIATE construct entities may require traversal of their
// referents.
//
// Default() supplies 1. Combine() discards both of its operands and also
// returns 1. So, assuming the usual Traverse protocol, anything that is not
// a directly referenced Symbol or Component gets a lower bound of 1.
class GetLowerBoundHelper : public Traverse<GetLowerBoundHelper, ExtentExpr> {
public:
  using Result = ExtentExpr;
  using Base = Traverse<GetLowerBoundHelper, ExtentExpr>;
  using Base::operator();
  // d: zero-based index of the dimension whose lower bound is wanted.
  explicit GetLowerBoundHelper(int d) : Base{*this}, dimension_{d} {}
  static ExtentExpr Default() { return ExtentExpr{1}; }
  static ExtentExpr Combine(Result &&, Result &&) { return Default(); }
  ExtentExpr operator()(const Symbol &);
  ExtentExpr operator()(const Component &);

private:
  int dimension_;
};
245
// Lower bound of dimension dimension_ of a symbol (its ultimate symbol is
// examined):
//  - object entity: the declared explicit lower bound if there is one;
//    otherwise, if the object is a descriptor, a run-time LowerBound
//    inquiry on symbol0; otherwise 1.
//  - SELECT RANK associate entity: a LowerBound inquiry on symbol0 when the
//    resolved selector is a descriptor and dimension_ is within the selected
//    rank; otherwise 1.
//  - other associate entities: the traversal of the associated expression.
//  - anything else: 1.
auto GetLowerBoundHelper::operator()(const Symbol &symbol0) -> Result {
  const Symbol &symbol{symbol0.GetUltimate()};
  if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
    int j{0};
    for (const auto &shapeSpec : details->shape()) {
      if (j++ == dimension_) {
        if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
          return *bound;
        } else if (IsDescriptor(symbol)) {
          // The inquiry names the original symbol0, not its ultimate symbol.
          return ExtentExpr{DescriptorInquiry{NamedEntity{symbol0},
              DescriptorInquiry::Field::LowerBound, dimension_}};
        } else {
          break;
        }
      }
    }
  } else if (const auto *assoc{
                 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
    if (assoc->rank()) { // SELECT RANK case
      const Symbol &resolved{ResolveAssociations(symbol)};
      if (IsDescriptor(resolved) && dimension_ < *assoc->rank()) {
        return ExtentExpr{DescriptorInquiry{NamedEntity{symbol0},
            DescriptorInquiry::Field::LowerBound, dimension_}};
      }
    } else {
      return (*this)(assoc->expr());
    }
  }
  return Default();
}
276
// Lower bound of dimension dimension_ of a component reference. Only a
// component whose base is scalar can have a lower bound other than 1. For
// such a component:
//  - use the declared explicit lower bound if there is one;
//  - otherwise, for a descriptor, use a run-time LowerBound inquiry on a
//    clone of the component reference.
// A component of an array-valued base (e.g. a(:)%c), and all other cases,
// yield 1.
auto GetLowerBoundHelper::operator()(const Component &component) -> Result {
  if (component.base().Rank() == 0) {
    const Symbol &symbol{component.GetLastSymbol().GetUltimate()};
    if (const auto *details{
            symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
      int j{0};
      for (const auto &shapeSpec : details->shape()) {
        if (j++ == dimension_) {
          if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
            return *bound;
          } else if (IsDescriptor(symbol)) {
            return ExtentExpr{
                DescriptorInquiry{NamedEntity{common::Clone(component)},
                    DescriptorInquiry::Field::LowerBound, dimension_}};
          } else {
            break;
          }
        }
      }
    }
  }
  return Default();
}
300
GetLowerBound(const NamedEntity & base,int dimension)301 ExtentExpr GetLowerBound(const NamedEntity &base, int dimension) {
302 return GetLowerBoundHelper{dimension}(base);
303 }
304
GetLowerBound(FoldingContext & context,const NamedEntity & base,int dimension)305 ExtentExpr GetLowerBound(
306 FoldingContext &context, const NamedEntity &base, int dimension) {
307 return Fold(context, GetLowerBound(base, dimension));
308 }
309
GetLowerBounds(const NamedEntity & base)310 Shape GetLowerBounds(const NamedEntity &base) {
311 Shape result;
312 int rank{base.Rank()};
313 for (int dim{0}; dim < rank; ++dim) {
314 result.emplace_back(GetLowerBound(base, dim));
315 }
316 return result;
317 }
318
GetLowerBounds(FoldingContext & context,const NamedEntity & base)319 Shape GetLowerBounds(FoldingContext &context, const NamedEntity &base) {
320 Shape result;
321 int rank{base.Rank()};
322 for (int dim{0}; dim < rank; ++dim) {
323 result.emplace_back(GetLowerBound(context, base, dim));
324 }
325 return result;
326 }
327
328 // If the upper and lower bounds are constant, return a constant expression for
329 // the extent. In particular, if the upper bound is less than the lower bound,
330 // return zero.
GetNonNegativeExtent(const semantics::ShapeSpec & shapeSpec)331 static MaybeExtentExpr GetNonNegativeExtent(
332 const semantics::ShapeSpec &shapeSpec) {
333 const auto &ubound{shapeSpec.ubound().GetExplicit()};
334 const auto &lbound{shapeSpec.lbound().GetExplicit()};
335 std::optional<ConstantSubscript> uval{ToInt64(ubound)};
336 std::optional<ConstantSubscript> lval{ToInt64(lbound)};
337 if (uval && lval) {
338 if (*uval < *lval) {
339 return ExtentExpr{0};
340 } else {
341 return ExtentExpr{*uval - *lval + 1};
342 }
343 }
344 return common::Clone(ubound.value()) - common::Clone(lbound.value()) +
345 ExtentExpr{1};
346 }
347
GetExtent(const NamedEntity & base,int dimension)348 MaybeExtentExpr GetExtent(const NamedEntity &base, int dimension) {
349 CHECK(dimension >= 0);
350 const Symbol &last{base.GetLastSymbol()};
351 const Symbol &symbol{ResolveAssociations(last)};
352 if (const auto *assoc{last.detailsIf<semantics::AssocEntityDetails>()}) {
353 if (assoc->rank()) { // SELECT RANK case
354 if (semantics::IsDescriptor(symbol) && dimension < *assoc->rank()) {
355 return ExtentExpr{DescriptorInquiry{
356 NamedEntity{base}, DescriptorInquiry::Field::Extent, dimension}};
357 }
358 } else if (auto shape{GetShape(assoc->expr())}) {
359 if (dimension < static_cast<int>(shape->size())) {
360 return std::move(shape->at(dimension));
361 }
362 }
363 }
364 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
365 if (IsImpliedShape(symbol) && details->init()) {
366 if (auto shape{GetShape(symbol)}) {
367 if (dimension < static_cast<int>(shape->size())) {
368 return std::move(shape->at(dimension));
369 }
370 }
371 } else {
372 int j{0};
373 for (const auto &shapeSpec : details->shape()) {
374 if (j++ == dimension) {
375 if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) {
376 if (shapeSpec.ubound().GetExplicit()) {
377 // 8.5.8.2, paragraph 3. If the upper bound is less than the
378 // lower bound, the extent is zero.
379 if (shapeSpec.lbound().GetExplicit()) {
380 return GetNonNegativeExtent(shapeSpec);
381 } else {
382 return ubound.value();
383 }
384 }
385 } else if (details->IsAssumedSize() && j == symbol.Rank()) {
386 return std::nullopt;
387 } else if (semantics::IsDescriptor(symbol)) {
388 return ExtentExpr{DescriptorInquiry{NamedEntity{base},
389 DescriptorInquiry::Field::Extent, dimension}};
390 }
391 }
392 }
393 }
394 }
395 return std::nullopt;
396 }
397
GetExtent(FoldingContext & context,const NamedEntity & base,int dimension)398 MaybeExtentExpr GetExtent(
399 FoldingContext &context, const NamedEntity &base, int dimension) {
400 return Fold(context, GetExtent(base, dimension));
401 }
402
GetExtent(const Subscript & subscript,const NamedEntity & base,int dimension)403 MaybeExtentExpr GetExtent(
404 const Subscript &subscript, const NamedEntity &base, int dimension) {
405 return std::visit(
406 common::visitors{
407 [&](const Triplet &triplet) -> MaybeExtentExpr {
408 MaybeExtentExpr upper{triplet.upper()};
409 if (!upper) {
410 upper = GetUpperBound(base, dimension);
411 }
412 MaybeExtentExpr lower{triplet.lower()};
413 if (!lower) {
414 lower = GetLowerBound(base, dimension);
415 }
416 return CountTrips(std::move(lower), std::move(upper),
417 MaybeExtentExpr{triplet.stride()});
418 },
419 [&](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtentExpr {
420 if (auto shape{GetShape(subs.value())}) {
421 if (GetRank(*shape) > 0) {
422 CHECK(GetRank(*shape) == 1); // vector-valued subscript
423 return std::move(shape->at(0));
424 }
425 }
426 return std::nullopt;
427 },
428 },
429 subscript.u);
430 }
431
GetExtent(FoldingContext & context,const Subscript & subscript,const NamedEntity & base,int dimension)432 MaybeExtentExpr GetExtent(FoldingContext &context, const Subscript &subscript,
433 const NamedEntity &base, int dimension) {
434 return Fold(context, GetExtent(subscript, base, dimension));
435 }
436
ComputeUpperBound(ExtentExpr && lower,MaybeExtentExpr && extent)437 MaybeExtentExpr ComputeUpperBound(
438 ExtentExpr &&lower, MaybeExtentExpr &&extent) {
439 if (extent) {
440 return std::move(*extent) + std::move(lower) - ExtentExpr{1};
441 } else {
442 return std::nullopt;
443 }
444 }
445
ComputeUpperBound(FoldingContext & context,ExtentExpr && lower,MaybeExtentExpr && extent)446 MaybeExtentExpr ComputeUpperBound(
447 FoldingContext &context, ExtentExpr &&lower, MaybeExtentExpr &&extent) {
448 return Fold(context, ComputeUpperBound(std::move(lower), std::move(extent)));
449 }
450
GetUpperBound(const NamedEntity & base,int dimension)451 MaybeExtentExpr GetUpperBound(const NamedEntity &base, int dimension) {
452 const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
453 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
454 int j{0};
455 for (const auto &shapeSpec : details->shape()) {
456 if (j++ == dimension) {
457 if (const auto &bound{shapeSpec.ubound().GetExplicit()}) {
458 return *bound;
459 } else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) {
460 break;
461 } else {
462 return ComputeUpperBound(
463 GetLowerBound(base, dimension), GetExtent(base, dimension));
464 }
465 }
466 }
467 } else if (const auto *assoc{
468 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
469 if (auto shape{GetShape(assoc->expr())}) {
470 if (dimension < static_cast<int>(shape->size())) {
471 return ComputeUpperBound(
472 GetLowerBound(base, dimension), std::move(shape->at(dimension)));
473 }
474 }
475 }
476 return std::nullopt;
477 }
478
GetUpperBound(FoldingContext & context,const NamedEntity & base,int dimension)479 MaybeExtentExpr GetUpperBound(
480 FoldingContext &context, const NamedEntity &base, int dimension) {
481 return Fold(context, GetUpperBound(base, dimension));
482 }
483
// Upper bounds of every dimension of a named entity, as a Shape.
// For an object entity:
//  - an explicit upper bound comes from the declaration;
//  - the (necessarily last, CHECKed) undeclared dimension of an assumed-size
//    array yields std::nullopt;
//  - any other dimension is lower bound + extent - 1.
// For any other kind of symbol, the symbol's shape (its extents) is returned
// as its upper bounds. NOTE(review): this presumes lower bounds of 1. It also
// requires GetShape(symbol) to succeed, because .value() throws otherwise.
Shape GetUpperBounds(const NamedEntity &base) {
  const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
  if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
    Shape result;
    int dim{0};
    for (const auto &shapeSpec : details->shape()) {
      if (const auto &bound{shapeSpec.ubound().GetExplicit()}) {
        result.push_back(*bound);
      } else if (details->IsAssumedSize()) {
        CHECK(dim + 1 == base.Rank());
        result.emplace_back(std::nullopt); // UBOUND folding replaces with -1
      } else {
        result.emplace_back(
            ComputeUpperBound(GetLowerBound(base, dim), GetExtent(base, dim)));
      }
      ++dim;
    }
    CHECK(GetRank(result) == symbol.Rank());
    return result;
  } else {
    return std::move(GetShape(symbol).value());
  }
}
507
GetUpperBounds(FoldingContext & context,const NamedEntity & base)508 Shape GetUpperBounds(FoldingContext &context, const NamedEntity &base) {
509 return Fold(context, GetUpperBounds(base));
510 }
511
// Shape of a symbol, dispatching on the details of its ultimate symbol:
//  - objects: the shape of the initializer for implied-shape named
//    constants; no shape (std::nullopt) for assumed-rank objects; otherwise
//    one extent per declared dimension, obtained via CreateShape();
//  - plain entities with no further information: scalar;
//  - procedure entities: the shape of the interface symbol if there is one,
//    else scalar;
//  - associate entities: one extent per dimension of the selected rank
//    (SELECT RANK), else the shape of the associated expression;
//  - subprograms: the shape of the function result; no shape for
//    subroutines;
//  - procedure bindings: the shape of the bound procedure's symbol;
//  - type parameters: scalar;
//  - anything else: no shape.
auto GetShapeHelper::operator()(const Symbol &symbol) const -> Result {
  return std::visit(
      common::visitors{
          [&](const semantics::ObjectEntityDetails &object) {
            if (IsImpliedShape(symbol) && object.init()) {
              return (*this)(object.init());
            } else if (IsAssumedRank(symbol)) {
              return Result{};
            } else {
              int n{object.shape().Rank()};
              NamedEntity base{symbol};
              return Result{CreateShape(n, base)};
            }
          },
          [](const semantics::EntityDetails &) {
            return ScalarShape(); // no dimensions seen
          },
          [&](const semantics::ProcEntityDetails &proc) {
            if (const Symbol * interface{proc.interface().symbol()}) {
              return (*this)(*interface);
            } else {
              return ScalarShape();
            }
          },
          [&](const semantics::AssocEntityDetails &assoc) {
            if (assoc.rank()) { // SELECT RANK case
              int n{assoc.rank().value()};
              NamedEntity base{symbol};
              return Result{CreateShape(n, base)};
            } else {
              return (*this)(assoc.expr());
            }
          },
          [&](const semantics::SubprogramDetails &subp) {
            if (subp.isFunction()) {
              return (*this)(subp.result());
            } else {
              return Result{};
            }
          },
          [&](const semantics::ProcBindingDetails &binding) {
            return (*this)(binding.symbol());
          },
          [](const semantics::TypeParamDetails &) { return ScalarShape(); },
          [](const auto &) { return Result{}; },
      },
      symbol.GetUltimate().details());
}
560
operator ()(const Component & component) const561 auto GetShapeHelper::operator()(const Component &component) const -> Result {
562 const Symbol &symbol{component.GetLastSymbol()};
563 int rank{symbol.Rank()};
564 if (rank == 0) {
565 return (*this)(component.base());
566 } else if (symbol.has<semantics::ObjectEntityDetails>()) {
567 NamedEntity base{Component{component}};
568 return CreateShape(rank, base);
569 } else if (symbol.has<semantics::AssocEntityDetails>()) {
570 NamedEntity base{Component{component}};
571 return Result{CreateShape(rank, base)};
572 } else {
573 return (*this)(symbol);
574 }
575 }
576
operator ()(const ArrayRef & arrayRef) const577 auto GetShapeHelper::operator()(const ArrayRef &arrayRef) const -> Result {
578 Shape shape;
579 int dimension{0};
580 const NamedEntity &base{arrayRef.base()};
581 for (const Subscript &ss : arrayRef.subscript()) {
582 if (ss.Rank() > 0) {
583 shape.emplace_back(GetExtent(ss, base, dimension));
584 }
585 ++dimension;
586 }
587 if (shape.empty()) {
588 if (const Component * component{base.UnwrapComponent()}) {
589 return (*this)(component->base());
590 }
591 }
592 return shape;
593 }
594
operator ()(const CoarrayRef & coarrayRef) const595 auto GetShapeHelper::operator()(const CoarrayRef &coarrayRef) const -> Result {
596 NamedEntity base{coarrayRef.GetBase()};
597 if (coarrayRef.subscript().empty()) {
598 return (*this)(base);
599 } else {
600 Shape shape;
601 int dimension{0};
602 for (const Subscript &ss : coarrayRef.subscript()) {
603 if (ss.Rank() > 0) {
604 shape.emplace_back(GetExtent(ss, base, dimension));
605 }
606 ++dimension;
607 }
608 return shape;
609 }
610 }
611
operator ()(const Substring & substring) const612 auto GetShapeHelper::operator()(const Substring &substring) const -> Result {
613 return (*this)(substring.parent());
614 }
615
// Shape of the result of a procedure reference.
// - Scalar results, elemental references, and references to procedures that
//   have a symbol are handled generically.
// - For specific intrinsic functions, the result shape is derived where
//   possible from the shapes and (constant) values of the actual arguments.
// Returns std::nullopt when the shape cannot be determined here.
auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
  if (call.Rank() == 0) {
    return ScalarShape();
  } else if (call.IsElemental()) {
    // An elemental reference takes the shape of its first array-valued
    // actual argument.
    for (const auto &arg : call.arguments()) {
      if (arg && arg->Rank() > 0) {
        return (*this)(*arg);
      }
    }
    return ScalarShape();
  } else if (const Symbol * symbol{call.proc().GetSymbol()}) {
    return (*this)(*symbol);
  } else if (const auto *intrinsic{call.proc().GetSpecificIntrinsic()}) {
    if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
        intrinsic->name == "ubound") {
      // These are the array-valued cases for LBOUND and UBOUND (no DIM=).
      // The result is a vector whose length is the rank of the argument.
      const auto *expr{call.arguments().front().value().UnwrapExpr()};
      CHECK(expr);
      return Shape{MaybeExtentExpr{ExtentExpr{expr->Rank()}}};
    } else if (intrinsic->name == "all" || intrinsic->name == "any" ||
        intrinsic->name == "count" || intrinsic->name == "iall" ||
        intrinsic->name == "iany" || intrinsic->name == "iparity" ||
        intrinsic->name == "maxval" || intrinsic->name == "minval" ||
        intrinsic->name == "norm2" || intrinsic->name == "parity" ||
        intrinsic->name == "product" || intrinsic->name == "sum") {
      // Reduction with DIM=
      // The result is the shape of ARRAY= with dimension DIM= removed. This
      // requires DIM= to be a constant within 1..rank.
      if (call.arguments().size() >= 2) {
        auto arrayShape{
            (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
        const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
        if (arrayShape && dimArg) {
          if (auto dim{ToInt64(*dimArg)}) {
            if (*dim >= 1 &&
                static_cast<std::size_t>(*dim) <= arrayShape->size()) {
              arrayShape->erase(arrayShape->begin() + (*dim - 1));
              return std::move(*arrayShape);
            }
          }
        }
      }
    } else if (intrinsic->name == "findloc" || intrinsic->name == "maxloc" ||
        intrinsic->name == "minloc") {
      // DIM= is actual argument index 2 (zero-based) for FINDLOC and index 1
      // for MAXLOC/MINLOC.
      std::size_t dimIndex{intrinsic->name == "findloc" ? 2u : 1u};
      if (call.arguments().size() > dimIndex) {
        if (auto arrayShape{
                (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))}) {
          auto rank{static_cast<int>(arrayShape->size())};
          if (const auto *dimArg{
                  UnwrapExpr<Expr<SomeType>>(call.arguments()[dimIndex])}) {
            auto dim{ToInt64(*dimArg)};
            if (dim && *dim >= 1 && *dim <= rank) {
              arrayShape->erase(arrayShape->begin() + (*dim - 1));
              return std::move(*arrayShape);
            }
          } else {
            // xxxLOC(no DIM=) result is vector(1:RANK(ARRAY=))
            return Shape{ExtentExpr{rank}};
          }
        }
      }
    } else if (intrinsic->name == "cshift" || intrinsic->name == "eoshift") {
      // The result has the shape of ARRAY=.
      if (!call.arguments().empty()) {
        return (*this)(call.arguments()[0]);
      }
    } else if (intrinsic->name == "matmul") {
      if (call.arguments().size() == 2) {
        if (auto ashape{(*this)(call.arguments()[0])}) {
          if (auto bshape{(*this)(call.arguments()[1])}) {
            if (ashape->size() == 1 && bshape->size() == 2) {
              bshape->erase(bshape->begin());
              return std::move(*bshape); // matmul(vector, matrix)
            } else if (ashape->size() == 2 && bshape->size() == 1) {
              ashape->pop_back();
              return std::move(*ashape); // matmul(matrix, vector)
            } else if (ashape->size() == 2 && bshape->size() == 2) {
              (*ashape)[1] = std::move((*bshape)[1]);
              return std::move(*ashape); // matmul(matrix, matrix)
            }
          }
        }
      }
    } else if (intrinsic->name == "reshape") {
      if (call.arguments().size() >= 2 && call.arguments().at(1)) {
        // SHAPE(RESHAPE(array,shape)) -> shape
        // NOTE(review): std::get throws if SHAPE= is not an integer
        // expression; presumably semantics has already guaranteed that.
        if (const auto *shapeExpr{
                call.arguments().at(1).value().UnwrapExpr()}) {
          auto shape{std::get<Expr<SomeInteger>>(shapeExpr->u)};
          return AsShape(ConvertToType<ExtentType>(std::move(shape)));
        }
      }
    } else if (intrinsic->name == "pack") {
      if (call.arguments().size() >= 3 && call.arguments().at(2)) {
        // SHAPE(PACK(,,VECTOR=v)) -> SHAPE(v)
        return (*this)(call.arguments().at(2));
      } else if (call.arguments().size() >= 2 && context_) {
        // Without VECTOR=, the result length depends on MASK=; building the
        // MERGE/COUNT reference below requires the folding context's
        // intrinsic table.
        if (auto maskShape{(*this)(call.arguments().at(1))}) {
          if (maskShape->size() == 0) {
            // Scalar MASK= -> [MERGE(SIZE(ARRAY=), 0, mask)]
            if (auto arrayShape{(*this)(call.arguments().at(0))}) {
              auto arraySize{GetSize(std::move(*arrayShape))};
              CHECK(arraySize);
              ActualArguments toMerge{
                  ActualArgument{AsGenericExpr(std::move(*arraySize))},
                  ActualArgument{AsGenericExpr(ExtentExpr{0})},
                  common::Clone(call.arguments().at(1))};
              auto specific{context_->intrinsics().Probe(
                  CallCharacteristics{"merge"}, toMerge, *context_)};
              CHECK(specific);
              return Shape{ExtentExpr{FunctionRef<ExtentType>{
                  ProcedureDesignator{std::move(specific->specificIntrinsic)},
                  std::move(specific->arguments)}}};
            }
          } else {
            // Non-scalar MASK= -> [COUNT(mask)]
            ActualArguments toCount{ActualArgument{common::Clone(
                DEREF(call.arguments().at(1).value().UnwrapExpr()))}};
            auto specific{context_->intrinsics().Probe(
                CallCharacteristics{"count"}, toCount, *context_)};
            CHECK(specific);
            return Shape{ExtentExpr{FunctionRef<ExtentType>{
                ProcedureDesignator{std::move(specific->specificIntrinsic)},
                std::move(specific->arguments)}}};
          }
        }
      }
    } else if (intrinsic->name == "spread") {
      // SHAPE(SPREAD(ARRAY,DIM,NCOPIES)) = SHAPE(ARRAY) with NCOPIES inserted
      // at position DIM.
      if (call.arguments().size() == 3) {
        auto arrayShape{
            (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
        const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
        const auto *nCopies{
            UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))};
        if (arrayShape && dimArg && nCopies) {
          if (auto dim{ToInt64(*dimArg)}) {
            if (*dim >= 1 &&
                static_cast<std::size_t>(*dim) <= arrayShape->size() + 1) {
              arrayShape->emplace(arrayShape->begin() + *dim - 1,
                  ConvertToType<ExtentType>(common::Clone(*nCopies)));
              return std::move(*arrayShape);
            }
          }
        }
      }
    } else if (intrinsic->name == "transfer") {
      if (call.arguments().size() == 3 && call.arguments().at(2)) {
        // SIZE= is present; shape is vector [SIZE=]
        if (const auto *size{
                UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))}) {
          return Shape{
              MaybeExtentExpr{ConvertToType<ExtentType>(common::Clone(*size))}};
        }
      } else if (context_) {
        if (auto moldTypeAndShape{characteristics::TypeAndShape::Characterize(
                call.arguments().at(1), *context_)}) {
          if (GetRank(moldTypeAndShape->shape()) == 0) {
            // SIZE= is absent and MOLD= is scalar: result is scalar
            return ScalarShape();
          } else {
            // SIZE= is absent and MOLD= is array: result is vector whose
            // length is determined by sizes of types. See 16.9.193p4 case(ii).
            // The length is the ceiling of the SOURCE= size divided by the
            // MOLD= element size. NOTE(review): a zero-sized MOLD= element
            // would divide by zero here; confirm it is excluded earlier.
            if (auto sourceTypeAndShape{
                    characteristics::TypeAndShape::Characterize(
                        call.arguments().at(0), *context_)}) {
              auto sourceBytes{
                  sourceTypeAndShape->MeasureSizeInBytes(*context_)};
              auto moldElementBytes{
                  moldTypeAndShape->MeasureElementSizeInBytes(*context_, true)};
              if (sourceBytes && moldElementBytes) {
                ExtentExpr extent{Fold(*context_,
                    (std::move(*sourceBytes) +
                        common::Clone(*moldElementBytes) - ExtentExpr{1}) /
                        common::Clone(*moldElementBytes))};
                return Shape{MaybeExtentExpr{std::move(extent)}};
              }
            }
          }
        }
      }
    } else if (intrinsic->name == "transpose") {
      // The two extents of the rank-2 argument are swapped.
      if (call.arguments().size() >= 1) {
        if (auto shape{(*this)(call.arguments().at(0))}) {
          if (shape->size() == 2) {
            std::swap((*shape)[0], (*shape)[1]);
            return shape;
          }
        }
      }
    } else if (intrinsic->name == "unpack") {
      if (call.arguments().size() >= 2) {
        return (*this)(call.arguments()[1]); // MASK=
      }
    } else if (intrinsic->characteristics.value().attrs.test(characteristics::
                       Procedure::Attr::NullPointer)) { // NULL(MOLD=)
      return (*this)(call.arguments());
    } else {
      // TODO: shapes of other non-elemental intrinsic results
    }
  }
  return std::nullopt;
}
818
819 // Check conformance of the passed shapes.
CheckConformance(parser::ContextualMessages & messages,const Shape & left,const Shape & right,CheckConformanceFlags::Flags flags,const char * leftIs,const char * rightIs)820 std::optional<bool> CheckConformance(parser::ContextualMessages &messages,
821 const Shape &left, const Shape &right, CheckConformanceFlags::Flags flags,
822 const char *leftIs, const char *rightIs) {
823 int n{GetRank(left)};
824 if (n == 0 && (flags & CheckConformanceFlags::LeftScalarExpandable)) {
825 return true;
826 }
827 int rn{GetRank(right)};
828 if (rn == 0 && (flags & CheckConformanceFlags::RightScalarExpandable)) {
829 return true;
830 }
831 if (n != rn) {
832 messages.Say("Rank of %1$s is %2$d, but %3$s has rank %4$d"_err_en_US,
833 leftIs, n, rightIs, rn);
834 return false;
835 }
836 for (int j{0}; j < n; ++j) {
837 if (auto leftDim{ToInt64(left[j])}) {
838 if (auto rightDim{ToInt64(right[j])}) {
839 if (*leftDim != *rightDim) {
840 messages.Say("Dimension %1$d of %2$s has extent %3$jd, "
841 "but %4$s has extent %5$jd"_err_en_US,
842 j + 1, leftIs, *leftDim, rightIs, *rightDim);
843 return false;
844 }
845 } else if (!(flags & CheckConformanceFlags::RightIsDeferredShape)) {
846 return std::nullopt;
847 }
848 } else if (!(flags & CheckConformanceFlags::LeftIsDeferredShape)) {
849 return std::nullopt;
850 }
851 }
852 return true;
853 }
854
IncrementSubscripts(ConstantSubscripts & indices,const ConstantSubscripts & extents)855 bool IncrementSubscripts(
856 ConstantSubscripts &indices, const ConstantSubscripts &extents) {
857 std::size_t rank(indices.size());
858 CHECK(rank <= extents.size());
859 for (std::size_t j{0}; j < rank; ++j) {
860 if (extents[j] < 1) {
861 return false;
862 }
863 }
864 for (std::size_t j{0}; j < rank; ++j) {
865 if (indices[j]++ < extents[j]) {
866 return true;
867 }
868 indices[j] = 1;
869 }
870 return false;
871 }
872
873 } // namespace Fortran::evaluate
874