1 //===-- lib/Evaluate/shape.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "flang/Evaluate/shape.h"
10 #include "flang/Common/idioms.h"
11 #include "flang/Common/template.h"
12 #include "flang/Evaluate/characteristics.h"
13 #include "flang/Evaluate/fold.h"
14 #include "flang/Evaluate/intrinsics.h"
15 #include "flang/Evaluate/tools.h"
16 #include "flang/Evaluate/type.h"
17 #include "flang/Parser/message.h"
18 #include "flang/Semantics/symbol.h"
19 #include <functional>
20
21 using namespace std::placeholders; // _1, _2, &c. for std::bind()
22
23 namespace Fortran::evaluate {
24
IsImpliedShape(const Symbol & original)25 bool IsImpliedShape(const Symbol &original) {
26 const Symbol &symbol{ResolveAssociations(original)};
27 const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()};
28 return details && symbol.attrs().test(semantics::Attr::PARAMETER) &&
29 details->shape().IsImpliedShape();
30 }
31
IsExplicitShape(const Symbol & original)32 bool IsExplicitShape(const Symbol &original) {
33 const Symbol &symbol{ResolveAssociations(original)};
34 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
35 const auto &shape{details->shape()};
36 return shape.Rank() == 0 ||
37 shape.IsExplicitShape(); // true when scalar, too
38 } else {
39 return symbol
40 .has<semantics::AssocEntityDetails>(); // exprs have explicit shape
41 }
42 }
43
ConstantShape(const Constant<ExtentType> & arrayConstant)44 Shape GetShapeHelper::ConstantShape(const Constant<ExtentType> &arrayConstant) {
45 CHECK(arrayConstant.Rank() == 1);
46 Shape result;
47 std::size_t dimensions{arrayConstant.size()};
48 for (std::size_t j{0}; j < dimensions; ++j) {
49 Scalar<ExtentType> extent{arrayConstant.values().at(j)};
50 result.emplace_back(MaybeExtentExpr{ExtentExpr{std::move(extent)}});
51 }
52 return result;
53 }
54
AsShape(ExtentExpr && arrayExpr) const55 auto GetShapeHelper::AsShape(ExtentExpr &&arrayExpr) const -> Result {
56 if (context_) {
57 arrayExpr = Fold(*context_, std::move(arrayExpr));
58 }
59 if (const auto *constArray{UnwrapConstantValue<ExtentType>(arrayExpr)}) {
60 return ConstantShape(*constArray);
61 }
62 if (auto *constructor{UnwrapExpr<ArrayConstructor<ExtentType>>(arrayExpr)}) {
63 Shape result;
64 for (auto &value : *constructor) {
65 if (auto *expr{std::get_if<ExtentExpr>(&value.u)}) {
66 if (expr->Rank() == 0) {
67 result.emplace_back(std::move(*expr));
68 continue;
69 }
70 }
71 return std::nullopt;
72 }
73 return result;
74 }
75 return std::nullopt;
76 }
77
CreateShape(int rank,NamedEntity & base)78 Shape GetShapeHelper::CreateShape(int rank, NamedEntity &base) {
79 Shape shape;
80 for (int dimension{0}; dimension < rank; ++dimension) {
81 shape.emplace_back(GetExtent(base, dimension));
82 }
83 return shape;
84 }
85
AsExtentArrayExpr(const Shape & shape)86 std::optional<ExtentExpr> AsExtentArrayExpr(const Shape &shape) {
87 ArrayConstructorValues<ExtentType> values;
88 for (const auto &dim : shape) {
89 if (dim) {
90 values.Push(common::Clone(*dim));
91 } else {
92 return std::nullopt;
93 }
94 }
95 return ExtentExpr{ArrayConstructor<ExtentType>{std::move(values)}};
96 }
97
AsConstantShape(FoldingContext & context,const Shape & shape)98 std::optional<Constant<ExtentType>> AsConstantShape(
99 FoldingContext &context, const Shape &shape) {
100 if (auto shapeArray{AsExtentArrayExpr(shape)}) {
101 auto folded{Fold(context, std::move(*shapeArray))};
102 if (auto *p{UnwrapConstantValue<ExtentType>(folded)}) {
103 return std::move(*p);
104 }
105 }
106 return std::nullopt;
107 }
108
AsConstantShape(const ConstantSubscripts & shape)109 Constant<SubscriptInteger> AsConstantShape(const ConstantSubscripts &shape) {
110 using IntType = Scalar<SubscriptInteger>;
111 std::vector<IntType> result;
112 for (auto dim : shape) {
113 result.emplace_back(dim);
114 }
115 return {std::move(result), ConstantSubscripts{GetRank(shape)}};
116 }
117
AsConstantExtents(const Constant<ExtentType> & shape)118 ConstantSubscripts AsConstantExtents(const Constant<ExtentType> &shape) {
119 ConstantSubscripts result;
120 for (const auto &extent : shape.values()) {
121 result.push_back(extent.ToInt64());
122 }
123 return result;
124 }
125
AsConstantExtents(FoldingContext & context,const Shape & shape)126 std::optional<ConstantSubscripts> AsConstantExtents(
127 FoldingContext &context, const Shape &shape) {
128 if (auto shapeConstant{AsConstantShape(context, shape)}) {
129 return AsConstantExtents(*shapeConstant);
130 } else {
131 return std::nullopt;
132 }
133 }
134
AsShape(const ConstantSubscripts & shape)135 Shape AsShape(const ConstantSubscripts &shape) {
136 Shape result;
137 for (const auto &extent : shape) {
138 result.emplace_back(ExtentExpr{extent});
139 }
140 return result;
141 }
142
AsShape(const std::optional<ConstantSubscripts> & shape)143 std::optional<Shape> AsShape(const std::optional<ConstantSubscripts> &shape) {
144 if (shape) {
145 return AsShape(*shape);
146 } else {
147 return std::nullopt;
148 }
149 }
150
Fold(FoldingContext & context,Shape && shape)151 Shape Fold(FoldingContext &context, Shape &&shape) {
152 for (auto &dim : shape) {
153 dim = Fold(context, std::move(dim));
154 }
155 return std::move(shape);
156 }
157
Fold(FoldingContext & context,std::optional<Shape> && shape)158 std::optional<Shape> Fold(
159 FoldingContext &context, std::optional<Shape> &&shape) {
160 if (shape) {
161 return Fold(context, std::move(*shape));
162 } else {
163 return std::nullopt;
164 }
165 }
166
ComputeTripCount(ExtentExpr && lower,ExtentExpr && upper,ExtentExpr && stride)167 static ExtentExpr ComputeTripCount(
168 ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
169 ExtentExpr strideCopy{common::Clone(stride)};
170 ExtentExpr span{
171 (std::move(upper) - std::move(lower) + std::move(strideCopy)) /
172 std::move(stride)};
173 return ExtentExpr{
174 Extremum<ExtentType>{Ordering::Greater, std::move(span), ExtentExpr{0}}};
175 }
176
CountTrips(ExtentExpr && lower,ExtentExpr && upper,ExtentExpr && stride)177 ExtentExpr CountTrips(
178 ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
179 return ComputeTripCount(
180 std::move(lower), std::move(upper), std::move(stride));
181 }
182
// Non-consuming overload of CountTrips: operates on deep copies so the
// caller's expressions are left untouched.
ExtentExpr CountTrips(const ExtentExpr &lower, const ExtentExpr &upper,
    const ExtentExpr &stride) {
  ExtentExpr lowerCopy{common::Clone(lower)};
  ExtentExpr upperCopy{common::Clone(upper)};
  ExtentExpr strideCopy{common::Clone(stride)};
  return ComputeTripCount(
      std::move(lowerCopy), std::move(upperCopy), std::move(strideCopy));
}
188
CountTrips(MaybeExtentExpr && lower,MaybeExtentExpr && upper,MaybeExtentExpr && stride)189 MaybeExtentExpr CountTrips(MaybeExtentExpr &&lower, MaybeExtentExpr &&upper,
190 MaybeExtentExpr &&stride) {
191 std::function<ExtentExpr(ExtentExpr &&, ExtentExpr &&, ExtentExpr &&)> bound{
192 std::bind(ComputeTripCount, _1, _2, _3)};
193 return common::MapOptional(
194 std::move(bound), std::move(lower), std::move(upper), std::move(stride));
195 }
196
GetSize(Shape && shape)197 MaybeExtentExpr GetSize(Shape &&shape) {
198 ExtentExpr extent{1};
199 for (auto &&dim : std::move(shape)) {
200 if (dim) {
201 extent = std::move(extent) * std::move(*dim);
202 } else {
203 return std::nullopt;
204 }
205 }
206 return extent;
207 }
208
GetSize(const ConstantSubscripts & shape)209 ConstantSubscript GetSize(const ConstantSubscripts &shape) {
210 ConstantSubscript size{1};
211 for (auto dim : std::move(shape)) {
212 size *= dim;
213 }
214 return size;
215 }
216
ContainsAnyImpliedDoIndex(const ExtentExpr & expr)217 bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
218 struct MyVisitor : public AnyTraverse<MyVisitor> {
219 using Base = AnyTraverse<MyVisitor>;
220 MyVisitor() : Base{*this} {}
221 using Base::operator();
222 bool operator()(const ImpliedDoIndex &) { return true; }
223 };
224 return MyVisitor{}(expr);
225 }
226
227 // Determines lower bound on a dimension. This can be other than 1 only
228 // for a reference to a whole array object or component. (See LBOUND, 16.9.109).
229 // ASSOCIATE construct entities may require traversal of their referents.
230 class GetLowerBoundHelper : public Traverse<GetLowerBoundHelper, ExtentExpr> {
231 public:
232 using Result = ExtentExpr;
233 using Base = Traverse<GetLowerBoundHelper, ExtentExpr>;
234 using Base::operator();
GetLowerBoundHelper(int d)235 explicit GetLowerBoundHelper(int d) : Base{*this}, dimension_{d} {}
Default()236 static ExtentExpr Default() { return ExtentExpr{1}; }
Combine(Result &&,Result &&)237 static ExtentExpr Combine(Result &&, Result &&) { return Default(); }
238 ExtentExpr operator()(const Symbol &);
239 ExtentExpr operator()(const Component &);
240
241 private:
242 int dimension_;
243 };
244
operator ()(const Symbol & symbol0)245 auto GetLowerBoundHelper::operator()(const Symbol &symbol0) -> Result {
246 const Symbol &symbol{symbol0.GetUltimate()};
247 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
248 int j{0};
249 for (const auto &shapeSpec : details->shape()) {
250 if (j++ == dimension_) {
251 if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
252 return *bound;
253 } else if (IsDescriptor(symbol)) {
254 return ExtentExpr{DescriptorInquiry{NamedEntity{symbol0},
255 DescriptorInquiry::Field::LowerBound, dimension_}};
256 } else {
257 break;
258 }
259 }
260 }
261 } else if (const auto *assoc{
262 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
263 return (*this)(assoc->expr());
264 }
265 return Default();
266 }
267
operator ()(const Component & component)268 auto GetLowerBoundHelper::operator()(const Component &component) -> Result {
269 if (component.base().Rank() == 0) {
270 const Symbol &symbol{component.GetLastSymbol().GetUltimate()};
271 if (const auto *details{
272 symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
273 int j{0};
274 for (const auto &shapeSpec : details->shape()) {
275 if (j++ == dimension_) {
276 if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
277 return *bound;
278 } else if (IsDescriptor(symbol)) {
279 return ExtentExpr{
280 DescriptorInquiry{NamedEntity{common::Clone(component)},
281 DescriptorInquiry::Field::LowerBound, dimension_}};
282 } else {
283 break;
284 }
285 }
286 }
287 }
288 }
289 return Default();
290 }
291
GetLowerBound(const NamedEntity & base,int dimension)292 ExtentExpr GetLowerBound(const NamedEntity &base, int dimension) {
293 return GetLowerBoundHelper{dimension}(base);
294 }
295
GetLowerBound(FoldingContext & context,const NamedEntity & base,int dimension)296 ExtentExpr GetLowerBound(
297 FoldingContext &context, const NamedEntity &base, int dimension) {
298 return Fold(context, GetLowerBound(base, dimension));
299 }
300
GetLowerBounds(const NamedEntity & base)301 Shape GetLowerBounds(const NamedEntity &base) {
302 Shape result;
303 int rank{base.Rank()};
304 for (int dim{0}; dim < rank; ++dim) {
305 result.emplace_back(GetLowerBound(base, dim));
306 }
307 return result;
308 }
309
GetLowerBounds(FoldingContext & context,const NamedEntity & base)310 Shape GetLowerBounds(FoldingContext &context, const NamedEntity &base) {
311 Shape result;
312 int rank{base.Rank()};
313 for (int dim{0}; dim < rank; ++dim) {
314 result.emplace_back(GetLowerBound(context, base, dim));
315 }
316 return result;
317 }
318
GetExtent(const NamedEntity & base,int dimension)319 MaybeExtentExpr GetExtent(const NamedEntity &base, int dimension) {
320 CHECK(dimension >= 0);
321 const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
322 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
323 if (IsImpliedShape(symbol) && details->init()) {
324 if (auto shape{GetShape(symbol)}) {
325 if (dimension < static_cast<int>(shape->size())) {
326 return std::move(shape->at(dimension));
327 }
328 }
329 } else {
330 int j{0};
331 for (const auto &shapeSpec : details->shape()) {
332 if (j++ == dimension) {
333 if (shapeSpec.ubound().isExplicit()) {
334 if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) {
335 if (const auto &lbound{shapeSpec.lbound().GetExplicit()}) {
336 return common::Clone(ubound.value()) -
337 common::Clone(lbound.value()) + ExtentExpr{1};
338 } else {
339 return ubound.value();
340 }
341 }
342 } else if (details->IsAssumedSize() && j == symbol.Rank()) {
343 return std::nullopt;
344 } else if (semantics::IsDescriptor(symbol)) {
345 return ExtentExpr{DescriptorInquiry{NamedEntity{base},
346 DescriptorInquiry::Field::Extent, dimension}};
347 }
348 }
349 }
350 }
351 } else if (const auto *assoc{
352 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
353 if (auto shape{GetShape(assoc->expr())}) {
354 if (dimension < static_cast<int>(shape->size())) {
355 return std::move(shape->at(dimension));
356 }
357 }
358 }
359 return std::nullopt;
360 }
361
GetExtent(FoldingContext & context,const NamedEntity & base,int dimension)362 MaybeExtentExpr GetExtent(
363 FoldingContext &context, const NamedEntity &base, int dimension) {
364 return Fold(context, GetExtent(base, dimension));
365 }
366
GetExtent(const Subscript & subscript,const NamedEntity & base,int dimension)367 MaybeExtentExpr GetExtent(
368 const Subscript &subscript, const NamedEntity &base, int dimension) {
369 return std::visit(
370 common::visitors{
371 [&](const Triplet &triplet) -> MaybeExtentExpr {
372 MaybeExtentExpr upper{triplet.upper()};
373 if (!upper) {
374 upper = GetUpperBound(base, dimension);
375 }
376 MaybeExtentExpr lower{triplet.lower()};
377 if (!lower) {
378 lower = GetLowerBound(base, dimension);
379 }
380 return CountTrips(std::move(lower), std::move(upper),
381 MaybeExtentExpr{triplet.stride()});
382 },
383 [&](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtentExpr {
384 if (auto shape{GetShape(subs.value())}) {
385 if (GetRank(*shape) > 0) {
386 CHECK(GetRank(*shape) == 1); // vector-valued subscript
387 return std::move(shape->at(0));
388 }
389 }
390 return std::nullopt;
391 },
392 },
393 subscript.u);
394 }
395
GetExtent(FoldingContext & context,const Subscript & subscript,const NamedEntity & base,int dimension)396 MaybeExtentExpr GetExtent(FoldingContext &context, const Subscript &subscript,
397 const NamedEntity &base, int dimension) {
398 return Fold(context, GetExtent(subscript, base, dimension));
399 }
400
ComputeUpperBound(ExtentExpr && lower,MaybeExtentExpr && extent)401 MaybeExtentExpr ComputeUpperBound(
402 ExtentExpr &&lower, MaybeExtentExpr &&extent) {
403 if (extent) {
404 return std::move(*extent) + std::move(lower) - ExtentExpr{1};
405 } else {
406 return std::nullopt;
407 }
408 }
409
ComputeUpperBound(FoldingContext & context,ExtentExpr && lower,MaybeExtentExpr && extent)410 MaybeExtentExpr ComputeUpperBound(
411 FoldingContext &context, ExtentExpr &&lower, MaybeExtentExpr &&extent) {
412 return Fold(context, ComputeUpperBound(std::move(lower), std::move(extent)));
413 }
414
GetUpperBound(const NamedEntity & base,int dimension)415 MaybeExtentExpr GetUpperBound(const NamedEntity &base, int dimension) {
416 const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
417 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
418 int j{0};
419 for (const auto &shapeSpec : details->shape()) {
420 if (j++ == dimension) {
421 if (const auto &bound{shapeSpec.ubound().GetExplicit()}) {
422 return *bound;
423 } else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) {
424 break;
425 } else {
426 return ComputeUpperBound(
427 GetLowerBound(base, dimension), GetExtent(base, dimension));
428 }
429 }
430 }
431 } else if (const auto *assoc{
432 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
433 if (auto shape{GetShape(assoc->expr())}) {
434 if (dimension < static_cast<int>(shape->size())) {
435 return ComputeUpperBound(
436 GetLowerBound(base, dimension), std::move(shape->at(dimension)));
437 }
438 }
439 }
440 return std::nullopt;
441 }
442
GetUpperBound(FoldingContext & context,const NamedEntity & base,int dimension)443 MaybeExtentExpr GetUpperBound(
444 FoldingContext &context, const NamedEntity &base, int dimension) {
445 return Fold(context, GetUpperBound(base, dimension));
446 }
447
GetUpperBounds(const NamedEntity & base)448 Shape GetUpperBounds(const NamedEntity &base) {
449 const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
450 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
451 Shape result;
452 int dim{0};
453 for (const auto &shapeSpec : details->shape()) {
454 if (const auto &bound{shapeSpec.ubound().GetExplicit()}) {
455 result.push_back(*bound);
456 } else if (details->IsAssumedSize()) {
457 CHECK(dim + 1 == base.Rank());
458 result.emplace_back(std::nullopt); // UBOUND folding replaces with -1
459 } else {
460 result.emplace_back(
461 ComputeUpperBound(GetLowerBound(base, dim), GetExtent(base, dim)));
462 }
463 ++dim;
464 }
465 CHECK(GetRank(result) == symbol.Rank());
466 return result;
467 } else {
468 return std::move(GetShape(symbol).value());
469 }
470 }
471
GetUpperBounds(FoldingContext & context,const NamedEntity & base)472 Shape GetUpperBounds(FoldingContext &context, const NamedEntity &base) {
473 return Fold(context, GetUpperBounds(base));
474 }
475
operator ()(const Symbol & symbol) const476 auto GetShapeHelper::operator()(const Symbol &symbol) const -> Result {
477 return std::visit(
478 common::visitors{
479 [&](const semantics::ObjectEntityDetails &object) {
480 if (IsImpliedShape(symbol) && object.init()) {
481 return (*this)(object.init());
482 } else {
483 int n{object.shape().Rank()};
484 NamedEntity base{symbol};
485 return Result{CreateShape(n, base)};
486 }
487 },
488 [](const semantics::EntityDetails &) {
489 return ScalarShape(); // no dimensions seen
490 },
491 [&](const semantics::ProcEntityDetails &proc) {
492 if (const Symbol * interface{proc.interface().symbol()}) {
493 return (*this)(*interface);
494 } else {
495 return ScalarShape();
496 }
497 },
498 [&](const semantics::AssocEntityDetails &assoc) {
499 if (!assoc.rank()) {
500 return (*this)(assoc.expr());
501 } else {
502 int n{assoc.rank().value()};
503 NamedEntity base{symbol};
504 return Result{CreateShape(n, base)};
505 }
506 },
507 [&](const semantics::SubprogramDetails &subp) {
508 if (subp.isFunction()) {
509 return (*this)(subp.result());
510 } else {
511 return Result{};
512 }
513 },
514 [&](const semantics::ProcBindingDetails &binding) {
515 return (*this)(binding.symbol());
516 },
517 [](const semantics::TypeParamDetails &) { return ScalarShape(); },
518 [](const auto &) { return Result{}; },
519 },
520 symbol.GetUltimate().details());
521 }
522
operator ()(const Component & component) const523 auto GetShapeHelper::operator()(const Component &component) const -> Result {
524 const Symbol &symbol{component.GetLastSymbol()};
525 int rank{symbol.Rank()};
526 if (rank == 0) {
527 return (*this)(component.base());
528 } else if (symbol.has<semantics::ObjectEntityDetails>()) {
529 NamedEntity base{Component{component}};
530 return CreateShape(rank, base);
531 } else if (symbol.has<semantics::AssocEntityDetails>()) {
532 NamedEntity base{Component{component}};
533 return Result{CreateShape(rank, base)};
534 } else {
535 return (*this)(symbol);
536 }
537 }
538
operator ()(const ArrayRef & arrayRef) const539 auto GetShapeHelper::operator()(const ArrayRef &arrayRef) const -> Result {
540 Shape shape;
541 int dimension{0};
542 const NamedEntity &base{arrayRef.base()};
543 for (const Subscript &ss : arrayRef.subscript()) {
544 if (ss.Rank() > 0) {
545 shape.emplace_back(GetExtent(ss, base, dimension));
546 }
547 ++dimension;
548 }
549 if (shape.empty()) {
550 if (const Component * component{base.UnwrapComponent()}) {
551 return (*this)(component->base());
552 }
553 }
554 return shape;
555 }
556
operator ()(const CoarrayRef & coarrayRef) const557 auto GetShapeHelper::operator()(const CoarrayRef &coarrayRef) const -> Result {
558 NamedEntity base{coarrayRef.GetBase()};
559 if (coarrayRef.subscript().empty()) {
560 return (*this)(base);
561 } else {
562 Shape shape;
563 int dimension{0};
564 for (const Subscript &ss : coarrayRef.subscript()) {
565 if (ss.Rank() > 0) {
566 shape.emplace_back(GetExtent(ss, base, dimension));
567 }
568 ++dimension;
569 }
570 return shape;
571 }
572 }
573
operator ()(const Substring & substring) const574 auto GetShapeHelper::operator()(const Substring &substring) const -> Result {
575 return (*this)(substring.parent());
576 }
577
operator ()(const ProcedureRef & call) const578 auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
579 if (call.Rank() == 0) {
580 return ScalarShape();
581 } else if (call.IsElemental()) {
582 for (const auto &arg : call.arguments()) {
583 if (arg && arg->Rank() > 0) {
584 return (*this)(*arg);
585 }
586 }
587 return ScalarShape();
588 } else if (const Symbol * symbol{call.proc().GetSymbol()}) {
589 return (*this)(*symbol);
590 } else if (const auto *intrinsic{call.proc().GetSpecificIntrinsic()}) {
591 if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
592 intrinsic->name == "ubound") {
593 // These are the array-valued cases for LBOUND and UBOUND (no DIM=).
594 const auto *expr{call.arguments().front().value().UnwrapExpr()};
595 CHECK(expr);
596 return Shape{MaybeExtentExpr{ExtentExpr{expr->Rank()}}};
597 } else if (intrinsic->name == "all" || intrinsic->name == "any" ||
598 intrinsic->name == "count" || intrinsic->name == "iall" ||
599 intrinsic->name == "iany" || intrinsic->name == "iparity" ||
600 intrinsic->name == "maxval" || intrinsic->name == "minval" ||
601 intrinsic->name == "norm2" || intrinsic->name == "parity" ||
602 intrinsic->name == "product" || intrinsic->name == "sum") {
603 // Reduction with DIM=
604 if (call.arguments().size() >= 2) {
605 auto arrayShape{
606 (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
607 const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
608 if (arrayShape && dimArg) {
609 if (auto dim{ToInt64(*dimArg)}) {
610 if (*dim >= 1 &&
611 static_cast<std::size_t>(*dim) <= arrayShape->size()) {
612 arrayShape->erase(arrayShape->begin() + (*dim - 1));
613 return std::move(*arrayShape);
614 }
615 }
616 }
617 }
618 } else if (intrinsic->name == "maxloc" || intrinsic->name == "minloc") {
619 // TODO: FINDLOC
620 if (call.arguments().size() >= 2) {
621 if (auto arrayShape{
622 (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))}) {
623 auto rank{static_cast<int>(arrayShape->size())};
624 if (const auto *dimArg{
625 UnwrapExpr<Expr<SomeType>>(call.arguments()[1])}) {
626 auto dim{ToInt64(*dimArg)};
627 if (dim && *dim >= 1 && *dim <= rank) {
628 arrayShape->erase(arrayShape->begin() + (*dim - 1));
629 return std::move(*arrayShape);
630 }
631 } else {
632 // xxxLOC(no DIM=) result is vector(1:RANK(ARRAY=))
633 return Shape{ExtentExpr{rank}};
634 }
635 }
636 }
637 } else if (intrinsic->name == "cshift" || intrinsic->name == "eoshift") {
638 if (!call.arguments().empty()) {
639 return (*this)(call.arguments()[0]);
640 }
641 } else if (intrinsic->name == "matmul") {
642 if (call.arguments().size() == 2) {
643 if (auto ashape{(*this)(call.arguments()[0])}) {
644 if (auto bshape{(*this)(call.arguments()[1])}) {
645 if (ashape->size() == 1 && bshape->size() == 2) {
646 bshape->erase(bshape->begin());
647 return std::move(*bshape); // matmul(vector, matrix)
648 } else if (ashape->size() == 2 && bshape->size() == 1) {
649 ashape->pop_back();
650 return std::move(*ashape); // matmul(matrix, vector)
651 } else if (ashape->size() == 2 && bshape->size() == 2) {
652 (*ashape)[1] = std::move((*bshape)[1]);
653 return std::move(*ashape); // matmul(matrix, matrix)
654 }
655 }
656 }
657 }
658 } else if (intrinsic->name == "reshape") {
659 if (call.arguments().size() >= 2 && call.arguments().at(1)) {
660 // SHAPE(RESHAPE(array,shape)) -> shape
661 if (const auto *shapeExpr{
662 call.arguments().at(1).value().UnwrapExpr()}) {
663 auto shape{std::get<Expr<SomeInteger>>(shapeExpr->u)};
664 return AsShape(ConvertToType<ExtentType>(std::move(shape)));
665 }
666 }
667 } else if (intrinsic->name == "pack") {
668 if (call.arguments().size() >= 3 && call.arguments().at(2)) {
669 // SHAPE(PACK(,,VECTOR=v)) -> SHAPE(v)
670 return (*this)(call.arguments().at(2));
671 } else if (call.arguments().size() >= 2 && context_) {
672 if (auto maskShape{(*this)(call.arguments().at(1))}) {
673 if (maskShape->size() == 0) {
674 // Scalar MASK= -> [MERGE(SIZE(ARRAY=), 0, mask)]
675 if (auto arrayShape{(*this)(call.arguments().at(0))}) {
676 auto arraySize{GetSize(std::move(*arrayShape))};
677 CHECK(arraySize);
678 ActualArguments toMerge{
679 ActualArgument{AsGenericExpr(std::move(*arraySize))},
680 ActualArgument{AsGenericExpr(ExtentExpr{0})},
681 common::Clone(call.arguments().at(1))};
682 auto specific{context_->intrinsics().Probe(
683 CallCharacteristics{"merge"}, toMerge, *context_)};
684 CHECK(specific);
685 return Shape{ExtentExpr{FunctionRef<ExtentType>{
686 ProcedureDesignator{std::move(specific->specificIntrinsic)},
687 std::move(specific->arguments)}}};
688 }
689 } else {
690 // Non-scalar MASK= -> [COUNT(mask)]
691 ActualArguments toCount{ActualArgument{common::Clone(
692 DEREF(call.arguments().at(1).value().UnwrapExpr()))}};
693 auto specific{context_->intrinsics().Probe(
694 CallCharacteristics{"count"}, toCount, *context_)};
695 CHECK(specific);
696 return Shape{ExtentExpr{FunctionRef<ExtentType>{
697 ProcedureDesignator{std::move(specific->specificIntrinsic)},
698 std::move(specific->arguments)}}};
699 }
700 }
701 }
702 } else if (intrinsic->name == "spread") {
703 // SHAPE(SPREAD(ARRAY,DIM,NCOPIES)) = SHAPE(ARRAY) with NCOPIES inserted
704 // at position DIM.
705 if (call.arguments().size() == 3) {
706 auto arrayShape{
707 (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
708 const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
709 const auto *nCopies{
710 UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))};
711 if (arrayShape && dimArg && nCopies) {
712 if (auto dim{ToInt64(*dimArg)}) {
713 if (*dim >= 1 &&
714 static_cast<std::size_t>(*dim) <= arrayShape->size() + 1) {
715 arrayShape->emplace(arrayShape->begin() + *dim - 1,
716 ConvertToType<ExtentType>(common::Clone(*nCopies)));
717 return std::move(*arrayShape);
718 }
719 }
720 }
721 }
722 } else if (intrinsic->name == "transfer") {
723 if (call.arguments().size() == 3 && call.arguments().at(2)) {
724 // SIZE= is present; shape is vector [SIZE=]
725 if (const auto *size{
726 UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))}) {
727 return Shape{
728 MaybeExtentExpr{ConvertToType<ExtentType>(common::Clone(*size))}};
729 }
730 } else if (context_) {
731 if (auto moldTypeAndShape{characteristics::TypeAndShape::Characterize(
732 call.arguments().at(1), *context_)}) {
733 if (GetRank(moldTypeAndShape->shape()) == 0) {
734 // SIZE= is absent and MOLD= is scalar: result is scalar
735 return ScalarShape();
736 } else {
737 // SIZE= is absent and MOLD= is array: result is vector whose
738 // length is determined by sizes of types. See 16.9.193p4 case(ii).
739 if (auto sourceTypeAndShape{
740 characteristics::TypeAndShape::Characterize(
741 call.arguments().at(0), *context_)}) {
742 auto sourceBytes{
743 sourceTypeAndShape->MeasureSizeInBytes(*context_)};
744 auto moldElementBytes{
745 moldTypeAndShape->MeasureElementSizeInBytes(*context_, true)};
746 if (sourceBytes && moldElementBytes) {
747 ExtentExpr extent{Fold(*context_,
748 (std::move(*sourceBytes) +
749 common::Clone(*moldElementBytes) - ExtentExpr{1}) /
750 common::Clone(*moldElementBytes))};
751 return Shape{MaybeExtentExpr{std::move(extent)}};
752 }
753 }
754 }
755 }
756 }
757 } else if (intrinsic->name == "transpose") {
758 if (call.arguments().size() >= 1) {
759 if (auto shape{(*this)(call.arguments().at(0))}) {
760 if (shape->size() == 2) {
761 std::swap((*shape)[0], (*shape)[1]);
762 return shape;
763 }
764 }
765 }
766 } else if (intrinsic->name == "unpack") {
767 if (call.arguments().size() >= 2) {
768 return (*this)(call.arguments()[1]); // MASK=
769 }
770 } else if (intrinsic->characteristics.value().attrs.test(characteristics::
771 Procedure::Attr::NullPointer)) { // NULL(MOLD=)
772 return (*this)(call.arguments());
773 } else {
774 // TODO: shapes of other non-elemental intrinsic results
775 }
776 }
777 return std::nullopt;
778 }
779
780 // Check conformance of the passed shapes.
CheckConformance(parser::ContextualMessages & messages,const Shape & left,const Shape & right,CheckConformanceFlags::Flags flags,const char * leftIs,const char * rightIs)781 std::optional<bool> CheckConformance(parser::ContextualMessages &messages,
782 const Shape &left, const Shape &right, CheckConformanceFlags::Flags flags,
783 const char *leftIs, const char *rightIs) {
784 int n{GetRank(left)};
785 if (n == 0 && (flags & CheckConformanceFlags::LeftScalarExpandable)) {
786 return true;
787 }
788 int rn{GetRank(right)};
789 if (rn == 0 && (flags & CheckConformanceFlags::RightScalarExpandable)) {
790 return true;
791 }
792 if (n != rn) {
793 messages.Say("Rank of %1$s is %2$d, but %3$s has rank %4$d"_err_en_US,
794 leftIs, n, rightIs, rn);
795 return false;
796 }
797 for (int j{0}; j < n; ++j) {
798 if (auto leftDim{ToInt64(left[j])}) {
799 if (auto rightDim{ToInt64(right[j])}) {
800 if (*leftDim != *rightDim) {
801 messages.Say("Dimension %1$d of %2$s has extent %3$jd, "
802 "but %4$s has extent %5$jd"_err_en_US,
803 j + 1, leftIs, *leftDim, rightIs, *rightDim);
804 return false;
805 }
806 } else if (!(flags & CheckConformanceFlags::RightIsDeferredShape)) {
807 return std::nullopt;
808 }
809 } else if (!(flags & CheckConformanceFlags::LeftIsDeferredShape)) {
810 return std::nullopt;
811 }
812 }
813 return true;
814 }
815
IncrementSubscripts(ConstantSubscripts & indices,const ConstantSubscripts & extents)816 bool IncrementSubscripts(
817 ConstantSubscripts &indices, const ConstantSubscripts &extents) {
818 std::size_t rank(indices.size());
819 CHECK(rank <= extents.size());
820 for (std::size_t j{0}; j < rank; ++j) {
821 if (extents[j] < 1) {
822 return false;
823 }
824 }
825 for (std::size_t j{0}; j < rank; ++j) {
826 if (indices[j]++ < extents[j]) {
827 return true;
828 }
829 indices[j] = 1;
830 }
831 return false;
832 }
833
834 } // namespace Fortran::evaluate
835