1 //===-- lib/Semantics/check-case.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "check-case.h"
10 #include "flang/Common/idioms.h"
11 #include "flang/Common/reference.h"
12 #include "flang/Common/template.h"
13 #include "flang/Evaluate/fold.h"
14 #include "flang/Evaluate/type.h"
15 #include "flang/Parser/parse-tree.h"
16 #include "flang/Semantics/semantics.h"
17 #include "flang/Semantics/tools.h"
18 #include <tuple>
19 
20 namespace Fortran::semantics {
21 
22 template <typename T> class CaseValues {
23 public:
CaseValues(SemanticsContext & c,const evaluate::DynamicType & t)24   CaseValues(SemanticsContext &c, const evaluate::DynamicType &t)
25       : context_{c}, caseExprType_{t} {}
26 
Check(const std::list<parser::CaseConstruct::Case> & cases)27   void Check(const std::list<parser::CaseConstruct::Case> &cases) {
28     for (const parser::CaseConstruct::Case &c : cases) {
29       AddCase(c);
30     }
31     if (!hasErrors_) {
32       cases_.sort(Comparator{});
33       if (!AreCasesDisjoint()) { // C1149
34         ReportConflictingCases();
35       }
36     }
37   }
38 
39 private:
40   using Value = evaluate::Scalar<T>;
41 
AddCase(const parser::CaseConstruct::Case & c)42   void AddCase(const parser::CaseConstruct::Case &c) {
43     const auto &stmt{std::get<parser::Statement<parser::CaseStmt>>(c.t)};
44     const parser::CaseStmt &caseStmt{stmt.statement};
45     const auto &selector{std::get<parser::CaseSelector>(caseStmt.t)};
46     std::visit(
47         common::visitors{
48             [&](const std::list<parser::CaseValueRange> &ranges) {
49               for (const auto &range : ranges) {
50                 auto pair{ComputeBounds(range)};
51                 if (pair.first && pair.second && *pair.first > *pair.second) {
52                   context_.Say(stmt.source,
53                       "CASE has lower bound greater than upper bound"_en_US);
54                 } else {
55                   if constexpr (T::category == TypeCategory::Logical) { // C1148
56                     if ((pair.first || pair.second) &&
57                         (!pair.first || !pair.second ||
58                             *pair.first != *pair.second)) {
59                       context_.Say(stmt.source,
60                           "CASE range is not allowed for LOGICAL"_err_en_US);
61                     }
62                   }
63                   cases_.emplace_back(stmt);
64                   cases_.back().lower = std::move(pair.first);
65                   cases_.back().upper = std::move(pair.second);
66                 }
67               }
68             },
69             [&](const parser::Default &) { cases_.emplace_front(stmt); },
70         },
71         selector.u);
72   }
73 
GetValue(const parser::CaseValue & caseValue)74   std::optional<Value> GetValue(const parser::CaseValue &caseValue) {
75     const parser::Expr &expr{caseValue.thing.thing.value()};
76     auto *x{expr.typedExpr.get()};
77     if (x && x->v) { // C1147
78       auto type{x->v->GetType()};
79       if (type && type->category() == caseExprType_.category() &&
80           (type->category() != TypeCategory::Character ||
81               type->kind() == caseExprType_.kind())) {
82         x->v = evaluate::Fold(context_.foldingContext(),
83             evaluate::ConvertToType(T::GetType(), std::move(*x->v)));
84         if (x->v) {
85           if (auto value{evaluate::GetScalarConstantValue<T>(*x->v)}) {
86             return *value;
87           }
88         }
89         context_.Say(
90             expr.source, "CASE value must be a constant scalar"_err_en_US);
91       } else {
92         std::string typeStr{type ? type->AsFortran() : "typeless"s};
93         context_.Say(expr.source,
94             "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US,
95             typeStr, caseExprType_.AsFortran());
96       }
97       hasErrors_ = true;
98     }
99     return std::nullopt;
100   }
101 
102   using PairOfValues = std::pair<std::optional<Value>, std::optional<Value>>;
ComputeBounds(const parser::CaseValueRange & range)103   PairOfValues ComputeBounds(const parser::CaseValueRange &range) {
104     return std::visit(common::visitors{
105                           [&](const parser::CaseValue &x) {
106                             auto value{GetValue(x)};
107                             return PairOfValues{value, value};
108                           },
109                           [&](const parser::CaseValueRange::Range &x) {
110                             std::optional<Value> lo, hi;
111                             if (x.lower) {
112                               lo = GetValue(*x.lower);
113                             }
114                             if (x.upper) {
115                               hi = GetValue(*x.upper);
116                             }
117                             if ((x.lower && !lo) || (x.upper && !hi)) {
118                               return PairOfValues{}; // error case
119                             }
120                             return PairOfValues{std::move(lo), std::move(hi)};
121                           },
122                       },
123         range.u);
124   }
125 
126   struct Case {
CaseFortran::semantics::CaseValues::Case127     explicit Case(const parser::Statement<parser::CaseStmt> &s) : stmt{s} {}
IsDefaultFortran::semantics::CaseValues::Case128     bool IsDefault() const { return !lower && !upper; }
AsFortranFortran::semantics::CaseValues::Case129     std::string AsFortran() const {
130       std::string result;
131       {
132         llvm::raw_string_ostream bs{result};
133         if (lower) {
134           evaluate::Constant<T>{*lower}.AsFortran(bs << '(');
135           if (!upper) {
136             bs << ':';
137           } else if (*lower != *upper) {
138             evaluate::Constant<T>{*upper}.AsFortran(bs << ':');
139           }
140           bs << ')';
141         } else if (upper) {
142           evaluate::Constant<T>{*upper}.AsFortran(bs << "(:") << ')';
143         } else {
144           bs << "DEFAULT";
145         }
146       }
147       return result;
148     }
149 
150     const parser::Statement<parser::CaseStmt> &stmt;
151     std::optional<Value> lower, upper;
152   };
153 
154   // Defines a comparator for use with std::list<>::sort().
155   // Returns true if and only if the highest value in range x is less
156   // than the least value in range y.  The DEFAULT case is arbitrarily
157   // defined to be less than all others.  When two ranges overlap,
158   // neither is less than the other.
159   struct Comparator {
operator ()Fortran::semantics::CaseValues::Comparator160     bool operator()(const Case &x, const Case &y) const {
161       if (x.IsDefault()) {
162         return !y.IsDefault();
163       } else {
164         return x.upper && y.lower && *x.upper < *y.lower;
165       }
166     }
167   };
168 
AreCasesDisjoint() const169   bool AreCasesDisjoint() const {
170     auto endIter{cases_.end()};
171     for (auto iter{cases_.begin()}; iter != endIter; ++iter) {
172       auto next{iter};
173       if (++next != endIter && !Comparator{}(*iter, *next)) {
174         return false;
175       }
176     }
177     return true;
178   }
179 
180   // This has quadratic time, but only runs in error cases
ReportConflictingCases()181   void ReportConflictingCases() {
182     for (auto iter{cases_.begin()}; iter != cases_.end(); ++iter) {
183       parser::Message *msg{nullptr};
184       for (auto p{cases_.begin()}; p != cases_.end(); ++p) {
185         if (p->stmt.source.begin() < iter->stmt.source.begin() &&
186             !Comparator{}(*p, *iter) && !Comparator{}(*iter, *p)) {
187           if (!msg) {
188             msg = &context_.Say(iter->stmt.source,
189                 "CASE %s conflicts with previous cases"_err_en_US,
190                 iter->AsFortran());
191           }
192           msg->Attach(
193               p->stmt.source, "Conflicting CASE %s"_en_US, p->AsFortran());
194         }
195       }
196     }
197   }
198 
199   SemanticsContext &context_;
200   const evaluate::DynamicType &caseExprType_;
201   std::list<Case> cases_;
202   bool hasErrors_{false};
203 };
204 
205 template <TypeCategory CAT> struct TypeVisitor {
206   using Result = bool;
207   using Types = evaluate::CategoryTypes<CAT>;
TestFortran::semantics::TypeVisitor208   template <typename T> Result Test() {
209     if (T::kind == exprType.kind()) {
210       CaseValues<T>(context, exprType).Check(caseList);
211       return true;
212     } else {
213       return false;
214     }
215   }
216   SemanticsContext &context;
217   const evaluate::DynamicType &exprType;
218   const std::list<parser::CaseConstruct::Case> &caseList;
219 };
220 
Enter(const parser::CaseConstruct & construct)221 void CaseChecker::Enter(const parser::CaseConstruct &construct) {
222   const auto &selectCaseStmt{
223       std::get<parser::Statement<parser::SelectCaseStmt>>(construct.t)};
224   const auto &selectCase{selectCaseStmt.statement};
225   const auto &selectExpr{
226       std::get<parser::Scalar<parser::Expr>>(selectCase.t).thing};
227   const auto *x{GetExpr(selectExpr)};
228   if (!x) {
229     return; // expression semantics failed
230   }
231   if (auto exprType{x->GetType()}) {
232     const auto &caseList{
233         std::get<std::list<parser::CaseConstruct::Case>>(construct.t)};
234     switch (exprType->category()) {
235     case TypeCategory::Integer:
236       common::SearchTypes(
237           TypeVisitor<TypeCategory::Integer>{context_, *exprType, caseList});
238       return;
239     case TypeCategory::Logical:
240       CaseValues<evaluate::Type<TypeCategory::Logical, 1>>{context_, *exprType}
241           .Check(caseList);
242       return;
243     case TypeCategory::Character:
244       common::SearchTypes(
245           TypeVisitor<TypeCategory::Character>{context_, *exprType, caseList});
246       return;
247     default:
248       break;
249     }
250   }
251   context_.Say(selectExpr.source,
252       "SELECT CASE expression must be integer, logical, or character"_err_en_US);
253 }
254 } // namespace Fortran::semantics
255