1 // Copyright (c) 2018-2019, NVIDIA CORPORATION.  All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "canonicalize-do.h"
16 #include "../parser/parse-tree-visitor.h"
17 
18 namespace Fortran::parser {
19 
20 class CanonicalizationOfDoLoops {
21   struct LabelInfo {
22     Block::iterator iter;
23     Label label;
24   };
25 
26 public:
Pre(T &)27   template<typename T> bool Pre(T &) { return true; }
Post(T &)28   template<typename T> void Post(T &) {}
Post(Block & block)29   void Post(Block &block) {
30     std::vector<LabelInfo> stack;
31     for (auto i{block.begin()}, end{block.end()}; i != end; ++i) {
32       if (auto *executableConstruct{std::get_if<ExecutableConstruct>(&i->u)}) {
33         std::visit(
34             common::visitors{
35                 [](auto &) {},
36                 // Labels on end-stmt of constructs are accepted by f18 as an
37                 // extension.
38                 [&](common::Indirection<AssociateConstruct> &associate) {
39                   CanonicalizeIfMatch(block, stack, i,
40                       std::get<Statement<EndAssociateStmt>>(
41                           associate.value().t));
42                 },
43                 [&](common::Indirection<BlockConstruct> &blockConstruct) {
44                   CanonicalizeIfMatch(block, stack, i,
45                       std::get<Statement<EndBlockStmt>>(
46                           blockConstruct.value().t));
47                 },
48                 [&](common::Indirection<ChangeTeamConstruct> &changeTeam) {
49                   CanonicalizeIfMatch(block, stack, i,
50                       std::get<Statement<EndChangeTeamStmt>>(
51                           changeTeam.value().t));
52                 },
53                 [&](common::Indirection<CriticalConstruct> &critical) {
54                   CanonicalizeIfMatch(block, stack, i,
55                       std::get<Statement<EndCriticalStmt>>(critical.value().t));
56                 },
57                 [&](common::Indirection<DoConstruct> &doConstruct) {
58                   CanonicalizeIfMatch(block, stack, i,
59                       std::get<Statement<EndDoStmt>>(doConstruct.value().t));
60                 },
61                 [&](common::Indirection<IfConstruct> &ifConstruct) {
62                   CanonicalizeIfMatch(block, stack, i,
63                       std::get<Statement<EndIfStmt>>(ifConstruct.value().t));
64                 },
65                 [&](common::Indirection<CaseConstruct> &caseConstruct) {
66                   CanonicalizeIfMatch(block, stack, i,
67                       std::get<Statement<EndSelectStmt>>(
68                           caseConstruct.value().t));
69                 },
70                 [&](common::Indirection<SelectRankConstruct> &selectRank) {
71                   CanonicalizeIfMatch(block, stack, i,
72                       std::get<Statement<EndSelectStmt>>(selectRank.value().t));
73                 },
74                 [&](common::Indirection<SelectTypeConstruct> &selectType) {
75                   CanonicalizeIfMatch(block, stack, i,
76                       std::get<Statement<EndSelectStmt>>(selectType.value().t));
77                 },
78                 [&](common::Indirection<ForallConstruct> &forall) {
79                   CanonicalizeIfMatch(block, stack, i,
80                       std::get<Statement<EndForallStmt>>(forall.value().t));
81                 },
82                 [&](common::Indirection<WhereConstruct> &where) {
83                   CanonicalizeIfMatch(block, stack, i,
84                       std::get<Statement<EndWhereStmt>>(where.value().t));
85                 },
86                 [&](Statement<common::Indirection<LabelDoStmt>> &labelDoStmt) {
87                   auto &label{std::get<Label>(labelDoStmt.statement.value().t)};
88                   stack.push_back(LabelInfo{i, label});
89                 },
90                 [&](Statement<common::Indirection<EndDoStmt>> &endDoStmt) {
91                   CanonicalizeIfMatch(block, stack, i, endDoStmt);
92                 },
93                 [&](Statement<ActionStmt> &actionStmt) {
94                   CanonicalizeIfMatch(block, stack, i, actionStmt);
95                 },
96             },
97             executableConstruct->u);
98       }
99     }
100   }
101 
102 private:
103   template<typename T>
CanonicalizeIfMatch(Block & originalBlock,std::vector<LabelInfo> & stack,Block::iterator & i,Statement<T> & statement)104   void CanonicalizeIfMatch(Block &originalBlock, std::vector<LabelInfo> &stack,
105       Block::iterator &i, Statement<T> &statement) {
106     if (!stack.empty() && statement.label.has_value() &&
107         stack.back().label == *statement.label) {
108       auto currentLabel{stack.back().label};
109       if constexpr (std::is_same_v<T, common::Indirection<EndDoStmt>>) {
110         std::get<ExecutableConstruct>(i->u).u = Statement<ActionStmt>{
111             std::optional<Label>{currentLabel}, ContinueStmt{}};
112       }
113       auto next{++i};
114       do {
115         Block block;
116         auto doLoop{stack.back().iter};
117         auto originalSource{
118             std::get<Statement<common::Indirection<LabelDoStmt>>>(
119                 std::get<ExecutableConstruct>(doLoop->u).u)
120                 .source};
121         block.splice(block.begin(), originalBlock, ++stack.back().iter, next);
122         auto &labelDo{std::get<Statement<common::Indirection<LabelDoStmt>>>(
123             std::get<ExecutableConstruct>(doLoop->u).u)};
124         auto &loopControl{
125             std::get<std::optional<LoopControl>>(labelDo.statement.value().t)};
126         auto &name{std::get<std::optional<Name>>(labelDo.statement.value().t)};
127         Statement<NonLabelDoStmt> nonLabelDoStmt{std::move(labelDo.label),
128             NonLabelDoStmt{
129                 std::make_tuple(common::Clone(name), std::move(loopControl))}};
130         nonLabelDoStmt.source = originalSource;
131         std::get<ExecutableConstruct>(doLoop->u).u =
132             common::Indirection<DoConstruct>{
133                 std::make_tuple(std::move(nonLabelDoStmt), std::move(block),
134                     Statement<EndDoStmt>{
135                         std::optional<Label>{}, EndDoStmt{std::move(name)}})};
136         stack.pop_back();
137       } while (!stack.empty() && stack.back().label == currentLabel);
138       i = --next;
139     }
140   }
141 };
142 
CanonicalizeDo(Program & program)143 bool CanonicalizeDo(Program &program) {
144   CanonicalizationOfDoLoops canonicalizationOfDoLoops;
145   Walk(program, canonicalizationOfDoLoops);
146   return true;
147 }
148 
149 }
150