1 //===-- lib/Evaluate/shape.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Evaluate/shape.h"
10 #include "flang/Common/idioms.h"
11 #include "flang/Common/template.h"
12 #include "flang/Evaluate/characteristics.h"
13 #include "flang/Evaluate/fold.h"
14 #include "flang/Evaluate/intrinsics.h"
15 #include "flang/Evaluate/tools.h"
16 #include "flang/Evaluate/type.h"
17 #include "flang/Parser/message.h"
18 #include "flang/Semantics/symbol.h"
19 #include <functional>
20 
21 using namespace std::placeholders; // _1, _2, &c. for std::bind()
22 
23 namespace Fortran::evaluate {
24 
IsImpliedShape(const Symbol & original)25 bool IsImpliedShape(const Symbol &original) {
26   const Symbol &symbol{ResolveAssociations(original)};
27   const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()};
28   return details && symbol.attrs().test(semantics::Attr::PARAMETER) &&
29       details->shape().IsImpliedShape();
30 }
31 
IsExplicitShape(const Symbol & original)32 bool IsExplicitShape(const Symbol &original) {
33   const Symbol &symbol{ResolveAssociations(original)};
34   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
35     const auto &shape{details->shape()};
36     return shape.Rank() == 0 ||
37         shape.IsExplicitShape(); // true when scalar, too
38   } else {
39     return symbol
40         .has<semantics::AssocEntityDetails>(); // exprs have explicit shape
41   }
42 }
43 
ConstantShape(const Constant<ExtentType> & arrayConstant)44 Shape GetShapeHelper::ConstantShape(const Constant<ExtentType> &arrayConstant) {
45   CHECK(arrayConstant.Rank() == 1);
46   Shape result;
47   std::size_t dimensions{arrayConstant.size()};
48   for (std::size_t j{0}; j < dimensions; ++j) {
49     Scalar<ExtentType> extent{arrayConstant.values().at(j)};
50     result.emplace_back(MaybeExtentExpr{ExtentExpr{std::move(extent)}});
51   }
52   return result;
53 }
54 
AsShape(ExtentExpr && arrayExpr) const55 auto GetShapeHelper::AsShape(ExtentExpr &&arrayExpr) const -> Result {
56   if (context_) {
57     arrayExpr = Fold(*context_, std::move(arrayExpr));
58   }
59   if (const auto *constArray{UnwrapConstantValue<ExtentType>(arrayExpr)}) {
60     return ConstantShape(*constArray);
61   }
62   if (auto *constructor{UnwrapExpr<ArrayConstructor<ExtentType>>(arrayExpr)}) {
63     Shape result;
64     for (auto &value : *constructor) {
65       if (auto *expr{std::get_if<ExtentExpr>(&value.u)}) {
66         if (expr->Rank() == 0) {
67           result.emplace_back(std::move(*expr));
68           continue;
69         }
70       }
71       return std::nullopt;
72     }
73     return result;
74   }
75   return std::nullopt;
76 }
77 
CreateShape(int rank,NamedEntity & base)78 Shape GetShapeHelper::CreateShape(int rank, NamedEntity &base) {
79   Shape shape;
80   for (int dimension{0}; dimension < rank; ++dimension) {
81     shape.emplace_back(GetExtent(base, dimension));
82   }
83   return shape;
84 }
85 
AsExtentArrayExpr(const Shape & shape)86 std::optional<ExtentExpr> AsExtentArrayExpr(const Shape &shape) {
87   ArrayConstructorValues<ExtentType> values;
88   for (const auto &dim : shape) {
89     if (dim) {
90       values.Push(common::Clone(*dim));
91     } else {
92       return std::nullopt;
93     }
94   }
95   return ExtentExpr{ArrayConstructor<ExtentType>{std::move(values)}};
96 }
97 
AsConstantShape(FoldingContext & context,const Shape & shape)98 std::optional<Constant<ExtentType>> AsConstantShape(
99     FoldingContext &context, const Shape &shape) {
100   if (auto shapeArray{AsExtentArrayExpr(shape)}) {
101     auto folded{Fold(context, std::move(*shapeArray))};
102     if (auto *p{UnwrapConstantValue<ExtentType>(folded)}) {
103       return std::move(*p);
104     }
105   }
106   return std::nullopt;
107 }
108 
AsConstantShape(const ConstantSubscripts & shape)109 Constant<SubscriptInteger> AsConstantShape(const ConstantSubscripts &shape) {
110   using IntType = Scalar<SubscriptInteger>;
111   std::vector<IntType> result;
112   for (auto dim : shape) {
113     result.emplace_back(dim);
114   }
115   return {std::move(result), ConstantSubscripts{GetRank(shape)}};
116 }
117 
AsConstantExtents(const Constant<ExtentType> & shape)118 ConstantSubscripts AsConstantExtents(const Constant<ExtentType> &shape) {
119   ConstantSubscripts result;
120   for (const auto &extent : shape.values()) {
121     result.push_back(extent.ToInt64());
122   }
123   return result;
124 }
125 
AsConstantExtents(FoldingContext & context,const Shape & shape)126 std::optional<ConstantSubscripts> AsConstantExtents(
127     FoldingContext &context, const Shape &shape) {
128   if (auto shapeConstant{AsConstantShape(context, shape)}) {
129     return AsConstantExtents(*shapeConstant);
130   } else {
131     return std::nullopt;
132   }
133 }
134 
AsShape(const ConstantSubscripts & shape)135 Shape AsShape(const ConstantSubscripts &shape) {
136   Shape result;
137   for (const auto &extent : shape) {
138     result.emplace_back(ExtentExpr{extent});
139   }
140   return result;
141 }
142 
AsShape(const std::optional<ConstantSubscripts> & shape)143 std::optional<Shape> AsShape(const std::optional<ConstantSubscripts> &shape) {
144   if (shape) {
145     return AsShape(*shape);
146   } else {
147     return std::nullopt;
148   }
149 }
150 
Fold(FoldingContext & context,Shape && shape)151 Shape Fold(FoldingContext &context, Shape &&shape) {
152   for (auto &dim : shape) {
153     dim = Fold(context, std::move(dim));
154   }
155   return std::move(shape);
156 }
157 
Fold(FoldingContext & context,std::optional<Shape> && shape)158 std::optional<Shape> Fold(
159     FoldingContext &context, std::optional<Shape> &&shape) {
160   if (shape) {
161     return Fold(context, std::move(*shape));
162   } else {
163     return std::nullopt;
164   }
165 }
166 
ComputeTripCount(ExtentExpr && lower,ExtentExpr && upper,ExtentExpr && stride)167 static ExtentExpr ComputeTripCount(
168     ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
169   ExtentExpr strideCopy{common::Clone(stride)};
170   ExtentExpr span{
171       (std::move(upper) - std::move(lower) + std::move(strideCopy)) /
172       std::move(stride)};
173   return ExtentExpr{
174       Extremum<ExtentType>{Ordering::Greater, std::move(span), ExtentExpr{0}}};
175 }
176 
CountTrips(ExtentExpr && lower,ExtentExpr && upper,ExtentExpr && stride)177 ExtentExpr CountTrips(
178     ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
179   return ComputeTripCount(
180       std::move(lower), std::move(upper), std::move(stride));
181 }
182 
// Trip count of lower:upper:stride. The operands are cloned, so the caller's
// expressions are left untouched.
ExtentExpr CountTrips(const ExtentExpr &lower, const ExtentExpr &upper,
    const ExtentExpr &stride) {
  ExtentExpr lowerCopy{common::Clone(lower)};
  ExtentExpr upperCopy{common::Clone(upper)};
  ExtentExpr strideCopy{common::Clone(stride)};
  return ComputeTripCount(
      std::move(lowerCopy), std::move(upperCopy), std::move(strideCopy));
}
188 
CountTrips(MaybeExtentExpr && lower,MaybeExtentExpr && upper,MaybeExtentExpr && stride)189 MaybeExtentExpr CountTrips(MaybeExtentExpr &&lower, MaybeExtentExpr &&upper,
190     MaybeExtentExpr &&stride) {
191   std::function<ExtentExpr(ExtentExpr &&, ExtentExpr &&, ExtentExpr &&)> bound{
192       std::bind(ComputeTripCount, _1, _2, _3)};
193   return common::MapOptional(
194       std::move(bound), std::move(lower), std::move(upper), std::move(stride));
195 }
196 
GetSize(Shape && shape)197 MaybeExtentExpr GetSize(Shape &&shape) {
198   ExtentExpr extent{1};
199   for (auto &&dim : std::move(shape)) {
200     if (dim) {
201       extent = std::move(extent) * std::move(*dim);
202     } else {
203       return std::nullopt;
204     }
205   }
206   return extent;
207 }
208 
GetSize(const ConstantSubscripts & shape)209 ConstantSubscript GetSize(const ConstantSubscripts &shape) {
210   ConstantSubscript size{1};
211   for (auto dim : std::move(shape)) {
212     size *= dim;
213   }
214   return size;
215 }
216 
ContainsAnyImpliedDoIndex(const ExtentExpr & expr)217 bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
218   struct MyVisitor : public AnyTraverse<MyVisitor> {
219     using Base = AnyTraverse<MyVisitor>;
220     MyVisitor() : Base{*this} {}
221     using Base::operator();
222     bool operator()(const ImpliedDoIndex &) { return true; }
223   };
224   return MyVisitor{}(expr);
225 }
226 
227 // Determines lower bound on a dimension.  This can be other than 1 only
228 // for a reference to a whole array object or component. (See LBOUND, 16.9.109).
229 // ASSOCIATE construct entities may require traversal of their referents.
230 class GetLowerBoundHelper : public Traverse<GetLowerBoundHelper, ExtentExpr> {
231 public:
232   using Result = ExtentExpr;
233   using Base = Traverse<GetLowerBoundHelper, ExtentExpr>;
234   using Base::operator();
GetLowerBoundHelper(int d)235   explicit GetLowerBoundHelper(int d) : Base{*this}, dimension_{d} {}
Default()236   static ExtentExpr Default() { return ExtentExpr{1}; }
Combine(Result &&,Result &&)237   static ExtentExpr Combine(Result &&, Result &&) { return Default(); }
238   ExtentExpr operator()(const Symbol &);
239   ExtentExpr operator()(const Component &);
240 
241 private:
242   int dimension_;
243 };
244 
operator ()(const Symbol & symbol0)245 auto GetLowerBoundHelper::operator()(const Symbol &symbol0) -> Result {
246   const Symbol &symbol{symbol0.GetUltimate()};
247   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
248     int j{0};
249     for (const auto &shapeSpec : details->shape()) {
250       if (j++ == dimension_) {
251         if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
252           return *bound;
253         } else if (IsDescriptor(symbol)) {
254           return ExtentExpr{DescriptorInquiry{NamedEntity{symbol0},
255               DescriptorInquiry::Field::LowerBound, dimension_}};
256         } else {
257           break;
258         }
259       }
260     }
261   } else if (const auto *assoc{
262                  symbol.detailsIf<semantics::AssocEntityDetails>()}) {
263     return (*this)(assoc->expr());
264   }
265   return Default();
266 }
267 
operator ()(const Component & component)268 auto GetLowerBoundHelper::operator()(const Component &component) -> Result {
269   if (component.base().Rank() == 0) {
270     const Symbol &symbol{component.GetLastSymbol().GetUltimate()};
271     if (const auto *details{
272             symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
273       int j{0};
274       for (const auto &shapeSpec : details->shape()) {
275         if (j++ == dimension_) {
276           if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
277             return *bound;
278           } else if (IsDescriptor(symbol)) {
279             return ExtentExpr{
280                 DescriptorInquiry{NamedEntity{common::Clone(component)},
281                     DescriptorInquiry::Field::LowerBound, dimension_}};
282           } else {
283             break;
284           }
285         }
286       }
287     }
288   }
289   return Default();
290 }
291 
GetLowerBound(const NamedEntity & base,int dimension)292 ExtentExpr GetLowerBound(const NamedEntity &base, int dimension) {
293   return GetLowerBoundHelper{dimension}(base);
294 }
295 
GetLowerBound(FoldingContext & context,const NamedEntity & base,int dimension)296 ExtentExpr GetLowerBound(
297     FoldingContext &context, const NamedEntity &base, int dimension) {
298   return Fold(context, GetLowerBound(base, dimension));
299 }
300 
GetLowerBounds(const NamedEntity & base)301 Shape GetLowerBounds(const NamedEntity &base) {
302   Shape result;
303   int rank{base.Rank()};
304   for (int dim{0}; dim < rank; ++dim) {
305     result.emplace_back(GetLowerBound(base, dim));
306   }
307   return result;
308 }
309 
GetLowerBounds(FoldingContext & context,const NamedEntity & base)310 Shape GetLowerBounds(FoldingContext &context, const NamedEntity &base) {
311   Shape result;
312   int rank{base.Rank()};
313   for (int dim{0}; dim < rank; ++dim) {
314     result.emplace_back(GetLowerBound(context, base, dim));
315   }
316   return result;
317 }
318 
GetExtent(const NamedEntity & base,int dimension)319 MaybeExtentExpr GetExtent(const NamedEntity &base, int dimension) {
320   CHECK(dimension >= 0);
321   const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
322   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
323     if (IsImpliedShape(symbol) && details->init()) {
324       if (auto shape{GetShape(symbol)}) {
325         if (dimension < static_cast<int>(shape->size())) {
326           return std::move(shape->at(dimension));
327         }
328       }
329     } else {
330       int j{0};
331       for (const auto &shapeSpec : details->shape()) {
332         if (j++ == dimension) {
333           if (shapeSpec.ubound().isExplicit()) {
334             if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) {
335               if (const auto &lbound{shapeSpec.lbound().GetExplicit()}) {
336                 return common::Clone(ubound.value()) -
337                     common::Clone(lbound.value()) + ExtentExpr{1};
338               } else {
339                 return ubound.value();
340               }
341             }
342           } else if (details->IsAssumedSize() && j == symbol.Rank()) {
343             return std::nullopt;
344           } else if (semantics::IsDescriptor(symbol)) {
345             return ExtentExpr{DescriptorInquiry{NamedEntity{base},
346                 DescriptorInquiry::Field::Extent, dimension}};
347           }
348         }
349       }
350     }
351   } else if (const auto *assoc{
352                  symbol.detailsIf<semantics::AssocEntityDetails>()}) {
353     if (auto shape{GetShape(assoc->expr())}) {
354       if (dimension < static_cast<int>(shape->size())) {
355         return std::move(shape->at(dimension));
356       }
357     }
358   }
359   return std::nullopt;
360 }
361 
GetExtent(FoldingContext & context,const NamedEntity & base,int dimension)362 MaybeExtentExpr GetExtent(
363     FoldingContext &context, const NamedEntity &base, int dimension) {
364   return Fold(context, GetExtent(base, dimension));
365 }
366 
GetExtent(const Subscript & subscript,const NamedEntity & base,int dimension)367 MaybeExtentExpr GetExtent(
368     const Subscript &subscript, const NamedEntity &base, int dimension) {
369   return std::visit(
370       common::visitors{
371           [&](const Triplet &triplet) -> MaybeExtentExpr {
372             MaybeExtentExpr upper{triplet.upper()};
373             if (!upper) {
374               upper = GetUpperBound(base, dimension);
375             }
376             MaybeExtentExpr lower{triplet.lower()};
377             if (!lower) {
378               lower = GetLowerBound(base, dimension);
379             }
380             return CountTrips(std::move(lower), std::move(upper),
381                 MaybeExtentExpr{triplet.stride()});
382           },
383           [&](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtentExpr {
384             if (auto shape{GetShape(subs.value())}) {
385               if (GetRank(*shape) > 0) {
386                 CHECK(GetRank(*shape) == 1); // vector-valued subscript
387                 return std::move(shape->at(0));
388               }
389             }
390             return std::nullopt;
391           },
392       },
393       subscript.u);
394 }
395 
GetExtent(FoldingContext & context,const Subscript & subscript,const NamedEntity & base,int dimension)396 MaybeExtentExpr GetExtent(FoldingContext &context, const Subscript &subscript,
397     const NamedEntity &base, int dimension) {
398   return Fold(context, GetExtent(subscript, base, dimension));
399 }
400 
ComputeUpperBound(ExtentExpr && lower,MaybeExtentExpr && extent)401 MaybeExtentExpr ComputeUpperBound(
402     ExtentExpr &&lower, MaybeExtentExpr &&extent) {
403   if (extent) {
404     return std::move(*extent) + std::move(lower) - ExtentExpr{1};
405   } else {
406     return std::nullopt;
407   }
408 }
409 
ComputeUpperBound(FoldingContext & context,ExtentExpr && lower,MaybeExtentExpr && extent)410 MaybeExtentExpr ComputeUpperBound(
411     FoldingContext &context, ExtentExpr &&lower, MaybeExtentExpr &&extent) {
412   return Fold(context, ComputeUpperBound(std::move(lower), std::move(extent)));
413 }
414 
GetUpperBound(const NamedEntity & base,int dimension)415 MaybeExtentExpr GetUpperBound(const NamedEntity &base, int dimension) {
416   const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
417   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
418     int j{0};
419     for (const auto &shapeSpec : details->shape()) {
420       if (j++ == dimension) {
421         if (const auto &bound{shapeSpec.ubound().GetExplicit()}) {
422           return *bound;
423         } else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) {
424           break;
425         } else {
426           return ComputeUpperBound(
427               GetLowerBound(base, dimension), GetExtent(base, dimension));
428         }
429       }
430     }
431   } else if (const auto *assoc{
432                  symbol.detailsIf<semantics::AssocEntityDetails>()}) {
433     if (auto shape{GetShape(assoc->expr())}) {
434       if (dimension < static_cast<int>(shape->size())) {
435         return ComputeUpperBound(
436             GetLowerBound(base, dimension), std::move(shape->at(dimension)));
437       }
438     }
439   }
440   return std::nullopt;
441 }
442 
GetUpperBound(FoldingContext & context,const NamedEntity & base,int dimension)443 MaybeExtentExpr GetUpperBound(
444     FoldingContext &context, const NamedEntity &base, int dimension) {
445   return Fold(context, GetUpperBound(base, dimension));
446 }
447 
GetUpperBounds(const NamedEntity & base)448 Shape GetUpperBounds(const NamedEntity &base) {
449   const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
450   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
451     Shape result;
452     int dim{0};
453     for (const auto &shapeSpec : details->shape()) {
454       if (const auto &bound{shapeSpec.ubound().GetExplicit()}) {
455         result.push_back(*bound);
456       } else if (details->IsAssumedSize()) {
457         CHECK(dim + 1 == base.Rank());
458         result.emplace_back(std::nullopt); // UBOUND folding replaces with -1
459       } else {
460         result.emplace_back(
461             ComputeUpperBound(GetLowerBound(base, dim), GetExtent(base, dim)));
462       }
463       ++dim;
464     }
465     CHECK(GetRank(result) == symbol.Rank());
466     return result;
467   } else {
468     return std::move(GetShape(symbol).value());
469   }
470 }
471 
GetUpperBounds(FoldingContext & context,const NamedEntity & base)472 Shape GetUpperBounds(FoldingContext &context, const NamedEntity &base) {
473   return Fold(context, GetUpperBounds(base));
474 }
475 
operator ()(const Symbol & symbol) const476 auto GetShapeHelper::operator()(const Symbol &symbol) const -> Result {
477   return std::visit(
478       common::visitors{
479           [&](const semantics::ObjectEntityDetails &object) {
480             if (IsImpliedShape(symbol) && object.init()) {
481               return (*this)(object.init());
482             } else {
483               int n{object.shape().Rank()};
484               NamedEntity base{symbol};
485               return Result{CreateShape(n, base)};
486             }
487           },
488           [](const semantics::EntityDetails &) {
489             return ScalarShape(); // no dimensions seen
490           },
491           [&](const semantics::ProcEntityDetails &proc) {
492             if (const Symbol * interface{proc.interface().symbol()}) {
493               return (*this)(*interface);
494             } else {
495               return ScalarShape();
496             }
497           },
498           [&](const semantics::AssocEntityDetails &assoc) {
499             if (!assoc.rank()) {
500               return (*this)(assoc.expr());
501             } else {
502               int n{assoc.rank().value()};
503               NamedEntity base{symbol};
504               return Result{CreateShape(n, base)};
505             }
506           },
507           [&](const semantics::SubprogramDetails &subp) {
508             if (subp.isFunction()) {
509               return (*this)(subp.result());
510             } else {
511               return Result{};
512             }
513           },
514           [&](const semantics::ProcBindingDetails &binding) {
515             return (*this)(binding.symbol());
516           },
517           [](const semantics::TypeParamDetails &) { return ScalarShape(); },
518           [](const auto &) { return Result{}; },
519       },
520       symbol.GetUltimate().details());
521 }
522 
operator ()(const Component & component) const523 auto GetShapeHelper::operator()(const Component &component) const -> Result {
524   const Symbol &symbol{component.GetLastSymbol()};
525   int rank{symbol.Rank()};
526   if (rank == 0) {
527     return (*this)(component.base());
528   } else if (symbol.has<semantics::ObjectEntityDetails>()) {
529     NamedEntity base{Component{component}};
530     return CreateShape(rank, base);
531   } else if (symbol.has<semantics::AssocEntityDetails>()) {
532     NamedEntity base{Component{component}};
533     return Result{CreateShape(rank, base)};
534   } else {
535     return (*this)(symbol);
536   }
537 }
538 
operator ()(const ArrayRef & arrayRef) const539 auto GetShapeHelper::operator()(const ArrayRef &arrayRef) const -> Result {
540   Shape shape;
541   int dimension{0};
542   const NamedEntity &base{arrayRef.base()};
543   for (const Subscript &ss : arrayRef.subscript()) {
544     if (ss.Rank() > 0) {
545       shape.emplace_back(GetExtent(ss, base, dimension));
546     }
547     ++dimension;
548   }
549   if (shape.empty()) {
550     if (const Component * component{base.UnwrapComponent()}) {
551       return (*this)(component->base());
552     }
553   }
554   return shape;
555 }
556 
operator ()(const CoarrayRef & coarrayRef) const557 auto GetShapeHelper::operator()(const CoarrayRef &coarrayRef) const -> Result {
558   NamedEntity base{coarrayRef.GetBase()};
559   if (coarrayRef.subscript().empty()) {
560     return (*this)(base);
561   } else {
562     Shape shape;
563     int dimension{0};
564     for (const Subscript &ss : coarrayRef.subscript()) {
565       if (ss.Rank() > 0) {
566         shape.emplace_back(GetExtent(ss, base, dimension));
567       }
568       ++dimension;
569     }
570     return shape;
571   }
572 }
573 
operator ()(const Substring & substring) const574 auto GetShapeHelper::operator()(const Substring &substring) const -> Result {
575   return (*this)(substring.parent());
576 }
577 
operator ()(const ProcedureRef & call) const578 auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
579   if (call.Rank() == 0) {
580     return ScalarShape();
581   } else if (call.IsElemental()) {
582     for (const auto &arg : call.arguments()) {
583       if (arg && arg->Rank() > 0) {
584         return (*this)(*arg);
585       }
586     }
587     return ScalarShape();
588   } else if (const Symbol * symbol{call.proc().GetSymbol()}) {
589     return (*this)(*symbol);
590   } else if (const auto *intrinsic{call.proc().GetSpecificIntrinsic()}) {
591     if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
592         intrinsic->name == "ubound") {
593       // These are the array-valued cases for LBOUND and UBOUND (no DIM=).
594       const auto *expr{call.arguments().front().value().UnwrapExpr()};
595       CHECK(expr);
596       return Shape{MaybeExtentExpr{ExtentExpr{expr->Rank()}}};
597     } else if (intrinsic->name == "all" || intrinsic->name == "any" ||
598         intrinsic->name == "count" || intrinsic->name == "iall" ||
599         intrinsic->name == "iany" || intrinsic->name == "iparity" ||
600         intrinsic->name == "maxval" || intrinsic->name == "minval" ||
601         intrinsic->name == "norm2" || intrinsic->name == "parity" ||
602         intrinsic->name == "product" || intrinsic->name == "sum") {
603       // Reduction with DIM=
604       if (call.arguments().size() >= 2) {
605         auto arrayShape{
606             (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
607         const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
608         if (arrayShape && dimArg) {
609           if (auto dim{ToInt64(*dimArg)}) {
610             if (*dim >= 1 &&
611                 static_cast<std::size_t>(*dim) <= arrayShape->size()) {
612               arrayShape->erase(arrayShape->begin() + (*dim - 1));
613               return std::move(*arrayShape);
614             }
615           }
616         }
617       }
618     } else if (intrinsic->name == "maxloc" || intrinsic->name == "minloc") {
619       // TODO: FINDLOC
620       if (call.arguments().size() >= 2) {
621         if (auto arrayShape{
622                 (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))}) {
623           auto rank{static_cast<int>(arrayShape->size())};
624           if (const auto *dimArg{
625                   UnwrapExpr<Expr<SomeType>>(call.arguments()[1])}) {
626             auto dim{ToInt64(*dimArg)};
627             if (dim && *dim >= 1 && *dim <= rank) {
628               arrayShape->erase(arrayShape->begin() + (*dim - 1));
629               return std::move(*arrayShape);
630             }
631           } else {
632             // xxxLOC(no DIM=) result is vector(1:RANK(ARRAY=))
633             return Shape{ExtentExpr{rank}};
634           }
635         }
636       }
637     } else if (intrinsic->name == "cshift" || intrinsic->name == "eoshift") {
638       if (!call.arguments().empty()) {
639         return (*this)(call.arguments()[0]);
640       }
641     } else if (intrinsic->name == "matmul") {
642       if (call.arguments().size() == 2) {
643         if (auto ashape{(*this)(call.arguments()[0])}) {
644           if (auto bshape{(*this)(call.arguments()[1])}) {
645             if (ashape->size() == 1 && bshape->size() == 2) {
646               bshape->erase(bshape->begin());
647               return std::move(*bshape); // matmul(vector, matrix)
648             } else if (ashape->size() == 2 && bshape->size() == 1) {
649               ashape->pop_back();
650               return std::move(*ashape); // matmul(matrix, vector)
651             } else if (ashape->size() == 2 && bshape->size() == 2) {
652               (*ashape)[1] = std::move((*bshape)[1]);
653               return std::move(*ashape); // matmul(matrix, matrix)
654             }
655           }
656         }
657       }
658     } else if (intrinsic->name == "reshape") {
659       if (call.arguments().size() >= 2 && call.arguments().at(1)) {
660         // SHAPE(RESHAPE(array,shape)) -> shape
661         if (const auto *shapeExpr{
662                 call.arguments().at(1).value().UnwrapExpr()}) {
663           auto shape{std::get<Expr<SomeInteger>>(shapeExpr->u)};
664           return AsShape(ConvertToType<ExtentType>(std::move(shape)));
665         }
666       }
667     } else if (intrinsic->name == "pack") {
668       if (call.arguments().size() >= 3 && call.arguments().at(2)) {
669         // SHAPE(PACK(,,VECTOR=v)) -> SHAPE(v)
670         return (*this)(call.arguments().at(2));
671       } else if (call.arguments().size() >= 2 && context_) {
672         if (auto maskShape{(*this)(call.arguments().at(1))}) {
673           if (maskShape->size() == 0) {
674             // Scalar MASK= -> [MERGE(SIZE(ARRAY=), 0, mask)]
675             if (auto arrayShape{(*this)(call.arguments().at(0))}) {
676               auto arraySize{GetSize(std::move(*arrayShape))};
677               CHECK(arraySize);
678               ActualArguments toMerge{
679                   ActualArgument{AsGenericExpr(std::move(*arraySize))},
680                   ActualArgument{AsGenericExpr(ExtentExpr{0})},
681                   common::Clone(call.arguments().at(1))};
682               auto specific{context_->intrinsics().Probe(
683                   CallCharacteristics{"merge"}, toMerge, *context_)};
684               CHECK(specific);
685               return Shape{ExtentExpr{FunctionRef<ExtentType>{
686                   ProcedureDesignator{std::move(specific->specificIntrinsic)},
687                   std::move(specific->arguments)}}};
688             }
689           } else {
690             // Non-scalar MASK= -> [COUNT(mask)]
691             ActualArguments toCount{ActualArgument{common::Clone(
692                 DEREF(call.arguments().at(1).value().UnwrapExpr()))}};
693             auto specific{context_->intrinsics().Probe(
694                 CallCharacteristics{"count"}, toCount, *context_)};
695             CHECK(specific);
696             return Shape{ExtentExpr{FunctionRef<ExtentType>{
697                 ProcedureDesignator{std::move(specific->specificIntrinsic)},
698                 std::move(specific->arguments)}}};
699           }
700         }
701       }
702     } else if (intrinsic->name == "spread") {
703       // SHAPE(SPREAD(ARRAY,DIM,NCOPIES)) = SHAPE(ARRAY) with NCOPIES inserted
704       // at position DIM.
705       if (call.arguments().size() == 3) {
706         auto arrayShape{
707             (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
708         const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
709         const auto *nCopies{
710             UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))};
711         if (arrayShape && dimArg && nCopies) {
712           if (auto dim{ToInt64(*dimArg)}) {
713             if (*dim >= 1 &&
714                 static_cast<std::size_t>(*dim) <= arrayShape->size() + 1) {
715               arrayShape->emplace(arrayShape->begin() + *dim - 1,
716                   ConvertToType<ExtentType>(common::Clone(*nCopies)));
717               return std::move(*arrayShape);
718             }
719           }
720         }
721       }
722     } else if (intrinsic->name == "transfer") {
723       if (call.arguments().size() == 3 && call.arguments().at(2)) {
724         // SIZE= is present; shape is vector [SIZE=]
725         if (const auto *size{
726                 UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))}) {
727           return Shape{
728               MaybeExtentExpr{ConvertToType<ExtentType>(common::Clone(*size))}};
729         }
730       } else if (context_) {
731         if (auto moldTypeAndShape{characteristics::TypeAndShape::Characterize(
732                 call.arguments().at(1), *context_)}) {
733           if (GetRank(moldTypeAndShape->shape()) == 0) {
734             // SIZE= is absent and MOLD= is scalar: result is scalar
735             return ScalarShape();
736           } else {
737             // SIZE= is absent and MOLD= is array: result is vector whose
738             // length is determined by sizes of types.  See 16.9.193p4 case(ii).
739             if (auto sourceTypeAndShape{
740                     characteristics::TypeAndShape::Characterize(
741                         call.arguments().at(0), *context_)}) {
742               auto sourceBytes{
743                   sourceTypeAndShape->MeasureSizeInBytes(*context_)};
744               auto moldElementBytes{
745                   moldTypeAndShape->MeasureElementSizeInBytes(*context_, true)};
746               if (sourceBytes && moldElementBytes) {
747                 ExtentExpr extent{Fold(*context_,
748                     (std::move(*sourceBytes) +
749                         common::Clone(*moldElementBytes) - ExtentExpr{1}) /
750                         common::Clone(*moldElementBytes))};
751                 return Shape{MaybeExtentExpr{std::move(extent)}};
752               }
753             }
754           }
755         }
756       }
757     } else if (intrinsic->name == "transpose") {
758       if (call.arguments().size() >= 1) {
759         if (auto shape{(*this)(call.arguments().at(0))}) {
760           if (shape->size() == 2) {
761             std::swap((*shape)[0], (*shape)[1]);
762             return shape;
763           }
764         }
765       }
766     } else if (intrinsic->name == "unpack") {
767       if (call.arguments().size() >= 2) {
768         return (*this)(call.arguments()[1]); // MASK=
769       }
770     } else if (intrinsic->characteristics.value().attrs.test(characteristics::
771                        Procedure::Attr::NullPointer)) { // NULL(MOLD=)
772       return (*this)(call.arguments());
773     } else {
774       // TODO: shapes of other non-elemental intrinsic results
775     }
776   }
777   return std::nullopt;
778 }
779 
780 // Check conformance of the passed shapes.
CheckConformance(parser::ContextualMessages & messages,const Shape & left,const Shape & right,CheckConformanceFlags::Flags flags,const char * leftIs,const char * rightIs)781 std::optional<bool> CheckConformance(parser::ContextualMessages &messages,
782     const Shape &left, const Shape &right, CheckConformanceFlags::Flags flags,
783     const char *leftIs, const char *rightIs) {
784   int n{GetRank(left)};
785   if (n == 0 && (flags & CheckConformanceFlags::LeftScalarExpandable)) {
786     return true;
787   }
788   int rn{GetRank(right)};
789   if (rn == 0 && (flags & CheckConformanceFlags::RightScalarExpandable)) {
790     return true;
791   }
792   if (n != rn) {
793     messages.Say("Rank of %1$s is %2$d, but %3$s has rank %4$d"_err_en_US,
794         leftIs, n, rightIs, rn);
795     return false;
796   }
797   for (int j{0}; j < n; ++j) {
798     if (auto leftDim{ToInt64(left[j])}) {
799       if (auto rightDim{ToInt64(right[j])}) {
800         if (*leftDim != *rightDim) {
801           messages.Say("Dimension %1$d of %2$s has extent %3$jd, "
802                        "but %4$s has extent %5$jd"_err_en_US,
803               j + 1, leftIs, *leftDim, rightIs, *rightDim);
804           return false;
805         }
806       } else if (!(flags & CheckConformanceFlags::RightIsDeferredShape)) {
807         return std::nullopt;
808       }
809     } else if (!(flags & CheckConformanceFlags::LeftIsDeferredShape)) {
810       return std::nullopt;
811     }
812   }
813   return true;
814 }
815 
IncrementSubscripts(ConstantSubscripts & indices,const ConstantSubscripts & extents)816 bool IncrementSubscripts(
817     ConstantSubscripts &indices, const ConstantSubscripts &extents) {
818   std::size_t rank(indices.size());
819   CHECK(rank <= extents.size());
820   for (std::size_t j{0}; j < rank; ++j) {
821     if (extents[j] < 1) {
822       return false;
823     }
824   }
825   for (std::size_t j{0}; j < rank; ++j) {
826     if (indices[j]++ < extents[j]) {
827       return true;
828     }
829     indices[j] = 1;
830   }
831   return false;
832 }
833 
834 } // namespace Fortran::evaluate
835