1 // Copyright (c) 2018-2019, NVIDIA CORPORATION.  All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "check-do.h"
#include "attr.h"
#include "scope.h"
#include "semantics.h"
#include "symbol.h"
#include "tools.h"
#include "type.h"
#include "../common/template.h"
#include "../evaluate/expression.h"
#include "../evaluate/tools.h"
#include "../parser/message.h"
#include "../parser/parse-tree-visitor.h"
#include "../parser/tools.h"
#include <utility>
28 
29 namespace Fortran::semantics {
30 
31 using namespace parser::literals;
32 
33 // Return the (possibly null)  name of the construct
34 template<typename A>
MaybeGetConstructName(const A & a)35 static const parser::Name *MaybeGetConstructName(const A &a) {
36   return common::GetPtrFromOptional(std::get<0>(std::get<0>(a.t).statement.t));
37 }
38 
// Return the (possibly null) name of a BLOCK construct, taken from the
// optional name carried by its BLOCK statement.
static const parser::Name *MaybeGetConstructName(
    const parser::BlockConstruct &blockConstruct) {
  const auto &blockStmt{
      std::get<parser::Statement<parser::BlockStmt>>(blockConstruct.t)};
  return common::GetPtrFromOptional(blockStmt.statement.v);
}
45 
46 // Return the (possibly null) name of the statement
MaybeGetStmtName(const A & a)47 template<typename A> static const parser::Name *MaybeGetStmtName(const A &a) {
48   return common::GetPtrFromOptional(std::get<0>(a.t));
49 }
50 
51 // 11.1.7.5 - enforce semantics constraints on a DO CONCURRENT loop body
52 class DoConcurrentBodyEnforce {
53 public:
DoConcurrentBodyEnforce(SemanticsContext & context)54   DoConcurrentBodyEnforce(SemanticsContext &context) : context_{context} {}
labels()55   std::set<parser::Label> labels() { return labels_; }
names()56   std::set<SourceName> names() { return names_; }
Pre(const T &)57   template<typename T> bool Pre(const T &) { return true; }
Post(const T &)58   template<typename T> void Post(const T &) {}
Pre(const parser::Statement<T> & statement)59   template<typename T> bool Pre(const parser::Statement<T> &statement) {
60     currentStatementSourcePosition_ = statement.source;
61     if (statement.label.has_value()) {
62       labels_.insert(*statement.label);
63     }
64     return true;
65   }
66 
67   // C1167
Pre(const parser::WhereConstruct & s)68   bool Pre(const parser::WhereConstruct &s) {
69     AddName(MaybeGetConstructName(s));
70     return true;
71   }
72 
Pre(const parser::ForallConstruct & s)73   bool Pre(const parser::ForallConstruct &s) {
74     AddName(MaybeGetConstructName(s));
75     return true;
76   }
77 
Pre(const parser::ChangeTeamConstruct & s)78   bool Pre(const parser::ChangeTeamConstruct &s) {
79     AddName(MaybeGetConstructName(s));
80     return true;
81   }
82 
Pre(const parser::CriticalConstruct & s)83   bool Pre(const parser::CriticalConstruct &s) {
84     AddName(MaybeGetConstructName(s));
85     return true;
86   }
87 
Pre(const parser::LabelDoStmt & s)88   bool Pre(const parser::LabelDoStmt &s) {
89     AddName(MaybeGetStmtName(s));
90     return true;
91   }
92 
Pre(const parser::NonLabelDoStmt & s)93   bool Pre(const parser::NonLabelDoStmt &s) {
94     AddName(MaybeGetStmtName(s));
95     return true;
96   }
97 
Pre(const parser::IfThenStmt & s)98   bool Pre(const parser::IfThenStmt &s) {
99     AddName(MaybeGetStmtName(s));
100     return true;
101   }
102 
Pre(const parser::SelectCaseStmt & s)103   bool Pre(const parser::SelectCaseStmt &s) {
104     AddName(MaybeGetStmtName(s));
105     return true;
106   }
107 
Pre(const parser::SelectRankStmt & s)108   bool Pre(const parser::SelectRankStmt &s) {
109     AddName(MaybeGetStmtName(s));
110     return true;
111   }
112 
Pre(const parser::SelectTypeStmt & s)113   bool Pre(const parser::SelectTypeStmt &s) {
114     AddName(MaybeGetStmtName(s));
115     return true;
116   }
117 
118   // C1136
Post(const parser::ReturnStmt &)119   void Post(const parser::ReturnStmt &) {
120     context_.Say(currentStatementSourcePosition_,
121         "RETURN not allowed in DO CONCURRENT"_err_en_US);
122   }
123 
124   // C1137
NoImageControl()125   void NoImageControl() {
126     context_.Say(currentStatementSourcePosition_,
127         "image control statement not allowed in DO CONCURRENT"_err_en_US);
128   }
129 
130   // more C1137 checks
Post(const parser::SyncAllStmt &)131   void Post(const parser::SyncAllStmt &) { NoImageControl(); }
Post(const parser::SyncImagesStmt &)132   void Post(const parser::SyncImagesStmt &) { NoImageControl(); }
Post(const parser::SyncMemoryStmt &)133   void Post(const parser::SyncMemoryStmt &) { NoImageControl(); }
Post(const parser::SyncTeamStmt &)134   void Post(const parser::SyncTeamStmt &) { NoImageControl(); }
Post(const parser::ChangeTeamConstruct &)135   void Post(const parser::ChangeTeamConstruct &) { NoImageControl(); }
Post(const parser::CriticalConstruct &)136   void Post(const parser::CriticalConstruct &) { NoImageControl(); }
Post(const parser::EventPostStmt &)137   void Post(const parser::EventPostStmt &) { NoImageControl(); }
Post(const parser::EventWaitStmt &)138   void Post(const parser::EventWaitStmt &) { NoImageControl(); }
Post(const parser::FormTeamStmt &)139   void Post(const parser::FormTeamStmt &) { NoImageControl(); }
Post(const parser::LockStmt &)140   void Post(const parser::LockStmt &) { NoImageControl(); }
Post(const parser::UnlockStmt &)141   void Post(const parser::UnlockStmt &) { NoImageControl(); }
Post(const parser::StopStmt &)142   void Post(const parser::StopStmt &) { NoImageControl(); }
143 
144   // more C1137 checks
Post(const parser::AllocateStmt & allocateStmt)145   void Post(const parser::AllocateStmt &allocateStmt) {
146     CheckDoesntContainCoarray(allocateStmt);
147   }
148 
Post(const parser::DeallocateStmt & deallocateStmt)149   void Post(const parser::DeallocateStmt &deallocateStmt) {
150     CheckDoesntContainCoarray(deallocateStmt);  // C1137
151 
152     // C1140: deallocation of polymorphic objects
153     if (anyObjectIsPolymorphic()) {
154       context_.Say(currentStatementSourcePosition_,
155           "DEALLOCATE polymorphic object(s) not allowed"
156           " in DO CONCURRENT"_err_en_US);
157     }
158   }
159 
Post(const parser::Statement<T> &)160   template<typename T> void Post(const parser::Statement<T> &) {
161     if (EndTDeallocatesCoarray()) {
162       context_.Say(currentStatementSourcePosition_,
163           "implicit deallocation of coarray not allowed"
164           " in DO CONCURRENT"_err_en_US);
165     }
166   }
167 
168   // C1141: cannot call ieee_get_flag, ieee_[gs]et_halting_mode
Post(const parser::ProcedureDesignator & procedureDesignator)169   void Post(const parser::ProcedureDesignator &procedureDesignator) {
170     if (auto *name{std::get_if<parser::Name>(&procedureDesignator.u)}) {
171       // C1137: call move_alloc with coarray arguments
172       if (name->source == "move_alloc") {
173         if (anyObjectIsCoarray()) {
174           context_.Say(currentStatementSourcePosition_,
175               "call to MOVE_ALLOC intrinsic in DO CONCURRENT with coarray"
176               " argument(s) not allowed"_err_en_US);
177         }
178       }
179       // C1139: call to impure procedure
180       if (name->symbol && !IsPureProcedure(*name->symbol)) {
181         context_.Say(currentStatementSourcePosition_,
182             "call to impure procedure in DO CONCURRENT not allowed"_err_en_US);
183       }
184       if (name->symbol && fromScope(*name->symbol, "ieee_exceptions"s)) {
185         if (name->source == "ieee_get_flag") {
186           context_.Say(currentStatementSourcePosition_,
187               "IEEE_GET_FLAG not allowed in DO CONCURRENT"_err_en_US);
188         } else if (name->source == "ieee_set_halting_mode") {
189           context_.Say(currentStatementSourcePosition_,
190               "IEEE_SET_HALTING_MODE not allowed in DO CONCURRENT"_err_en_US);
191         } else if (name->source == "ieee_get_halting_mode") {
192           context_.Say(currentStatementSourcePosition_,
193               "IEEE_GET_HALTING_MODE not allowed in DO CONCURRENT"_err_en_US);
194         }
195       }
196     } else {
197       // C1139: this a procedure component
198       auto &component{std::get<parser::ProcComponentRef>(procedureDesignator.u)
199                           .v.thing.component};
200       if (component.symbol && !IsPureProcedure(*component.symbol)) {
201         context_.Say(currentStatementSourcePosition_,
202             "call to impure procedure in DO CONCURRENT not allowed"_err_en_US);
203       }
204     }
205   }
206 
207   // 11.1.7.5
Post(const parser::IoControlSpec & ioControlSpec)208   void Post(const parser::IoControlSpec &ioControlSpec) {
209     if (auto *charExpr{
210             std::get_if<parser::IoControlSpec::CharExpr>(&ioControlSpec.u)}) {
211       if (std::get<parser::IoControlSpec::CharExpr::Kind>(charExpr->t) ==
212           parser::IoControlSpec::CharExpr::Kind::Advance) {
213         context_.Say(currentStatementSourcePosition_,
214             "ADVANCE specifier not allowed in DO CONCURRENT"_err_en_US);
215       }
216     }
217   }
218 
219 private:
220   // C1137 helper functions
CheckAllocateObjectIsntCoarray(const parser::AllocateObject & allocateObject,StmtType stmtType)221   void CheckAllocateObjectIsntCoarray(
222       const parser::AllocateObject &allocateObject, StmtType stmtType) {
223     const parser::Name &name{GetLastName(allocateObject)};
224     if (name.symbol && IsCoarray(*name.symbol)) {
225       context_.Say(name.source,
226           "%s coarray not allowed in DO CONCURRENT"_err_en_US,
227           EnumToString(stmtType));
228     }
229   }
230 
CheckDoesntContainCoarray(const parser::AllocateStmt & allocateStmt)231   void CheckDoesntContainCoarray(const parser::AllocateStmt &allocateStmt) {
232     const auto &allocationList{
233         std::get<std::list<parser::Allocation>>(allocateStmt.t)};
234     for (const auto &allocation : allocationList) {
235       const auto &allocateObject{
236           std::get<parser::AllocateObject>(allocation.t)};
237       CheckAllocateObjectIsntCoarray(allocateObject, StmtType::ALLOCATE);
238     }
239   }
240 
CheckDoesntContainCoarray(const parser::DeallocateStmt & deallocateStmt)241   void CheckDoesntContainCoarray(const parser::DeallocateStmt &deallocateStmt) {
242     const auto &allocateObjectList{
243         std::get<std::list<parser::AllocateObject>>(deallocateStmt.t)};
244     for (const auto &allocateObject : allocateObjectList) {
245       CheckAllocateObjectIsntCoarray(allocateObject, StmtType::DEALLOCATE);
246     }
247   }
248 
anyObjectIsCoarray()249   bool anyObjectIsCoarray() { return false; }  // FIXME placeholder
anyObjectIsPolymorphic()250   bool anyObjectIsPolymorphic() { return false; }  // FIXME placeholder
EndTDeallocatesCoarray()251   bool EndTDeallocatesCoarray() { return false; }  // FIXME placeholder
fromScope(const Symbol & symbol,const std::string & moduleName)252   bool fromScope(const Symbol &symbol, const std::string &moduleName) {
253     if (symbol.GetUltimate().owner().IsModule() &&
254         symbol.GetUltimate().owner().GetName().value().ToString() ==
255             moduleName) {
256       return true;
257     }
258     return false;
259   }
260 
AddName(const parser::Name * nm)261   void AddName(const parser::Name *nm) {
262     if (nm) {
263       names_.insert(nm->source);
264     }
265   }
266 
267   std::set<parser::CharBlock> names_;
268   std::set<parser::Label> labels_;
269   parser::CharBlock currentStatementSourcePosition_;
270   SemanticsContext &context_;
271 };  // class DoConcurrentBodyEnforce
272 
273 class DoConcurrentLabelEnforce {
274 public:
DoConcurrentLabelEnforce(SemanticsContext & context,std::set<parser::Label> && labels,std::set<parser::CharBlock> && names,parser::CharBlock doConcurrentSourcePosition)275   DoConcurrentLabelEnforce(SemanticsContext &context,
276       std::set<parser::Label> &&labels, std::set<parser::CharBlock> &&names,
277       parser::CharBlock doConcurrentSourcePosition)
278     : context_{context}, labels_{labels}, names_{names},
279       doConcurrentSourcePosition_{doConcurrentSourcePosition} {}
Pre(const T &)280   template<typename T> bool Pre(const T &) { return true; }
Pre(const parser::Statement<T> & statement)281   template<typename T> bool Pre(const parser::Statement<T> &statement) {
282     currentStatementSourcePosition_ = statement.source;
283     return true;
284   }
285 
Post(const T &)286   template<typename T> void Post(const T &) {}
287 
Post(const parser::GotoStmt & gotoStmt)288   void Post(const parser::GotoStmt &gotoStmt) { checkLabelUse(gotoStmt.v); }
Post(const parser::ComputedGotoStmt & computedGotoStmt)289   void Post(const parser::ComputedGotoStmt &computedGotoStmt) {
290     for (auto &i : std::get<std::list<parser::Label>>(computedGotoStmt.t)) {
291       checkLabelUse(i);
292     }
293   }
294 
Post(const parser::ArithmeticIfStmt & arithmeticIfStmt)295   void Post(const parser::ArithmeticIfStmt &arithmeticIfStmt) {
296     checkLabelUse(std::get<1>(arithmeticIfStmt.t));
297     checkLabelUse(std::get<2>(arithmeticIfStmt.t));
298     checkLabelUse(std::get<3>(arithmeticIfStmt.t));
299   }
300 
Post(const parser::AssignStmt & assignStmt)301   void Post(const parser::AssignStmt &assignStmt) {
302     checkLabelUse(std::get<parser::Label>(assignStmt.t));
303   }
304 
Post(const parser::AssignedGotoStmt & assignedGotoStmt)305   void Post(const parser::AssignedGotoStmt &assignedGotoStmt) {
306     for (auto &i : std::get<std::list<parser::Label>>(assignedGotoStmt.t)) {
307       checkLabelUse(i);
308     }
309   }
310 
Post(const parser::AltReturnSpec & altReturnSpec)311   void Post(const parser::AltReturnSpec &altReturnSpec) {
312     checkLabelUse(altReturnSpec.v);
313   }
314 
Post(const parser::ErrLabel & errLabel)315   void Post(const parser::ErrLabel &errLabel) { checkLabelUse(errLabel.v); }
Post(const parser::EndLabel & endLabel)316   void Post(const parser::EndLabel &endLabel) { checkLabelUse(endLabel.v); }
Post(const parser::EorLabel & eorLabel)317   void Post(const parser::EorLabel &eorLabel) { checkLabelUse(eorLabel.v); }
318 
checkLabelUse(const parser::Label & labelUsed)319   void checkLabelUse(const parser::Label &labelUsed) {
320     if (labels_.find(labelUsed) == labels_.end()) {
321       context_.Say(currentStatementSourcePosition_,
322           "control flow escapes from DO CONCURRENT"_err_en_US);
323     }
324   }
325 
326 private:
327   SemanticsContext &context_;
328   std::set<parser::Label> labels_;
329   std::set<parser::CharBlock> names_;
330   parser::CharBlock currentStatementSourcePosition_{nullptr};
331   parser::CharBlock doConcurrentSourcePosition_{nullptr};
332 };  // class DoConcurrentLabelEnforce
333 
334 // Class for enforcing C1130
335 class DoConcurrentVariableEnforce {
336 public:
DoConcurrentVariableEnforce(SemanticsContext & context,parser::CharBlock doConcurrentSourcePosition)337   DoConcurrentVariableEnforce(
338       SemanticsContext &context, parser::CharBlock doConcurrentSourcePosition)
339     : context_{context},
340       doConcurrentSourcePosition_{doConcurrentSourcePosition},
341       blockScope_{context.FindScope(doConcurrentSourcePosition_)} {}
342 
Pre(const T &)343   template<typename T> bool Pre(const T &) { return true; }
Post(const T &)344   template<typename T> void Post(const T &) {}
345 
346   // Check to see if the name is a variable from an enclosing scope
Post(const parser::Name & name)347   void Post(const parser::Name &name) {
348     if (const Symbol * symbol{name.symbol}) {
349       if (IsVariableName(*symbol)) {
350         const Scope &variableScope{symbol->owner()};
351         if (DoesScopeContain(&variableScope, blockScope_)) {
352           context_.Say(name.source,
353               "Variable '%s' from an enclosing scope referenced in a DO "
354               "CONCURRENT with DEFAULT(NONE) must appear in a "
355               "locality-spec"_err_en_US,
356               name.source);
357         }
358       }
359     }
360   }
361 
362 private:
363   SemanticsContext &context_;
364   parser::CharBlock doConcurrentSourcePosition_;
365   const Scope &blockScope_;
366 };  // class DoConcurrentVariableEnforce
367 
368 // Find a DO statement and enforce semantics checks on its body
369 class DoContext {
370 public:
DoContext(SemanticsContext & context)371   DoContext(SemanticsContext &context) : context_{context} {}
372 
Check(const parser::DoConstruct & doConstruct)373   void Check(const parser::DoConstruct &doConstruct) {
374     if (doConstruct.IsDoConcurrent()) {
375       CheckDoConcurrent(doConstruct);
376       return;
377     }
378     if (doConstruct.IsDoNormal()) {
379       CheckDoNormal(doConstruct);
380       return;
381     }
382     // TODO: handle the other cases
383   }
384 
385 private:
386   using Bounds = parser::LoopControl::Bounds;
387 
GetBounds(const parser::DoConstruct & doConstruct)388   const Bounds &GetBounds(const parser::DoConstruct &doConstruct) {
389     auto &loopControl{doConstruct.GetLoopControl().value()};
390     return std::get<Bounds>(loopControl.u);
391   }
392 
SayBadDoControl(parser::CharBlock sourceLocation)393   void SayBadDoControl(parser::CharBlock sourceLocation) {
394     context_.Say(sourceLocation, "DO controls should be INTEGER"_err_en_US);
395   }
396 
CheckDoControl(const parser::CharBlock & sourceLocation,bool isReal)397   void CheckDoControl(const parser::CharBlock &sourceLocation, bool isReal) {
398     const bool warn{context_.warnOnNonstandardUsage() ||
399         context_.ShouldWarn(parser::LanguageFeature::RealDoControls)};
400     if (isReal && !warn) {
401       // No messages for the default case
402     } else if (isReal && warn) {
403       // TODO: Mark the following message as a warning when we have warnings
404       context_.Say(sourceLocation, "DO controls should be INTEGER"_en_US);
405     } else {
406       SayBadDoControl(sourceLocation);
407     }
408   }
409 
CheckDoVariable(const parser::ScalarName & scalarName)410   void CheckDoVariable(const parser::ScalarName &scalarName) {
411     const parser::CharBlock &sourceLocation{scalarName.thing.source};
412     if (const Symbol * symbol{scalarName.thing.symbol}) {
413       if (!IsVariableName(*symbol)) {
414         context_.Say(
415             sourceLocation, "DO control must be an INTEGER variable"_err_en_US);
416       } else {
417         const DeclTypeSpec *symType{symbol->GetType()};
418         if (!symType) {
419           SayBadDoControl(sourceLocation);
420         } else {
421           if (!symType->IsNumeric(TypeCategory::Integer)) {
422             CheckDoControl(
423                 sourceLocation, symType->IsNumeric(TypeCategory::Real));
424           }
425         }
426       }  // No messages for INTEGER
427     }
428   }
429 
430   // Semantic checks for the limit and step expressions
CheckDoExpression(const parser::ScalarExpr & scalarExpression)431   void CheckDoExpression(const parser::ScalarExpr &scalarExpression) {
432     if (const SomeExpr * expr{GetExpr(scalarExpression)}) {
433       if (!ExprHasTypeCategory(*expr, TypeCategory::Integer)) {
434         // No warnings or errors for type INTEGER
435         const parser::CharBlock &loc{scalarExpression.thing.value().source};
436         CheckDoControl(loc, ExprHasTypeCategory(*expr, TypeCategory::Real));
437       }
438     }
439   }
440 
CheckDoNormal(const parser::DoConstruct & doConstruct)441   void CheckDoNormal(const parser::DoConstruct &doConstruct) {
442     // C1120 extended by allowing REAL and DOUBLE PRECISION
443     // Get the bounds, then check the variable, init, final, and step
444     const Bounds &bounds{GetBounds(doConstruct)};
445     CheckDoVariable(bounds.name);
446     CheckDoExpression(bounds.lower);
447     CheckDoExpression(bounds.upper);
448     if (bounds.step.has_value()) {
449       CheckDoExpression(bounds.step.value());
450     }
451   }
452 
CheckDoConcurrent(const parser::DoConstruct & doConstruct)453   void CheckDoConcurrent(const parser::DoConstruct &doConstruct) {
454     auto &doStmt{
455         std::get<parser::Statement<parser::NonLabelDoStmt>>(doConstruct.t)};
456     currentStatementSourcePosition_ = doStmt.source;
457 
458     const parser::Block &block{std::get<parser::Block>(doConstruct.t)};
459     DoConcurrentBodyEnforce doConcurrentBodyEnforce{context_};
460     parser::Walk(block, doConcurrentBodyEnforce);
461 
462     DoConcurrentLabelEnforce doConcurrentLabelEnforce{context_,
463         doConcurrentBodyEnforce.labels(), doConcurrentBodyEnforce.names(),
464         currentStatementSourcePosition_};
465     parser::Walk(block, doConcurrentLabelEnforce);
466 
467     const auto &loopControl{
468         std::get<std::optional<parser::LoopControl>>(doStmt.statement.t)};
469     const auto &concurrent{
470         std::get<parser::LoopControl::Concurrent>(loopControl->u)};
471     CheckConcurrentLoopControl(concurrent, block);
472   }
473 
474   using SymbolSet = std::set<const Symbol *>;
475 
476   // Return a set of symbols whose names are in a Local locality-spec.  Look
477   // the names up in the scope that encloses the DO construct to avoid getting
478   // the local versions of them.  Then follow the host-, use-, and
479   // construct-associations to get the root symbols
GatherLocals(const std::list<parser::LocalitySpec> & localitySpecs) const480   SymbolSet GatherLocals(
481       const std::list<parser::LocalitySpec> &localitySpecs) const {
482     SymbolSet symbols;
483     const Scope &parentScope{
484         context_.FindScope(currentStatementSourcePosition_).parent()};
485     // Loop through the LocalitySpec::Local locality-specs
486     for (const auto &ls : localitySpecs) {
487       if (const auto *names{std::get_if<parser::LocalitySpec::Local>(&ls.u)}) {
488         // Loop through the names in the Local locality-spec getting their
489         // symbols
490         for (const parser::Name &name : names->v) {
491           if (const Symbol * symbol{parentScope.FindSymbol(name.source)}) {
492             if (const Symbol * root{GetAssociationRoot(*symbol)}) {
493               symbols.insert(root);
494             }
495           }
496         }
497       }
498     }
499     return symbols;
500   }
501 
GatherSymbolsFromExpression(const parser::Expr & expression)502   static SymbolSet GatherSymbolsFromExpression(const parser::Expr &expression) {
503     SymbolSet result;
504     if (const auto *expr{GetExpr(expression)}) {
505       for (const Symbol *symbol : evaluate::CollectSymbols(*expr)) {
506         if (const Symbol * root{GetAssociationRoot(DEREF(symbol))}) {
507           result.insert(root);
508         }
509       }
510     }
511     return result;
512   }
513 
514   // C1121 - procedures in mask must be pure
CheckMaskIsPure(const parser::ScalarLogicalExpr & mask) const515   void CheckMaskIsPure(const parser::ScalarLogicalExpr &mask) const {
516     SymbolSet references{GatherSymbolsFromExpression(mask.thing.thing.value())};
517     for (const Symbol *ref : references) {
518       if (IsProcedure(*ref) && !IsPureProcedure(*ref)) {
519         const parser::CharBlock &name{ref->name()};
520         context_
521             .Say(currentStatementSourcePosition_,
522                 "concurrent-header mask expression cannot reference an impure"
523                 " procedure"_err_en_US)
524             .Attach(name, "Declaration of impure procedure '%s'"_en_US, name);
525         return;
526       }
527     }
528   }
529 
CheckNoCollisions(const SymbolSet & refs,const SymbolSet & uses,const parser::MessageFixedText & errorMessage,const parser::CharBlock & refPosition) const530   void CheckNoCollisions(const SymbolSet &refs, const SymbolSet &uses,
531       const parser::MessageFixedText &errorMessage,
532       const parser::CharBlock &refPosition) const {
533     for (const Symbol *ref : refs) {
534       if (uses.find(ref) != uses.end()) {
535         const parser::CharBlock &name{ref->name()};
536         context_.Say(refPosition, errorMessage, name)
537             .Attach(name, "Declaration of '%s'"_en_US, name);
538         return;
539       }
540     }
541   }
542 
HasNoReferences(const SymbolSet & indexNames,const parser::ScalarIntExpr & expr) const543   void HasNoReferences(
544       const SymbolSet &indexNames, const parser::ScalarIntExpr &expr) const {
545     CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()),
546         indexNames,
547         "concurrent-control expression references index-name '%s'"_err_en_US,
548         expr.thing.thing.value().source);
549   }
550 
551   // C1129, names in local locality-specs can't be in mask expressions
CheckMaskDoesNotReferenceLocal(const parser::ScalarLogicalExpr & mask,const SymbolSet & localVars) const552   void CheckMaskDoesNotReferenceLocal(
553       const parser::ScalarLogicalExpr &mask, const SymbolSet &localVars) const {
554     CheckNoCollisions(GatherSymbolsFromExpression(mask.thing.thing.value()),
555         localVars,
556         "concurrent-header mask-expr references variable '%s'"
557         " in LOCAL locality-spec"_err_en_US,
558         mask.thing.thing.value().source);
559   }
560 
561   // C1129, names in local locality-specs can't be in limit or step expressions
CheckExprDoesNotReferenceLocal(const parser::ScalarIntExpr & expr,const SymbolSet & localVars) const562   void CheckExprDoesNotReferenceLocal(
563       const parser::ScalarIntExpr &expr, const SymbolSet &localVars) const {
564     CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()),
565         localVars,
566         "concurrent-header expression references variable '%s'"
567         " in LOCAL locality-spec"_err_en_US,
568         expr.thing.thing.value().source);
569   }
570 
571   // C1130, default(none) locality requires names to be in locality-specs to be
572   // used in the body of the DO loop
CheckDefaultNoneImpliesExplicitLocality(const std::list<parser::LocalitySpec> & localitySpecs,const parser::Block & block) const573   void CheckDefaultNoneImpliesExplicitLocality(
574       const std::list<parser::LocalitySpec> &localitySpecs,
575       const parser::Block &block) const {
576     bool hasDefaultNone{false};
577     for (auto &ls : localitySpecs) {
578       if (std::holds_alternative<parser::LocalitySpec::DefaultNone>(ls.u)) {
579         if (hasDefaultNone) {
580           // C1127, you can only have one DEFAULT(NONE)
581           context_.Say(currentStatementSourcePosition_,
582               "only one DEFAULT(NONE) may appear"_en_US);
583           break;
584         }
585         hasDefaultNone = true;
586       }
587     }
588     if (hasDefaultNone) {
589       DoConcurrentVariableEnforce doConcurrentVariableEnforce{
590           context_, currentStatementSourcePosition_};
591       parser::Walk(block, doConcurrentVariableEnforce);
592     }
593   }
594 
595   // C1123, concurrent limit or step expressions can't reference index-names
CheckConcurrentHeader(const parser::ConcurrentHeader & header) const596   void CheckConcurrentHeader(const parser::ConcurrentHeader &header) const {
597     auto &controls{std::get<std::list<parser::ConcurrentControl>>(header.t)};
598     SymbolSet indexNames;
599     for (const auto &c : controls) {
600       const auto &indexName{std::get<parser::Name>(c.t)};
601       if (indexName.symbol) {
602         indexNames.insert(indexName.symbol);
603       }
604     }
605     if (!indexNames.empty()) {
606       for (const auto &c : controls) {
607         HasNoReferences(indexNames, std::get<1>(c.t));
608         HasNoReferences(indexNames, std::get<2>(c.t));
609         if (const auto &expr{
610                 std::get<std::optional<parser::ScalarIntExpr>>(c.t)}) {
611           HasNoReferences(indexNames, *expr);
612         }
613       }
614     }
615   }
616 
CheckLocalitySpecs(const parser::LoopControl::Concurrent & concurrent,const parser::Block & block) const617   void CheckLocalitySpecs(const parser::LoopControl::Concurrent &concurrent,
618       const parser::Block &block) const {
619     const auto &header{std::get<parser::ConcurrentHeader>(concurrent.t)};
620     const auto &controls{
621         std::get<std::list<parser::ConcurrentControl>>(header.t)};
622     const auto &localitySpecs{
623         std::get<std::list<parser::LocalitySpec>>(concurrent.t)};
624     if (!localitySpecs.empty()) {
625       const SymbolSet &localVars{GatherLocals(localitySpecs)};
626       for (const auto &c : controls) {
627         CheckExprDoesNotReferenceLocal(std::get<1>(c.t), localVars);
628         CheckExprDoesNotReferenceLocal(std::get<2>(c.t), localVars);
629         if (const auto &expr{
630                 std::get<std::optional<parser::ScalarIntExpr>>(c.t)}) {
631           CheckExprDoesNotReferenceLocal(*expr, localVars);
632         }
633       }
634       if (const auto &mask{
635               std::get<std::optional<parser::ScalarLogicalExpr>>(header.t)}) {
636         CheckMaskDoesNotReferenceLocal(*mask, localVars);
637       }
638       CheckDefaultNoneImpliesExplicitLocality(localitySpecs, block);
639     }
640   }
641 
642   // check constraints [C1121 .. C1130]
CheckConcurrentLoopControl(const parser::LoopControl::Concurrent & concurrent,const parser::Block & block) const643   void CheckConcurrentLoopControl(
644       const parser::LoopControl::Concurrent &concurrent,
645       const parser::Block &block) const {
646 
647     const auto &header{std::get<parser::ConcurrentHeader>(concurrent.t)};
648     const auto &mask{
649         std::get<std::optional<parser::ScalarLogicalExpr>>(header.t)};
650     if (mask.has_value()) {
651       CheckMaskIsPure(*mask);
652     }
653     CheckConcurrentHeader(header);
654     CheckLocalitySpecs(concurrent, block);
655   }
656 
657   SemanticsContext &context_;
658   parser::CharBlock currentStatementSourcePosition_;
659 };  // class DoContext
660 
661 // DO loops must be canonicalized prior to calling
Leave(const parser::DoConstruct & x)662 void DoChecker::Leave(const parser::DoConstruct &x) {
663   DoContext doContext{context_};
664   doContext.Check(x);
665 }
666 
667 // Return the (possibly null) name of the ConstructNode
MaybeGetNodeName(const ConstructNode & construct)668 static const parser::Name *MaybeGetNodeName(const ConstructNode &construct) {
669   return std::visit(
670       [&](const auto &x) { return MaybeGetConstructName(*x); }, construct);
671 }
672 
GetConstructPosition(const A & a)673 template<typename A> static parser::CharBlock GetConstructPosition(const A &a) {
674   return std::get<0>(a.t).source;
675 }
676 
GetNodePosition(const ConstructNode & construct)677 static parser::CharBlock GetNodePosition(const ConstructNode &construct) {
678   return std::visit(
679       [&](const auto &x) { return GetConstructPosition(*x); }, construct);
680 }
681 
SayBadLeave(StmtType stmtType,const char * enclosingStmtName,const ConstructNode & construct) const682 void DoChecker::SayBadLeave(StmtType stmtType, const char *enclosingStmtName,
683     const ConstructNode &construct) const {
684   context_
685       .Say("%s must not leave a %s statement"_err_en_US, EnumToString(stmtType),
686           enclosingStmtName)
687       .Attach(GetNodePosition(construct), "The construct that was left"_en_US);
688 }
689 
MaybeGetDoConstruct(const ConstructNode & construct)690 static const parser::DoConstruct *MaybeGetDoConstruct(
691     const ConstructNode &construct) {
692   if (const auto *doNode{
693           std::get_if<const parser::DoConstruct *>(&construct)}) {
694     return *doNode;
695   } else {
696     return nullptr;
697   }
698 }
699 
ConstructIsDoConcurrent(const ConstructNode & construct)700 static bool ConstructIsDoConcurrent(const ConstructNode &construct) {
701   const parser::DoConstruct *doConstruct{MaybeGetDoConstruct(construct)};
702   return doConstruct && doConstruct->IsDoConcurrent();
703 }
704 
705 // Check that CYCLE and EXIT statements do not cause flow of control to
706 // leave DO CONCURRENT, CRITICAL, or CHANGE TEAM constructs.
CheckForBadLeave(StmtType stmtType,const ConstructNode & construct) const707 void DoChecker::CheckForBadLeave(
708     StmtType stmtType, const ConstructNode &construct) const {
709   std::visit(
710       common::visitors{
711           [&](const parser::DoConstruct *doConstructPtr) {
712             if (doConstructPtr->IsDoConcurrent()) {
713               // C1135 and C1167
714               SayBadLeave(stmtType, "DO CONCURRENT", construct);
715             }
716           },
717           [&](const parser::CriticalConstruct *) {
718             // C1135 and C1168
719             SayBadLeave(stmtType, "CRITICAL", construct);
720           },
721           [&](const parser::ChangeTeamConstruct *) {
722             // C1135 and C1168
723             SayBadLeave(stmtType, "CHANGE TEAM", construct);
724           },
725           [](const auto *) {},
726       },
727       construct);
728 }
729 
StmtMatchesConstruct(const parser::Name * stmtName,StmtType stmtType,const parser::Name * constructName,const ConstructNode & construct)730 static bool StmtMatchesConstruct(const parser::Name *stmtName,
731     StmtType stmtType, const parser::Name *constructName,
732     const ConstructNode &construct) {
733   bool inDoConstruct{MaybeGetDoConstruct(construct) != nullptr};
734   if (stmtName == nullptr) {
735     return inDoConstruct;  // Unlabeled statements match all DO constructs
736   } else if (constructName && constructName->source == stmtName->source) {
737     return stmtType == StmtType::EXIT || inDoConstruct;
738   } else {
739     return false;
740   }
741 }
742 
743 // C1167 Can't EXIT from a DO CONCURRENT
CheckDoConcurrentExit(StmtType stmtType,const ConstructNode & construct) const744 void DoChecker::CheckDoConcurrentExit(
745     StmtType stmtType, const ConstructNode &construct) const {
746   if (stmtType == StmtType::EXIT && ConstructIsDoConcurrent(construct)) {
747     SayBadLeave(StmtType::EXIT, "DO CONCURRENT", construct);
748   }
749 }
750 
751 // Check nesting violations for a CYCLE or EXIT statement.  Loop up the nesting
752 // levels looking for a construct that matches the CYCLE or EXIT statment.  At
753 // every construct, check for a violation.  If we find a match without finding
754 // a violation, the check is complete.
CheckNesting(StmtType stmtType,const parser::Name * stmtName) const755 void DoChecker::CheckNesting(
756     StmtType stmtType, const parser::Name *stmtName) const {
757   const ConstructStack &stack{context_.constructStack()};
758   for (auto iter{stack.cend()}; iter-- != stack.cbegin();) {
759     const ConstructNode &construct{*iter};
760     const parser::Name *constructName{MaybeGetNodeName(construct)};
761     if (StmtMatchesConstruct(stmtName, stmtType, constructName, construct)) {
762       CheckDoConcurrentExit(stmtType, construct);
763       return;  // We got a match, so we're finished checking
764     }
765     CheckForBadLeave(stmtType, construct);
766   }
767 
768   // We haven't found a match in the enclosing constructs
769   if (stmtType == StmtType::EXIT) {
770     context_.Say("No matching construct for EXIT statement"_err_en_US);
771   } else {
772     context_.Say("No matching DO construct for CYCLE statement"_err_en_US);
773   }
774 }
775 
776 // C1135
Enter(const parser::CycleStmt & cycleStmt)777 void DoChecker::Enter(const parser::CycleStmt &cycleStmt) {
778   CheckNesting(StmtType::CYCLE, common::GetPtrFromOptional(cycleStmt.v));
779 }
780 
781 // C1167 and C1168
Enter(const parser::ExitStmt & exitStmt)782 void DoChecker::Enter(const parser::ExitStmt &exitStmt) {
783   CheckNesting(StmtType::EXIT, common::GetPtrFromOptional(exitStmt.v));
784 }
785 
786 }  // namespace Fortran::semantics
787