1 //
2 // Copyright (c) 2016 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // SimplifyLoopConditions is an AST traverser that converts loop conditions and loop expressions
7 // to regular statements inside the loop. This way further transformations that generate statements
8 // from loop conditions and loop expressions work correctly.
9 //
10
11 #include "compiler/translator/SimplifyLoopConditions.h"
12
13 #include "compiler/translator/IntermNode.h"
14 #include "compiler/translator/IntermNodePatternMatcher.h"
15
16 namespace sh
17 {
18
19 namespace
20 {
21
CreateBoolConstantNode(bool value)22 TIntermConstantUnion *CreateBoolConstantNode(bool value)
23 {
24 TConstantUnion *u = new TConstantUnion;
25 u->setBConst(value);
26 TIntermConstantUnion *node =
27 new TIntermConstantUnion(u, TType(EbtBool, EbpUndefined, EvqConst, 1));
28 return node;
29 }
30
31 class SimplifyLoopConditionsTraverser : public TLValueTrackingTraverser
32 {
33 public:
34 SimplifyLoopConditionsTraverser(unsigned int conditionsToSimplifyMask,
35 const TSymbolTable &symbolTable,
36 int shaderVersion);
37
38 void traverseLoop(TIntermLoop *node) override;
39
40 bool visitBinary(Visit visit, TIntermBinary *node) override;
41 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
42 bool visitTernary(Visit visit, TIntermTernary *node) override;
43
44 void nextIteration();
foundLoopToChange() const45 bool foundLoopToChange() const { return mFoundLoopToChange; }
46
47 protected:
48 // Marked to true once an operation that needs to be hoisted out of the expression has been
49 // found. After that, no more AST updates are performed on that traversal.
50 bool mFoundLoopToChange;
51 bool mInsideLoopConditionOrExpression;
52 IntermNodePatternMatcher mConditionsToSimplify;
53 };
54
SimplifyLoopConditionsTraverser(unsigned int conditionsToSimplifyMask,const TSymbolTable & symbolTable,int shaderVersion)55 SimplifyLoopConditionsTraverser::SimplifyLoopConditionsTraverser(
56 unsigned int conditionsToSimplifyMask,
57 const TSymbolTable &symbolTable,
58 int shaderVersion)
59 : TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
60 mFoundLoopToChange(false),
61 mInsideLoopConditionOrExpression(false),
62 mConditionsToSimplify(conditionsToSimplifyMask)
63 {
64 }
65
nextIteration()66 void SimplifyLoopConditionsTraverser::nextIteration()
67 {
68 mFoundLoopToChange = false;
69 mInsideLoopConditionOrExpression = false;
70 nextTemporaryIndex();
71 }
72
visitBinary(Visit visit,TIntermBinary * node)73 bool SimplifyLoopConditionsTraverser::visitBinary(Visit visit, TIntermBinary *node)
74 {
75 // The visit functions operate in three modes:
76 // 1. If a matching expression has already been found, we return early since only one loop can
77 // be transformed on one traversal.
78 // 2. We try to find loops. In case a node is not inside a loop and can not contain loops, we
79 // stop traversing the subtree.
80 // 3. If we're inside a loop condition or expression, we check for expressions that should be
81 // moved out of the loop condition or expression. If one is found, the loop is processed.
82
83 if (mFoundLoopToChange)
84 return false;
85
86 if (!mInsideLoopConditionOrExpression)
87 return false;
88
89 mFoundLoopToChange = mConditionsToSimplify.match(node, getParentNode(), isLValueRequiredHere());
90 return !mFoundLoopToChange;
91 }
92
visitAggregate(Visit visit,TIntermAggregate * node)93 bool SimplifyLoopConditionsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
94 {
95 if (mFoundLoopToChange)
96 return false;
97
98 // If we're outside a loop condition, we only need to traverse nodes that may contain loops.
99 if (!mInsideLoopConditionOrExpression)
100 return false;
101
102 mFoundLoopToChange = mConditionsToSimplify.match(node, getParentNode());
103 return !mFoundLoopToChange;
104 }
105
visitTernary(Visit visit,TIntermTernary * node)106 bool SimplifyLoopConditionsTraverser::visitTernary(Visit visit, TIntermTernary *node)
107 {
108 if (mFoundLoopToChange)
109 return false;
110
111 // Don't traverse ternary operators outside loop conditions.
112 if (!mInsideLoopConditionOrExpression)
113 return false;
114
115 mFoundLoopToChange = mConditionsToSimplify.match(node);
116 return !mFoundLoopToChange;
117 }
118
// Custom traversal of a loop node.
//
// The loop condition and then the loop expression are traversed with
// mInsideLoopConditionOrExpression set. This lets the visit functions report, through
// mFoundLoopToChange, an expression that has to be moved out. When that happens, the loop is
// rewritten so the condition / expression becomes a regular statement:
//  - while and do-while loops are modified in place, and a temporary declaration is inserted
//    into the parent block;
//  - for loops with a matching condition are queued to be replaced by a new block holding the
//    init, a temporary declaration and an equivalent while loop.
// Queued insertions/replacements take effect when the caller runs updateTree().
// If nothing matches, the traversal continues into the loop body, where nested loops may be
// found. The loop init node is never traversed.
//
// NOTE(review): every rewrite below appends the condition update (`s0 = expr;`) and/or the loop
// expression (`exprB;`) after the original body. A `continue` inside the body jumps directly to
// the loop test, skipping those appended statements. As a result:
//  - in the condition rewrites, `s0` keeps its stale value;
//  - in the for-loop rewrites, `exprB` is not executed for that iteration.
// Confirm whether loops reaching this transformation can contain `continue`.
void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
{
    // Only one loop is transformed per traversal.
    if (mFoundLoopToChange)
        return;

    // Mark that we're inside a loop condition or expression, and transform the loop if needed.

    incrementDepth(node);

    // Note: No need to traverse the loop init node.

    mInsideLoopConditionOrExpression = true;
    TLoopType loopType = node->getType();

    if (node->getCondition())
    {
        node->getCondition()->traverse(this);

        if (mFoundLoopToChange)
        {
            // Replace the loop condition with a boolean variable that's updated on each iteration.
            if (loopType == ELoopWhile)
            {
                // Transform:
                //   while (expr) { body; }
                // into
                //   bool s0 = expr;
                //   while (s0) { { body; } s0 = expr; }
                TIntermSequence tempInitSeq;
                tempInitSeq.push_back(createTempInitDeclaration(node->getCondition()->deepCopy()));
                insertStatementsInParentBlock(tempInitSeq);

                TIntermBlock *newBody = new TIntermBlock();
                if (node->getBody())
                {
                    newBody->getSequence()->push_back(node->getBody());
                }
                // The deep copies above and here must be taken before setCondition() below
                // replaces the original condition.
                newBody->getSequence()->push_back(
                    createTempAssignment(node->getCondition()->deepCopy()));

                // Can't use queueReplacement to replace old body, since it may have been nullptr.
                // It's safe to do the replacements in place here - this node won't be traversed
                // further.
                node->setBody(newBody);
                node->setCondition(createTempSymbol(node->getCondition()->getType()));
            }
            else if (loopType == ELoopDoWhile)
            {
                // Transform:
                //   do {
                //     body;
                //   } while (expr);
                // into
                //   bool s0 = true;
                //   do {
                //     { body; }
                //     s0 = expr;
                //   } while (s0);
                // The temporary starts as true because a do-while body always runs at least once
                // before the condition is evaluated.
                TIntermSequence tempInitSeq;
                tempInitSeq.push_back(createTempInitDeclaration(CreateBoolConstantNode(true)));
                insertStatementsInParentBlock(tempInitSeq);

                TIntermBlock *newBody = new TIntermBlock();
                if (node->getBody())
                {
                    newBody->getSequence()->push_back(node->getBody());
                }
                newBody->getSequence()->push_back(
                    createTempAssignment(node->getCondition()->deepCopy()));

                // Can't use queueReplacement to replace old body, since it may have been nullptr.
                // It's safe to do the replacements in place here - this node won't be traversed
                // further.
                node->setBody(newBody);
                node->setCondition(createTempSymbol(node->getCondition()->getType()));
            }
            else if (loopType == ELoopFor)
            {
                // Move the loop condition inside the loop.
                // Transform:
                //   for (init; expr; exprB) { body; }
                // into
                //   {
                //     init;
                //     bool s0 = expr;
                //     while (s0) { { body; } exprB; s0 = expr; }
                //   }
                // The new block also keeps any variables declared in `init` scoped to the loop.
                TIntermBlock *loopScope = new TIntermBlock();
                if (node->getInit())
                {
                    loopScope->getSequence()->push_back(node->getInit());
                }
                loopScope->getSequence()->push_back(
                    createTempInitDeclaration(node->getCondition()->deepCopy()));

                TIntermBlock *whileLoopBody = new TIntermBlock();
                if (node->getBody())
                {
                    whileLoopBody->getSequence()->push_back(node->getBody());
                }
                if (node->getExpression())
                {
                    whileLoopBody->getSequence()->push_back(node->getExpression());
                }
                whileLoopBody->getSequence()->push_back(
                    createTempAssignment(node->getCondition()->deepCopy()));
                TIntermLoop *whileLoop = new TIntermLoop(
                    ELoopWhile, nullptr, createTempSymbol(node->getCondition()->getType()), nullptr,
                    whileLoopBody);
                loopScope->getSequence()->push_back(whileLoop);
                // getAncestorNode(1) is the loop's parent, since incrementDepth(node) pushed the
                // loop itself onto the path. The original for node is dropped from the tree.
                queueReplacementWithParent(getAncestorNode(1), node, loopScope,
                                           OriginalNode::IS_DROPPED);
            }
        }
    }

    // Only examine the loop expression if the condition did not already trigger a rewrite.
    if (!mFoundLoopToChange && node->getExpression())
    {
        node->getExpression()->traverse(this);

        if (mFoundLoopToChange)
        {
            ASSERT(loopType == ELoopFor);
            // Move the loop expression to inside the loop.
            // Transform:
            //   for (init; expr; exprB) { body; }
            // into
            //   for (init; expr; ) { { body; } exprB; }
            TIntermTyped *loopExpression = node->getExpression();
            node->setExpression(nullptr);
            TIntermBlock *oldBody = node->getBody();
            node->setBody(new TIntermBlock());
            if (oldBody != nullptr)
            {
                node->getBody()->getSequence()->push_back(oldBody);
            }
            node->getBody()->getSequence()->push_back(loopExpression);
        }
    }

    mInsideLoopConditionOrExpression = false;

    // No loop in this node's header needed changing: look for nested loops in the body.
    if (!mFoundLoopToChange && node->getBody())
        node->getBody()->traverse(this);

    decrementDepth();
}
266
267 } // namespace
268
SimplifyLoopConditions(TIntermNode * root,unsigned int conditionsToSimplifyMask,unsigned int * temporaryIndex,const TSymbolTable & symbolTable,int shaderVersion)269 void SimplifyLoopConditions(TIntermNode *root,
270 unsigned int conditionsToSimplifyMask,
271 unsigned int *temporaryIndex,
272 const TSymbolTable &symbolTable,
273 int shaderVersion)
274 {
275 SimplifyLoopConditionsTraverser traverser(conditionsToSimplifyMask, symbolTable, shaderVersion);
276 ASSERT(temporaryIndex != nullptr);
277 traverser.useTemporaryIndex(temporaryIndex);
278 // Process one loop at a time, and reset the traverser between iterations.
279 do
280 {
281 traverser.nextIteration();
282 root->traverse(&traverser);
283 if (traverser.foundLoopToChange())
284 traverser.updateTree();
285 } while (traverser.foundLoopToChange());
286 }
287
288 } // namespace sh
289