1 //===-- lib/Semantics/check-case.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "check-case.h"
10 #include "flang/Common/idioms.h"
11 #include "flang/Common/reference.h"
12 #include "flang/Common/template.h"
13 #include "flang/Evaluate/fold.h"
14 #include "flang/Evaluate/type.h"
15 #include "flang/Parser/parse-tree.h"
16 #include "flang/Semantics/semantics.h"
17 #include "flang/Semantics/tools.h"
18 #include <tuple>
19
20 namespace Fortran::semantics {
21
22 template <typename T> class CaseValues {
23 public:
CaseValues(SemanticsContext & c,const evaluate::DynamicType & t)24 CaseValues(SemanticsContext &c, const evaluate::DynamicType &t)
25 : context_{c}, caseExprType_{t} {}
26
Check(const std::list<parser::CaseConstruct::Case> & cases)27 void Check(const std::list<parser::CaseConstruct::Case> &cases) {
28 for (const parser::CaseConstruct::Case &c : cases) {
29 AddCase(c);
30 }
31 if (!hasErrors_) {
32 cases_.sort(Comparator{});
33 if (!AreCasesDisjoint()) { // C1149
34 ReportConflictingCases();
35 }
36 }
37 }
38
39 private:
40 using Value = evaluate::Scalar<T>;
41
AddCase(const parser::CaseConstruct::Case & c)42 void AddCase(const parser::CaseConstruct::Case &c) {
43 const auto &stmt{std::get<parser::Statement<parser::CaseStmt>>(c.t)};
44 const parser::CaseStmt &caseStmt{stmt.statement};
45 const auto &selector{std::get<parser::CaseSelector>(caseStmt.t)};
46 std::visit(
47 common::visitors{
48 [&](const std::list<parser::CaseValueRange> &ranges) {
49 for (const auto &range : ranges) {
50 auto pair{ComputeBounds(range)};
51 if (pair.first && pair.second && *pair.first > *pair.second) {
52 context_.Say(stmt.source,
53 "CASE has lower bound greater than upper bound"_en_US);
54 } else {
55 if constexpr (T::category == TypeCategory::Logical) { // C1148
56 if ((pair.first || pair.second) &&
57 (!pair.first || !pair.second ||
58 *pair.first != *pair.second)) {
59 context_.Say(stmt.source,
60 "CASE range is not allowed for LOGICAL"_err_en_US);
61 }
62 }
63 cases_.emplace_back(stmt);
64 cases_.back().lower = std::move(pair.first);
65 cases_.back().upper = std::move(pair.second);
66 }
67 }
68 },
69 [&](const parser::Default &) { cases_.emplace_front(stmt); },
70 },
71 selector.u);
72 }
73
GetValue(const parser::CaseValue & caseValue)74 std::optional<Value> GetValue(const parser::CaseValue &caseValue) {
75 const parser::Expr &expr{caseValue.thing.thing.value()};
76 auto *x{expr.typedExpr.get()};
77 if (x && x->v) { // C1147
78 auto type{x->v->GetType()};
79 if (type && type->category() == caseExprType_.category() &&
80 (type->category() != TypeCategory::Character ||
81 type->kind() == caseExprType_.kind())) {
82 x->v = evaluate::Fold(context_.foldingContext(),
83 evaluate::ConvertToType(T::GetType(), std::move(*x->v)));
84 if (x->v) {
85 if (auto value{evaluate::GetScalarConstantValue<T>(*x->v)}) {
86 return *value;
87 }
88 }
89 context_.Say(
90 expr.source, "CASE value must be a constant scalar"_err_en_US);
91 } else {
92 std::string typeStr{type ? type->AsFortran() : "typeless"s};
93 context_.Say(expr.source,
94 "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US,
95 typeStr, caseExprType_.AsFortran());
96 }
97 hasErrors_ = true;
98 }
99 return std::nullopt;
100 }
101
102 using PairOfValues = std::pair<std::optional<Value>, std::optional<Value>>;
ComputeBounds(const parser::CaseValueRange & range)103 PairOfValues ComputeBounds(const parser::CaseValueRange &range) {
104 return std::visit(common::visitors{
105 [&](const parser::CaseValue &x) {
106 auto value{GetValue(x)};
107 return PairOfValues{value, value};
108 },
109 [&](const parser::CaseValueRange::Range &x) {
110 std::optional<Value> lo, hi;
111 if (x.lower) {
112 lo = GetValue(*x.lower);
113 }
114 if (x.upper) {
115 hi = GetValue(*x.upper);
116 }
117 if ((x.lower && !lo) || (x.upper && !hi)) {
118 return PairOfValues{}; // error case
119 }
120 return PairOfValues{std::move(lo), std::move(hi)};
121 },
122 },
123 range.u);
124 }
125
126 struct Case {
CaseFortran::semantics::CaseValues::Case127 explicit Case(const parser::Statement<parser::CaseStmt> &s) : stmt{s} {}
IsDefaultFortran::semantics::CaseValues::Case128 bool IsDefault() const { return !lower && !upper; }
AsFortranFortran::semantics::CaseValues::Case129 std::string AsFortran() const {
130 std::string result;
131 {
132 llvm::raw_string_ostream bs{result};
133 if (lower) {
134 evaluate::Constant<T>{*lower}.AsFortran(bs << '(');
135 if (!upper) {
136 bs << ':';
137 } else if (*lower != *upper) {
138 evaluate::Constant<T>{*upper}.AsFortran(bs << ':');
139 }
140 bs << ')';
141 } else if (upper) {
142 evaluate::Constant<T>{*upper}.AsFortran(bs << "(:") << ')';
143 } else {
144 bs << "DEFAULT";
145 }
146 }
147 return result;
148 }
149
150 const parser::Statement<parser::CaseStmt> &stmt;
151 std::optional<Value> lower, upper;
152 };
153
154 // Defines a comparator for use with std::list<>::sort().
155 // Returns true if and only if the highest value in range x is less
156 // than the least value in range y. The DEFAULT case is arbitrarily
157 // defined to be less than all others. When two ranges overlap,
158 // neither is less than the other.
159 struct Comparator {
operator ()Fortran::semantics::CaseValues::Comparator160 bool operator()(const Case &x, const Case &y) const {
161 if (x.IsDefault()) {
162 return !y.IsDefault();
163 } else {
164 return x.upper && y.lower && *x.upper < *y.lower;
165 }
166 }
167 };
168
AreCasesDisjoint() const169 bool AreCasesDisjoint() const {
170 auto endIter{cases_.end()};
171 for (auto iter{cases_.begin()}; iter != endIter; ++iter) {
172 auto next{iter};
173 if (++next != endIter && !Comparator{}(*iter, *next)) {
174 return false;
175 }
176 }
177 return true;
178 }
179
180 // This has quadratic time, but only runs in error cases
ReportConflictingCases()181 void ReportConflictingCases() {
182 for (auto iter{cases_.begin()}; iter != cases_.end(); ++iter) {
183 parser::Message *msg{nullptr};
184 for (auto p{cases_.begin()}; p != cases_.end(); ++p) {
185 if (p->stmt.source.begin() < iter->stmt.source.begin() &&
186 !Comparator{}(*p, *iter) && !Comparator{}(*iter, *p)) {
187 if (!msg) {
188 msg = &context_.Say(iter->stmt.source,
189 "CASE %s conflicts with previous cases"_err_en_US,
190 iter->AsFortran());
191 }
192 msg->Attach(
193 p->stmt.source, "Conflicting CASE %s"_en_US, p->AsFortran());
194 }
195 }
196 }
197 }
198
199 SemanticsContext &context_;
200 const evaluate::DynamicType &caseExprType_;
201 std::list<Case> cases_;
202 bool hasErrors_{false};
203 };
204
205 template <TypeCategory CAT> struct TypeVisitor {
206 using Result = bool;
207 using Types = evaluate::CategoryTypes<CAT>;
TestFortran::semantics::TypeVisitor208 template <typename T> Result Test() {
209 if (T::kind == exprType.kind()) {
210 CaseValues<T>(context, exprType).Check(caseList);
211 return true;
212 } else {
213 return false;
214 }
215 }
216 SemanticsContext &context;
217 const evaluate::DynamicType &exprType;
218 const std::list<parser::CaseConstruct::Case> &caseList;
219 };
220
Enter(const parser::CaseConstruct & construct)221 void CaseChecker::Enter(const parser::CaseConstruct &construct) {
222 const auto &selectCaseStmt{
223 std::get<parser::Statement<parser::SelectCaseStmt>>(construct.t)};
224 const auto &selectCase{selectCaseStmt.statement};
225 const auto &selectExpr{
226 std::get<parser::Scalar<parser::Expr>>(selectCase.t).thing};
227 const auto *x{GetExpr(selectExpr)};
228 if (!x) {
229 return; // expression semantics failed
230 }
231 if (auto exprType{x->GetType()}) {
232 const auto &caseList{
233 std::get<std::list<parser::CaseConstruct::Case>>(construct.t)};
234 switch (exprType->category()) {
235 case TypeCategory::Integer:
236 common::SearchTypes(
237 TypeVisitor<TypeCategory::Integer>{context_, *exprType, caseList});
238 return;
239 case TypeCategory::Logical:
240 CaseValues<evaluate::Type<TypeCategory::Logical, 1>>{context_, *exprType}
241 .Check(caseList);
242 return;
243 case TypeCategory::Character:
244 common::SearchTypes(
245 TypeVisitor<TypeCategory::Character>{context_, *exprType, caseList});
246 return;
247 default:
248 break;
249 }
250 }
251 context_.Say(selectExpr.source,
252 "SELECT CASE expression must be integer, logical, or character"_err_en_US);
253 }
254 } // namespace Fortran::semantics
255