1 //===-- lib/Evaluate/fold-integer.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "fold-implementation.h"
10 #include "fold-reduction.h"
11 #include "flang/Evaluate/check-expression.h"
12
13 namespace Fortran::evaluate {
14
15 // Class to retrieve the constant lower bound of an expression which is an
16 // array that devolves to a type of Constant<T>
17 class GetConstantArrayLboundHelper {
18 public:
GetConstantArrayLboundHelper(ConstantSubscript dim)19 GetConstantArrayLboundHelper(ConstantSubscript dim) : dim_{dim} {}
20
GetLbound(const T &)21 template <typename T> ConstantSubscript GetLbound(const T &) {
22 // The method is needed for template expansion, but we should never get
23 // here in practice.
24 CHECK(false);
25 return 0;
26 }
27
GetLbound(const Constant<T> & x)28 template <typename T> ConstantSubscript GetLbound(const Constant<T> &x) {
29 // Return the lower bound
30 return x.lbounds()[dim_];
31 }
32
GetLbound(const Parentheses<T> & x)33 template <typename T> ConstantSubscript GetLbound(const Parentheses<T> &x) {
34 // Strip off the parentheses
35 return GetLbound(x.left());
36 }
37
GetLbound(const Expr<T> & x)38 template <typename T> ConstantSubscript GetLbound(const Expr<T> &x) {
39 // recurse through Expr<T>'a until we hit a constant
40 return std::visit([&](const auto &inner) { return GetLbound(inner); },
41 // [&](const auto &) { return 0; },
42 x.u);
43 }
44
45 private:
46 ConstantSubscript dim_;
47 };
48
49 template <int KIND>
LBOUND(FoldingContext & context,FunctionRef<Type<TypeCategory::Integer,KIND>> && funcRef)50 Expr<Type<TypeCategory::Integer, KIND>> LBOUND(FoldingContext &context,
51 FunctionRef<Type<TypeCategory::Integer, KIND>> &&funcRef) {
52 using T = Type<TypeCategory::Integer, KIND>;
53 ActualArguments &args{funcRef.arguments()};
54 if (const auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) {
55 if (int rank{array->Rank()}; rank > 0) {
56 std::optional<int> dim;
57 if (funcRef.Rank() == 0) {
58 // Optional DIM= argument is present: result is scalar.
59 if (auto dim64{GetInt64Arg(args[1])}) {
60 if (*dim64 < 1 || *dim64 > rank) {
61 context.messages().Say("DIM=%jd dimension is out of range for "
62 "rank-%d array"_err_en_US,
63 *dim64, rank);
64 return MakeInvalidIntrinsic<T>(std::move(funcRef));
65 } else {
66 dim = *dim64 - 1; // 1-based to 0-based
67 }
68 } else {
69 // DIM= is present but not constant
70 return Expr<T>{std::move(funcRef)};
71 }
72 }
73 bool lowerBoundsAreOne{true};
74 if (auto named{ExtractNamedEntity(*array)}) {
75 const Symbol &symbol{named->GetLastSymbol()};
76 if (symbol.Rank() == rank) {
77 lowerBoundsAreOne = false;
78 if (dim) {
79 return Fold(context,
80 ConvertToType<T>(GetLowerBound(context, *named, *dim)));
81 } else if (auto extents{
82 AsExtentArrayExpr(GetLowerBounds(context, *named))}) {
83 return Fold(context,
84 ConvertToType<T>(Expr<ExtentType>{std::move(*extents)}));
85 }
86 } else {
87 lowerBoundsAreOne = symbol.Rank() == 0; // LBOUND(array%component)
88 }
89 }
90 if (IsActuallyConstant(*array)) {
91 return Expr<T>{GetConstantArrayLboundHelper{*dim}.GetLbound(*array)};
92 }
93 if (lowerBoundsAreOne) {
94 if (dim) {
95 return Expr<T>{1};
96 } else {
97 std::vector<Scalar<T>> ones(rank, Scalar<T>{1});
98 return Expr<T>{
99 Constant<T>{std::move(ones), ConstantSubscripts{rank}}};
100 }
101 }
102 }
103 }
104 return Expr<T>{std::move(funcRef)};
105 }
106
107 template <int KIND>
UBOUND(FoldingContext & context,FunctionRef<Type<TypeCategory::Integer,KIND>> && funcRef)108 Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
109 FunctionRef<Type<TypeCategory::Integer, KIND>> &&funcRef) {
110 using T = Type<TypeCategory::Integer, KIND>;
111 ActualArguments &args{funcRef.arguments()};
112 if (auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) {
113 if (int rank{array->Rank()}; rank > 0) {
114 std::optional<int> dim;
115 if (funcRef.Rank() == 0) {
116 // Optional DIM= argument is present: result is scalar.
117 if (auto dim64{GetInt64Arg(args[1])}) {
118 if (*dim64 < 1 || *dim64 > rank) {
119 context.messages().Say("DIM=%jd dimension is out of range for "
120 "rank-%d array"_err_en_US,
121 *dim64, rank);
122 return MakeInvalidIntrinsic<T>(std::move(funcRef));
123 } else {
124 dim = *dim64 - 1; // 1-based to 0-based
125 }
126 } else {
127 // DIM= is present but not constant
128 return Expr<T>{std::move(funcRef)};
129 }
130 }
131 bool takeBoundsFromShape{true};
132 if (auto named{ExtractNamedEntity(*array)}) {
133 const Symbol &symbol{named->GetLastSymbol()};
134 if (symbol.Rank() == rank) {
135 takeBoundsFromShape = false;
136 if (dim) {
137 if (semantics::IsAssumedSizeArray(symbol) && *dim == rank - 1) {
138 context.messages().Say("DIM=%jd dimension is out of range for "
139 "rank-%d assumed-size array"_err_en_US,
140 rank, rank);
141 return MakeInvalidIntrinsic<T>(std::move(funcRef));
142 } else if (auto ub{GetUpperBound(context, *named, *dim)}) {
143 return Fold(context, ConvertToType<T>(std::move(*ub)));
144 }
145 } else {
146 Shape ubounds{GetUpperBounds(context, *named)};
147 if (semantics::IsAssumedSizeArray(symbol)) {
148 CHECK(!ubounds.back());
149 ubounds.back() = ExtentExpr{-1};
150 }
151 if (auto extents{AsExtentArrayExpr(ubounds)}) {
152 return Fold(context,
153 ConvertToType<T>(Expr<ExtentType>{std::move(*extents)}));
154 }
155 }
156 } else {
157 takeBoundsFromShape = symbol.Rank() == 0; // UBOUND(array%component)
158 }
159 }
160 if (takeBoundsFromShape) {
161 if (auto shape{GetShape(context, *array)}) {
162 if (dim) {
163 if (auto &dimSize{shape->at(*dim)}) {
164 return Fold(context,
165 ConvertToType<T>(Expr<ExtentType>{std::move(*dimSize)}));
166 }
167 } else if (auto shapeExpr{AsExtentArrayExpr(*shape)}) {
168 return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
169 }
170 }
171 }
172 }
173 }
174 return Expr<T>{std::move(funcRef)};
175 }
176
177 // COUNT()
178 template <typename T>
FoldCount(FoldingContext & context,FunctionRef<T> && ref)179 static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
180 static_assert(T::category == TypeCategory::Integer);
181 ActualArguments &arg{ref.arguments()};
182 if (const Constant<LogicalResult> *mask{arg.empty()
183 ? nullptr
184 : Folder<LogicalResult>{context}.Folding(arg[0])}) {
185 std::optional<int> dim;
186 if (CheckReductionDIM(dim, context, arg, 1, mask->Rank())) {
187 auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
188 if (mask->At(at).IsTrue()) {
189 element = element.AddSigned(Scalar<T>{1}).value;
190 }
191 }};
192 return Expr<T>{DoReduction<T>(*mask, dim, Scalar<T>{}, accumulator)};
193 }
194 }
195 return Expr<T>{std::move(ref)};
196 }
197
198 // FINDLOC(), MAXLOC(), & MINLOC()
199 enum class WhichLocation { Findloc, Maxloc, Minloc };
200 template <WhichLocation WHICH> class LocationHelper {
201 public:
LocationHelper(DynamicType && type,ActualArguments & arg,FoldingContext & context)202 LocationHelper(
203 DynamicType &&type, ActualArguments &arg, FoldingContext &context)
204 : type_{type}, arg_{arg}, context_{context} {}
205 using Result = std::optional<Constant<SubscriptInteger>>;
206 using Types = std::conditional_t<WHICH == WhichLocation::Findloc,
207 AllIntrinsicTypes, RelationalTypes>;
208
Test() const209 template <typename T> Result Test() const {
210 if (T::category != type_.category() || T::kind != type_.kind()) {
211 return std::nullopt;
212 }
213 CHECK(arg_.size() == (WHICH == WhichLocation::Findloc ? 6 : 5));
214 Folder<T> folder{context_};
215 Constant<T> *array{folder.Folding(arg_[0])};
216 if (!array) {
217 return std::nullopt;
218 }
219 std::optional<Constant<T>> value;
220 if constexpr (WHICH == WhichLocation::Findloc) {
221 if (const Constant<T> *p{folder.Folding(arg_[1])}) {
222 value.emplace(*p);
223 } else {
224 return std::nullopt;
225 }
226 }
227 std::optional<int> dim;
228 Constant<LogicalResult> *mask{
229 GetReductionMASK(arg_[maskArg], array->shape(), context_)};
230 if ((!mask && arg_[maskArg]) ||
231 !CheckReductionDIM(dim, context_, arg_, dimArg, array->Rank())) {
232 return std::nullopt;
233 }
234 bool back{false};
235 if (arg_[backArg]) {
236 const auto *backConst{
237 Folder<LogicalResult>{context_}.Folding(arg_[backArg])};
238 if (backConst) {
239 back = backConst->GetScalarValue().value().IsTrue();
240 } else {
241 return std::nullopt;
242 }
243 }
244 const RelationalOperator relation{WHICH == WhichLocation::Findloc
245 ? RelationalOperator::EQ
246 : WHICH == WhichLocation::Maxloc
247 ? (back ? RelationalOperator::GE : RelationalOperator::GT)
248 : back ? RelationalOperator::LE
249 : RelationalOperator::LT};
250 // Use lower bounds of 1 exclusively.
251 array->SetLowerBoundsToOne();
252 ConstantSubscripts at{array->lbounds()}, maskAt, resultIndices, resultShape;
253 if (mask) {
254 mask->SetLowerBoundsToOne();
255 maskAt = mask->lbounds();
256 }
257 if (dim) { // DIM=
258 if (*dim < 1 || *dim > array->Rank()) {
259 context_.messages().Say(
260 "FINDLOC(DIM=%d) is out of range"_err_en_US, *dim);
261 return std::nullopt;
262 }
263 int zbDim{*dim - 1};
264 resultShape = array->shape();
265 resultShape.erase(
266 resultShape.begin() + zbDim); // scalar if array is vector
267 ConstantSubscript dimLength{array->shape()[zbDim]};
268 ConstantSubscript n{GetSize(resultShape)};
269 for (ConstantSubscript j{0}; j < n; ++j) {
270 ConstantSubscript hit{array->lbounds()[zbDim] - 1};
271 value.reset();
272 for (ConstantSubscript k{0}; k < dimLength;
273 ++k, ++at[zbDim], mask && ++maskAt[zbDim]) {
274 if ((!mask || mask->At(maskAt).IsTrue()) &&
275 IsHit(array->At(at), value, relation)) {
276 hit = at[zbDim];
277 if (!back) {
278 break;
279 }
280 }
281 }
282 resultIndices.emplace_back(hit);
283 at[zbDim] = array->lbounds()[zbDim] + dimLength - 1;
284 array->IncrementSubscripts(at);
285 at[zbDim] = array->lbounds()[zbDim];
286 if (mask) {
287 maskAt[zbDim] = mask->lbounds()[zbDim] + dimLength - 1;
288 mask->IncrementSubscripts(maskAt);
289 maskAt[zbDim] = mask->lbounds()[zbDim];
290 }
291 }
292 } else { // no DIM=
293 resultShape = ConstantSubscripts{array->Rank()}; // always a vector
294 ConstantSubscript n{GetSize(array->shape())};
295 resultIndices = ConstantSubscripts(array->Rank(), 0);
296 for (ConstantSubscript j{0}; j < n; ++j, array->IncrementSubscripts(at),
297 mask && mask->IncrementSubscripts(maskAt)) {
298 if ((!mask || mask->At(maskAt).IsTrue()) &&
299 IsHit(array->At(at), value, relation)) {
300 resultIndices = at;
301 if (!back) {
302 break;
303 }
304 }
305 }
306 }
307 std::vector<Scalar<SubscriptInteger>> resultElements;
308 for (ConstantSubscript j : resultIndices) {
309 resultElements.emplace_back(j);
310 }
311 return Constant<SubscriptInteger>{
312 std::move(resultElements), std::move(resultShape)};
313 }
314
315 private:
316 template <typename T>
IsHit(typename Constant<T>::Element element,std::optional<Constant<T>> & value,RelationalOperator relation) const317 bool IsHit(typename Constant<T>::Element element,
318 std::optional<Constant<T>> &value,
319 [[maybe_unused]] RelationalOperator relation) const {
320 std::optional<Expr<LogicalResult>> cmp;
321 if (value) {
322 if constexpr (T::category == TypeCategory::Logical) {
323 // array(at) .EQV. value?
324 static_assert(WHICH == WhichLocation::Findloc);
325 cmp.emplace(
326 ConvertToType<LogicalResult>(Expr<T>{LogicalOperation<T::kind>{
327 LogicalOperator::Eqv, Expr<T>{Constant<T>{std::move(element)}},
328 Expr<T>{Constant<T>{*value}}}}));
329 } else { // compare array(at) to value
330 cmp.emplace(
331 PackageRelation(relation, Expr<T>{Constant<T>{std::move(element)}},
332 Expr<T>{Constant<T>{*value}}));
333 }
334 Expr<LogicalResult> folded{Fold(context_, std::move(*cmp))};
335 return GetScalarConstantValue<LogicalResult>(folded).value().IsTrue();
336 } else { // first unmasked element seen for MAXLOC/MINLOC
337 value.emplace(std::move(element));
338 return true;
339 }
340 }
341
342 static constexpr int dimArg{WHICH == WhichLocation::Findloc ? 2 : 1};
343 static constexpr int maskArg{dimArg + 1};
344 static constexpr int backArg{maskArg + 2};
345
346 DynamicType type_;
347 ActualArguments &arg_;
348 FoldingContext &context_;
349 };
350
351 template <WhichLocation which>
FoldLocationCall(ActualArguments & arg,FoldingContext & context)352 static std::optional<Constant<SubscriptInteger>> FoldLocationCall(
353 ActualArguments &arg, FoldingContext &context) {
354 if (arg[0]) {
355 if (auto type{arg[0]->GetType()}) {
356 return common::SearchTypes(
357 LocationHelper<which>{std::move(*type), arg, context});
358 }
359 }
360 return std::nullopt;
361 }
362
363 template <WhichLocation which, typename T>
FoldLocation(FoldingContext & context,FunctionRef<T> && ref)364 static Expr<T> FoldLocation(FoldingContext &context, FunctionRef<T> &&ref) {
365 static_assert(T::category == TypeCategory::Integer);
366 if (std::optional<Constant<SubscriptInteger>> found{
367 FoldLocationCall<which>(ref.arguments(), context)}) {
368 return Expr<T>{Fold(
369 context, ConvertToType<T>(Expr<SubscriptInteger>{std::move(*found)}))};
370 } else {
371 return Expr<T>{std::move(ref)};
372 }
373 }
374
375 // for IALL, IANY, & IPARITY
376 template <typename T>
FoldBitReduction(FoldingContext & context,FunctionRef<T> && ref,Scalar<T> (Scalar<T>::* operation)(const Scalar<T> &)const,Scalar<T> identity)377 static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
378 Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
379 Scalar<T> identity) {
380 static_assert(T::category == TypeCategory::Integer);
381 std::optional<int> dim;
382 if (std::optional<Constant<T>> array{
383 ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
384 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
385 auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
386 element = (element.*operation)(array->At(at));
387 }};
388 return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
389 }
390 return Expr<T>{std::move(ref)};
391 }
392
393 template <int KIND>
FoldIntrinsicFunction(FoldingContext & context,FunctionRef<Type<TypeCategory::Integer,KIND>> && funcRef)394 Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
395 FoldingContext &context,
396 FunctionRef<Type<TypeCategory::Integer, KIND>> &&funcRef) {
397 using T = Type<TypeCategory::Integer, KIND>;
398 using Int4 = Type<TypeCategory::Integer, 4>;
399 ActualArguments &args{funcRef.arguments()};
400 auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)};
401 CHECK(intrinsic);
402 std::string name{intrinsic->name};
403 if (name == "abs") {
404 return FoldElementalIntrinsic<T, T>(context, std::move(funcRef),
405 ScalarFunc<T, T>([&context](const Scalar<T> &i) -> Scalar<T> {
406 typename Scalar<T>::ValueWithOverflow j{i.ABS()};
407 if (j.overflow) {
408 context.messages().Say(
409 "abs(integer(kind=%d)) folding overflowed"_en_US, KIND);
410 }
411 return j.value;
412 }));
413 } else if (name == "bit_size") {
414 return Expr<T>{Scalar<T>::bits};
415 } else if (name == "ceiling" || name == "floor" || name == "nint") {
416 if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
417 // NINT rounds ties away from zero, not to even
418 common::RoundingMode mode{name == "ceiling" ? common::RoundingMode::Up
419 : name == "floor" ? common::RoundingMode::Down
420 : common::RoundingMode::TiesAwayFromZero};
421 return std::visit(
422 [&](const auto &kx) {
423 using TR = ResultType<decltype(kx)>;
424 return FoldElementalIntrinsic<T, TR>(context, std::move(funcRef),
425 ScalarFunc<T, TR>([&](const Scalar<TR> &x) {
426 auto y{x.template ToInteger<Scalar<T>>(mode)};
427 if (y.flags.test(RealFlag::Overflow)) {
428 context.messages().Say(
429 "%s intrinsic folding overflow"_en_US, name);
430 }
431 return y.value;
432 }));
433 },
434 cx->u);
435 }
436 } else if (name == "count") {
437 return FoldCount<T>(context, std::move(funcRef));
438 } else if (name == "digits") {
439 if (const auto *cx{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
440 return Expr<T>{std::visit(
441 [](const auto &kx) {
442 return Scalar<ResultType<decltype(kx)>>::DIGITS;
443 },
444 cx->u)};
445 } else if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
446 return Expr<T>{std::visit(
447 [](const auto &kx) {
448 return Scalar<ResultType<decltype(kx)>>::DIGITS;
449 },
450 cx->u)};
451 } else if (const auto *cx{UnwrapExpr<Expr<SomeComplex>>(args[0])}) {
452 return Expr<T>{std::visit(
453 [](const auto &kx) {
454 return Scalar<typename ResultType<decltype(kx)>::Part>::DIGITS;
455 },
456 cx->u)};
457 }
458 } else if (name == "dim") {
459 return FoldElementalIntrinsic<T, T, T>(
460 context, std::move(funcRef), &Scalar<T>::DIM);
461 } else if (name == "dshiftl" || name == "dshiftr") {
462 const auto fptr{
463 name == "dshiftl" ? &Scalar<T>::DSHIFTL : &Scalar<T>::DSHIFTR};
464 // Third argument can be of any kind. However, it must be smaller or equal
465 // than BIT_SIZE. It can be converted to Int4 to simplify.
466 return FoldElementalIntrinsic<T, T, T, Int4>(context, std::move(funcRef),
467 ScalarFunc<T, T, T, Int4>(
468 [&fptr](const Scalar<T> &i, const Scalar<T> &j,
469 const Scalar<Int4> &shift) -> Scalar<T> {
470 return std::invoke(fptr, i, j, static_cast<int>(shift.ToInt64()));
471 }));
472 } else if (name == "exponent") {
473 if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
474 return std::visit(
475 [&funcRef, &context](const auto &x) -> Expr<T> {
476 using TR = typename std::decay_t<decltype(x)>::Result;
477 return FoldElementalIntrinsic<T, TR>(context, std::move(funcRef),
478 &Scalar<TR>::template EXPONENT<Scalar<T>>);
479 },
480 sx->u);
481 } else {
482 DIE("exponent argument must be real");
483 }
484 } else if (name == "findloc") {
485 return FoldLocation<WhichLocation::Findloc, T>(context, std::move(funcRef));
486 } else if (name == "huge") {
487 return Expr<T>{Scalar<T>::HUGE()};
488 } else if (name == "iachar" || name == "ichar") {
489 auto *someChar{UnwrapExpr<Expr<SomeCharacter>>(args[0])};
490 CHECK(someChar);
491 if (auto len{ToInt64(someChar->LEN())}) {
492 if (len.value() != 1) {
493 // Do not die, this was not checked before
494 context.messages().Say(
495 "Character in intrinsic function %s must have length one"_en_US,
496 name);
497 } else {
498 return std::visit(
499 [&funcRef, &context](const auto &str) -> Expr<T> {
500 using Char = typename std::decay_t<decltype(str)>::Result;
501 return FoldElementalIntrinsic<T, Char>(context,
502 std::move(funcRef),
503 ScalarFunc<T, Char>([](const Scalar<Char> &c) {
504 return Scalar<T>{CharacterUtils<Char::kind>::ICHAR(c)};
505 }));
506 },
507 someChar->u);
508 }
509 }
510 } else if (name == "iand" || name == "ior" || name == "ieor") {
511 auto fptr{&Scalar<T>::IAND};
512 if (name == "iand") { // done in fptr declaration
513 } else if (name == "ior") {
514 fptr = &Scalar<T>::IOR;
515 } else if (name == "ieor") {
516 fptr = &Scalar<T>::IEOR;
517 } else {
518 common::die("missing case to fold intrinsic function %s", name.c_str());
519 }
520 return FoldElementalIntrinsic<T, T, T>(
521 context, std::move(funcRef), ScalarFunc<T, T, T>(fptr));
522 } else if (name == "iall") {
523 return FoldBitReduction(
524 context, std::move(funcRef), &Scalar<T>::IAND, Scalar<T>{}.NOT());
525 } else if (name == "iany") {
526 return FoldBitReduction(
527 context, std::move(funcRef), &Scalar<T>::IOR, Scalar<T>{});
528 } else if (name == "ibclr" || name == "ibset") {
529 // Second argument can be of any kind. However, it must be smaller
530 // than BIT_SIZE. It can be converted to Int4 to simplify.
531 auto fptr{&Scalar<T>::IBCLR};
532 if (name == "ibclr") { // done in fptr definition
533 } else if (name == "ibset") {
534 fptr = &Scalar<T>::IBSET;
535 } else {
536 common::die("missing case to fold intrinsic function %s", name.c_str());
537 }
538 return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
539 ScalarFunc<T, T, Int4>([&](const Scalar<T> &i,
540 const Scalar<Int4> &pos) -> Scalar<T> {
541 auto posVal{static_cast<int>(pos.ToInt64())};
542 if (posVal < 0) {
543 context.messages().Say(
544 "bit position for %s (%d) is negative"_err_en_US, name, posVal);
545 } else if (posVal >= i.bits) {
546 context.messages().Say(
547 "bit position for %s (%d) is not less than %d"_err_en_US, name,
548 posVal, i.bits);
549 }
550 return std::invoke(fptr, i, posVal);
551 }));
552 } else if (name == "index" || name == "scan" || name == "verify") {
553 if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) {
554 return std::visit(
555 [&](const auto &kch) -> Expr<T> {
556 using TC = typename std::decay_t<decltype(kch)>::Result;
557 if (UnwrapExpr<Expr<SomeLogical>>(args[2])) { // BACK=
558 return FoldElementalIntrinsic<T, TC, TC, LogicalResult>(context,
559 std::move(funcRef),
560 ScalarFunc<T, TC, TC, LogicalResult>{
561 [&name](const Scalar<TC> &str, const Scalar<TC> &other,
562 const Scalar<LogicalResult> &back) -> Scalar<T> {
563 return name == "index"
564 ? CharacterUtils<TC::kind>::INDEX(
565 str, other, back.IsTrue())
566 : name == "scan" ? CharacterUtils<TC::kind>::SCAN(
567 str, other, back.IsTrue())
568 : CharacterUtils<TC::kind>::VERIFY(
569 str, other, back.IsTrue());
570 }});
571 } else {
572 return FoldElementalIntrinsic<T, TC, TC>(context,
573 std::move(funcRef),
574 ScalarFunc<T, TC, TC>{
575 [&name](const Scalar<TC> &str,
576 const Scalar<TC> &other) -> Scalar<T> {
577 return name == "index"
578 ? CharacterUtils<TC::kind>::INDEX(str, other)
579 : name == "scan"
580 ? CharacterUtils<TC::kind>::SCAN(str, other)
581 : CharacterUtils<TC::kind>::VERIFY(str, other);
582 }});
583 }
584 },
585 charExpr->u);
586 } else {
587 DIE("first argument must be CHARACTER");
588 }
589 } else if (name == "int") {
590 if (auto *expr{UnwrapExpr<Expr<SomeType>>(args[0])}) {
591 return std::visit(
592 [&](auto &&x) -> Expr<T> {
593 using From = std::decay_t<decltype(x)>;
594 if constexpr (std::is_same_v<From, BOZLiteralConstant> ||
595 IsNumericCategoryExpr<From>()) {
596 return Fold(context, ConvertToType<T>(std::move(x)));
597 }
598 DIE("int() argument type not valid");
599 },
600 std::move(expr->u));
601 }
602 } else if (name == "int_ptr_kind") {
603 return Expr<T>{8};
604 } else if (name == "kind") {
605 if constexpr (common::HasMember<T, IntegerTypes>) {
606 return Expr<T>{args[0].value().GetType()->kind()};
607 } else {
608 DIE("kind() result not integral");
609 }
610 } else if (name == "iparity") {
611 return FoldBitReduction(
612 context, std::move(funcRef), &Scalar<T>::IEOR, Scalar<T>{});
613 } else if (name == "ishft" || name == "shifta" || name == "shiftr" ||
614 name == "shiftl") {
615 // Second argument can be of any kind. However, it must be smaller or
616 // equal than BIT_SIZE. It can be converted to Int4 to simplify.
617 auto fptr{&Scalar<T>::ISHFT};
618 if (name == "ishft") { // done in fptr definition
619 } else if (name == "shifta") {
620 fptr = &Scalar<T>::SHIFTA;
621 } else if (name == "shiftr") {
622 fptr = &Scalar<T>::SHIFTR;
623 } else if (name == "shiftl") {
624 fptr = &Scalar<T>::SHIFTL;
625 } else {
626 common::die("missing case to fold intrinsic function %s", name.c_str());
627 }
628 return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
629 ScalarFunc<T, T, Int4>([&](const Scalar<T> &i,
630 const Scalar<Int4> &pos) -> Scalar<T> {
631 auto posVal{static_cast<int>(pos.ToInt64())};
632 if (posVal < 0) {
633 context.messages().Say(
634 "shift count for %s (%d) is negative"_err_en_US, name, posVal);
635 } else if (posVal > i.bits) {
636 context.messages().Say(
637 "shift count for %s (%d) is greater than %d"_err_en_US, name,
638 posVal, i.bits);
639 }
640 return std::invoke(fptr, i, posVal);
641 }));
642 } else if (name == "lbound") {
643 return LBOUND(context, std::move(funcRef));
644 } else if (name == "leadz" || name == "trailz" || name == "poppar" ||
645 name == "popcnt") {
646 if (auto *sn{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
647 return std::visit(
648 [&funcRef, &context, &name](const auto &n) -> Expr<T> {
649 using TI = typename std::decay_t<decltype(n)>::Result;
650 if (name == "poppar") {
651 return FoldElementalIntrinsic<T, TI>(context, std::move(funcRef),
652 ScalarFunc<T, TI>([](const Scalar<TI> &i) -> Scalar<T> {
653 return Scalar<T>{i.POPPAR() ? 1 : 0};
654 }));
655 }
656 auto fptr{&Scalar<TI>::LEADZ};
657 if (name == "leadz") { // done in fptr definition
658 } else if (name == "trailz") {
659 fptr = &Scalar<TI>::TRAILZ;
660 } else if (name == "popcnt") {
661 fptr = &Scalar<TI>::POPCNT;
662 } else {
663 common::die(
664 "missing case to fold intrinsic function %s", name.c_str());
665 }
666 return FoldElementalIntrinsic<T, TI>(context, std::move(funcRef),
667 ScalarFunc<T, TI>([&fptr](const Scalar<TI> &i) -> Scalar<T> {
668 return Scalar<T>{std::invoke(fptr, i)};
669 }));
670 },
671 sn->u);
672 } else {
673 DIE("leadz argument must be integer");
674 }
675 } else if (name == "len") {
676 if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) {
677 return std::visit(
678 [&](auto &kx) {
679 if (auto len{kx.LEN()}) {
680 return Fold(context, ConvertToType<T>(*std::move(len)));
681 } else {
682 return Expr<T>{std::move(funcRef)};
683 }
684 },
685 charExpr->u);
686 } else {
687 DIE("len() argument must be of character type");
688 }
689 } else if (name == "len_trim") {
690 if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) {
691 return std::visit(
692 [&](const auto &kch) -> Expr<T> {
693 using TC = typename std::decay_t<decltype(kch)>::Result;
694 return FoldElementalIntrinsic<T, TC>(context, std::move(funcRef),
695 ScalarFunc<T, TC>{[](const Scalar<TC> &str) -> Scalar<T> {
696 return CharacterUtils<TC::kind>::LEN_TRIM(str);
697 }});
698 },
699 charExpr->u);
700 } else {
701 DIE("len_trim() argument must be of character type");
702 }
703 } else if (name == "maskl" || name == "maskr") {
704 // Argument can be of any kind but value has to be smaller than BIT_SIZE.
705 // It can be safely converted to Int4 to simplify.
706 const auto fptr{name == "maskl" ? &Scalar<T>::MASKL : &Scalar<T>::MASKR};
707 return FoldElementalIntrinsic<T, Int4>(context, std::move(funcRef),
708 ScalarFunc<T, Int4>([&fptr](const Scalar<Int4> &places) -> Scalar<T> {
709 return fptr(static_cast<int>(places.ToInt64()));
710 }));
711 } else if (name == "max") {
712 return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
713 } else if (name == "max0" || name == "max1") {
714 return RewriteSpecificMINorMAX(context, std::move(funcRef));
715 } else if (name == "maxexponent") {
716 if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
717 return std::visit(
718 [](const auto &x) {
719 using TR = typename std::decay_t<decltype(x)>::Result;
720 return Expr<T>{Scalar<TR>::MAXEXPONENT};
721 },
722 sx->u);
723 }
724 } else if (name == "maxloc") {
725 return FoldLocation<WhichLocation::Maxloc, T>(context, std::move(funcRef));
726 } else if (name == "maxval") {
727 return FoldMaxvalMinval<T>(context, std::move(funcRef),
728 RelationalOperator::GT, T::Scalar::Least());
729 } else if (name == "merge") {
730 return FoldMerge<T>(context, std::move(funcRef));
731 } else if (name == "merge_bits") {
732 return FoldElementalIntrinsic<T, T, T, T>(
733 context, std::move(funcRef), &Scalar<T>::MERGE_BITS);
734 } else if (name == "min") {
735 return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
736 } else if (name == "min0" || name == "min1") {
737 return RewriteSpecificMINorMAX(context, std::move(funcRef));
738 } else if (name == "minexponent") {
739 if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
740 return std::visit(
741 [](const auto &x) {
742 using TR = typename std::decay_t<decltype(x)>::Result;
743 return Expr<T>{Scalar<TR>::MINEXPONENT};
744 },
745 sx->u);
746 }
747 } else if (name == "minloc") {
748 return FoldLocation<WhichLocation::Minloc, T>(context, std::move(funcRef));
749 } else if (name == "minval") {
750 return FoldMaxvalMinval<T>(
751 context, std::move(funcRef), RelationalOperator::LT, T::Scalar::HUGE());
752 } else if (name == "mod") {
753 return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
754 ScalarFuncWithContext<T, T, T>(
755 [](FoldingContext &context, const Scalar<T> &x,
756 const Scalar<T> &y) -> Scalar<T> {
757 auto quotRem{x.DivideSigned(y)};
758 if (quotRem.divisionByZero) {
759 context.messages().Say("mod() by zero"_en_US);
760 } else if (quotRem.overflow) {
761 context.messages().Say("mod() folding overflowed"_en_US);
762 }
763 return quotRem.remainder;
764 }));
765 } else if (name == "modulo") {
766 return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
767 ScalarFuncWithContext<T, T, T>(
768 [](FoldingContext &context, const Scalar<T> &x,
769 const Scalar<T> &y) -> Scalar<T> {
770 auto result{x.MODULO(y)};
771 if (result.overflow) {
772 context.messages().Say("modulo() folding overflowed"_en_US);
773 }
774 return result.value;
775 }));
776 } else if (name == "not") {
777 return FoldElementalIntrinsic<T, T>(
778 context, std::move(funcRef), &Scalar<T>::NOT);
779 } else if (name == "precision") {
780 if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
781 return Expr<T>{std::visit(
782 [](const auto &kx) {
783 return Scalar<ResultType<decltype(kx)>>::PRECISION;
784 },
785 cx->u)};
786 } else if (const auto *cx{UnwrapExpr<Expr<SomeComplex>>(args[0])}) {
787 return Expr<T>{std::visit(
788 [](const auto &kx) {
789 return Scalar<typename ResultType<decltype(kx)>::Part>::PRECISION;
790 },
791 cx->u)};
792 }
793 } else if (name == "product") {
794 return FoldProduct<T>(context, std::move(funcRef), Scalar<T>{1});
795 } else if (name == "radix") {
796 return Expr<T>{2};
797 } else if (name == "range") {
798 if (const auto *cx{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
799 return Expr<T>{std::visit(
800 [](const auto &kx) {
801 return Scalar<ResultType<decltype(kx)>>::RANGE;
802 },
803 cx->u)};
804 } else if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
805 return Expr<T>{std::visit(
806 [](const auto &kx) {
807 return Scalar<ResultType<decltype(kx)>>::RANGE;
808 },
809 cx->u)};
810 } else if (const auto *cx{UnwrapExpr<Expr<SomeComplex>>(args[0])}) {
811 return Expr<T>{std::visit(
812 [](const auto &kx) {
813 return Scalar<typename ResultType<decltype(kx)>::Part>::RANGE;
814 },
815 cx->u)};
816 }
817 } else if (name == "rank") {
818 if (const auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) {
819 if (auto named{ExtractNamedEntity(*array)}) {
820 const Symbol &symbol{named->GetLastSymbol()};
821 if (IsAssumedRank(symbol)) {
822 // DescriptorInquiry can only be placed in expression of kind
823 // DescriptorInquiry::Result::kind.
824 return ConvertToType<T>(Expr<
825 Type<TypeCategory::Integer, DescriptorInquiry::Result::kind>>{
826 DescriptorInquiry{*named, DescriptorInquiry::Field::Rank}});
827 }
828 }
829 return Expr<T>{args[0].value().Rank()};
830 }
831 return Expr<T>{args[0].value().Rank()};
832 } else if (name == "selected_char_kind") {
833 if (const auto *chCon{UnwrapExpr<Constant<TypeOf<std::string>>>(args[0])}) {
834 if (std::optional<std::string> value{chCon->GetScalarValue()}) {
835 int defaultKind{
836 context.defaults().GetDefaultKind(TypeCategory::Character)};
837 return Expr<T>{SelectedCharKind(*value, defaultKind)};
838 }
839 }
840 } else if (name == "selected_int_kind") {
841 if (auto p{GetInt64Arg(args[0])}) {
842 return Expr<T>{SelectedIntKind(*p)};
843 }
844 } else if (name == "selected_real_kind" ||
845 name == "__builtin_ieee_selected_real_kind") {
846 if (auto p{GetInt64ArgOr(args[0], 0)}) {
847 if (auto r{GetInt64ArgOr(args[1], 0)}) {
848 if (auto radix{GetInt64ArgOr(args[2], 2)}) {
849 return Expr<T>{SelectedRealKind(*p, *r, *radix)};
850 }
851 }
852 }
853 } else if (name == "shape") {
854 if (auto shape{GetShape(context, args[0])}) {
855 if (auto shapeExpr{AsExtentArrayExpr(*shape)}) {
856 return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
857 }
858 }
859 } else if (name == "sign") {
860 return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
861 ScalarFunc<T, T, T>(
862 [&context](const Scalar<T> &j, const Scalar<T> &k) -> Scalar<T> {
863 typename Scalar<T>::ValueWithOverflow result{j.SIGN(k)};
864 if (result.overflow) {
865 context.messages().Say(
866 "sign(integer(kind=%d)) folding overflowed"_en_US, KIND);
867 }
868 return result.value;
869 }));
870 } else if (name == "size") {
871 if (auto shape{GetShape(context, args[0])}) {
872 if (auto &dimArg{args[1]}) { // DIM= is present, get one extent
873 if (auto dim{GetInt64Arg(args[1])}) {
874 int rank{GetRank(*shape)};
875 if (*dim >= 1 && *dim <= rank) {
876 const Symbol *symbol{UnwrapWholeSymbolDataRef(args[0])};
877 if (symbol && IsAssumedSizeArray(*symbol) && *dim == rank) {
878 context.messages().Say(
879 "size(array,dim=%jd) of last dimension is not available for rank-%d assumed-size array dummy argument"_err_en_US,
880 *dim, rank);
881 return MakeInvalidIntrinsic<T>(std::move(funcRef));
882 } else if (auto &extent{shape->at(*dim - 1)}) {
883 return Fold(context, ConvertToType<T>(std::move(*extent)));
884 }
885 } else {
886 context.messages().Say(
887 "size(array,dim=%jd) dimension is out of range for rank-%d array"_en_US,
888 *dim, rank);
889 }
890 }
891 } else if (auto extents{common::AllElementsPresent(std::move(*shape))}) {
892 // DIM= is absent; compute PRODUCT(SHAPE())
893 ExtentExpr product{1};
894 for (auto &&extent : std::move(*extents)) {
895 product = std::move(product) * std::move(extent);
896 }
897 return Expr<T>{ConvertToType<T>(Fold(context, std::move(product)))};
898 }
899 }
900 } else if (name == "sizeof") { // in bytes; extension
901 if (auto info{
902 characteristics::TypeAndShape::Characterize(args[0], context)}) {
903 if (auto bytes{info->MeasureSizeInBytes(context)}) {
904 return Expr<T>{Fold(context, ConvertToType<T>(std::move(*bytes)))};
905 }
906 }
907 } else if (name == "storage_size") { // in bits
908 if (auto info{
909 characteristics::TypeAndShape::Characterize(args[0], context)}) {
910 if (auto bytes{info->MeasureElementSizeInBytes(context, true)}) {
911 return Expr<T>{
912 Fold(context, Expr<T>{8} * ConvertToType<T>(std::move(*bytes)))};
913 }
914 }
915 } else if (name == "sum") {
916 return FoldSum<T>(context, std::move(funcRef));
917 } else if (name == "ubound") {
918 return UBOUND(context, std::move(funcRef));
919 }
920 // TODO: dot_product, ibits, ishftc, matmul, sign, transfer
921 return Expr<T>{std::move(funcRef)};
922 }
923
924 // Substitutes a bare type parameter reference with its value if it has one now
925 // in an instantiation. Bare LEN type parameters are substituted only when
926 // the known value is constant.
FoldOperation(FoldingContext & context,TypeParamInquiry && inquiry)927 Expr<TypeParamInquiry::Result> FoldOperation(
928 FoldingContext &context, TypeParamInquiry &&inquiry) {
929 std::optional<NamedEntity> base{inquiry.base()};
930 parser::CharBlock parameterName{inquiry.parameter().name()};
931 if (base) {
932 // Handling "designator%typeParam". Get the value of the type parameter
933 // from the instantiation of the base
934 if (const semantics::DeclTypeSpec *
935 declType{base->GetLastSymbol().GetType()}) {
936 if (const semantics::ParamValue *
937 paramValue{
938 declType->derivedTypeSpec().FindParameter(parameterName)}) {
939 const semantics::MaybeIntExpr ¶mExpr{paramValue->GetExplicit()};
940 if (paramExpr && IsConstantExpr(*paramExpr)) {
941 Expr<SomeInteger> intExpr{*paramExpr};
942 return Fold(context,
943 ConvertToType<TypeParamInquiry::Result>(std::move(intExpr)));
944 }
945 }
946 }
947 } else {
948 // A "bare" type parameter: replace with its value, if that's now known
949 // in a current derived type instantiation, for KIND type parameters.
950 if (const auto *pdt{context.pdtInstance()}) {
951 bool isLen{false};
952 if (const semantics::Scope * scope{context.pdtInstance()->scope()}) {
953 auto iter{scope->find(parameterName)};
954 if (iter != scope->end()) {
955 const Symbol &symbol{*iter->second};
956 const auto *details{symbol.detailsIf<semantics::TypeParamDetails>()};
957 if (details) {
958 isLen = details->attr() == common::TypeParamAttr::Len;
959 const semantics::MaybeIntExpr &initExpr{details->init()};
960 if (initExpr && IsConstantExpr(*initExpr) &&
961 (!isLen || ToInt64(*initExpr))) {
962 Expr<SomeInteger> expr{*initExpr};
963 return Fold(context,
964 ConvertToType<TypeParamInquiry::Result>(std::move(expr)));
965 }
966 }
967 }
968 }
969 if (const auto *value{pdt->FindParameter(parameterName)}) {
970 if (value->isExplicit()) {
971 auto folded{Fold(context,
972 AsExpr(ConvertToType<TypeParamInquiry::Result>(
973 Expr<SomeInteger>{value->GetExplicit().value()})))};
974 if (!isLen || ToInt64(folded)) {
975 return folded;
976 }
977 }
978 }
979 }
980 }
981 return AsExpr(std::move(inquiry));
982 }
983
ToInt64(const Expr<SomeInteger> & expr)984 std::optional<std::int64_t> ToInt64(const Expr<SomeInteger> &expr) {
985 return std::visit(
986 [](const auto &kindExpr) { return ToInt64(kindExpr); }, expr.u);
987 }
988
ToInt64(const Expr<SomeType> & expr)989 std::optional<std::int64_t> ToInt64(const Expr<SomeType> &expr) {
990 if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(expr)}) {
991 return ToInt64(*intExpr);
992 } else {
993 return std::nullopt;
994 }
995 }
996
// Explicit instantiations of ExpressionBase for integer expressions.
// FOR_EACH_INTEGER_KIND presumably expands the given text once per supported
// integer kind (TODO confirm against the macro's definition); the second line
// explicitly instantiates the kind-generic SomeInteger category.
997 FOR_EACH_INTEGER_KIND(template class ExpressionBase, )
998 template class ExpressionBase<SomeInteger>;
999 } // namespace Fortran::evaluate
1000