1 //
2 // Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RemoveSwitchFallThrough.cpp: Remove fall-through from switch statements.
7 // Note that it is unsafe to do further AST transformations on the AST generated
8 // by this function. It leaves duplicate nodes in the AST making replacements
9 // unreliable.
10 
11 #include "compiler/translator/RemoveSwitchFallThrough.h"
12 
13 #include "compiler/translator/Diagnostics.h"
14 #include "compiler/translator/IntermTraverse.h"
15 
16 namespace sh
17 {
18 
19 namespace
20 {
21 
22 class RemoveSwitchFallThroughTraverser : public TIntermTraverser
23 {
24   public:
25     static TIntermBlock *removeFallThrough(TIntermBlock *statementList,
26                                            PerformanceDiagnostics *perfDiagnostics);
27 
28   private:
29     RemoveSwitchFallThroughTraverser(TIntermBlock *statementList,
30                                      PerformanceDiagnostics *perfDiagnostics);
31 
32     void visitSymbol(TIntermSymbol *node) override;
33     void visitConstantUnion(TIntermConstantUnion *node) override;
34     bool visitDeclaration(Visit, TIntermDeclaration *node) override;
35     bool visitBinary(Visit, TIntermBinary *node) override;
36     bool visitUnary(Visit, TIntermUnary *node) override;
37     bool visitTernary(Visit visit, TIntermTernary *node) override;
38     bool visitSwizzle(Visit, TIntermSwizzle *node) override;
39     bool visitIfElse(Visit visit, TIntermIfElse *node) override;
40     bool visitSwitch(Visit, TIntermSwitch *node) override;
41     bool visitCase(Visit, TIntermCase *node) override;
42     bool visitAggregate(Visit, TIntermAggregate *node) override;
43     bool visitBlock(Visit, TIntermBlock *node) override;
44     bool visitLoop(Visit, TIntermLoop *node) override;
45     bool visitBranch(Visit, TIntermBranch *node) override;
46 
47     void outputSequence(TIntermSequence *sequence, size_t startIndex);
48     void handlePreviousCase();
49 
50     TIntermBlock *mStatementList;
51     TIntermBlock *mStatementListOut;
52     bool mLastStatementWasBreak;
53     TIntermBlock *mPreviousCase;
54     std::vector<TIntermBlock *> mCasesSharingBreak;
55     PerformanceDiagnostics *mPerfDiagnostics;
56 };
57 
removeFallThrough(TIntermBlock * statementList,PerformanceDiagnostics * perfDiagnostics)58 TIntermBlock *RemoveSwitchFallThroughTraverser::removeFallThrough(
59     TIntermBlock *statementList,
60     PerformanceDiagnostics *perfDiagnostics)
61 {
62     RemoveSwitchFallThroughTraverser rm(statementList, perfDiagnostics);
63     ASSERT(statementList);
64     statementList->traverse(&rm);
65     ASSERT(rm.mPreviousCase || statementList->getSequence()->empty());
66     if (!rm.mLastStatementWasBreak && rm.mPreviousCase)
67     {
68         // Make sure that there's a branch at the end of the final case inside the switch statement.
69         // This also ensures that any cases that fall through to the final case will get the break.
70         TIntermBranch *finalBreak = new TIntermBranch(EOpBreak, nullptr);
71         rm.mPreviousCase->getSequence()->push_back(finalBreak);
72         rm.mLastStatementWasBreak = true;
73     }
74     rm.handlePreviousCase();
75     return rm.mStatementListOut;
76 }
77 
RemoveSwitchFallThroughTraverser(TIntermBlock * statementList,PerformanceDiagnostics * perfDiagnostics)78 RemoveSwitchFallThroughTraverser::RemoveSwitchFallThroughTraverser(
79     TIntermBlock *statementList,
80     PerformanceDiagnostics *perfDiagnostics)
81     : TIntermTraverser(true, false, false),
82       mStatementList(statementList),
83       mLastStatementWasBreak(false),
84       mPreviousCase(nullptr),
85       mPerfDiagnostics(perfDiagnostics)
86 {
87     mStatementListOut = new TIntermBlock();
88 }
89 
visitSymbol(TIntermSymbol * node)90 void RemoveSwitchFallThroughTraverser::visitSymbol(TIntermSymbol *node)
91 {
92     // Note that this assumes that switch statements which don't begin by a case statement
93     // have already been weeded out in validation.
94     mPreviousCase->getSequence()->push_back(node);
95     mLastStatementWasBreak = false;
96 }
97 
visitConstantUnion(TIntermConstantUnion * node)98 void RemoveSwitchFallThroughTraverser::visitConstantUnion(TIntermConstantUnion *node)
99 {
100     // Conditions of case labels are not traversed, so this is a constant statement like "0;".
101     // These are no-ops so there's no need to add them back to the statement list. Should have
102     // already been pruned out of the AST, in fact.
103     UNREACHABLE();
104 }
105 
visitDeclaration(Visit,TIntermDeclaration * node)106 bool RemoveSwitchFallThroughTraverser::visitDeclaration(Visit, TIntermDeclaration *node)
107 {
108     mPreviousCase->getSequence()->push_back(node);
109     mLastStatementWasBreak = false;
110     return false;
111 }
112 
visitBinary(Visit,TIntermBinary * node)113 bool RemoveSwitchFallThroughTraverser::visitBinary(Visit, TIntermBinary *node)
114 {
115     mPreviousCase->getSequence()->push_back(node);
116     mLastStatementWasBreak = false;
117     return false;
118 }
119 
visitUnary(Visit,TIntermUnary * node)120 bool RemoveSwitchFallThroughTraverser::visitUnary(Visit, TIntermUnary *node)
121 {
122     mPreviousCase->getSequence()->push_back(node);
123     mLastStatementWasBreak = false;
124     return false;
125 }
126 
visitTernary(Visit,TIntermTernary * node)127 bool RemoveSwitchFallThroughTraverser::visitTernary(Visit, TIntermTernary *node)
128 {
129     mPreviousCase->getSequence()->push_back(node);
130     mLastStatementWasBreak = false;
131     return false;
132 }
133 
visitSwizzle(Visit,TIntermSwizzle * node)134 bool RemoveSwitchFallThroughTraverser::visitSwizzle(Visit, TIntermSwizzle *node)
135 {
136     mPreviousCase->getSequence()->push_back(node);
137     mLastStatementWasBreak = false;
138     return false;
139 }
140 
visitIfElse(Visit,TIntermIfElse * node)141 bool RemoveSwitchFallThroughTraverser::visitIfElse(Visit, TIntermIfElse *node)
142 {
143     mPreviousCase->getSequence()->push_back(node);
144     mLastStatementWasBreak = false;
145     return false;
146 }
147 
visitSwitch(Visit,TIntermSwitch * node)148 bool RemoveSwitchFallThroughTraverser::visitSwitch(Visit, TIntermSwitch *node)
149 {
150     mPreviousCase->getSequence()->push_back(node);
151     mLastStatementWasBreak = false;
152     // Don't go into nested switch statements
153     return false;
154 }
155 
outputSequence(TIntermSequence * sequence,size_t startIndex)156 void RemoveSwitchFallThroughTraverser::outputSequence(TIntermSequence *sequence, size_t startIndex)
157 {
158     for (size_t i = startIndex; i < sequence->size(); ++i)
159     {
160         mStatementListOut->getSequence()->push_back(sequence->at(i));
161     }
162 }
163 
handlePreviousCase()164 void RemoveSwitchFallThroughTraverser::handlePreviousCase()
165 {
166     if (mPreviousCase)
167         mCasesSharingBreak.push_back(mPreviousCase);
168     if (mLastStatementWasBreak)
169     {
170         for (size_t i = 0; i < mCasesSharingBreak.size(); ++i)
171         {
172             ASSERT(!mCasesSharingBreak.at(i)->getSequence()->empty());
173             if (mCasesSharingBreak.at(i)->getSequence()->size() == 1)
174             {
175                 // Fall-through is allowed in case the label has no statements.
176                 outputSequence(mCasesSharingBreak.at(i)->getSequence(), 0);
177             }
178             else
179             {
180                 // Include all the statements that this case can fall through under the same label.
181                 if (mCasesSharingBreak.size() > i + 1u)
182                 {
183                     mPerfDiagnostics->warning(mCasesSharingBreak.at(i)->getLine(),
184                                               "Performance: non-empty fall-through cases in "
185                                               "switch statements generate extra code.",
186                                               "switch");
187                 }
188                 for (size_t j = i; j < mCasesSharingBreak.size(); ++j)
189                 {
190                     size_t startIndex =
191                         j > i ? 1 : 0;  // Add the label only from the first sequence.
192                     outputSequence(mCasesSharingBreak.at(j)->getSequence(), startIndex);
193                 }
194             }
195         }
196         mCasesSharingBreak.clear();
197     }
198     mLastStatementWasBreak = false;
199     mPreviousCase          = nullptr;
200 }
201 
visitCase(Visit,TIntermCase * node)202 bool RemoveSwitchFallThroughTraverser::visitCase(Visit, TIntermCase *node)
203 {
204     handlePreviousCase();
205     mPreviousCase = new TIntermBlock();
206     mPreviousCase->getSequence()->push_back(node);
207     mPreviousCase->setLine(node->getLine());
208     // Don't traverse the condition of the case statement
209     return false;
210 }
211 
visitAggregate(Visit,TIntermAggregate * node)212 bool RemoveSwitchFallThroughTraverser::visitAggregate(Visit, TIntermAggregate *node)
213 {
214     mPreviousCase->getSequence()->push_back(node);
215     mLastStatementWasBreak = false;
216     return false;
217 }
218 
DoesBlockAlwaysBreak(TIntermBlock * node)219 bool DoesBlockAlwaysBreak(TIntermBlock *node)
220 {
221     if (node->getSequence()->empty())
222     {
223         return false;
224     }
225 
226     TIntermBlock *lastStatementAsBlock = node->getSequence()->back()->getAsBlock();
227     if (lastStatementAsBlock)
228     {
229         return DoesBlockAlwaysBreak(lastStatementAsBlock);
230     }
231 
232     TIntermBranch *lastStatementAsBranch = node->getSequence()->back()->getAsBranchNode();
233     return lastStatementAsBranch != nullptr;
234 }
235 
visitBlock(Visit,TIntermBlock * node)236 bool RemoveSwitchFallThroughTraverser::visitBlock(Visit, TIntermBlock *node)
237 {
238     if (node != mStatementList)
239     {
240         mPreviousCase->getSequence()->push_back(node);
241         mLastStatementWasBreak = DoesBlockAlwaysBreak(node);
242         return false;
243     }
244     return true;
245 }
246 
visitLoop(Visit,TIntermLoop * node)247 bool RemoveSwitchFallThroughTraverser::visitLoop(Visit, TIntermLoop *node)
248 {
249     mPreviousCase->getSequence()->push_back(node);
250     mLastStatementWasBreak = false;
251     return false;
252 }
253 
visitBranch(Visit,TIntermBranch * node)254 bool RemoveSwitchFallThroughTraverser::visitBranch(Visit, TIntermBranch *node)
255 {
256     mPreviousCase->getSequence()->push_back(node);
257     // TODO: Verify that accepting return or continue statements here doesn't cause problems.
258     mLastStatementWasBreak = true;
259     return false;
260 }
261 
262 }  // anonymous namespace
263 
RemoveSwitchFallThrough(TIntermBlock * statementList,PerformanceDiagnostics * perfDiagnostics)264 TIntermBlock *RemoveSwitchFallThrough(TIntermBlock *statementList,
265                                       PerformanceDiagnostics *perfDiagnostics)
266 {
267     return RemoveSwitchFallThroughTraverser::removeFallThrough(statementList, perfDiagnostics);
268 }
269 
270 }  // namespace sh
271