1 //
2 // Copyright (c) 2016 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // SimplifyLoopConditions is an AST traverser that converts loop conditions and loop expressions
7 // to regular statements inside the loop. This way further transformations that generate statements
8 // from loop conditions and loop expressions work correctly.
9 //
10 
11 #include "compiler/translator/SimplifyLoopConditions.h"
12 
13 #include "compiler/translator/IntermNode.h"
14 #include "compiler/translator/IntermNodePatternMatcher.h"
15 
16 namespace sh
17 {
18 
19 namespace
20 {
21 
CreateBoolConstantNode(bool value)22 TIntermConstantUnion *CreateBoolConstantNode(bool value)
23 {
24     TConstantUnion *u = new TConstantUnion;
25     u->setBConst(value);
26     TIntermConstantUnion *node =
27         new TIntermConstantUnion(u, TType(EbtBool, EbpUndefined, EvqConst, 1));
28     return node;
29 }
30 
// Searches for a loop whose condition or loop expression contains a node accepted by the pattern
// matcher, and rewrites that loop so the condition/expression is evaluated by regular statements
// inside the loop. At most one loop is transformed per traversal; SimplifyLoopConditions() below
// re-runs the traversal (calling nextIteration() in between) until foundLoopToChange() is false.
class SimplifyLoopConditionsTraverser : public TLValueTrackingTraverser
{
  public:
    // |conditionsToSimplifyMask| is passed to IntermNodePatternMatcher and selects which kinds of
    // nodes inside a loop condition/expression cause the loop to be rewritten.
    SimplifyLoopConditionsTraverser(unsigned int conditionsToSimplifyMask,
                                    const TSymbolTable &symbolTable,
                                    int shaderVersion);

    void traverseLoop(TIntermLoop *node) override;

    bool visitBinary(Visit visit, TIntermBinary *node) override;
    bool visitAggregate(Visit visit, TIntermAggregate *node) override;
    bool visitTernary(Visit visit, TIntermTernary *node) override;

    // Resets the per-traversal flags and advances the temporary index before the next pass.
    void nextIteration();
    // True if the most recent traversal found a loop to transform.
    bool foundLoopToChange() const { return mFoundLoopToChange; }

  protected:
    // Marked to true once an operation that needs to be hoisted out of the expression has been
    // found. After that, no more AST updates are performed on that traversal.
    bool mFoundLoopToChange;
    // True only while the condition or expression of a loop is being traversed (not its body).
    bool mInsideLoopConditionOrExpression;
    // Decides which nodes inside a loop condition/expression require the loop to be rewritten.
    IntermNodePatternMatcher mConditionsToSimplify;
};
54 
SimplifyLoopConditionsTraverser(unsigned int conditionsToSimplifyMask,const TSymbolTable & symbolTable,int shaderVersion)55 SimplifyLoopConditionsTraverser::SimplifyLoopConditionsTraverser(
56     unsigned int conditionsToSimplifyMask,
57     const TSymbolTable &symbolTable,
58     int shaderVersion)
59     : TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
60       mFoundLoopToChange(false),
61       mInsideLoopConditionOrExpression(false),
62       mConditionsToSimplify(conditionsToSimplifyMask)
63 {
64 }
65 
nextIteration()66 void SimplifyLoopConditionsTraverser::nextIteration()
67 {
68     mFoundLoopToChange               = false;
69     mInsideLoopConditionOrExpression = false;
70     nextTemporaryIndex();
71 }
72 
visitBinary(Visit visit,TIntermBinary * node)73 bool SimplifyLoopConditionsTraverser::visitBinary(Visit visit, TIntermBinary *node)
74 {
75     // The visit functions operate in three modes:
76     // 1. If a matching expression has already been found, we return early since only one loop can
77     //    be transformed on one traversal.
78     // 2. We try to find loops. In case a node is not inside a loop and can not contain loops, we
79     //    stop traversing the subtree.
80     // 3. If we're inside a loop condition or expression, we check for expressions that should be
81     //    moved out of the loop condition or expression. If one is found, the loop is processed.
82 
83     if (mFoundLoopToChange)
84         return false;
85 
86     if (!mInsideLoopConditionOrExpression)
87         return false;
88 
89     mFoundLoopToChange = mConditionsToSimplify.match(node, getParentNode(), isLValueRequiredHere());
90     return !mFoundLoopToChange;
91 }
92 
visitAggregate(Visit visit,TIntermAggregate * node)93 bool SimplifyLoopConditionsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
94 {
95     if (mFoundLoopToChange)
96         return false;
97 
98     // If we're outside a loop condition, we only need to traverse nodes that may contain loops.
99     if (!mInsideLoopConditionOrExpression)
100         return false;
101 
102     mFoundLoopToChange = mConditionsToSimplify.match(node, getParentNode());
103     return !mFoundLoopToChange;
104 }
105 
visitTernary(Visit visit,TIntermTernary * node)106 bool SimplifyLoopConditionsTraverser::visitTernary(Visit visit, TIntermTernary *node)
107 {
108     if (mFoundLoopToChange)
109         return false;
110 
111     // Don't traverse ternary operators outside loop conditions.
112     if (!mInsideLoopConditionOrExpression)
113         return false;
114 
115     mFoundLoopToChange = mConditionsToSimplify.match(node);
116     return !mFoundLoopToChange;
117 }
118 
traverseLoop(TIntermLoop * node)119 void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
120 {
121     if (mFoundLoopToChange)
122         return;
123 
124     // Mark that we're inside a loop condition or expression, and transform the loop if needed.
125 
126     incrementDepth(node);
127 
128     // Note: No need to traverse the loop init node.
129 
130     mInsideLoopConditionOrExpression = true;
131     TLoopType loopType               = node->getType();
132 
133     if (node->getCondition())
134     {
135         node->getCondition()->traverse(this);
136 
137         if (mFoundLoopToChange)
138         {
139             // Replace the loop condition with a boolean variable that's updated on each iteration.
140             if (loopType == ELoopWhile)
141             {
142                 // Transform:
143                 //   while (expr) { body; }
144                 // into
145                 //   bool s0 = expr;
146                 //   while (s0) { { body; } s0 = expr; }
147                 TIntermSequence tempInitSeq;
148                 tempInitSeq.push_back(createTempInitDeclaration(node->getCondition()->deepCopy()));
149                 insertStatementsInParentBlock(tempInitSeq);
150 
151                 TIntermBlock *newBody = new TIntermBlock();
152                 if (node->getBody())
153                 {
154                     newBody->getSequence()->push_back(node->getBody());
155                 }
156                 newBody->getSequence()->push_back(
157                     createTempAssignment(node->getCondition()->deepCopy()));
158 
159                 // Can't use queueReplacement to replace old body, since it may have been nullptr.
160                 // It's safe to do the replacements in place here - this node won't be traversed
161                 // further.
162                 node->setBody(newBody);
163                 node->setCondition(createTempSymbol(node->getCondition()->getType()));
164             }
165             else if (loopType == ELoopDoWhile)
166             {
167                 // Transform:
168                 //   do {
169                 //     body;
170                 //   } while (expr);
171                 // into
172                 //   bool s0 = true;
173                 //   do {
174                 //     { body; }
175                 //     s0 = expr;
176                 //   while (s0);
177                 TIntermSequence tempInitSeq;
178                 tempInitSeq.push_back(createTempInitDeclaration(CreateBoolConstantNode(true)));
179                 insertStatementsInParentBlock(tempInitSeq);
180 
181                 TIntermBlock *newBody = new TIntermBlock();
182                 if (node->getBody())
183                 {
184                     newBody->getSequence()->push_back(node->getBody());
185                 }
186                 newBody->getSequence()->push_back(
187                     createTempAssignment(node->getCondition()->deepCopy()));
188 
189                 // Can't use queueReplacement to replace old body, since it may have been nullptr.
190                 // It's safe to do the replacements in place here - this node won't be traversed
191                 // further.
192                 node->setBody(newBody);
193                 node->setCondition(createTempSymbol(node->getCondition()->getType()));
194             }
195             else if (loopType == ELoopFor)
196             {
197                 // Move the loop condition inside the loop.
198                 // Transform:
199                 //   for (init; expr; exprB) { body; }
200                 // into
201                 //   {
202                 //     init;
203                 //     bool s0 = expr;
204                 //     while (s0) { { body; } exprB; s0 = expr; }
205                 //   }
206                 TIntermBlock *loopScope = new TIntermBlock();
207                 if (node->getInit())
208                 {
209                     loopScope->getSequence()->push_back(node->getInit());
210                 }
211                 loopScope->getSequence()->push_back(
212                     createTempInitDeclaration(node->getCondition()->deepCopy()));
213 
214                 TIntermBlock *whileLoopBody = new TIntermBlock();
215                 if (node->getBody())
216                 {
217                     whileLoopBody->getSequence()->push_back(node->getBody());
218                 }
219                 if (node->getExpression())
220                 {
221                     whileLoopBody->getSequence()->push_back(node->getExpression());
222                 }
223                 whileLoopBody->getSequence()->push_back(
224                     createTempAssignment(node->getCondition()->deepCopy()));
225                 TIntermLoop *whileLoop = new TIntermLoop(
226                     ELoopWhile, nullptr, createTempSymbol(node->getCondition()->getType()), nullptr,
227                     whileLoopBody);
228                 loopScope->getSequence()->push_back(whileLoop);
229                 queueReplacementWithParent(getAncestorNode(1), node, loopScope,
230                                            OriginalNode::IS_DROPPED);
231             }
232         }
233     }
234 
235     if (!mFoundLoopToChange && node->getExpression())
236     {
237         node->getExpression()->traverse(this);
238 
239         if (mFoundLoopToChange)
240         {
241             ASSERT(loopType == ELoopFor);
242             // Move the loop expression to inside the loop.
243             // Transform:
244             //   for (init; expr; exprB) { body; }
245             // into
246             //   for (init; expr; ) { { body; } exprB; }
247             TIntermTyped *loopExpression = node->getExpression();
248             node->setExpression(nullptr);
249             TIntermBlock *oldBody = node->getBody();
250             node->setBody(new TIntermBlock());
251             if (oldBody != nullptr)
252             {
253                 node->getBody()->getSequence()->push_back(oldBody);
254             }
255             node->getBody()->getSequence()->push_back(loopExpression);
256         }
257     }
258 
259     mInsideLoopConditionOrExpression = false;
260 
261     if (!mFoundLoopToChange && node->getBody())
262         node->getBody()->traverse(this);
263 
264     decrementDepth();
265 }
266 
267 }  // namespace
268 
SimplifyLoopConditions(TIntermNode * root,unsigned int conditionsToSimplifyMask,unsigned int * temporaryIndex,const TSymbolTable & symbolTable,int shaderVersion)269 void SimplifyLoopConditions(TIntermNode *root,
270                             unsigned int conditionsToSimplifyMask,
271                             unsigned int *temporaryIndex,
272                             const TSymbolTable &symbolTable,
273                             int shaderVersion)
274 {
275     SimplifyLoopConditionsTraverser traverser(conditionsToSimplifyMask, symbolTable, shaderVersion);
276     ASSERT(temporaryIndex != nullptr);
277     traverser.useTemporaryIndex(temporaryIndex);
278     // Process one loop at a time, and reset the traverser between iterations.
279     do
280     {
281         traverser.nextIteration();
282         root->traverse(&traverser);
283         if (traverser.foundLoopToChange())
284             traverser.updateTree();
285     } while (traverser.foundLoopToChange());
286 }
287 
288 }  // namespace sh
289