1 //===-- lib/Evaluate/shape.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Evaluate/shape.h"
10 #include "flang/Common/idioms.h"
11 #include "flang/Common/template.h"
12 #include "flang/Evaluate/characteristics.h"
13 #include "flang/Evaluate/fold.h"
14 #include "flang/Evaluate/intrinsics.h"
15 #include "flang/Evaluate/tools.h"
16 #include "flang/Evaluate/type.h"
17 #include "flang/Parser/message.h"
18 #include "flang/Semantics/symbol.h"
19 #include <functional>
20 
21 using namespace std::placeholders; // _1, _2, &c. for std::bind()
22 
23 namespace Fortran::evaluate {
24 
IsImpliedShape(const Symbol & original)25 bool IsImpliedShape(const Symbol &original) {
26   const Symbol &symbol{ResolveAssociations(original)};
27   const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()};
28   return details && symbol.attrs().test(semantics::Attr::PARAMETER) &&
29       details->shape().IsImpliedShape();
30 }
31 
IsExplicitShape(const Symbol & original)32 bool IsExplicitShape(const Symbol &original) {
33   const Symbol &symbol{ResolveAssociations(original)};
34   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
35     const auto &shape{details->shape()};
36     return shape.Rank() == 0 ||
37         shape.IsExplicitShape(); // true when scalar, too
38   } else {
39     return symbol
40         .has<semantics::AssocEntityDetails>(); // exprs have explicit shape
41   }
42 }
43 
ConstantShape(const Constant<ExtentType> & arrayConstant)44 Shape GetShapeHelper::ConstantShape(const Constant<ExtentType> &arrayConstant) {
45   CHECK(arrayConstant.Rank() == 1);
46   Shape result;
47   std::size_t dimensions{arrayConstant.size()};
48   for (std::size_t j{0}; j < dimensions; ++j) {
49     Scalar<ExtentType> extent{arrayConstant.values().at(j)};
50     result.emplace_back(MaybeExtentExpr{ExtentExpr{std::move(extent)}});
51   }
52   return result;
53 }
54 
AsShape(ExtentExpr && arrayExpr) const55 auto GetShapeHelper::AsShape(ExtentExpr &&arrayExpr) const -> Result {
56   if (context_) {
57     arrayExpr = Fold(*context_, std::move(arrayExpr));
58   }
59   if (const auto *constArray{UnwrapConstantValue<ExtentType>(arrayExpr)}) {
60     return ConstantShape(*constArray);
61   }
62   if (auto *constructor{UnwrapExpr<ArrayConstructor<ExtentType>>(arrayExpr)}) {
63     Shape result;
64     for (auto &value : *constructor) {
65       if (auto *expr{std::get_if<ExtentExpr>(&value.u)}) {
66         if (expr->Rank() == 0) {
67           result.emplace_back(std::move(*expr));
68           continue;
69         }
70       }
71       return std::nullopt;
72     }
73     return result;
74   }
75   return std::nullopt;
76 }
77 
CreateShape(int rank,NamedEntity & base)78 Shape GetShapeHelper::CreateShape(int rank, NamedEntity &base) {
79   Shape shape;
80   for (int dimension{0}; dimension < rank; ++dimension) {
81     shape.emplace_back(GetExtent(base, dimension));
82   }
83   return shape;
84 }
85 
AsExtentArrayExpr(const Shape & shape)86 std::optional<ExtentExpr> AsExtentArrayExpr(const Shape &shape) {
87   ArrayConstructorValues<ExtentType> values;
88   for (const auto &dim : shape) {
89     if (dim) {
90       values.Push(common::Clone(*dim));
91     } else {
92       return std::nullopt;
93     }
94   }
95   return ExtentExpr{ArrayConstructor<ExtentType>{std::move(values)}};
96 }
97 
AsConstantShape(FoldingContext & context,const Shape & shape)98 std::optional<Constant<ExtentType>> AsConstantShape(
99     FoldingContext &context, const Shape &shape) {
100   if (auto shapeArray{AsExtentArrayExpr(shape)}) {
101     auto folded{Fold(context, std::move(*shapeArray))};
102     if (auto *p{UnwrapConstantValue<ExtentType>(folded)}) {
103       return std::move(*p);
104     }
105   }
106   return std::nullopt;
107 }
108 
AsConstantShape(const ConstantSubscripts & shape)109 Constant<SubscriptInteger> AsConstantShape(const ConstantSubscripts &shape) {
110   using IntType = Scalar<SubscriptInteger>;
111   std::vector<IntType> result;
112   for (auto dim : shape) {
113     result.emplace_back(dim);
114   }
115   return {std::move(result), ConstantSubscripts{GetRank(shape)}};
116 }
117 
AsConstantExtents(const Constant<ExtentType> & shape)118 ConstantSubscripts AsConstantExtents(const Constant<ExtentType> &shape) {
119   ConstantSubscripts result;
120   for (const auto &extent : shape.values()) {
121     result.push_back(extent.ToInt64());
122   }
123   return result;
124 }
125 
AsConstantExtents(FoldingContext & context,const Shape & shape)126 std::optional<ConstantSubscripts> AsConstantExtents(
127     FoldingContext &context, const Shape &shape) {
128   if (auto shapeConstant{AsConstantShape(context, shape)}) {
129     return AsConstantExtents(*shapeConstant);
130   } else {
131     return std::nullopt;
132   }
133 }
134 
AsShape(const ConstantSubscripts & shape)135 Shape AsShape(const ConstantSubscripts &shape) {
136   Shape result;
137   for (const auto &extent : shape) {
138     result.emplace_back(ExtentExpr{extent});
139   }
140   return result;
141 }
142 
AsShape(const std::optional<ConstantSubscripts> & shape)143 std::optional<Shape> AsShape(const std::optional<ConstantSubscripts> &shape) {
144   if (shape) {
145     return AsShape(*shape);
146   } else {
147     return std::nullopt;
148   }
149 }
150 
Fold(FoldingContext & context,Shape && shape)151 Shape Fold(FoldingContext &context, Shape &&shape) {
152   for (auto &dim : shape) {
153     dim = Fold(context, std::move(dim));
154   }
155   return std::move(shape);
156 }
157 
Fold(FoldingContext & context,std::optional<Shape> && shape)158 std::optional<Shape> Fold(
159     FoldingContext &context, std::optional<Shape> &&shape) {
160   if (shape) {
161     return Fold(context, std::move(*shape));
162   } else {
163     return std::nullopt;
164   }
165 }
166 
ComputeTripCount(ExtentExpr && lower,ExtentExpr && upper,ExtentExpr && stride)167 static ExtentExpr ComputeTripCount(
168     ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
169   ExtentExpr strideCopy{common::Clone(stride)};
170   ExtentExpr span{
171       (std::move(upper) - std::move(lower) + std::move(strideCopy)) /
172       std::move(stride)};
173   return ExtentExpr{
174       Extremum<ExtentType>{Ordering::Greater, std::move(span), ExtentExpr{0}}};
175 }
176 
CountTrips(ExtentExpr && lower,ExtentExpr && upper,ExtentExpr && stride)177 ExtentExpr CountTrips(
178     ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
179   return ComputeTripCount(
180       std::move(lower), std::move(upper), std::move(stride));
181 }
182 
// Copying form of CountTrips(): clones the operands, leaving the caller's
// expressions untouched, and returns MAX((upper - lower + stride) / stride, 0).
ExtentExpr CountTrips(const ExtentExpr &lower, const ExtentExpr &upper,
    const ExtentExpr &stride) {
  ExtentExpr lo{common::Clone(lower)};
  ExtentExpr hi{common::Clone(upper)};
  ExtentExpr step{common::Clone(stride)};
  return ComputeTripCount(std::move(lo), std::move(hi), std::move(step));
}
188 
CountTrips(MaybeExtentExpr && lower,MaybeExtentExpr && upper,MaybeExtentExpr && stride)189 MaybeExtentExpr CountTrips(MaybeExtentExpr &&lower, MaybeExtentExpr &&upper,
190     MaybeExtentExpr &&stride) {
191   std::function<ExtentExpr(ExtentExpr &&, ExtentExpr &&, ExtentExpr &&)> bound{
192       std::bind(ComputeTripCount, _1, _2, _3)};
193   return common::MapOptional(
194       std::move(bound), std::move(lower), std::move(upper), std::move(stride));
195 }
196 
GetSize(Shape && shape)197 MaybeExtentExpr GetSize(Shape &&shape) {
198   ExtentExpr extent{1};
199   for (auto &&dim : std::move(shape)) {
200     if (dim) {
201       extent = std::move(extent) * std::move(*dim);
202     } else {
203       return std::nullopt;
204     }
205   }
206   return extent;
207 }
208 
GetSize(const ConstantSubscripts & shape)209 ConstantSubscript GetSize(const ConstantSubscripts &shape) {
210   ConstantSubscript size{1};
211   for (auto dim : shape) {
212     CHECK(dim >= 0);
213     size *= dim;
214   }
215   return size;
216 }
217 
ContainsAnyImpliedDoIndex(const ExtentExpr & expr)218 bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
219   struct MyVisitor : public AnyTraverse<MyVisitor> {
220     using Base = AnyTraverse<MyVisitor>;
221     MyVisitor() : Base{*this} {}
222     using Base::operator();
223     bool operator()(const ImpliedDoIndex &) { return true; }
224   };
225   return MyVisitor{}(expr);
226 }
227 
// Determines the lower bound of one dimension.  This can be other than 1 only
// for a reference to a whole array object or component (see LBOUND,
// 16.9.109).  ASSOCIATE construct entities may require traversal of their
// selector expressions.
//
// This is a Traverse visitor.  The Symbol and Component overloads supply
// declared or descriptor-based lower bounds.  For every composite node,
// Combine() discards its operands and yields Default(), i.e. 1.
class GetLowerBoundHelper : public Traverse<GetLowerBoundHelper, ExtentExpr> {
public:
  using Result = ExtentExpr;
  using Base = Traverse<GetLowerBoundHelper, ExtentExpr>;
  using Base::operator();
  // `d` is the zero-based dimension whose lower bound is requested.
  explicit GetLowerBoundHelper(int d) : Base{*this}, dimension_{d} {}
  // Lower bound used when nothing more specific is known.
  static ExtentExpr Default() { return ExtentExpr{1}; }
  // Composite expressions have lower bound 1 regardless of their parts.
  static ExtentExpr Combine(Result &&, Result &&) { return Default(); }
  ExtentExpr operator()(const Symbol &);
  ExtentExpr operator()(const Component &);

private:
  int dimension_; // zero-based dimension being queried
};
245 
// Lower bound of dimension `dimension_` for a reference to a whole symbol.
// Falls back to Default() (1) whenever nothing more specific applies.
auto GetLowerBoundHelper::operator()(const Symbol &symbol0) -> Result {
  const Symbol &symbol{symbol0.GetUltimate()};
  if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
    // Object entity: locate the array-spec entry for the requested dimension.
    int j{0};
    for (const auto &shapeSpec : details->shape()) {
      if (j++ == dimension_) {
        if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
          // Explicit declared lower bound.
          return *bound;
        } else if (IsDescriptor(symbol)) {
          // Not explicit: inquire the descriptor of the entity as referenced
          // (symbol0) at run time.
          return ExtentExpr{DescriptorInquiry{NamedEntity{symbol0},
              DescriptorInquiry::Field::LowerBound, dimension_}};
        } else {
          break;
        }
      }
    }
  } else if (const auto *assoc{
                 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
    if (assoc->rank()) { // SELECT RANK case
      // Only a descriptor-based selector, within the selected rank, gets a
      // run-time inquiry; otherwise fall through to Default().
      const Symbol &resolved{ResolveAssociations(symbol)};
      if (IsDescriptor(resolved) && dimension_ < *assoc->rank()) {
        return ExtentExpr{DescriptorInquiry{NamedEntity{symbol0},
            DescriptorInquiry::Field::LowerBound, dimension_}};
      }
    } else {
      // Other construct entities take the lower bound of their selector.
      return (*this)(assoc->expr());
    }
  }
  return Default();
}
276 
operator ()(const Component & component)277 auto GetLowerBoundHelper::operator()(const Component &component) -> Result {
278   if (component.base().Rank() == 0) {
279     const Symbol &symbol{component.GetLastSymbol().GetUltimate()};
280     if (const auto *details{
281             symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
282       int j{0};
283       for (const auto &shapeSpec : details->shape()) {
284         if (j++ == dimension_) {
285           if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
286             return *bound;
287           } else if (IsDescriptor(symbol)) {
288             return ExtentExpr{
289                 DescriptorInquiry{NamedEntity{common::Clone(component)},
290                     DescriptorInquiry::Field::LowerBound, dimension_}};
291           } else {
292             break;
293           }
294         }
295       }
296     }
297   }
298   return Default();
299 }
300 
GetLowerBound(const NamedEntity & base,int dimension)301 ExtentExpr GetLowerBound(const NamedEntity &base, int dimension) {
302   return GetLowerBoundHelper{dimension}(base);
303 }
304 
GetLowerBound(FoldingContext & context,const NamedEntity & base,int dimension)305 ExtentExpr GetLowerBound(
306     FoldingContext &context, const NamedEntity &base, int dimension) {
307   return Fold(context, GetLowerBound(base, dimension));
308 }
309 
GetLowerBounds(const NamedEntity & base)310 Shape GetLowerBounds(const NamedEntity &base) {
311   Shape result;
312   int rank{base.Rank()};
313   for (int dim{0}; dim < rank; ++dim) {
314     result.emplace_back(GetLowerBound(base, dim));
315   }
316   return result;
317 }
318 
GetLowerBounds(FoldingContext & context,const NamedEntity & base)319 Shape GetLowerBounds(FoldingContext &context, const NamedEntity &base) {
320   Shape result;
321   int rank{base.Rank()};
322   for (int dim{0}; dim < rank; ++dim) {
323     result.emplace_back(GetLowerBound(context, base, dim));
324   }
325   return result;
326 }
327 
328 // If the upper and lower bounds are constant, return a constant expression for
329 // the extent.  In particular, if the upper bound is less than the lower bound,
330 // return zero.
GetNonNegativeExtent(const semantics::ShapeSpec & shapeSpec)331 static MaybeExtentExpr GetNonNegativeExtent(
332     const semantics::ShapeSpec &shapeSpec) {
333   const auto &ubound{shapeSpec.ubound().GetExplicit()};
334   const auto &lbound{shapeSpec.lbound().GetExplicit()};
335   std::optional<ConstantSubscript> uval{ToInt64(ubound)};
336   std::optional<ConstantSubscript> lval{ToInt64(lbound)};
337   if (uval && lval) {
338     if (*uval < *lval) {
339       return ExtentExpr{0};
340     } else {
341       return ExtentExpr{*uval - *lval + 1};
342     }
343   }
344   return common::Clone(ubound.value()) - common::Clone(lbound.value()) +
345       ExtentExpr{1};
346 }
347 
GetExtent(const NamedEntity & base,int dimension)348 MaybeExtentExpr GetExtent(const NamedEntity &base, int dimension) {
349   CHECK(dimension >= 0);
350   const Symbol &last{base.GetLastSymbol()};
351   const Symbol &symbol{ResolveAssociations(last)};
352   if (const auto *assoc{last.detailsIf<semantics::AssocEntityDetails>()}) {
353     if (assoc->rank()) { // SELECT RANK case
354       if (semantics::IsDescriptor(symbol) && dimension < *assoc->rank()) {
355         return ExtentExpr{DescriptorInquiry{
356             NamedEntity{base}, DescriptorInquiry::Field::Extent, dimension}};
357       }
358     } else if (auto shape{GetShape(assoc->expr())}) {
359       if (dimension < static_cast<int>(shape->size())) {
360         return std::move(shape->at(dimension));
361       }
362     }
363   }
364   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
365     if (IsImpliedShape(symbol) && details->init()) {
366       if (auto shape{GetShape(symbol)}) {
367         if (dimension < static_cast<int>(shape->size())) {
368           return std::move(shape->at(dimension));
369         }
370       }
371     } else {
372       int j{0};
373       for (const auto &shapeSpec : details->shape()) {
374         if (j++ == dimension) {
375           if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) {
376             if (shapeSpec.ubound().GetExplicit()) {
377               // 8.5.8.2, paragraph 3.  If the upper bound is less than the
378               // lower bound, the extent is zero.
379               if (shapeSpec.lbound().GetExplicit()) {
380                 return GetNonNegativeExtent(shapeSpec);
381               } else {
382                 return ubound.value();
383               }
384             }
385           } else if (details->IsAssumedSize() && j == symbol.Rank()) {
386             return std::nullopt;
387           } else if (semantics::IsDescriptor(symbol)) {
388             return ExtentExpr{DescriptorInquiry{NamedEntity{base},
389                 DescriptorInquiry::Field::Extent, dimension}};
390           }
391         }
392       }
393     }
394   }
395   return std::nullopt;
396 }
397 
GetExtent(FoldingContext & context,const NamedEntity & base,int dimension)398 MaybeExtentExpr GetExtent(
399     FoldingContext &context, const NamedEntity &base, int dimension) {
400   return Fold(context, GetExtent(base, dimension));
401 }
402 
GetExtent(const Subscript & subscript,const NamedEntity & base,int dimension)403 MaybeExtentExpr GetExtent(
404     const Subscript &subscript, const NamedEntity &base, int dimension) {
405   return std::visit(
406       common::visitors{
407           [&](const Triplet &triplet) -> MaybeExtentExpr {
408             MaybeExtentExpr upper{triplet.upper()};
409             if (!upper) {
410               upper = GetUpperBound(base, dimension);
411             }
412             MaybeExtentExpr lower{triplet.lower()};
413             if (!lower) {
414               lower = GetLowerBound(base, dimension);
415             }
416             return CountTrips(std::move(lower), std::move(upper),
417                 MaybeExtentExpr{triplet.stride()});
418           },
419           [&](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtentExpr {
420             if (auto shape{GetShape(subs.value())}) {
421               if (GetRank(*shape) > 0) {
422                 CHECK(GetRank(*shape) == 1); // vector-valued subscript
423                 return std::move(shape->at(0));
424               }
425             }
426             return std::nullopt;
427           },
428       },
429       subscript.u);
430 }
431 
GetExtent(FoldingContext & context,const Subscript & subscript,const NamedEntity & base,int dimension)432 MaybeExtentExpr GetExtent(FoldingContext &context, const Subscript &subscript,
433     const NamedEntity &base, int dimension) {
434   return Fold(context, GetExtent(subscript, base, dimension));
435 }
436 
ComputeUpperBound(ExtentExpr && lower,MaybeExtentExpr && extent)437 MaybeExtentExpr ComputeUpperBound(
438     ExtentExpr &&lower, MaybeExtentExpr &&extent) {
439   if (extent) {
440     return std::move(*extent) + std::move(lower) - ExtentExpr{1};
441   } else {
442     return std::nullopt;
443   }
444 }
445 
ComputeUpperBound(FoldingContext & context,ExtentExpr && lower,MaybeExtentExpr && extent)446 MaybeExtentExpr ComputeUpperBound(
447     FoldingContext &context, ExtentExpr &&lower, MaybeExtentExpr &&extent) {
448   return Fold(context, ComputeUpperBound(std::move(lower), std::move(extent)));
449 }
450 
// Returns the unfolded upper bound of zero-based dimension `dimension` of
// `base`, or std::nullopt when it is not available (e.g., the last dimension
// of an assumed-size array).
MaybeExtentExpr GetUpperBound(const NamedEntity &base, int dimension) {
  const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
  if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
    int j{0};
    for (const auto &shapeSpec : details->shape()) {
      if (j++ == dimension) {
        if (const auto &bound{shapeSpec.ubound().GetExplicit()}) {
          // Explicit declared upper bound.
          return *bound;
        } else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) {
          // Last dimension of an assumed-size array: no upper bound.
          break;
        } else {
          // Otherwise: lower bound + extent - 1.
          return ComputeUpperBound(
              GetLowerBound(base, dimension), GetExtent(base, dimension));
        }
      }
    }
  } else if (const auto *assoc{
                 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
    // Construct entity: take the extent from the selector's shape.
    if (auto shape{GetShape(assoc->expr())}) {
      if (dimension < static_cast<int>(shape->size())) {
        return ComputeUpperBound(
            GetLowerBound(base, dimension), std::move(shape->at(dimension)));
      }
    }
  }
  return std::nullopt;
}
478 
GetUpperBound(FoldingContext & context,const NamedEntity & base,int dimension)479 MaybeExtentExpr GetUpperBound(
480     FoldingContext &context, const NamedEntity &base, int dimension) {
481   return Fold(context, GetUpperBound(base, dimension));
482 }
483 
GetUpperBounds(const NamedEntity & base)484 Shape GetUpperBounds(const NamedEntity &base) {
485   const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
486   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
487     Shape result;
488     int dim{0};
489     for (const auto &shapeSpec : details->shape()) {
490       if (const auto &bound{shapeSpec.ubound().GetExplicit()}) {
491         result.push_back(*bound);
492       } else if (details->IsAssumedSize()) {
493         CHECK(dim + 1 == base.Rank());
494         result.emplace_back(std::nullopt); // UBOUND folding replaces with -1
495       } else {
496         result.emplace_back(
497             ComputeUpperBound(GetLowerBound(base, dim), GetExtent(base, dim)));
498       }
499       ++dim;
500     }
501     CHECK(GetRank(result) == symbol.Rank());
502     return result;
503   } else {
504     return std::move(GetShape(symbol).value());
505   }
506 }
507 
GetUpperBounds(FoldingContext & context,const NamedEntity & base)508 Shape GetUpperBounds(FoldingContext &context, const NamedEntity &base) {
509   return Fold(context, GetUpperBounds(base));
510 }
511 
// Computes the shape of a reference to a whole symbol, dispatching on the
// details of its ultimate symbol.  An empty Result means the shape is unknown
// or not applicable.
auto GetShapeHelper::operator()(const Symbol &symbol) const -> Result {
  return std::visit(
      common::visitors{
          [&](const semantics::ObjectEntityDetails &object) {
            if (IsImpliedShape(symbol) && object.init()) {
              // Implied-shape named constant: shape of its initializer.
              return (*this)(object.init());
            } else if (IsAssumedRank(symbol)) {
              // Rank, and hence shape, is not known here.
              return Result{};
            } else {
              // One GetExtent() per declared dimension.
              int n{object.shape().Rank()};
              NamedEntity base{symbol};
              return Result{CreateShape(n, base)};
            }
          },
          [](const semantics::EntityDetails &) {
            return ScalarShape(); // no dimensions seen
          },
          [&](const semantics::ProcEntityDetails &proc) {
            // Shape of the explicit interface's symbol, if any.
            if (const Symbol * interface{proc.interface().symbol()}) {
              return (*this)(*interface);
            } else {
              return ScalarShape();
            }
          },
          [&](const semantics::AssocEntityDetails &assoc) {
            if (assoc.rank()) { // SELECT RANK case
              int n{assoc.rank().value()};
              NamedEntity base{symbol};
              return Result{CreateShape(n, base)};
            } else {
              // Other construct entities: shape of the selector expression.
              return (*this)(assoc.expr());
            }
          },
          [&](const semantics::SubprogramDetails &subp) {
            // A function has the shape of its result; a subroutine has none.
            if (subp.isFunction()) {
              return (*this)(subp.result());
            } else {
              return Result{};
            }
          },
          [&](const semantics::ProcBindingDetails &binding) {
            // Binding: shape of the bound procedure's symbol.
            return (*this)(binding.symbol());
          },
          [](const semantics::TypeParamDetails &) { return ScalarShape(); },
          // Any other kind of symbol: shape unknown.
          [](const auto &) { return Result{}; },
      },
      symbol.GetUltimate().details());
}
560 
operator ()(const Component & component) const561 auto GetShapeHelper::operator()(const Component &component) const -> Result {
562   const Symbol &symbol{component.GetLastSymbol()};
563   int rank{symbol.Rank()};
564   if (rank == 0) {
565     return (*this)(component.base());
566   } else if (symbol.has<semantics::ObjectEntityDetails>()) {
567     NamedEntity base{Component{component}};
568     return CreateShape(rank, base);
569   } else if (symbol.has<semantics::AssocEntityDetails>()) {
570     NamedEntity base{Component{component}};
571     return Result{CreateShape(rank, base)};
572   } else {
573     return (*this)(symbol);
574   }
575 }
576 
operator ()(const ArrayRef & arrayRef) const577 auto GetShapeHelper::operator()(const ArrayRef &arrayRef) const -> Result {
578   Shape shape;
579   int dimension{0};
580   const NamedEntity &base{arrayRef.base()};
581   for (const Subscript &ss : arrayRef.subscript()) {
582     if (ss.Rank() > 0) {
583       shape.emplace_back(GetExtent(ss, base, dimension));
584     }
585     ++dimension;
586   }
587   if (shape.empty()) {
588     if (const Component * component{base.UnwrapComponent()}) {
589       return (*this)(component->base());
590     }
591   }
592   return shape;
593 }
594 
operator ()(const CoarrayRef & coarrayRef) const595 auto GetShapeHelper::operator()(const CoarrayRef &coarrayRef) const -> Result {
596   NamedEntity base{coarrayRef.GetBase()};
597   if (coarrayRef.subscript().empty()) {
598     return (*this)(base);
599   } else {
600     Shape shape;
601     int dimension{0};
602     for (const Subscript &ss : coarrayRef.subscript()) {
603       if (ss.Rank() > 0) {
604         shape.emplace_back(GetExtent(ss, base, dimension));
605       }
606       ++dimension;
607     }
608     return shape;
609   }
610 }
611 
operator ()(const Substring & substring) const612 auto GetShapeHelper::operator()(const Substring &substring) const -> Result {
613   return (*this)(substring.parent());
614 }
615 
operator ()(const ProcedureRef & call) const616 auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
617   if (call.Rank() == 0) {
618     return ScalarShape();
619   } else if (call.IsElemental()) {
620     for (const auto &arg : call.arguments()) {
621       if (arg && arg->Rank() > 0) {
622         return (*this)(*arg);
623       }
624     }
625     return ScalarShape();
626   } else if (const Symbol * symbol{call.proc().GetSymbol()}) {
627     return (*this)(*symbol);
628   } else if (const auto *intrinsic{call.proc().GetSpecificIntrinsic()}) {
629     if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
630         intrinsic->name == "ubound") {
631       // These are the array-valued cases for LBOUND and UBOUND (no DIM=).
632       const auto *expr{call.arguments().front().value().UnwrapExpr()};
633       CHECK(expr);
634       return Shape{MaybeExtentExpr{ExtentExpr{expr->Rank()}}};
635     } else if (intrinsic->name == "all" || intrinsic->name == "any" ||
636         intrinsic->name == "count" || intrinsic->name == "iall" ||
637         intrinsic->name == "iany" || intrinsic->name == "iparity" ||
638         intrinsic->name == "maxval" || intrinsic->name == "minval" ||
639         intrinsic->name == "norm2" || intrinsic->name == "parity" ||
640         intrinsic->name == "product" || intrinsic->name == "sum") {
641       // Reduction with DIM=
642       if (call.arguments().size() >= 2) {
643         auto arrayShape{
644             (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
645         const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
646         if (arrayShape && dimArg) {
647           if (auto dim{ToInt64(*dimArg)}) {
648             if (*dim >= 1 &&
649                 static_cast<std::size_t>(*dim) <= arrayShape->size()) {
650               arrayShape->erase(arrayShape->begin() + (*dim - 1));
651               return std::move(*arrayShape);
652             }
653           }
654         }
655       }
656     } else if (intrinsic->name == "findloc" || intrinsic->name == "maxloc" ||
657         intrinsic->name == "minloc") {
658       std::size_t dimIndex{intrinsic->name == "findloc" ? 2u : 1u};
659       if (call.arguments().size() > dimIndex) {
660         if (auto arrayShape{
661                 (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))}) {
662           auto rank{static_cast<int>(arrayShape->size())};
663           if (const auto *dimArg{
664                   UnwrapExpr<Expr<SomeType>>(call.arguments()[dimIndex])}) {
665             auto dim{ToInt64(*dimArg)};
666             if (dim && *dim >= 1 && *dim <= rank) {
667               arrayShape->erase(arrayShape->begin() + (*dim - 1));
668               return std::move(*arrayShape);
669             }
670           } else {
671             // xxxLOC(no DIM=) result is vector(1:RANK(ARRAY=))
672             return Shape{ExtentExpr{rank}};
673           }
674         }
675       }
676     } else if (intrinsic->name == "cshift" || intrinsic->name == "eoshift") {
677       if (!call.arguments().empty()) {
678         return (*this)(call.arguments()[0]);
679       }
680     } else if (intrinsic->name == "matmul") {
681       if (call.arguments().size() == 2) {
682         if (auto ashape{(*this)(call.arguments()[0])}) {
683           if (auto bshape{(*this)(call.arguments()[1])}) {
684             if (ashape->size() == 1 && bshape->size() == 2) {
685               bshape->erase(bshape->begin());
686               return std::move(*bshape); // matmul(vector, matrix)
687             } else if (ashape->size() == 2 && bshape->size() == 1) {
688               ashape->pop_back();
689               return std::move(*ashape); // matmul(matrix, vector)
690             } else if (ashape->size() == 2 && bshape->size() == 2) {
691               (*ashape)[1] = std::move((*bshape)[1]);
692               return std::move(*ashape); // matmul(matrix, matrix)
693             }
694           }
695         }
696       }
697     } else if (intrinsic->name == "reshape") {
698       if (call.arguments().size() >= 2 && call.arguments().at(1)) {
699         // SHAPE(RESHAPE(array,shape)) -> shape
700         if (const auto *shapeExpr{
701                 call.arguments().at(1).value().UnwrapExpr()}) {
702           auto shape{std::get<Expr<SomeInteger>>(shapeExpr->u)};
703           return AsShape(ConvertToType<ExtentType>(std::move(shape)));
704         }
705       }
706     } else if (intrinsic->name == "pack") {
707       if (call.arguments().size() >= 3 && call.arguments().at(2)) {
708         // SHAPE(PACK(,,VECTOR=v)) -> SHAPE(v)
709         return (*this)(call.arguments().at(2));
710       } else if (call.arguments().size() >= 2 && context_) {
711         if (auto maskShape{(*this)(call.arguments().at(1))}) {
712           if (maskShape->size() == 0) {
713             // Scalar MASK= -> [MERGE(SIZE(ARRAY=), 0, mask)]
714             if (auto arrayShape{(*this)(call.arguments().at(0))}) {
715               auto arraySize{GetSize(std::move(*arrayShape))};
716               CHECK(arraySize);
717               ActualArguments toMerge{
718                   ActualArgument{AsGenericExpr(std::move(*arraySize))},
719                   ActualArgument{AsGenericExpr(ExtentExpr{0})},
720                   common::Clone(call.arguments().at(1))};
721               auto specific{context_->intrinsics().Probe(
722                   CallCharacteristics{"merge"}, toMerge, *context_)};
723               CHECK(specific);
724               return Shape{ExtentExpr{FunctionRef<ExtentType>{
725                   ProcedureDesignator{std::move(specific->specificIntrinsic)},
726                   std::move(specific->arguments)}}};
727             }
728           } else {
729             // Non-scalar MASK= -> [COUNT(mask)]
730             ActualArguments toCount{ActualArgument{common::Clone(
731                 DEREF(call.arguments().at(1).value().UnwrapExpr()))}};
732             auto specific{context_->intrinsics().Probe(
733                 CallCharacteristics{"count"}, toCount, *context_)};
734             CHECK(specific);
735             return Shape{ExtentExpr{FunctionRef<ExtentType>{
736                 ProcedureDesignator{std::move(specific->specificIntrinsic)},
737                 std::move(specific->arguments)}}};
738           }
739         }
740       }
741     } else if (intrinsic->name == "spread") {
742       // SHAPE(SPREAD(ARRAY,DIM,NCOPIES)) = SHAPE(ARRAY) with NCOPIES inserted
743       // at position DIM.
744       if (call.arguments().size() == 3) {
745         auto arrayShape{
746             (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
747         const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
748         const auto *nCopies{
749             UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))};
750         if (arrayShape && dimArg && nCopies) {
751           if (auto dim{ToInt64(*dimArg)}) {
752             if (*dim >= 1 &&
753                 static_cast<std::size_t>(*dim) <= arrayShape->size() + 1) {
754               arrayShape->emplace(arrayShape->begin() + *dim - 1,
755                   ConvertToType<ExtentType>(common::Clone(*nCopies)));
756               return std::move(*arrayShape);
757             }
758           }
759         }
760       }
761     } else if (intrinsic->name == "transfer") {
762       if (call.arguments().size() == 3 && call.arguments().at(2)) {
763         // SIZE= is present; shape is vector [SIZE=]
764         if (const auto *size{
765                 UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))}) {
766           return Shape{
767               MaybeExtentExpr{ConvertToType<ExtentType>(common::Clone(*size))}};
768         }
769       } else if (context_) {
770         if (auto moldTypeAndShape{characteristics::TypeAndShape::Characterize(
771                 call.arguments().at(1), *context_)}) {
772           if (GetRank(moldTypeAndShape->shape()) == 0) {
773             // SIZE= is absent and MOLD= is scalar: result is scalar
774             return ScalarShape();
775           } else {
776             // SIZE= is absent and MOLD= is array: result is vector whose
777             // length is determined by sizes of types.  See 16.9.193p4 case(ii).
778             if (auto sourceTypeAndShape{
779                     characteristics::TypeAndShape::Characterize(
780                         call.arguments().at(0), *context_)}) {
781               auto sourceBytes{
782                   sourceTypeAndShape->MeasureSizeInBytes(*context_)};
783               auto moldElementBytes{
784                   moldTypeAndShape->MeasureElementSizeInBytes(*context_, true)};
785               if (sourceBytes && moldElementBytes) {
786                 ExtentExpr extent{Fold(*context_,
787                     (std::move(*sourceBytes) +
788                         common::Clone(*moldElementBytes) - ExtentExpr{1}) /
789                         common::Clone(*moldElementBytes))};
790                 return Shape{MaybeExtentExpr{std::move(extent)}};
791               }
792             }
793           }
794         }
795       }
796     } else if (intrinsic->name == "transpose") {
797       if (call.arguments().size() >= 1) {
798         if (auto shape{(*this)(call.arguments().at(0))}) {
799           if (shape->size() == 2) {
800             std::swap((*shape)[0], (*shape)[1]);
801             return shape;
802           }
803         }
804       }
805     } else if (intrinsic->name == "unpack") {
806       if (call.arguments().size() >= 2) {
807         return (*this)(call.arguments()[1]); // MASK=
808       }
809     } else if (intrinsic->characteristics.value().attrs.test(characteristics::
810                        Procedure::Attr::NullPointer)) { // NULL(MOLD=)
811       return (*this)(call.arguments());
812     } else {
813       // TODO: shapes of other non-elemental intrinsic results
814     }
815   }
816   return std::nullopt;
817 }
818 
819 // Check conformance of the passed shapes.
CheckConformance(parser::ContextualMessages & messages,const Shape & left,const Shape & right,CheckConformanceFlags::Flags flags,const char * leftIs,const char * rightIs)820 std::optional<bool> CheckConformance(parser::ContextualMessages &messages,
821     const Shape &left, const Shape &right, CheckConformanceFlags::Flags flags,
822     const char *leftIs, const char *rightIs) {
823   int n{GetRank(left)};
824   if (n == 0 && (flags & CheckConformanceFlags::LeftScalarExpandable)) {
825     return true;
826   }
827   int rn{GetRank(right)};
828   if (rn == 0 && (flags & CheckConformanceFlags::RightScalarExpandable)) {
829     return true;
830   }
831   if (n != rn) {
832     messages.Say("Rank of %1$s is %2$d, but %3$s has rank %4$d"_err_en_US,
833         leftIs, n, rightIs, rn);
834     return false;
835   }
836   for (int j{0}; j < n; ++j) {
837     if (auto leftDim{ToInt64(left[j])}) {
838       if (auto rightDim{ToInt64(right[j])}) {
839         if (*leftDim != *rightDim) {
840           messages.Say("Dimension %1$d of %2$s has extent %3$jd, "
841                        "but %4$s has extent %5$jd"_err_en_US,
842               j + 1, leftIs, *leftDim, rightIs, *rightDim);
843           return false;
844         }
845       } else if (!(flags & CheckConformanceFlags::RightIsDeferredShape)) {
846         return std::nullopt;
847       }
848     } else if (!(flags & CheckConformanceFlags::LeftIsDeferredShape)) {
849       return std::nullopt;
850     }
851   }
852   return true;
853 }
854 
IncrementSubscripts(ConstantSubscripts & indices,const ConstantSubscripts & extents)855 bool IncrementSubscripts(
856     ConstantSubscripts &indices, const ConstantSubscripts &extents) {
857   std::size_t rank(indices.size());
858   CHECK(rank <= extents.size());
859   for (std::size_t j{0}; j < rank; ++j) {
860     if (extents[j] < 1) {
861       return false;
862     }
863   }
864   for (std::size_t j{0}; j < rank; ++j) {
865     if (indices[j]++ < extents[j]) {
866       return true;
867     }
868     indices[j] = 1;
869   }
870   return false;
871 }
872 
873 } // namespace Fortran::evaluate
874