1 //===-- lib/Evaluate/fold-integer.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "fold-implementation.h"
10 #include "fold-reduction.h"
11 #include "flang/Evaluate/check-expression.h"
12 
13 namespace Fortran::evaluate {
14 
15 // Class to retrieve the constant lower bound of an expression which is an
16 // array that devolves to a type of Constant<T>
17 class GetConstantArrayLboundHelper {
18 public:
GetConstantArrayLboundHelper(ConstantSubscript dim)19   GetConstantArrayLboundHelper(ConstantSubscript dim) : dim_{dim} {}
20 
GetLbound(const T &)21   template <typename T> ConstantSubscript GetLbound(const T &) {
22     // The method is needed for template expansion, but we should never get
23     // here in practice.
24     CHECK(false);
25     return 0;
26   }
27 
GetLbound(const Constant<T> & x)28   template <typename T> ConstantSubscript GetLbound(const Constant<T> &x) {
29     // Return the lower bound
30     return x.lbounds()[dim_];
31   }
32 
GetLbound(const Parentheses<T> & x)33   template <typename T> ConstantSubscript GetLbound(const Parentheses<T> &x) {
34     // Strip off the parentheses
35     return GetLbound(x.left());
36   }
37 
GetLbound(const Expr<T> & x)38   template <typename T> ConstantSubscript GetLbound(const Expr<T> &x) {
39     // recurse through Expr<T>'a until we hit a constant
40     return std::visit([&](const auto &inner) { return GetLbound(inner); },
41         //      [&](const auto &) { return 0; },
42         x.u);
43   }
44 
45 private:
46   ConstantSubscript dim_;
47 };
48 
49 template <int KIND>
LBOUND(FoldingContext & context,FunctionRef<Type<TypeCategory::Integer,KIND>> && funcRef)50 Expr<Type<TypeCategory::Integer, KIND>> LBOUND(FoldingContext &context,
51     FunctionRef<Type<TypeCategory::Integer, KIND>> &&funcRef) {
52   using T = Type<TypeCategory::Integer, KIND>;
53   ActualArguments &args{funcRef.arguments()};
54   if (const auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) {
55     if (int rank{array->Rank()}; rank > 0) {
56       std::optional<int> dim;
57       if (funcRef.Rank() == 0) {
58         // Optional DIM= argument is present: result is scalar.
59         if (auto dim64{GetInt64Arg(args[1])}) {
60           if (*dim64 < 1 || *dim64 > rank) {
61             context.messages().Say("DIM=%jd dimension is out of range for "
62                                    "rank-%d array"_err_en_US,
63                 *dim64, rank);
64             return MakeInvalidIntrinsic<T>(std::move(funcRef));
65           } else {
66             dim = *dim64 - 1; // 1-based to 0-based
67           }
68         } else {
69           // DIM= is present but not constant
70           return Expr<T>{std::move(funcRef)};
71         }
72       }
73       bool lowerBoundsAreOne{true};
74       if (auto named{ExtractNamedEntity(*array)}) {
75         const Symbol &symbol{named->GetLastSymbol()};
76         if (symbol.Rank() == rank) {
77           lowerBoundsAreOne = false;
78           if (dim) {
79             return Fold(context,
80                 ConvertToType<T>(GetLowerBound(context, *named, *dim)));
81           } else if (auto extents{
82                          AsExtentArrayExpr(GetLowerBounds(context, *named))}) {
83             return Fold(context,
84                 ConvertToType<T>(Expr<ExtentType>{std::move(*extents)}));
85           }
86         } else {
87           lowerBoundsAreOne = symbol.Rank() == 0; // LBOUND(array%component)
88         }
89       }
90       if (IsActuallyConstant(*array)) {
91         return Expr<T>{GetConstantArrayLboundHelper{*dim}.GetLbound(*array)};
92       }
93       if (lowerBoundsAreOne) {
94         if (dim) {
95           return Expr<T>{1};
96         } else {
97           std::vector<Scalar<T>> ones(rank, Scalar<T>{1});
98           return Expr<T>{
99               Constant<T>{std::move(ones), ConstantSubscripts{rank}}};
100         }
101       }
102     }
103   }
104   return Expr<T>{std::move(funcRef)};
105 }
106 
107 template <int KIND>
UBOUND(FoldingContext & context,FunctionRef<Type<TypeCategory::Integer,KIND>> && funcRef)108 Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
109     FunctionRef<Type<TypeCategory::Integer, KIND>> &&funcRef) {
110   using T = Type<TypeCategory::Integer, KIND>;
111   ActualArguments &args{funcRef.arguments()};
112   if (auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) {
113     if (int rank{array->Rank()}; rank > 0) {
114       std::optional<int> dim;
115       if (funcRef.Rank() == 0) {
116         // Optional DIM= argument is present: result is scalar.
117         if (auto dim64{GetInt64Arg(args[1])}) {
118           if (*dim64 < 1 || *dim64 > rank) {
119             context.messages().Say("DIM=%jd dimension is out of range for "
120                                    "rank-%d array"_err_en_US,
121                 *dim64, rank);
122             return MakeInvalidIntrinsic<T>(std::move(funcRef));
123           } else {
124             dim = *dim64 - 1; // 1-based to 0-based
125           }
126         } else {
127           // DIM= is present but not constant
128           return Expr<T>{std::move(funcRef)};
129         }
130       }
131       bool takeBoundsFromShape{true};
132       if (auto named{ExtractNamedEntity(*array)}) {
133         const Symbol &symbol{named->GetLastSymbol()};
134         if (symbol.Rank() == rank) {
135           takeBoundsFromShape = false;
136           if (dim) {
137             if (semantics::IsAssumedSizeArray(symbol) && *dim == rank - 1) {
138               context.messages().Say("DIM=%jd dimension is out of range for "
139                                      "rank-%d assumed-size array"_err_en_US,
140                   rank, rank);
141               return MakeInvalidIntrinsic<T>(std::move(funcRef));
142             } else if (auto ub{GetUpperBound(context, *named, *dim)}) {
143               return Fold(context, ConvertToType<T>(std::move(*ub)));
144             }
145           } else {
146             Shape ubounds{GetUpperBounds(context, *named)};
147             if (semantics::IsAssumedSizeArray(symbol)) {
148               CHECK(!ubounds.back());
149               ubounds.back() = ExtentExpr{-1};
150             }
151             if (auto extents{AsExtentArrayExpr(ubounds)}) {
152               return Fold(context,
153                   ConvertToType<T>(Expr<ExtentType>{std::move(*extents)}));
154             }
155           }
156         } else {
157           takeBoundsFromShape = symbol.Rank() == 0; // UBOUND(array%component)
158         }
159       }
160       if (takeBoundsFromShape) {
161         if (auto shape{GetShape(context, *array)}) {
162           if (dim) {
163             if (auto &dimSize{shape->at(*dim)}) {
164               return Fold(context,
165                   ConvertToType<T>(Expr<ExtentType>{std::move(*dimSize)}));
166             }
167           } else if (auto shapeExpr{AsExtentArrayExpr(*shape)}) {
168             return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
169           }
170         }
171       }
172     }
173   }
174   return Expr<T>{std::move(funcRef)};
175 }
176 
177 // COUNT()
178 template <typename T>
FoldCount(FoldingContext & context,FunctionRef<T> && ref)179 static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
180   static_assert(T::category == TypeCategory::Integer);
181   ActualArguments &arg{ref.arguments()};
182   if (const Constant<LogicalResult> *mask{arg.empty()
183               ? nullptr
184               : Folder<LogicalResult>{context}.Folding(arg[0])}) {
185     std::optional<int> dim;
186     if (CheckReductionDIM(dim, context, arg, 1, mask->Rank())) {
187       auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
188         if (mask->At(at).IsTrue()) {
189           element = element.AddSigned(Scalar<T>{1}).value;
190         }
191       }};
192       return Expr<T>{DoReduction<T>(*mask, dim, Scalar<T>{}, accumulator)};
193     }
194   }
195   return Expr<T>{std::move(ref)};
196 }
197 
198 // FINDLOC(), MAXLOC(), & MINLOC()
199 enum class WhichLocation { Findloc, Maxloc, Minloc };
200 template <WhichLocation WHICH> class LocationHelper {
201 public:
LocationHelper(DynamicType && type,ActualArguments & arg,FoldingContext & context)202   LocationHelper(
203       DynamicType &&type, ActualArguments &arg, FoldingContext &context)
204       : type_{type}, arg_{arg}, context_{context} {}
205   using Result = std::optional<Constant<SubscriptInteger>>;
206   using Types = std::conditional_t<WHICH == WhichLocation::Findloc,
207       AllIntrinsicTypes, RelationalTypes>;
208 
Test() const209   template <typename T> Result Test() const {
210     if (T::category != type_.category() || T::kind != type_.kind()) {
211       return std::nullopt;
212     }
213     CHECK(arg_.size() == (WHICH == WhichLocation::Findloc ? 6 : 5));
214     Folder<T> folder{context_};
215     Constant<T> *array{folder.Folding(arg_[0])};
216     if (!array) {
217       return std::nullopt;
218     }
219     std::optional<Constant<T>> value;
220     if constexpr (WHICH == WhichLocation::Findloc) {
221       if (const Constant<T> *p{folder.Folding(arg_[1])}) {
222         value.emplace(*p);
223       } else {
224         return std::nullopt;
225       }
226     }
227     std::optional<int> dim;
228     Constant<LogicalResult> *mask{
229         GetReductionMASK(arg_[maskArg], array->shape(), context_)};
230     if ((!mask && arg_[maskArg]) ||
231         !CheckReductionDIM(dim, context_, arg_, dimArg, array->Rank())) {
232       return std::nullopt;
233     }
234     bool back{false};
235     if (arg_[backArg]) {
236       const auto *backConst{
237           Folder<LogicalResult>{context_}.Folding(arg_[backArg])};
238       if (backConst) {
239         back = backConst->GetScalarValue().value().IsTrue();
240       } else {
241         return std::nullopt;
242       }
243     }
244     const RelationalOperator relation{WHICH == WhichLocation::Findloc
245             ? RelationalOperator::EQ
246             : WHICH == WhichLocation::Maxloc
247             ? (back ? RelationalOperator::GE : RelationalOperator::GT)
248             : back ? RelationalOperator::LE
249                    : RelationalOperator::LT};
250     // Use lower bounds of 1 exclusively.
251     array->SetLowerBoundsToOne();
252     ConstantSubscripts at{array->lbounds()}, maskAt, resultIndices, resultShape;
253     if (mask) {
254       mask->SetLowerBoundsToOne();
255       maskAt = mask->lbounds();
256     }
257     if (dim) { // DIM=
258       if (*dim < 1 || *dim > array->Rank()) {
259         context_.messages().Say(
260             "FINDLOC(DIM=%d) is out of range"_err_en_US, *dim);
261         return std::nullopt;
262       }
263       int zbDim{*dim - 1};
264       resultShape = array->shape();
265       resultShape.erase(
266           resultShape.begin() + zbDim); // scalar if array is vector
267       ConstantSubscript dimLength{array->shape()[zbDim]};
268       ConstantSubscript n{GetSize(resultShape)};
269       for (ConstantSubscript j{0}; j < n; ++j) {
270         ConstantSubscript hit{array->lbounds()[zbDim] - 1};
271         value.reset();
272         for (ConstantSubscript k{0}; k < dimLength;
273              ++k, ++at[zbDim], mask && ++maskAt[zbDim]) {
274           if ((!mask || mask->At(maskAt).IsTrue()) &&
275               IsHit(array->At(at), value, relation)) {
276             hit = at[zbDim];
277             if (!back) {
278               break;
279             }
280           }
281         }
282         resultIndices.emplace_back(hit);
283         at[zbDim] = array->lbounds()[zbDim] + dimLength - 1;
284         array->IncrementSubscripts(at);
285         at[zbDim] = array->lbounds()[zbDim];
286         if (mask) {
287           maskAt[zbDim] = mask->lbounds()[zbDim] + dimLength - 1;
288           mask->IncrementSubscripts(maskAt);
289           maskAt[zbDim] = mask->lbounds()[zbDim];
290         }
291       }
292     } else { // no DIM=
293       resultShape = ConstantSubscripts{array->Rank()}; // always a vector
294       ConstantSubscript n{GetSize(array->shape())};
295       resultIndices = ConstantSubscripts(array->Rank(), 0);
296       for (ConstantSubscript j{0}; j < n; ++j, array->IncrementSubscripts(at),
297            mask && mask->IncrementSubscripts(maskAt)) {
298         if ((!mask || mask->At(maskAt).IsTrue()) &&
299             IsHit(array->At(at), value, relation)) {
300           resultIndices = at;
301           if (!back) {
302             break;
303           }
304         }
305       }
306     }
307     std::vector<Scalar<SubscriptInteger>> resultElements;
308     for (ConstantSubscript j : resultIndices) {
309       resultElements.emplace_back(j);
310     }
311     return Constant<SubscriptInteger>{
312         std::move(resultElements), std::move(resultShape)};
313   }
314 
315 private:
316   template <typename T>
IsHit(typename Constant<T>::Element element,std::optional<Constant<T>> & value,RelationalOperator relation) const317   bool IsHit(typename Constant<T>::Element element,
318       std::optional<Constant<T>> &value,
319       [[maybe_unused]] RelationalOperator relation) const {
320     std::optional<Expr<LogicalResult>> cmp;
321     if (value) {
322       if constexpr (T::category == TypeCategory::Logical) {
323         // array(at) .EQV. value?
324         static_assert(WHICH == WhichLocation::Findloc);
325         cmp.emplace(
326             ConvertToType<LogicalResult>(Expr<T>{LogicalOperation<T::kind>{
327                 LogicalOperator::Eqv, Expr<T>{Constant<T>{std::move(element)}},
328                 Expr<T>{Constant<T>{*value}}}}));
329       } else { // compare array(at) to value
330         cmp.emplace(
331             PackageRelation(relation, Expr<T>{Constant<T>{std::move(element)}},
332                 Expr<T>{Constant<T>{*value}}));
333       }
334       Expr<LogicalResult> folded{Fold(context_, std::move(*cmp))};
335       return GetScalarConstantValue<LogicalResult>(folded).value().IsTrue();
336     } else { // first unmasked element seen for MAXLOC/MINLOC
337       value.emplace(std::move(element));
338       return true;
339     }
340   }
341 
342   static constexpr int dimArg{WHICH == WhichLocation::Findloc ? 2 : 1};
343   static constexpr int maskArg{dimArg + 1};
344   static constexpr int backArg{maskArg + 2};
345 
346   DynamicType type_;
347   ActualArguments &arg_;
348   FoldingContext &context_;
349 };
350 
351 template <WhichLocation which>
FoldLocationCall(ActualArguments & arg,FoldingContext & context)352 static std::optional<Constant<SubscriptInteger>> FoldLocationCall(
353     ActualArguments &arg, FoldingContext &context) {
354   if (arg[0]) {
355     if (auto type{arg[0]->GetType()}) {
356       return common::SearchTypes(
357           LocationHelper<which>{std::move(*type), arg, context});
358     }
359   }
360   return std::nullopt;
361 }
362 
363 template <WhichLocation which, typename T>
FoldLocation(FoldingContext & context,FunctionRef<T> && ref)364 static Expr<T> FoldLocation(FoldingContext &context, FunctionRef<T> &&ref) {
365   static_assert(T::category == TypeCategory::Integer);
366   if (std::optional<Constant<SubscriptInteger>> found{
367           FoldLocationCall<which>(ref.arguments(), context)}) {
368     return Expr<T>{Fold(
369         context, ConvertToType<T>(Expr<SubscriptInteger>{std::move(*found)}))};
370   } else {
371     return Expr<T>{std::move(ref)};
372   }
373 }
374 
375 // for IALL, IANY, & IPARITY
376 template <typename T>
FoldBitReduction(FoldingContext & context,FunctionRef<T> && ref,Scalar<T> (Scalar<T>::* operation)(const Scalar<T> &)const,Scalar<T> identity)377 static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
378     Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
379     Scalar<T> identity) {
380   static_assert(T::category == TypeCategory::Integer);
381   std::optional<int> dim;
382   if (std::optional<Constant<T>> array{
383           ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
384               /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
385     auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
386       element = (element.*operation)(array->At(at));
387     }};
388     return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
389   }
390   return Expr<T>{std::move(ref)};
391 }
392 
393 template <int KIND>
FoldIntrinsicFunction(FoldingContext & context,FunctionRef<Type<TypeCategory::Integer,KIND>> && funcRef)394 Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
395     FoldingContext &context,
396     FunctionRef<Type<TypeCategory::Integer, KIND>> &&funcRef) {
397   using T = Type<TypeCategory::Integer, KIND>;
398   using Int4 = Type<TypeCategory::Integer, 4>;
399   ActualArguments &args{funcRef.arguments()};
400   auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)};
401   CHECK(intrinsic);
402   std::string name{intrinsic->name};
403   if (name == "abs") {
404     return FoldElementalIntrinsic<T, T>(context, std::move(funcRef),
405         ScalarFunc<T, T>([&context](const Scalar<T> &i) -> Scalar<T> {
406           typename Scalar<T>::ValueWithOverflow j{i.ABS()};
407           if (j.overflow) {
408             context.messages().Say(
409                 "abs(integer(kind=%d)) folding overflowed"_en_US, KIND);
410           }
411           return j.value;
412         }));
413   } else if (name == "bit_size") {
414     return Expr<T>{Scalar<T>::bits};
415   } else if (name == "ceiling" || name == "floor" || name == "nint") {
416     if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
417       // NINT rounds ties away from zero, not to even
418       common::RoundingMode mode{name == "ceiling" ? common::RoundingMode::Up
419               : name == "floor"                   ? common::RoundingMode::Down
420                                 : common::RoundingMode::TiesAwayFromZero};
421       return std::visit(
422           [&](const auto &kx) {
423             using TR = ResultType<decltype(kx)>;
424             return FoldElementalIntrinsic<T, TR>(context, std::move(funcRef),
425                 ScalarFunc<T, TR>([&](const Scalar<TR> &x) {
426                   auto y{x.template ToInteger<Scalar<T>>(mode)};
427                   if (y.flags.test(RealFlag::Overflow)) {
428                     context.messages().Say(
429                         "%s intrinsic folding overflow"_en_US, name);
430                   }
431                   return y.value;
432                 }));
433           },
434           cx->u);
435     }
436   } else if (name == "count") {
437     return FoldCount<T>(context, std::move(funcRef));
438   } else if (name == "digits") {
439     if (const auto *cx{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
440       return Expr<T>{std::visit(
441           [](const auto &kx) {
442             return Scalar<ResultType<decltype(kx)>>::DIGITS;
443           },
444           cx->u)};
445     } else if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
446       return Expr<T>{std::visit(
447           [](const auto &kx) {
448             return Scalar<ResultType<decltype(kx)>>::DIGITS;
449           },
450           cx->u)};
451     } else if (const auto *cx{UnwrapExpr<Expr<SomeComplex>>(args[0])}) {
452       return Expr<T>{std::visit(
453           [](const auto &kx) {
454             return Scalar<typename ResultType<decltype(kx)>::Part>::DIGITS;
455           },
456           cx->u)};
457     }
458   } else if (name == "dim") {
459     return FoldElementalIntrinsic<T, T, T>(
460         context, std::move(funcRef), &Scalar<T>::DIM);
461   } else if (name == "dshiftl" || name == "dshiftr") {
462     const auto fptr{
463         name == "dshiftl" ? &Scalar<T>::DSHIFTL : &Scalar<T>::DSHIFTR};
464     // Third argument can be of any kind. However, it must be smaller or equal
465     // than BIT_SIZE. It can be converted to Int4 to simplify.
466     return FoldElementalIntrinsic<T, T, T, Int4>(context, std::move(funcRef),
467         ScalarFunc<T, T, T, Int4>(
468             [&fptr](const Scalar<T> &i, const Scalar<T> &j,
469                 const Scalar<Int4> &shift) -> Scalar<T> {
470               return std::invoke(fptr, i, j, static_cast<int>(shift.ToInt64()));
471             }));
472   } else if (name == "exponent") {
473     if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
474       return std::visit(
475           [&funcRef, &context](const auto &x) -> Expr<T> {
476             using TR = typename std::decay_t<decltype(x)>::Result;
477             return FoldElementalIntrinsic<T, TR>(context, std::move(funcRef),
478                 &Scalar<TR>::template EXPONENT<Scalar<T>>);
479           },
480           sx->u);
481     } else {
482       DIE("exponent argument must be real");
483     }
484   } else if (name == "findloc") {
485     return FoldLocation<WhichLocation::Findloc, T>(context, std::move(funcRef));
486   } else if (name == "huge") {
487     return Expr<T>{Scalar<T>::HUGE()};
488   } else if (name == "iachar" || name == "ichar") {
489     auto *someChar{UnwrapExpr<Expr<SomeCharacter>>(args[0])};
490     CHECK(someChar);
491     if (auto len{ToInt64(someChar->LEN())}) {
492       if (len.value() != 1) {
493         // Do not die, this was not checked before
494         context.messages().Say(
495             "Character in intrinsic function %s must have length one"_en_US,
496             name);
497       } else {
498         return std::visit(
499             [&funcRef, &context](const auto &str) -> Expr<T> {
500               using Char = typename std::decay_t<decltype(str)>::Result;
501               return FoldElementalIntrinsic<T, Char>(context,
502                   std::move(funcRef),
503                   ScalarFunc<T, Char>([](const Scalar<Char> &c) {
504                     return Scalar<T>{CharacterUtils<Char::kind>::ICHAR(c)};
505                   }));
506             },
507             someChar->u);
508       }
509     }
510   } else if (name == "iand" || name == "ior" || name == "ieor") {
511     auto fptr{&Scalar<T>::IAND};
512     if (name == "iand") { // done in fptr declaration
513     } else if (name == "ior") {
514       fptr = &Scalar<T>::IOR;
515     } else if (name == "ieor") {
516       fptr = &Scalar<T>::IEOR;
517     } else {
518       common::die("missing case to fold intrinsic function %s", name.c_str());
519     }
520     return FoldElementalIntrinsic<T, T, T>(
521         context, std::move(funcRef), ScalarFunc<T, T, T>(fptr));
522   } else if (name == "iall") {
523     return FoldBitReduction(
524         context, std::move(funcRef), &Scalar<T>::IAND, Scalar<T>{}.NOT());
525   } else if (name == "iany") {
526     return FoldBitReduction(
527         context, std::move(funcRef), &Scalar<T>::IOR, Scalar<T>{});
528   } else if (name == "ibclr" || name == "ibset") {
529     // Second argument can be of any kind. However, it must be smaller
530     // than BIT_SIZE. It can be converted to Int4 to simplify.
531     auto fptr{&Scalar<T>::IBCLR};
532     if (name == "ibclr") { // done in fptr definition
533     } else if (name == "ibset") {
534       fptr = &Scalar<T>::IBSET;
535     } else {
536       common::die("missing case to fold intrinsic function %s", name.c_str());
537     }
538     return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
539         ScalarFunc<T, T, Int4>([&](const Scalar<T> &i,
540                                    const Scalar<Int4> &pos) -> Scalar<T> {
541           auto posVal{static_cast<int>(pos.ToInt64())};
542           if (posVal < 0) {
543             context.messages().Say(
544                 "bit position for %s (%d) is negative"_err_en_US, name, posVal);
545           } else if (posVal >= i.bits) {
546             context.messages().Say(
547                 "bit position for %s (%d) is not less than %d"_err_en_US, name,
548                 posVal, i.bits);
549           }
550           return std::invoke(fptr, i, posVal);
551         }));
552   } else if (name == "index" || name == "scan" || name == "verify") {
553     if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) {
554       return std::visit(
555           [&](const auto &kch) -> Expr<T> {
556             using TC = typename std::decay_t<decltype(kch)>::Result;
557             if (UnwrapExpr<Expr<SomeLogical>>(args[2])) { // BACK=
558               return FoldElementalIntrinsic<T, TC, TC, LogicalResult>(context,
559                   std::move(funcRef),
560                   ScalarFunc<T, TC, TC, LogicalResult>{
561                       [&name](const Scalar<TC> &str, const Scalar<TC> &other,
562                           const Scalar<LogicalResult> &back) -> Scalar<T> {
563                         return name == "index"
564                             ? CharacterUtils<TC::kind>::INDEX(
565                                   str, other, back.IsTrue())
566                             : name == "scan" ? CharacterUtils<TC::kind>::SCAN(
567                                                    str, other, back.IsTrue())
568                                              : CharacterUtils<TC::kind>::VERIFY(
569                                                    str, other, back.IsTrue());
570                       }});
571             } else {
572               return FoldElementalIntrinsic<T, TC, TC>(context,
573                   std::move(funcRef),
574                   ScalarFunc<T, TC, TC>{
575                       [&name](const Scalar<TC> &str,
576                           const Scalar<TC> &other) -> Scalar<T> {
577                         return name == "index"
578                             ? CharacterUtils<TC::kind>::INDEX(str, other)
579                             : name == "scan"
580                             ? CharacterUtils<TC::kind>::SCAN(str, other)
581                             : CharacterUtils<TC::kind>::VERIFY(str, other);
582                       }});
583             }
584           },
585           charExpr->u);
586     } else {
587       DIE("first argument must be CHARACTER");
588     }
589   } else if (name == "int") {
590     if (auto *expr{UnwrapExpr<Expr<SomeType>>(args[0])}) {
591       return std::visit(
592           [&](auto &&x) -> Expr<T> {
593             using From = std::decay_t<decltype(x)>;
594             if constexpr (std::is_same_v<From, BOZLiteralConstant> ||
595                 IsNumericCategoryExpr<From>()) {
596               return Fold(context, ConvertToType<T>(std::move(x)));
597             }
598             DIE("int() argument type not valid");
599           },
600           std::move(expr->u));
601     }
602   } else if (name == "int_ptr_kind") {
603     return Expr<T>{8};
604   } else if (name == "kind") {
605     if constexpr (common::HasMember<T, IntegerTypes>) {
606       return Expr<T>{args[0].value().GetType()->kind()};
607     } else {
608       DIE("kind() result not integral");
609     }
610   } else if (name == "iparity") {
611     return FoldBitReduction(
612         context, std::move(funcRef), &Scalar<T>::IEOR, Scalar<T>{});
613   } else if (name == "ishft" || name == "shifta" || name == "shiftr" ||
614       name == "shiftl") {
615     // Second argument can be of any kind. However, it must be smaller or
616     // equal than BIT_SIZE. It can be converted to Int4 to simplify.
617     auto fptr{&Scalar<T>::ISHFT};
618     if (name == "ishft") { // done in fptr definition
619     } else if (name == "shifta") {
620       fptr = &Scalar<T>::SHIFTA;
621     } else if (name == "shiftr") {
622       fptr = &Scalar<T>::SHIFTR;
623     } else if (name == "shiftl") {
624       fptr = &Scalar<T>::SHIFTL;
625     } else {
626       common::die("missing case to fold intrinsic function %s", name.c_str());
627     }
628     return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
629         ScalarFunc<T, T, Int4>([&](const Scalar<T> &i,
630                                    const Scalar<Int4> &pos) -> Scalar<T> {
631           auto posVal{static_cast<int>(pos.ToInt64())};
632           if (posVal < 0) {
633             context.messages().Say(
634                 "shift count for %s (%d) is negative"_err_en_US, name, posVal);
635           } else if (posVal > i.bits) {
636             context.messages().Say(
637                 "shift count for %s (%d) is greater than %d"_err_en_US, name,
638                 posVal, i.bits);
639           }
640           return std::invoke(fptr, i, posVal);
641         }));
642   } else if (name == "lbound") {
643     return LBOUND(context, std::move(funcRef));
644   } else if (name == "leadz" || name == "trailz" || name == "poppar" ||
645       name == "popcnt") {
646     if (auto *sn{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
647       return std::visit(
648           [&funcRef, &context, &name](const auto &n) -> Expr<T> {
649             using TI = typename std::decay_t<decltype(n)>::Result;
650             if (name == "poppar") {
651               return FoldElementalIntrinsic<T, TI>(context, std::move(funcRef),
652                   ScalarFunc<T, TI>([](const Scalar<TI> &i) -> Scalar<T> {
653                     return Scalar<T>{i.POPPAR() ? 1 : 0};
654                   }));
655             }
656             auto fptr{&Scalar<TI>::LEADZ};
657             if (name == "leadz") { // done in fptr definition
658             } else if (name == "trailz") {
659               fptr = &Scalar<TI>::TRAILZ;
660             } else if (name == "popcnt") {
661               fptr = &Scalar<TI>::POPCNT;
662             } else {
663               common::die(
664                   "missing case to fold intrinsic function %s", name.c_str());
665             }
666             return FoldElementalIntrinsic<T, TI>(context, std::move(funcRef),
667                 ScalarFunc<T, TI>([&fptr](const Scalar<TI> &i) -> Scalar<T> {
668                   return Scalar<T>{std::invoke(fptr, i)};
669                 }));
670           },
671           sn->u);
672     } else {
673       DIE("leadz argument must be integer");
674     }
675   } else if (name == "len") {
676     if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) {
677       return std::visit(
678           [&](auto &kx) {
679             if (auto len{kx.LEN()}) {
680               return Fold(context, ConvertToType<T>(*std::move(len)));
681             } else {
682               return Expr<T>{std::move(funcRef)};
683             }
684           },
685           charExpr->u);
686     } else {
687       DIE("len() argument must be of character type");
688     }
689   } else if (name == "len_trim") {
690     if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) {
691       return std::visit(
692           [&](const auto &kch) -> Expr<T> {
693             using TC = typename std::decay_t<decltype(kch)>::Result;
694             return FoldElementalIntrinsic<T, TC>(context, std::move(funcRef),
695                 ScalarFunc<T, TC>{[](const Scalar<TC> &str) -> Scalar<T> {
696                   return CharacterUtils<TC::kind>::LEN_TRIM(str);
697                 }});
698           },
699           charExpr->u);
700     } else {
701       DIE("len_trim() argument must be of character type");
702     }
703   } else if (name == "maskl" || name == "maskr") {
704     // Argument can be of any kind but value has to be smaller than BIT_SIZE.
705     // It can be safely converted to Int4 to simplify.
706     const auto fptr{name == "maskl" ? &Scalar<T>::MASKL : &Scalar<T>::MASKR};
707     return FoldElementalIntrinsic<T, Int4>(context, std::move(funcRef),
708         ScalarFunc<T, Int4>([&fptr](const Scalar<Int4> &places) -> Scalar<T> {
709           return fptr(static_cast<int>(places.ToInt64()));
710         }));
711   } else if (name == "max") {
712     return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
713   } else if (name == "max0" || name == "max1") {
714     return RewriteSpecificMINorMAX(context, std::move(funcRef));
715   } else if (name == "maxexponent") {
716     if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
717       return std::visit(
718           [](const auto &x) {
719             using TR = typename std::decay_t<decltype(x)>::Result;
720             return Expr<T>{Scalar<TR>::MAXEXPONENT};
721           },
722           sx->u);
723     }
724   } else if (name == "maxloc") {
725     return FoldLocation<WhichLocation::Maxloc, T>(context, std::move(funcRef));
726   } else if (name == "maxval") {
727     return FoldMaxvalMinval<T>(context, std::move(funcRef),
728         RelationalOperator::GT, T::Scalar::Least());
729   } else if (name == "merge") {
730     return FoldMerge<T>(context, std::move(funcRef));
731   } else if (name == "merge_bits") {
732     return FoldElementalIntrinsic<T, T, T, T>(
733         context, std::move(funcRef), &Scalar<T>::MERGE_BITS);
734   } else if (name == "min") {
735     return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
736   } else if (name == "min0" || name == "min1") {
737     return RewriteSpecificMINorMAX(context, std::move(funcRef));
738   } else if (name == "minexponent") {
739     if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
740       return std::visit(
741           [](const auto &x) {
742             using TR = typename std::decay_t<decltype(x)>::Result;
743             return Expr<T>{Scalar<TR>::MINEXPONENT};
744           },
745           sx->u);
746     }
747   } else if (name == "minloc") {
748     return FoldLocation<WhichLocation::Minloc, T>(context, std::move(funcRef));
749   } else if (name == "minval") {
750     return FoldMaxvalMinval<T>(
751         context, std::move(funcRef), RelationalOperator::LT, T::Scalar::HUGE());
752   } else if (name == "mod") {
753     return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
754         ScalarFuncWithContext<T, T, T>(
755             [](FoldingContext &context, const Scalar<T> &x,
756                 const Scalar<T> &y) -> Scalar<T> {
757               auto quotRem{x.DivideSigned(y)};
758               if (quotRem.divisionByZero) {
759                 context.messages().Say("mod() by zero"_en_US);
760               } else if (quotRem.overflow) {
761                 context.messages().Say("mod() folding overflowed"_en_US);
762               }
763               return quotRem.remainder;
764             }));
765   } else if (name == "modulo") {
766     return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
767         ScalarFuncWithContext<T, T, T>(
768             [](FoldingContext &context, const Scalar<T> &x,
769                 const Scalar<T> &y) -> Scalar<T> {
770               auto result{x.MODULO(y)};
771               if (result.overflow) {
772                 context.messages().Say("modulo() folding overflowed"_en_US);
773               }
774               return result.value;
775             }));
776   } else if (name == "not") {
777     return FoldElementalIntrinsic<T, T>(
778         context, std::move(funcRef), &Scalar<T>::NOT);
779   } else if (name == "precision") {
780     if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
781       return Expr<T>{std::visit(
782           [](const auto &kx) {
783             return Scalar<ResultType<decltype(kx)>>::PRECISION;
784           },
785           cx->u)};
786     } else if (const auto *cx{UnwrapExpr<Expr<SomeComplex>>(args[0])}) {
787       return Expr<T>{std::visit(
788           [](const auto &kx) {
789             return Scalar<typename ResultType<decltype(kx)>::Part>::PRECISION;
790           },
791           cx->u)};
792     }
793   } else if (name == "product") {
794     return FoldProduct<T>(context, std::move(funcRef), Scalar<T>{1});
795   } else if (name == "radix") {
796     return Expr<T>{2};
797   } else if (name == "range") {
798     if (const auto *cx{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
799       return Expr<T>{std::visit(
800           [](const auto &kx) {
801             return Scalar<ResultType<decltype(kx)>>::RANGE;
802           },
803           cx->u)};
804     } else if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
805       return Expr<T>{std::visit(
806           [](const auto &kx) {
807             return Scalar<ResultType<decltype(kx)>>::RANGE;
808           },
809           cx->u)};
810     } else if (const auto *cx{UnwrapExpr<Expr<SomeComplex>>(args[0])}) {
811       return Expr<T>{std::visit(
812           [](const auto &kx) {
813             return Scalar<typename ResultType<decltype(kx)>::Part>::RANGE;
814           },
815           cx->u)};
816     }
817   } else if (name == "rank") {
818     if (const auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) {
819       if (auto named{ExtractNamedEntity(*array)}) {
820         const Symbol &symbol{named->GetLastSymbol()};
821         if (IsAssumedRank(symbol)) {
822           // DescriptorInquiry can only be placed in expression of kind
823           // DescriptorInquiry::Result::kind.
824           return ConvertToType<T>(Expr<
825               Type<TypeCategory::Integer, DescriptorInquiry::Result::kind>>{
826               DescriptorInquiry{*named, DescriptorInquiry::Field::Rank}});
827         }
828       }
829       return Expr<T>{args[0].value().Rank()};
830     }
831     return Expr<T>{args[0].value().Rank()};
832   } else if (name == "selected_char_kind") {
833     if (const auto *chCon{UnwrapExpr<Constant<TypeOf<std::string>>>(args[0])}) {
834       if (std::optional<std::string> value{chCon->GetScalarValue()}) {
835         int defaultKind{
836             context.defaults().GetDefaultKind(TypeCategory::Character)};
837         return Expr<T>{SelectedCharKind(*value, defaultKind)};
838       }
839     }
840   } else if (name == "selected_int_kind") {
841     if (auto p{GetInt64Arg(args[0])}) {
842       return Expr<T>{SelectedIntKind(*p)};
843     }
844   } else if (name == "selected_real_kind" ||
845       name == "__builtin_ieee_selected_real_kind") {
846     if (auto p{GetInt64ArgOr(args[0], 0)}) {
847       if (auto r{GetInt64ArgOr(args[1], 0)}) {
848         if (auto radix{GetInt64ArgOr(args[2], 2)}) {
849           return Expr<T>{SelectedRealKind(*p, *r, *radix)};
850         }
851       }
852     }
853   } else if (name == "shape") {
854     if (auto shape{GetShape(context, args[0])}) {
855       if (auto shapeExpr{AsExtentArrayExpr(*shape)}) {
856         return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
857       }
858     }
859   } else if (name == "sign") {
860     return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
861         ScalarFunc<T, T, T>(
862             [&context](const Scalar<T> &j, const Scalar<T> &k) -> Scalar<T> {
863               typename Scalar<T>::ValueWithOverflow result{j.SIGN(k)};
864               if (result.overflow) {
865                 context.messages().Say(
866                     "sign(integer(kind=%d)) folding overflowed"_en_US, KIND);
867               }
868               return result.value;
869             }));
870   } else if (name == "size") {
871     if (auto shape{GetShape(context, args[0])}) {
872       if (auto &dimArg{args[1]}) { // DIM= is present, get one extent
873         if (auto dim{GetInt64Arg(args[1])}) {
874           int rank{GetRank(*shape)};
875           if (*dim >= 1 && *dim <= rank) {
876             const Symbol *symbol{UnwrapWholeSymbolDataRef(args[0])};
877             if (symbol && IsAssumedSizeArray(*symbol) && *dim == rank) {
878               context.messages().Say(
879                   "size(array,dim=%jd) of last dimension is not available for rank-%d assumed-size array dummy argument"_err_en_US,
880                   *dim, rank);
881               return MakeInvalidIntrinsic<T>(std::move(funcRef));
882             } else if (auto &extent{shape->at(*dim - 1)}) {
883               return Fold(context, ConvertToType<T>(std::move(*extent)));
884             }
885           } else {
886             context.messages().Say(
887                 "size(array,dim=%jd) dimension is out of range for rank-%d array"_en_US,
888                 *dim, rank);
889           }
890         }
891       } else if (auto extents{common::AllElementsPresent(std::move(*shape))}) {
892         // DIM= is absent; compute PRODUCT(SHAPE())
893         ExtentExpr product{1};
894         for (auto &&extent : std::move(*extents)) {
895           product = std::move(product) * std::move(extent);
896         }
897         return Expr<T>{ConvertToType<T>(Fold(context, std::move(product)))};
898       }
899     }
900   } else if (name == "sizeof") { // in bytes; extension
901     if (auto info{
902             characteristics::TypeAndShape::Characterize(args[0], context)}) {
903       if (auto bytes{info->MeasureSizeInBytes(context)}) {
904         return Expr<T>{Fold(context, ConvertToType<T>(std::move(*bytes)))};
905       }
906     }
907   } else if (name == "storage_size") { // in bits
908     if (auto info{
909             characteristics::TypeAndShape::Characterize(args[0], context)}) {
910       if (auto bytes{info->MeasureElementSizeInBytes(context, true)}) {
911         return Expr<T>{
912             Fold(context, Expr<T>{8} * ConvertToType<T>(std::move(*bytes)))};
913       }
914     }
915   } else if (name == "sum") {
916     return FoldSum<T>(context, std::move(funcRef));
917   } else if (name == "ubound") {
918     return UBOUND(context, std::move(funcRef));
919   }
  // TODO: dot_product, ibits, ishftc, matmul, transfer
