1 //
2 // Copyright (c) 2016 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // SimplifyLoopConditions is an AST traverser that converts loop conditions and loop expressions
7 // to regular statements inside the loop. This way further transformations that generate statements
8 // from loop conditions and loop expressions work correctly.
9 //
10 
11 #include "compiler/translator/SimplifyLoopConditions.h"
12 
13 #include "compiler/translator/IntermNodePatternMatcher.h"
14 #include "compiler/translator/IntermNode_util.h"
15 #include "compiler/translator/IntermTraverse.h"
16 #include "compiler/translator/StaticType.h"
17 
18 namespace sh
19 {
20 
21 namespace
22 {
23 
24 class SimplifyLoopConditionsTraverser : public TLValueTrackingTraverser
25 {
26   public:
27     SimplifyLoopConditionsTraverser(unsigned int conditionsToSimplifyMask,
28                                     TSymbolTable *symbolTable);
29 
30     void traverseLoop(TIntermLoop *node) override;
31 
32     bool visitUnary(Visit visit, TIntermUnary *node) override;
33     bool visitBinary(Visit visit, TIntermBinary *node) override;
34     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
35     bool visitTernary(Visit visit, TIntermTernary *node) override;
36     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
37 
foundLoopToChange() const38     bool foundLoopToChange() const { return mFoundLoopToChange; }
39 
40   protected:
41     // Marked to true once an operation that needs to be hoisted out of a loop expression has been
42     // found.
43     bool mFoundLoopToChange;
44     bool mInsideLoopInitConditionOrExpression;
45     IntermNodePatternMatcher mConditionsToSimplify;
46 };
47 
SimplifyLoopConditionsTraverser(unsigned int conditionsToSimplifyMask,TSymbolTable * symbolTable)48 SimplifyLoopConditionsTraverser::SimplifyLoopConditionsTraverser(
49     unsigned int conditionsToSimplifyMask,
50     TSymbolTable *symbolTable)
51     : TLValueTrackingTraverser(true, false, false, symbolTable),
52       mFoundLoopToChange(false),
53       mInsideLoopInitConditionOrExpression(false),
54       mConditionsToSimplify(conditionsToSimplifyMask)
55 {
56 }
57 
58 // If we're inside a loop initialization, condition, or expression, we check for expressions that
59 // should be moved out of the loop condition or expression. If one is found, the loop is
60 // transformed.
61 // If we're not inside loop initialization, condition, or expression, we only need to traverse nodes
62 // that may contain loops.
63 
visitUnary(Visit visit,TIntermUnary * node)64 bool SimplifyLoopConditionsTraverser::visitUnary(Visit visit, TIntermUnary *node)
65 {
66     if (!mInsideLoopInitConditionOrExpression)
67         return false;
68 
69     if (mFoundLoopToChange)
70         return false;  // Already decided to change this loop.
71 
72     mFoundLoopToChange = mConditionsToSimplify.match(node);
73     return !mFoundLoopToChange;
74 }
75 
visitBinary(Visit visit,TIntermBinary * node)76 bool SimplifyLoopConditionsTraverser::visitBinary(Visit visit, TIntermBinary *node)
77 {
78     if (!mInsideLoopInitConditionOrExpression)
79         return false;
80 
81     if (mFoundLoopToChange)
82         return false;  // Already decided to change this loop.
83 
84     mFoundLoopToChange = mConditionsToSimplify.match(node, getParentNode(), isLValueRequiredHere());
85     return !mFoundLoopToChange;
86 }
87 
visitAggregate(Visit visit,TIntermAggregate * node)88 bool SimplifyLoopConditionsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
89 {
90     if (!mInsideLoopInitConditionOrExpression)
91         return false;
92 
93     if (mFoundLoopToChange)
94         return false;  // Already decided to change this loop.
95 
96     mFoundLoopToChange = mConditionsToSimplify.match(node, getParentNode());
97     return !mFoundLoopToChange;
98 }
99 
visitTernary(Visit visit,TIntermTernary * node)100 bool SimplifyLoopConditionsTraverser::visitTernary(Visit visit, TIntermTernary *node)
101 {
102     if (!mInsideLoopInitConditionOrExpression)
103         return false;
104 
105     if (mFoundLoopToChange)
106         return false;  // Already decided to change this loop.
107 
108     mFoundLoopToChange = mConditionsToSimplify.match(node);
109     return !mFoundLoopToChange;
110 }
111 
visitDeclaration(Visit visit,TIntermDeclaration * node)112 bool SimplifyLoopConditionsTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
113 {
114     if (!mInsideLoopInitConditionOrExpression)
115         return false;
116 
117     if (mFoundLoopToChange)
118         return false;  // Already decided to change this loop.
119 
120     mFoundLoopToChange = mConditionsToSimplify.match(node);
121     return !mFoundLoopToChange;
122 }
123 
traverseLoop(TIntermLoop * node)124 void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
125 {
126     // Mark that we're inside a loop condition or expression, and determine if the loop needs to be
127     // transformed.
128 
129     ScopedNodeInTraversalPath addToPath(this, node);
130 
131     mInsideLoopInitConditionOrExpression = true;
132     mFoundLoopToChange                   = false;
133 
134     if (!mFoundLoopToChange && node->getInit())
135     {
136         node->getInit()->traverse(this);
137     }
138 
139     if (!mFoundLoopToChange && node->getCondition())
140     {
141         node->getCondition()->traverse(this);
142     }
143 
144     if (!mFoundLoopToChange && node->getExpression())
145     {
146         node->getExpression()->traverse(this);
147     }
148 
149     mInsideLoopInitConditionOrExpression = false;
150 
151     if (mFoundLoopToChange)
152     {
153         const TType *boolType        = StaticType::Get<EbtBool, EbpUndefined, EvqTemporary, 1, 1>();
154         TVariable *conditionVariable = CreateTempVariable(mSymbolTable, boolType);
155 
156         // Replace the loop condition with a boolean variable that's updated on each iteration.
157         TLoopType loopType = node->getType();
158         if (loopType == ELoopWhile)
159         {
160             // Transform:
161             //   while (expr) { body; }
162             // into
163             //   bool s0 = expr;
164             //   while (s0) { { body; } s0 = expr; }
165             TIntermDeclaration *tempInitDeclaration =
166                 CreateTempInitDeclarationNode(conditionVariable, node->getCondition()->deepCopy());
167             insertStatementInParentBlock(tempInitDeclaration);
168 
169             TIntermBlock *newBody = new TIntermBlock();
170             if (node->getBody())
171             {
172                 newBody->getSequence()->push_back(node->getBody());
173             }
174             newBody->getSequence()->push_back(
175                 CreateTempAssignmentNode(conditionVariable, node->getCondition()->deepCopy()));
176 
177             // Can't use queueReplacement to replace old body, since it may have been nullptr.
178             // It's safe to do the replacements in place here - the new body will still be
179             // traversed, but that won't create any problems.
180             node->setBody(newBody);
181             node->setCondition(CreateTempSymbolNode(conditionVariable));
182         }
183         else if (loopType == ELoopDoWhile)
184         {
185             // Transform:
186             //   do {
187             //     body;
188             //   } while (expr);
189             // into
190             //   bool s0 = true;
191             //   do {
192             //     { body; }
193             //     s0 = expr;
194             //   } while (s0);
195             TIntermDeclaration *tempInitDeclaration =
196                 CreateTempInitDeclarationNode(conditionVariable, CreateBoolNode(true));
197             insertStatementInParentBlock(tempInitDeclaration);
198 
199             TIntermBlock *newBody = new TIntermBlock();
200             if (node->getBody())
201             {
202                 newBody->getSequence()->push_back(node->getBody());
203             }
204             newBody->getSequence()->push_back(
205                 CreateTempAssignmentNode(conditionVariable, node->getCondition()->deepCopy()));
206 
207             // Can't use queueReplacement to replace old body, since it may have been nullptr.
208             // It's safe to do the replacements in place here - the new body will still be
209             // traversed, but that won't create any problems.
210             node->setBody(newBody);
211             node->setCondition(CreateTempSymbolNode(conditionVariable));
212         }
213         else if (loopType == ELoopFor)
214         {
215             // Move the loop condition inside the loop.
216             // Transform:
217             //   for (init; expr; exprB) { body; }
218             // into
219             //   {
220             //     init;
221             //     bool s0 = expr;
222             //     while (s0) {
223             //       { body; }
224             //       exprB;
225             //       s0 = expr;
226             //     }
227             //   }
228             TIntermBlock *loopScope            = new TIntermBlock();
229             TIntermSequence *loopScopeSequence = loopScope->getSequence();
230 
231             // Insert "init;"
232             if (node->getInit())
233             {
234                 loopScopeSequence->push_back(node->getInit());
235             }
236 
237             // Insert "bool s0 = expr;" if applicable, "bool s0 = true;" otherwise
238             TIntermTyped *conditionInitializer = nullptr;
239             if (node->getCondition())
240             {
241                 conditionInitializer = node->getCondition()->deepCopy();
242             }
243             else
244             {
245                 conditionInitializer = CreateBoolNode(true);
246             }
247             loopScopeSequence->push_back(
248                 CreateTempInitDeclarationNode(conditionVariable, conditionInitializer));
249 
250             // Insert "{ body; }" in the while loop
251             TIntermBlock *whileLoopBody = new TIntermBlock();
252             if (node->getBody())
253             {
254                 whileLoopBody->getSequence()->push_back(node->getBody());
255             }
256             // Insert "exprB;" in the while loop
257             if (node->getExpression())
258             {
259                 whileLoopBody->getSequence()->push_back(node->getExpression());
260             }
261             // Insert "s0 = expr;" in the while loop
262             if (node->getCondition())
263             {
264                 whileLoopBody->getSequence()->push_back(
265                     CreateTempAssignmentNode(conditionVariable, node->getCondition()->deepCopy()));
266             }
267 
268             // Create "while(s0) { whileLoopBody }"
269             TIntermLoop *whileLoop =
270                 new TIntermLoop(ELoopWhile, nullptr, CreateTempSymbolNode(conditionVariable),
271                                 nullptr, whileLoopBody);
272             loopScope->getSequence()->push_back(whileLoop);
273             queueReplacement(loopScope, OriginalNode::IS_DROPPED);
274 
275             // After this the old body node will be traversed and loops inside it may be
276             // transformed. This is fine, since the old body node will still be in the AST after the
277             // transformation that's queued here, and transforming loops inside it doesn't need to
278             // know the exact post-transform path to it.
279         }
280     }
281 
282     mFoundLoopToChange = false;
283 
284     // We traverse the body of the loop even if the loop is transformed.
285     if (node->getBody())
286         node->getBody()->traverse(this);
287 }
288 
289 }  // namespace
290 
SimplifyLoopConditions(TIntermNode * root,unsigned int conditionsToSimplifyMask,TSymbolTable * symbolTable)291 void SimplifyLoopConditions(TIntermNode *root,
292                             unsigned int conditionsToSimplifyMask,
293                             TSymbolTable *symbolTable)
294 {
295     SimplifyLoopConditionsTraverser traverser(conditionsToSimplifyMask, symbolTable);
296     root->traverse(&traverser);
297     traverser.updateTree();
298 }
299 
300 }  // namespace sh
301