1 //===-- lib/Evaluate/shape.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "flang/Evaluate/shape.h"
10 #include "flang/Common/idioms.h"
11 #include "flang/Common/template.h"
12 #include "flang/Evaluate/characteristics.h"
13 #include "flang/Evaluate/fold.h"
14 #include "flang/Evaluate/intrinsics.h"
15 #include "flang/Evaluate/tools.h"
16 #include "flang/Evaluate/type.h"
17 #include "flang/Parser/message.h"
18 #include "flang/Semantics/symbol.h"
19 #include <functional>
20
21 using namespace std::placeholders; // _1, _2, &c. for std::bind()
22
23 namespace Fortran::evaluate {
24
IsImpliedShape(const Symbol & symbol0)25 bool IsImpliedShape(const Symbol &symbol0) {
26 const Symbol &symbol{ResolveAssociations(symbol0)};
27 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
28 if (symbol.attrs().test(semantics::Attr::PARAMETER) && details->init()) {
29 return details->shape().IsImpliedShape();
30 }
31 }
32 return false;
33 }
34
IsExplicitShape(const Symbol & symbol0)35 bool IsExplicitShape(const Symbol &symbol0) {
36 const Symbol &symbol{ResolveAssociations(symbol0)};
37 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
38 const auto &shape{details->shape()};
39 return shape.Rank() == 0 || shape.IsExplicitShape(); // even if scalar
40 } else {
41 return false;
42 }
43 }
44
AsShape(const Constant<ExtentType> & arrayConstant)45 Shape AsShape(const Constant<ExtentType> &arrayConstant) {
46 CHECK(arrayConstant.Rank() == 1);
47 Shape result;
48 std::size_t dimensions{arrayConstant.size()};
49 for (std::size_t j{0}; j < dimensions; ++j) {
50 Scalar<ExtentType> extent{arrayConstant.values().at(j)};
51 result.emplace_back(MaybeExtentExpr{ExtentExpr{extent}});
52 }
53 return result;
54 }
55
AsShape(FoldingContext & context,ExtentExpr && arrayExpr)56 std::optional<Shape> AsShape(FoldingContext &context, ExtentExpr &&arrayExpr) {
57 // Flatten any array expression into an array constructor if possible.
58 arrayExpr = Fold(context, std::move(arrayExpr));
59 if (const auto *constArray{UnwrapConstantValue<ExtentType>(arrayExpr)}) {
60 return AsShape(*constArray);
61 }
62 if (auto *constructor{UnwrapExpr<ArrayConstructor<ExtentType>>(arrayExpr)}) {
63 Shape result;
64 for (auto &value : *constructor) {
65 if (auto *expr{std::get_if<ExtentExpr>(&value.u)}) {
66 if (expr->Rank() == 0) {
67 result.emplace_back(std::move(*expr));
68 continue;
69 }
70 }
71 return std::nullopt;
72 }
73 return result;
74 }
75 return std::nullopt;
76 }
77
AsExtentArrayExpr(const Shape & shape)78 std::optional<ExtentExpr> AsExtentArrayExpr(const Shape &shape) {
79 ArrayConstructorValues<ExtentType> values;
80 for (const auto &dim : shape) {
81 if (dim) {
82 values.Push(common::Clone(*dim));
83 } else {
84 return std::nullopt;
85 }
86 }
87 return ExtentExpr{ArrayConstructor<ExtentType>{std::move(values)}};
88 }
89
AsConstantShape(FoldingContext & context,const Shape & shape)90 std::optional<Constant<ExtentType>> AsConstantShape(
91 FoldingContext &context, const Shape &shape) {
92 if (auto shapeArray{AsExtentArrayExpr(shape)}) {
93 auto folded{Fold(context, std::move(*shapeArray))};
94 if (auto *p{UnwrapConstantValue<ExtentType>(folded)}) {
95 return std::move(*p);
96 }
97 }
98 return std::nullopt;
99 }
100
AsConstantShape(const ConstantSubscripts & shape)101 Constant<SubscriptInteger> AsConstantShape(const ConstantSubscripts &shape) {
102 using IntType = Scalar<SubscriptInteger>;
103 std::vector<IntType> result;
104 for (auto dim : shape) {
105 result.emplace_back(dim);
106 }
107 return {std::move(result), ConstantSubscripts{GetRank(shape)}};
108 }
109
AsConstantExtents(const Constant<ExtentType> & shape)110 ConstantSubscripts AsConstantExtents(const Constant<ExtentType> &shape) {
111 ConstantSubscripts result;
112 for (const auto &extent : shape.values()) {
113 result.push_back(extent.ToInt64());
114 }
115 return result;
116 }
117
AsConstantExtents(FoldingContext & context,const Shape & shape)118 std::optional<ConstantSubscripts> AsConstantExtents(
119 FoldingContext &context, const Shape &shape) {
120 if (auto shapeConstant{AsConstantShape(context, shape)}) {
121 return AsConstantExtents(*shapeConstant);
122 } else {
123 return std::nullopt;
124 }
125 }
126
ComputeTripCount(FoldingContext & context,ExtentExpr && lower,ExtentExpr && upper,ExtentExpr && stride)127 static ExtentExpr ComputeTripCount(FoldingContext &context, ExtentExpr &&lower,
128 ExtentExpr &&upper, ExtentExpr &&stride) {
129 ExtentExpr strideCopy{common::Clone(stride)};
130 ExtentExpr span{
131 (std::move(upper) - std::move(lower) + std::move(strideCopy)) /
132 std::move(stride)};
133 ExtentExpr extent{
134 Extremum<ExtentType>{Ordering::Greater, std::move(span), ExtentExpr{0}}};
135 return Fold(context, std::move(extent));
136 }
137
CountTrips(FoldingContext & context,ExtentExpr && lower,ExtentExpr && upper,ExtentExpr && stride)138 ExtentExpr CountTrips(FoldingContext &context, ExtentExpr &&lower,
139 ExtentExpr &&upper, ExtentExpr &&stride) {
140 return ComputeTripCount(
141 context, std::move(lower), std::move(upper), std::move(stride));
142 }
143
CountTrips(FoldingContext & context,const ExtentExpr & lower,const ExtentExpr & upper,const ExtentExpr & stride)144 ExtentExpr CountTrips(FoldingContext &context, const ExtentExpr &lower,
145 const ExtentExpr &upper, const ExtentExpr &stride) {
146 return ComputeTripCount(context, common::Clone(lower), common::Clone(upper),
147 common::Clone(stride));
148 }
149
CountTrips(FoldingContext & context,MaybeExtentExpr && lower,MaybeExtentExpr && upper,MaybeExtentExpr && stride)150 MaybeExtentExpr CountTrips(FoldingContext &context, MaybeExtentExpr &&lower,
151 MaybeExtentExpr &&upper, MaybeExtentExpr &&stride) {
152 std::function<ExtentExpr(ExtentExpr &&, ExtentExpr &&, ExtentExpr &&)> bound{
153 std::bind(ComputeTripCount, context, _1, _2, _3)};
154 return common::MapOptional(
155 std::move(bound), std::move(lower), std::move(upper), std::move(stride));
156 }
157
GetSize(Shape && shape)158 MaybeExtentExpr GetSize(Shape &&shape) {
159 ExtentExpr extent{1};
160 for (auto &&dim : std::move(shape)) {
161 if (dim) {
162 extent = std::move(extent) * std::move(*dim);
163 } else {
164 return std::nullopt;
165 }
166 }
167 return extent;
168 }
169
ContainsAnyImpliedDoIndex(const ExtentExpr & expr)170 bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
171 struct MyVisitor : public AnyTraverse<MyVisitor> {
172 using Base = AnyTraverse<MyVisitor>;
173 MyVisitor() : Base{*this} {}
174 using Base::operator();
175 bool operator()(const ImpliedDoIndex &) { return true; }
176 };
177 return MyVisitor{}(expr);
178 }
179
// Determines the lower bound of a dimension. This can be other than 1 only
// for a reference to a whole array object or component. (See LBOUND, 16.9.109).
// ASSOCIATE construct entities may require traversal of their referents.
// Traversal that computes the lower bound of one dimension (`dimension_`,
// counted from zero by the Symbol and Component overloads) of the entity an
// expression designates.
class GetLowerBoundHelper : public Traverse<GetLowerBoundHelper, ExtentExpr> {
public:
  using Result = ExtentExpr;
  using Base = Traverse<GetLowerBoundHelper, ExtentExpr>;
  using Base::operator();
  // `c` is used to fold bound expressions; `d` is the dimension of interest.
  GetLowerBoundHelper(FoldingContext &c, int d)
      : Base{*this}, context_{c}, dimension_{d} {}
  // Lower bound used when nothing more specific is known.
  static ExtentExpr Default() { return ExtentExpr{1}; }
  // Combining sub-results discards both operands and yields Default().
  static ExtentExpr Combine(Result &&, Result &&) { return Default(); }
  ExtentExpr operator()(const Symbol &);
  ExtentExpr operator()(const Component &);

private:
  FoldingContext &context_;
  int dimension_;
};
199
// Lower bound of dimension `dimension_` for a whole-symbol reference.
// For an object: the declared explicit lower bound (folded) if there is one,
// else a descriptor inquiry if the object is described by a descriptor,
// else Default().  For an ASSOCIATE entity, the selector expression is
// traversed instead.  Anything else yields Default().
auto GetLowerBoundHelper::operator()(const Symbol &symbol0) -> Result {
  const Symbol &symbol{symbol0.GetUltimate()};
  if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
    int j{0};
    for (const auto &shapeSpec : details->shape()) {
      if (j++ == dimension_) {
        if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
          return Fold(context_, common::Clone(*bound));
        } else if (IsDescriptor(symbol)) {
          // Note that the inquiry names symbol0, not its ultimate symbol.
          return ExtentExpr{DescriptorInquiry{NamedEntity{symbol0},
              DescriptorInquiry::Field::LowerBound, dimension_}};
        } else {
          break;
        }
      }
    }
  } else if (const auto *assoc{
                 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
    return (*this)(assoc->expr());
  }
  return Default();
}
222
// Lower bound of dimension `dimension_` for a component reference.  Only a
// component whose parent (base) is scalar is examined.  In that case the
// component's declared explicit lower bound (folded) is used, or a descriptor
// inquiry on the component when it is described by a descriptor.  Every other
// case yields Default().
auto GetLowerBoundHelper::operator()(const Component &component) -> Result {
  if (component.base().Rank() == 0) {
    const Symbol &symbol{component.GetLastSymbol().GetUltimate()};
    if (const auto *details{
            symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
      int j{0};
      for (const auto &shapeSpec : details->shape()) {
        if (j++ == dimension_) {
          if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
            return Fold(context_, common::Clone(*bound));
          } else if (IsDescriptor(symbol)) {
            return ExtentExpr{
                DescriptorInquiry{NamedEntity{common::Clone(component)},
                    DescriptorInquiry::Field::LowerBound, dimension_}};
          } else {
            break;
          }
        }
      }
    }
  }
  return Default();
}
246
GetLowerBound(FoldingContext & context,const NamedEntity & base,int dimension)247 ExtentExpr GetLowerBound(
248 FoldingContext &context, const NamedEntity &base, int dimension) {
249 return GetLowerBoundHelper{context, dimension}(base);
250 }
251
GetLowerBounds(FoldingContext & context,const NamedEntity & base)252 Shape GetLowerBounds(FoldingContext &context, const NamedEntity &base) {
253 Shape result;
254 int rank{base.Rank()};
255 for (int dim{0}; dim < rank; ++dim) {
256 result.emplace_back(GetLowerBound(context, base, dim));
257 }
258 return result;
259 }
260
GetExtent(FoldingContext & context,const NamedEntity & base,int dimension)261 MaybeExtentExpr GetExtent(
262 FoldingContext &context, const NamedEntity &base, int dimension) {
263 CHECK(dimension >= 0);
264 const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
265 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
266 if (IsImpliedShape(symbol)) {
267 Shape shape{GetShape(context, symbol).value()};
268 return std::move(shape.at(dimension));
269 }
270 int j{0};
271 for (const auto &shapeSpec : details->shape()) {
272 if (j++ == dimension) {
273 if (shapeSpec.ubound().isExplicit()) {
274 if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) {
275 if (const auto &lbound{shapeSpec.lbound().GetExplicit()}) {
276 return Fold(context,
277 common::Clone(ubound.value()) -
278 common::Clone(lbound.value()) + ExtentExpr{1});
279 } else {
280 return Fold(context, common::Clone(ubound.value()));
281 }
282 }
283 } else if (details->IsAssumedSize() && j == symbol.Rank()) {
284 return std::nullopt;
285 } else if (semantics::IsDescriptor(symbol)) {
286 return ExtentExpr{DescriptorInquiry{
287 NamedEntity{base}, DescriptorInquiry::Field::Extent, dimension}};
288 }
289 }
290 }
291 } else if (const auto *assoc{
292 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
293 if (auto shape{GetShape(context, assoc->expr())}) {
294 if (dimension < static_cast<int>(shape->size())) {
295 return std::move(shape->at(dimension));
296 }
297 }
298 }
299 return std::nullopt;
300 }
301
// Extent that one subscript of an array reference contributes to the shape
// of the reference.  `dimension` (counted from zero) selects the dimension
// of `base` being subscripted.
// - A triplet yields its trip count.  Omitted lower/upper bounds are taken
//   from the bounds of `base` in that dimension.
// - A vector-valued subscript yields the length of the vector.
// - Anything else (e.g. a scalar subscript) yields std::nullopt.
MaybeExtentExpr GetExtent(FoldingContext &context, const Subscript &subscript,
    const NamedEntity &base, int dimension) {
  return std::visit(
      common::visitors{
          [&](const Triplet &triplet) -> MaybeExtentExpr {
            MaybeExtentExpr upper{triplet.upper()};
            if (!upper) {
              upper = GetUpperBound(context, base, dimension);
            }
            MaybeExtentExpr lower{triplet.lower()};
            if (!lower) {
              lower = GetLowerBound(context, base, dimension);
            }
            return CountTrips(context, std::move(lower), std::move(upper),
                MaybeExtentExpr{triplet.stride()});
          },
          [&](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtentExpr {
            if (auto shape{GetShape(context, subs.value())}) {
              if (GetRank(*shape) > 0) {
                CHECK(GetRank(*shape) == 1); // vector-valued subscript
                return std::move(shape->at(0));
              }
            }
            return std::nullopt;
          },
      },
      subscript.u);
}
330
ComputeUpperBound(FoldingContext & context,ExtentExpr && lower,MaybeExtentExpr && extent)331 MaybeExtentExpr ComputeUpperBound(
332 FoldingContext &context, ExtentExpr &&lower, MaybeExtentExpr &&extent) {
333 if (extent) {
334 return Fold(context, std::move(*extent) - std::move(lower) + ExtentExpr{1});
335 } else {
336 return std::nullopt;
337 }
338 }
339
// Upper bound of one dimension (counted from zero) of a named entity.
// For an object: the declared explicit upper bound (folded) if present.
// The final dimension of an assumed-size array yields std::nullopt.
// Otherwise the bound is derived from the lower bound and the extent.
// For an ASSOCIATE entity, the extent comes from the shape of its selector.
MaybeExtentExpr GetUpperBound(
    FoldingContext &context, const NamedEntity &base, int dimension) {
  const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
  if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
    int j{0};
    for (const auto &shapeSpec : details->shape()) {
      if (j++ == dimension) {
        if (const auto &bound{shapeSpec.ubound().GetExplicit()}) {
          return Fold(context, common::Clone(*bound));
        } else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) {
          break;
        } else {
          return ComputeUpperBound(context,
              GetLowerBound(context, base, dimension),
              GetExtent(context, base, dimension));
        }
      }
    }
  } else if (const auto *assoc{
                 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
    if (auto shape{GetShape(context, assoc->expr())}) {
      if (dimension < static_cast<int>(shape->size())) {
        return ComputeUpperBound(context,
            GetLowerBound(context, base, dimension),
            std::move(shape->at(dimension)));
      }
    }
  }
  return std::nullopt;
}
370
// Upper bounds of every dimension of a named entity.  For an object, each
// dimension uses its declared explicit upper bound when present.  A
// non-explicit dimension of an assumed-size array (CHECKed to be the last
// one) is reported as std::nullopt.  Other dimensions are derived from the
// lower bound and extent.  For any other kind of symbol, the entity's shape
// is returned.
// NOTE(review): that fallback returns extents, which equal the upper bounds
// only when the lower bounds are 1.  .value() also throws if no shape is
// known.  Confirm both are intended.
Shape GetUpperBounds(FoldingContext &context, const NamedEntity &base) {
  const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
  if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
    Shape result;
    int dim{0};
    for (const auto &shapeSpec : details->shape()) {
      if (const auto &bound{shapeSpec.ubound().GetExplicit()}) {
        result.emplace_back(Fold(context, common::Clone(*bound)));
      } else if (details->IsAssumedSize()) {
        CHECK(dim + 1 == base.Rank());
        result.emplace_back(std::nullopt); // UBOUND folding replaces with -1
      } else {
        result.emplace_back(ComputeUpperBound(context,
            GetLowerBound(context, base, dim), GetExtent(context, base, dim)));
      }
      ++dim;
    }
    CHECK(GetRank(result) == symbol.Rank());
    return result;
  } else {
    return std::move(GetShape(context, base).value());
  }
}
394
// Shape of a reference to a symbol, dispatched on the kind of its details:
// - objects: implied-shape named constants take the shape of their
//   initializer.  Other objects get a shape of their declared rank from
//   CreateShape.
// - entities with bare EntityDetails are scalar.
// - procedure entities take the shape of their interface symbol, or are
//   scalar without one.
// - ASSOCIATE entities with no recorded rank take the shape of the selector.
//   With a rank, CreateShape is used.
// - functions take the shape of their result.  Subroutines yield Result{}.
// - procedure bindings, USE and host associations forward to the symbol
//   they refer to.
// - type parameters are scalar.  Anything else yields Result{}.
auto GetShapeHelper::operator()(const Symbol &symbol) const -> Result {
  return std::visit(
      common::visitors{
          [&](const semantics::ObjectEntityDetails &object) {
            if (IsImpliedShape(symbol)) {
              return (*this)(object.init());
            } else {
              int n{object.shape().Rank()};
              NamedEntity base{symbol};
              return Result{CreateShape(n, base)};
            }
          },
          [](const semantics::EntityDetails &) {
            return Scalar(); // no dimensions seen
          },
          [&](const semantics::ProcEntityDetails &proc) {
            if (const Symbol * interface{proc.interface().symbol()}) {
              return (*this)(*interface);
            } else {
              return Scalar();
            }
          },
          [&](const semantics::AssocEntityDetails &assoc) {
            if (!assoc.rank()) {
              return (*this)(assoc.expr());
            } else {
              int n{assoc.rank().value()};
              NamedEntity base{symbol};
              return Result{CreateShape(n, base)};
            }
          },
          [&](const semantics::SubprogramDetails &subp) {
            if (subp.isFunction()) {
              return (*this)(subp.result());
            } else {
              return Result{};
            }
          },
          [&](const semantics::ProcBindingDetails &binding) {
            return (*this)(binding.symbol());
          },
          [&](const semantics::UseDetails &use) {
            return (*this)(use.symbol());
          },
          [&](const semantics::HostAssocDetails &assoc) {
            return (*this)(assoc.symbol());
          },
          [](const semantics::TypeParamDetails &) { return Scalar(); },
          [](const auto &) { return Result{}; },
      },
      symbol.details());
}
447
operator ()(const Component & component) const448 auto GetShapeHelper::operator()(const Component &component) const -> Result {
449 const Symbol &symbol{component.GetLastSymbol()};
450 int rank{symbol.Rank()};
451 if (rank == 0) {
452 return (*this)(component.base());
453 } else if (symbol.has<semantics::ObjectEntityDetails>()) {
454 NamedEntity base{Component{component}};
455 return CreateShape(rank, base);
456 } else if (symbol.has<semantics::AssocEntityDetails>()) {
457 NamedEntity base{Component{component}};
458 return Result{CreateShape(rank, base)};
459 } else {
460 return (*this)(symbol);
461 }
462 }
463
operator ()(const ArrayRef & arrayRef) const464 auto GetShapeHelper::operator()(const ArrayRef &arrayRef) const -> Result {
465 Shape shape;
466 int dimension{0};
467 const NamedEntity &base{arrayRef.base()};
468 for (const Subscript &ss : arrayRef.subscript()) {
469 if (ss.Rank() > 0) {
470 shape.emplace_back(GetExtent(context_, ss, base, dimension));
471 }
472 ++dimension;
473 }
474 if (shape.empty()) {
475 if (const Component * component{base.UnwrapComponent()}) {
476 return (*this)(component->base());
477 }
478 }
479 return shape;
480 }
481
operator ()(const CoarrayRef & coarrayRef) const482 auto GetShapeHelper::operator()(const CoarrayRef &coarrayRef) const -> Result {
483 NamedEntity base{coarrayRef.GetBase()};
484 if (coarrayRef.subscript().empty()) {
485 return (*this)(base);
486 } else {
487 Shape shape;
488 int dimension{0};
489 for (const Subscript &ss : coarrayRef.subscript()) {
490 if (ss.Rank() > 0) {
491 shape.emplace_back(GetExtent(context_, ss, base, dimension));
492 }
493 ++dimension;
494 }
495 return shape;
496 }
497 }
498
// A substring has the same shape as the parent it is taken from.
auto GetShapeHelper::operator()(const Substring &substring) const -> Result {
  return (*this)(substring.parent());
}
502
// Shape of the result of a procedure reference.
// - A rank-0 call is scalar.
// - An elemental call takes the shape of its first array argument, or is
//   scalar if there is none.
// - A call through a symbol takes the shape of that symbol.
// - For specific intrinsics, the result shapes of a number of
//   transformational intrinsics are derived from their arguments below.
// Returns std::nullopt whenever the shape cannot be determined here.
auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
  if (call.Rank() == 0) {
    return Scalar();
  } else if (call.IsElemental()) {
    for (const auto &arg : call.arguments()) {
      if (arg && arg->Rank() > 0) {
        return (*this)(*arg);
      }
    }
    return Scalar();
  } else if (const Symbol * symbol{call.proc().GetSymbol()}) {
    return (*this)(*symbol);
  } else if (const auto *intrinsic{call.proc().GetSpecificIntrinsic()}) {
    if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
        intrinsic->name == "ubound") {
      // These are the array-valued cases for LBOUND and UBOUND (no DIM=).
      // The result is a vector whose length is the rank of the first argument.
      const auto *expr{call.arguments().front().value().UnwrapExpr()};
      CHECK(expr);
      return Shape{MaybeExtentExpr{ExtentExpr{expr->Rank()}}};
    } else if (intrinsic->name == "all" || intrinsic->name == "any" ||
        intrinsic->name == "count" || intrinsic->name == "iall" ||
        intrinsic->name == "iany" || intrinsic->name == "iparity" ||
        intrinsic->name == "maxloc" || intrinsic->name == "maxval" ||
        intrinsic->name == "minloc" || intrinsic->name == "minval" ||
        intrinsic->name == "norm2" || intrinsic->name == "parity" ||
        intrinsic->name == "product" || intrinsic->name == "sum") {
      // Reduction with DIM=
      // Result shape = shape of the first argument with dimension DIM
      // removed.  Only handled when DIM is a constant within range.
      if (call.arguments().size() >= 2) {
        auto arrayShape{
            (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
        const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
        if (arrayShape && dimArg) {
          if (auto dim{ToInt64(*dimArg)}) {
            if (*dim >= 1 &&
                static_cast<std::size_t>(*dim) <= arrayShape->size()) {
              arrayShape->erase(arrayShape->begin() + (*dim - 1));
              return std::move(*arrayShape);
            }
          }
        }
      }
    } else if (intrinsic->name == "cshift" || intrinsic->name == "eoshift") {
      // Shifts preserve the shape of their first argument.
      if (!call.arguments().empty()) {
        return (*this)(call.arguments()[0]);
      }
    } else if (intrinsic->name == "matmul") {
      if (call.arguments().size() == 2) {
        if (auto ashape{(*this)(call.arguments()[0])}) {
          if (auto bshape{(*this)(call.arguments()[1])}) {
            if (ashape->size() == 1 && bshape->size() == 2) {
              bshape->erase(bshape->begin());
              return std::move(*bshape); // matmul(vector, matrix)
            } else if (ashape->size() == 2 && bshape->size() == 1) {
              ashape->pop_back();
              return std::move(*ashape); // matmul(matrix, vector)
            } else if (ashape->size() == 2 && bshape->size() == 2) {
              (*ashape)[1] = std::move((*bshape)[1]);
              return std::move(*ashape); // matmul(matrix, matrix)
            }
          }
        }
      }
    } else if (intrinsic->name == "reshape") {
      if (call.arguments().size() >= 2 && call.arguments().at(1)) {
        // SHAPE(RESHAPE(array,shape)) -> shape
        // std::get below assumes the SHAPE= argument has integer type.
        if (const auto *shapeExpr{
                call.arguments().at(1).value().UnwrapExpr()}) {
          auto shape{std::get<Expr<SomeInteger>>(shapeExpr->u)};
          return AsShape(context_, ConvertToType<ExtentType>(std::move(shape)));
        }
      }
    } else if (intrinsic->name == "pack") {
      if (call.arguments().size() >= 3 && call.arguments().at(2)) {
        // SHAPE(PACK(,,VECTOR=v)) -> SHAPE(v)
        return (*this)(call.arguments().at(2));
      } else if (call.arguments().size() >= 2) {
        if (auto maskShape{(*this)(call.arguments().at(1))}) {
          if (maskShape->size() == 0) {
            // Scalar MASK= -> [MERGE(SIZE(ARRAY=), 0, mask)]
            if (auto arrayShape{(*this)(call.arguments().at(0))}) {
              auto arraySize{GetSize(std::move(*arrayShape))};
              CHECK(arraySize);
              ActualArguments toMerge{
                  ActualArgument{AsGenericExpr(std::move(*arraySize))},
                  ActualArgument{AsGenericExpr(ExtentExpr{0})},
                  common::Clone(call.arguments().at(1))};
              auto specific{context_.intrinsics().Probe(
                  CallCharacteristics{"merge"}, toMerge, context_)};
              CHECK(specific);
              return Shape{ExtentExpr{FunctionRef<ExtentType>{
                  ProcedureDesignator{std::move(specific->specificIntrinsic)},
                  std::move(specific->arguments)}}};
            }
          } else {
            // Non-scalar MASK= -> [COUNT(mask)]
            ActualArguments toCount{ActualArgument{common::Clone(
                DEREF(call.arguments().at(1).value().UnwrapExpr()))}};
            auto specific{context_.intrinsics().Probe(
                CallCharacteristics{"count"}, toCount, context_)};
            CHECK(specific);
            return Shape{ExtentExpr{FunctionRef<ExtentType>{
                ProcedureDesignator{std::move(specific->specificIntrinsic)},
                std::move(specific->arguments)}}};
          }
        }
      }
    } else if (intrinsic->name == "spread") {
      // SHAPE(SPREAD(ARRAY,DIM,NCOPIES)) = SHAPE(ARRAY) with NCOPIES inserted
      // at position DIM.
      if (call.arguments().size() == 3) {
        auto arrayShape{
            (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
        const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
        const auto *nCopies{
            UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))};
        if (arrayShape && dimArg && nCopies) {
          if (auto dim{ToInt64(*dimArg)}) {
            if (*dim >= 1 &&
                static_cast<std::size_t>(*dim) <= arrayShape->size() + 1) {
              arrayShape->emplace(arrayShape->begin() + *dim - 1,
                  ConvertToType<ExtentType>(common::Clone(*nCopies)));
              return std::move(*arrayShape);
            }
          }
        }
      }
    } else if (intrinsic->name == "transfer") {
      if (call.arguments().size() == 3 && call.arguments().at(2)) {
        // SIZE= is present; shape is vector [SIZE=]
        if (const auto *size{
                UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))}) {
          return Shape{
              MaybeExtentExpr{ConvertToType<ExtentType>(common::Clone(*size))}};
        }
      } else if (auto moldTypeAndShape{
                     characteristics::TypeAndShape::Characterize(
                         call.arguments().at(1), context_)}) {
        if (GetRank(moldTypeAndShape->shape()) == 0) {
          // SIZE= is absent and MOLD= is scalar: result is scalar
          return Scalar();
        } else {
          // SIZE= is absent and MOLD= is array: result is vector whose
          // length is determined by sizes of types.  See 16.9.193p4 case(ii).
          // Computed as CEILING(source bytes / mold element bytes) using
          // integer arithmetic: (a + b - 1) / b.
          if (auto sourceTypeAndShape{
                  characteristics::TypeAndShape::Characterize(
                      call.arguments().at(0), context_)}) {
            auto sourceElements{
                GetSize(common::Clone(sourceTypeAndShape->shape()))};
            auto sourceElementBytes{
                sourceTypeAndShape->MeasureSizeInBytes(&context_)};
            auto moldElementBytes{
                moldTypeAndShape->MeasureSizeInBytes(&context_)};
            if (sourceElements && sourceElementBytes && moldElementBytes) {
              ExtentExpr extent{Fold(context_,
                  ((std::move(*sourceElements) *
                       std::move(*sourceElementBytes)) +
                      common::Clone(*moldElementBytes) - ExtentExpr{1}) /
                      common::Clone(*moldElementBytes))};
              return Shape{MaybeExtentExpr{std::move(extent)}};
            }
          }
        }
      }
    } else if (intrinsic->name == "transpose") {
      // Swap the two extents of a rank-2 argument.
      if (call.arguments().size() >= 1) {
        if (auto shape{(*this)(call.arguments().at(0))}) {
          if (shape->size() == 2) {
            std::swap((*shape)[0], (*shape)[1]);
            return shape;
          }
        }
      }
    } else if (intrinsic->characteristics.value().attrs.test(characteristics::
                       Procedure::Attr::NullPointer)) { // NULL(MOLD=)
      return (*this)(call.arguments());
    } else {
      // TODO: shapes of other non-elemental intrinsic results
    }
  }
  return std::nullopt;
}
684
CheckConformance(parser::ContextualMessages & messages,const Shape & left,const Shape & right,const char * leftIs,const char * rightIs)685 bool CheckConformance(parser::ContextualMessages &messages, const Shape &left,
686 const Shape &right, const char *leftIs, const char *rightIs) {
687 int n{GetRank(left)};
688 int rn{GetRank(right)};
689 if (n != 0 && rn != 0) {
690 if (n != rn) {
691 messages.Say("Rank of %1$s is %2$d, but %3$s has rank %4$d"_err_en_US,
692 leftIs, n, rightIs, rn);
693 return false;
694 } else {
695 for (int j{0}; j < n; ++j) {
696 if (auto leftDim{ToInt64(left[j])}) {
697 if (auto rightDim{ToInt64(right[j])}) {
698 if (*leftDim != *rightDim) {
699 messages.Say("Dimension %1$d of %2$s has extent %3$jd, "
700 "but %4$s has extent %5$jd"_err_en_US,
701 j + 1, leftIs, *leftDim, rightIs, *rightDim);
702 return false;
703 }
704 }
705 }
706 }
707 }
708 }
709 return true;
710 }
711
IncrementSubscripts(ConstantSubscripts & indices,const ConstantSubscripts & extents)712 bool IncrementSubscripts(
713 ConstantSubscripts &indices, const ConstantSubscripts &extents) {
714 std::size_t rank(indices.size());
715 CHECK(rank <= extents.size());
716 for (std::size_t j{0}; j < rank; ++j) {
717 if (extents[j] < 1) {
718 return false;
719 }
720 }
721 for (std::size_t j{0}; j < rank; ++j) {
722 if (indices[j]++ < extents[j]) {
723 return true;
724 }
725 indices[j] = 1;
726 }
727 return false;
728 }
729 } // namespace Fortran::evaluate
730