921   return Expr<T>{std::move(funcRef)};
922 }
923 
924 // Substitutes a bare type parameter reference with its value if it has one now
925 // in an instantiation.  Bare LEN type parameters are substituted only when
926 // the known value is constant.
FoldOperation(FoldingContext & context,TypeParamInquiry && inquiry)927 Expr<TypeParamInquiry::Result> FoldOperation(
928     FoldingContext &context, TypeParamInquiry &&inquiry) {
929   std::optional<NamedEntity> base{inquiry.base()};
930   parser::CharBlock parameterName{inquiry.parameter().name()};
931   if (base) {
932     // Handling "designator%typeParam".  Get the value of the type parameter
933     // from the instantiation of the base
934     if (const semantics::DeclTypeSpec *
935         declType{base->GetLastSymbol().GetType()}) {
936       if (const semantics::ParamValue *
937           paramValue{
938               declType->derivedTypeSpec().FindParameter(parameterName)}) {
939         const semantics::MaybeIntExpr &paramExpr{paramValue->GetExplicit()};
940         if (paramExpr && IsConstantExpr(*paramExpr)) {
941           Expr<SomeInteger> intExpr{*paramExpr};
942           return Fold(context,
943               ConvertToType<TypeParamInquiry::Result>(std::move(intExpr)));
944         }
945       }
946     }
947   } else {
948     // A "bare" type parameter: replace with its value, if that's now known
949     // in a current derived type instantiation, for KIND type parameters.
950     if (const auto *pdt{context.pdtInstance()}) {
951       bool isLen{false};
952       if (const semantics::Scope * scope{context.pdtInstance()->scope()}) {
953         auto iter{scope->find(parameterName)};
954         if (iter != scope->end()) {
955           const Symbol &symbol{*iter->second};
956           const auto *details{symbol.detailsIf<semantics::TypeParamDetails>()};
957           if (details) {
958             isLen = details->attr() == common::TypeParamAttr::Len;
959             const semantics::MaybeIntExpr &initExpr{details->init()};
960             if (initExpr && IsConstantExpr(*initExpr) &&
961                 (!isLen || ToInt64(*initExpr))) {
962               Expr<SomeInteger> expr{*initExpr};
963               return Fold(context,
964                   ConvertToType<TypeParamInquiry::Result>(std::move(expr)));
965             }
966           }
967         }
968       }
969       if (const auto *value{pdt->FindParameter(parameterName)}) {
970         if (value->isExplicit()) {
971           auto folded{Fold(context,
972               AsExpr(ConvertToType<TypeParamInquiry::Result>(
973                   Expr<SomeInteger>{value->GetExplicit().value()})))};
974           if (!isLen || ToInt64(folded)) {
975             return folded;
976           }
977         }
978       }
979     }
980   }
981   return AsExpr(std::move(inquiry));
982 }
983 
ToInt64(const Expr<SomeInteger> & expr)984 std::optional<std::int64_t> ToInt64(const Expr<SomeInteger> &expr) {
985   return std::visit(
986       [](const auto &kindExpr) { return ToInt64(kindExpr); }, expr.u);
987 }
988 
ToInt64(const Expr<SomeType> & expr)989 std::optional<std::int64_t> ToInt64(const Expr<SomeType> &expr) {
990   if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(expr)}) {
991     return ToInt64(*intExpr);
992   } else {
993     return std::nullopt;
994   }
995 }
996 
// Explicit instantiations of ExpressionBase for integer expressions.
// NOTE(review): FOR_EACH_INTEGER_KIND presumably applies its prefix to the
// type of each supported INTEGER kind -- confirm against the macro's
// definition.
FOR_EACH_INTEGER_KIND(template class ExpressionBase, )
template class ExpressionBase<SomeInteger>;
999 } // namespace Fortran::evaluate
1000