1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
#include "compiler/translator/tree_util/IntermTraverse.h"

#include <functional>

#include "compiler/translator/Compiler.h"
#include "compiler/translator/InfoSink.h"
#include "compiler/translator/SymbolTable.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
13 
14 namespace sh
15 {
16 
17 // Traverse the intermediate representation tree, and call a node type specific visit function for
18 // each node. Traversal is done recursively through the node member function traverse(). Nodes with
19 // children can have their whole subtree skipped if preVisit is turned on and the type specific
20 // function returns false.
21 template <typename T>
traverse(T * node)22 void TIntermTraverser::traverse(T *node)
23 {
24     ScopedNodeInTraversalPath addToPath(this, node);
25     if (!addToPath.isWithinDepthLimit())
26         return;
27 
28     bool visit = true;
29 
30     // Visit the node before children if pre-visiting.
31     if (preVisit)
32         visit = node->visit(PreVisit, this);
33 
34     if (visit)
35     {
36         size_t childIndex = 0;
37         size_t childCount = node->getChildCount();
38 
39         while (childIndex < childCount && visit)
40         {
41             node->getChildNode(childIndex)->traverse(this);
42             if (inVisit && childIndex != childCount - 1)
43             {
44                 visit = node->visit(InVisit, this);
45             }
46             ++childIndex;
47         }
48 
49         if (visit && postVisit)
50             node->visit(PostVisit, this);
51     }
52 }
53 
54 // Instantiate template for RewriteAtomicFunctionExpressions, in case this gets inlined thus not
55 // exported from the TU.
56 template void TIntermTraverser::traverse(TIntermNode *);
57 
traverse(TIntermTraverser * it)58 void TIntermNode::traverse(TIntermTraverser *it)
59 {
60     it->traverse(this);
61 }
62 
traverse(TIntermTraverser * it)63 void TIntermSymbol::traverse(TIntermTraverser *it)
64 {
65     TIntermTraverser::ScopedNodeInTraversalPath addToPath(it, this);
66     it->visitSymbol(this);
67 }
68 
traverse(TIntermTraverser * it)69 void TIntermConstantUnion::traverse(TIntermTraverser *it)
70 {
71     TIntermTraverser::ScopedNodeInTraversalPath addToPath(it, this);
72     it->visitConstantUnion(this);
73 }
74 
traverse(TIntermTraverser * it)75 void TIntermFunctionPrototype::traverse(TIntermTraverser *it)
76 {
77     TIntermTraverser::ScopedNodeInTraversalPath addToPath(it, this);
78     it->visitFunctionPrototype(this);
79 }
80 
traverse(TIntermTraverser * it)81 void TIntermBinary::traverse(TIntermTraverser *it)
82 {
83     it->traverseBinary(this);
84 }
85 
traverse(TIntermTraverser * it)86 void TIntermUnary::traverse(TIntermTraverser *it)
87 {
88     it->traverseUnary(this);
89 }
90 
traverse(TIntermTraverser * it)91 void TIntermFunctionDefinition::traverse(TIntermTraverser *it)
92 {
93     it->traverseFunctionDefinition(this);
94 }
95 
traverse(TIntermTraverser * it)96 void TIntermBlock::traverse(TIntermTraverser *it)
97 {
98     it->traverseBlock(this);
99 }
100 
traverse(TIntermTraverser * it)101 void TIntermAggregate::traverse(TIntermTraverser *it)
102 {
103     it->traverseAggregate(this);
104 }
105 
traverse(TIntermTraverser * it)106 void TIntermLoop::traverse(TIntermTraverser *it)
107 {
108     it->traverseLoop(this);
109 }
110 
traverse(TIntermTraverser * it)111 void TIntermPreprocessorDirective::traverse(TIntermTraverser *it)
112 {
113     it->visitPreprocessorDirective(this);
114 }
115 
visit(Visit visit,TIntermTraverser * it)116 bool TIntermSymbol::visit(Visit visit, TIntermTraverser *it)
117 {
118     it->visitSymbol(this);
119     return false;
120 }
121 
visit(Visit visit,TIntermTraverser * it)122 bool TIntermConstantUnion::visit(Visit visit, TIntermTraverser *it)
123 {
124     it->visitConstantUnion(this);
125     return false;
126 }
127 
visit(Visit visit,TIntermTraverser * it)128 bool TIntermFunctionPrototype::visit(Visit visit, TIntermTraverser *it)
129 {
130     it->visitFunctionPrototype(this);
131     return false;
132 }
133 
visit(Visit visit,TIntermTraverser * it)134 bool TIntermFunctionDefinition::visit(Visit visit, TIntermTraverser *it)
135 {
136     return it->visitFunctionDefinition(visit, this);
137 }
138 
visit(Visit visit,TIntermTraverser * it)139 bool TIntermUnary::visit(Visit visit, TIntermTraverser *it)
140 {
141     return it->visitUnary(visit, this);
142 }
143 
visit(Visit visit,TIntermTraverser * it)144 bool TIntermSwizzle::visit(Visit visit, TIntermTraverser *it)
145 {
146     return it->visitSwizzle(visit, this);
147 }
148 
visit(Visit visit,TIntermTraverser * it)149 bool TIntermBinary::visit(Visit visit, TIntermTraverser *it)
150 {
151     return it->visitBinary(visit, this);
152 }
153 
visit(Visit visit,TIntermTraverser * it)154 bool TIntermTernary::visit(Visit visit, TIntermTraverser *it)
155 {
156     return it->visitTernary(visit, this);
157 }
158 
visit(Visit visit,TIntermTraverser * it)159 bool TIntermAggregate::visit(Visit visit, TIntermTraverser *it)
160 {
161     return it->visitAggregate(visit, this);
162 }
163 
visit(Visit visit,TIntermTraverser * it)164 bool TIntermDeclaration::visit(Visit visit, TIntermTraverser *it)
165 {
166     return it->visitDeclaration(visit, this);
167 }
168 
visit(Visit visit,TIntermTraverser * it)169 bool TIntermGlobalQualifierDeclaration::visit(Visit visit, TIntermTraverser *it)
170 {
171     return it->visitGlobalQualifierDeclaration(visit, this);
172 }
173 
visit(Visit visit,TIntermTraverser * it)174 bool TIntermBlock::visit(Visit visit, TIntermTraverser *it)
175 {
176     return it->visitBlock(visit, this);
177 }
178 
visit(Visit visit,TIntermTraverser * it)179 bool TIntermIfElse::visit(Visit visit, TIntermTraverser *it)
180 {
181     return it->visitIfElse(visit, this);
182 }
183 
visit(Visit visit,TIntermTraverser * it)184 bool TIntermLoop::visit(Visit visit, TIntermTraverser *it)
185 {
186     return it->visitLoop(visit, this);
187 }
188 
visit(Visit visit,TIntermTraverser * it)189 bool TIntermBranch::visit(Visit visit, TIntermTraverser *it)
190 {
191     return it->visitBranch(visit, this);
192 }
193 
visit(Visit visit,TIntermTraverser * it)194 bool TIntermSwitch::visit(Visit visit, TIntermTraverser *it)
195 {
196     return it->visitSwitch(visit, this);
197 }
198 
visit(Visit visit,TIntermTraverser * it)199 bool TIntermCase::visit(Visit visit, TIntermTraverser *it)
200 {
201     return it->visitCase(visit, this);
202 }
203 
visit(Visit visit,TIntermTraverser * it)204 bool TIntermPreprocessorDirective::visit(Visit visit, TIntermTraverser *it)
205 {
206     it->visitPreprocessorDirective(this);
207     return false;
208 }
209 
TIntermTraverser(bool preVisit,bool inVisit,bool postVisit,TSymbolTable * symbolTable)210 TIntermTraverser::TIntermTraverser(bool preVisit,
211                                    bool inVisit,
212                                    bool postVisit,
213                                    TSymbolTable *symbolTable)
214     : preVisit(preVisit),
215       inVisit(inVisit),
216       postVisit(postVisit),
217       mMaxDepth(0),
218       mMaxAllowedDepth(std::numeric_limits<int>::max()),
219       mInGlobalScope(true),
220       mSymbolTable(symbolTable)
221 {
222     // Only enabling inVisit is not supported.
223     ASSERT(!(inVisit && !preVisit && !postVisit));
224 }
225 
~TIntermTraverser()226 TIntermTraverser::~TIntermTraverser() {}
227 
setMaxAllowedDepth(int depth)228 void TIntermTraverser::setMaxAllowedDepth(int depth)
229 {
230     mMaxAllowedDepth = depth;
231 }
232 
getParentBlock() const233 const TIntermBlock *TIntermTraverser::getParentBlock() const
234 {
235     if (!mParentBlockStack.empty())
236     {
237         return mParentBlockStack.back().node;
238     }
239     return nullptr;
240 }
241 
pushParentBlock(TIntermBlock * node)242 void TIntermTraverser::pushParentBlock(TIntermBlock *node)
243 {
244     mParentBlockStack.push_back(ParentBlock(node, 0));
245 }
246 
incrementParentBlockPos()247 void TIntermTraverser::incrementParentBlockPos()
248 {
249     ++mParentBlockStack.back().pos;
250 }
251 
popParentBlock()252 void TIntermTraverser::popParentBlock()
253 {
254     ASSERT(!mParentBlockStack.empty());
255     mParentBlockStack.pop_back();
256 }
257 
insertStatementsInParentBlock(const TIntermSequence & insertions)258 void TIntermTraverser::insertStatementsInParentBlock(const TIntermSequence &insertions)
259 {
260     TIntermSequence emptyInsertionsAfter;
261     insertStatementsInParentBlock(insertions, emptyInsertionsAfter);
262 }
263 
insertStatementsInParentBlock(const TIntermSequence & insertionsBefore,const TIntermSequence & insertionsAfter)264 void TIntermTraverser::insertStatementsInParentBlock(const TIntermSequence &insertionsBefore,
265                                                      const TIntermSequence &insertionsAfter)
266 {
267     ASSERT(!mParentBlockStack.empty());
268     ParentBlock &parentBlock = mParentBlockStack.back();
269     if (mPath.back() == parentBlock.node)
270     {
271         ASSERT(mParentBlockStack.size() >= 2u);
272         // The current node is a block node, so the parent block is not the topmost one in the block
273         // stack, but the one below that.
274         parentBlock = mParentBlockStack.at(mParentBlockStack.size() - 2u);
275     }
276     NodeInsertMultipleEntry insert(parentBlock.node, parentBlock.pos, insertionsBefore,
277                                    insertionsAfter);
278     mInsertions.push_back(insert);
279 }
280 
insertStatementInParentBlock(TIntermNode * statement)281 void TIntermTraverser::insertStatementInParentBlock(TIntermNode *statement)
282 {
283     TIntermSequence insertions;
284     insertions.push_back(statement);
285     insertStatementsInParentBlock(insertions);
286 }
287 
insertStatementsInBlockAtPosition(TIntermBlock * parent,size_t position,const TIntermSequence & insertionsBefore,const TIntermSequence & insertionsAfter)288 void TIntermTraverser::insertStatementsInBlockAtPosition(TIntermBlock *parent,
289                                                          size_t position,
290                                                          const TIntermSequence &insertionsBefore,
291                                                          const TIntermSequence &insertionsAfter)
292 {
293     ASSERT(parent);
294     ASSERT(position >= 0);
295     ASSERT(position < parent->getChildCount());
296 
297     mInsertions.emplace_back(parent, position, insertionsBefore, insertionsAfter);
298 }
299 
setInFunctionCallOutParameter(bool inOutParameter)300 void TLValueTrackingTraverser::setInFunctionCallOutParameter(bool inOutParameter)
301 {
302     mInFunctionCallOutParameter = inOutParameter;
303 }
304 
isInFunctionCallOutParameter() const305 bool TLValueTrackingTraverser::isInFunctionCallOutParameter() const
306 {
307     return mInFunctionCallOutParameter;
308 }
309 
traverseBinary(TIntermBinary * node)310 void TIntermTraverser::traverseBinary(TIntermBinary *node)
311 {
312     traverse(node);
313 }
314 
traverseBinary(TIntermBinary * node)315 void TLValueTrackingTraverser::traverseBinary(TIntermBinary *node)
316 {
317     ScopedNodeInTraversalPath addToPath(this, node);
318     if (!addToPath.isWithinDepthLimit())
319         return;
320 
321     bool visit = true;
322 
323     // visit the node before children if pre-visiting.
324     if (preVisit)
325         visit = node->visit(PreVisit, this);
326 
327     // Visit the children, in the right order.
328     if (visit)
329     {
330         if (node->isAssignment())
331         {
332             ASSERT(!isLValueRequiredHere());
333             setOperatorRequiresLValue(true);
334         }
335 
336         node->getLeft()->traverse(this);
337 
338         if (node->isAssignment())
339             setOperatorRequiresLValue(false);
340 
341         if (inVisit)
342             visit = node->visit(InVisit, this);
343 
344         if (visit)
345         {
346             // Some binary operations like indexing can be inside an expression which must be an
347             // l-value.
348             bool parentOperatorRequiresLValue     = operatorRequiresLValue();
349             bool parentInFunctionCallOutParameter = isInFunctionCallOutParameter();
350 
351             // Index is not required to be an l-value even when the surrounding expression is
352             // required to be an l-value.
353             TOperator op = node->getOp();
354             if (op == EOpIndexDirect || op == EOpIndexDirectInterfaceBlock ||
355                 op == EOpIndexDirectStruct || op == EOpIndexIndirect)
356             {
357                 setOperatorRequiresLValue(false);
358                 setInFunctionCallOutParameter(false);
359             }
360 
361             node->getRight()->traverse(this);
362 
363             setOperatorRequiresLValue(parentOperatorRequiresLValue);
364             setInFunctionCallOutParameter(parentInFunctionCallOutParameter);
365 
366             // Visit the node after the children, if requested and the traversal
367             // hasn't been cancelled yet.
368             if (postVisit)
369                 visit = node->visit(PostVisit, this);
370         }
371     }
372 }
373 
traverseUnary(TIntermUnary * node)374 void TIntermTraverser::traverseUnary(TIntermUnary *node)
375 {
376     traverse(node);
377 }
378 
traverseUnary(TIntermUnary * node)379 void TLValueTrackingTraverser::traverseUnary(TIntermUnary *node)
380 {
381     ScopedNodeInTraversalPath addToPath(this, node);
382     if (!addToPath.isWithinDepthLimit())
383         return;
384 
385     bool visit = true;
386 
387     if (preVisit)
388         visit = node->visit(PreVisit, this);
389 
390     if (visit)
391     {
392         ASSERT(!operatorRequiresLValue());
393         switch (node->getOp())
394         {
395             case EOpPostIncrement:
396             case EOpPostDecrement:
397             case EOpPreIncrement:
398             case EOpPreDecrement:
399                 setOperatorRequiresLValue(true);
400                 break;
401             default:
402                 break;
403         }
404 
405         node->getOperand()->traverse(this);
406 
407         setOperatorRequiresLValue(false);
408 
409         if (postVisit)
410             visit = node->visit(PostVisit, this);
411     }
412 }
413 
414 // Traverse a function definition node. This keeps track of global scope.
traverseFunctionDefinition(TIntermFunctionDefinition * node)415 void TIntermTraverser::traverseFunctionDefinition(TIntermFunctionDefinition *node)
416 {
417     ScopedNodeInTraversalPath addToPath(this, node);
418     if (!addToPath.isWithinDepthLimit())
419         return;
420 
421     bool visit = true;
422 
423     if (preVisit)
424         visit = node->visit(PreVisit, this);
425 
426     if (visit)
427     {
428         node->getFunctionPrototype()->traverse(this);
429         if (inVisit)
430             visit = node->visit(InVisit, this);
431         if (visit)
432         {
433             mInGlobalScope = false;
434             node->getBody()->traverse(this);
435             mInGlobalScope = true;
436             if (postVisit)
437                 visit = node->visit(PostVisit, this);
438         }
439     }
440 }
441 
442 // Traverse a block node. This keeps track of the position of traversed child nodes within the block
443 // so that nodes may be inserted before or after them.
traverseBlock(TIntermBlock * node)444 void TIntermTraverser::traverseBlock(TIntermBlock *node)
445 {
446     ScopedNodeInTraversalPath addToPath(this, node);
447     if (!addToPath.isWithinDepthLimit())
448         return;
449 
450     pushParentBlock(node);
451 
452     bool visit = true;
453 
454     TIntermSequence *sequence = node->getSequence();
455 
456     if (preVisit)
457         visit = node->visit(PreVisit, this);
458 
459     if (visit)
460     {
461         for (auto *child : *sequence)
462         {
463             if (visit)
464             {
465                 child->traverse(this);
466                 if (inVisit)
467                 {
468                     if (child != sequence->back())
469                         visit = node->visit(InVisit, this);
470                 }
471 
472                 incrementParentBlockPos();
473             }
474         }
475 
476         if (visit && postVisit)
477             visit = node->visit(PostVisit, this);
478     }
479 
480     popParentBlock();
481 }
482 
traverseAggregate(TIntermAggregate * node)483 void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
484 {
485     traverse(node);
486 }
487 
CompareInsertion(const NodeInsertMultipleEntry & a,const NodeInsertMultipleEntry & b)488 bool TIntermTraverser::CompareInsertion(const NodeInsertMultipleEntry &a,
489                                         const NodeInsertMultipleEntry &b)
490 {
491     if (a.parent != b.parent)
492     {
493         return a.parent < b.parent;
494     }
495     return a.position < b.position;
496 }
497 
updateTree(TCompiler * compiler,TIntermNode * node)498 bool TIntermTraverser::updateTree(TCompiler *compiler, TIntermNode *node)
499 {
500     // Sort the insertions so that insertion position is increasing and same position insertions are
501     // not reordered. The insertions are processed in reverse order so that multiple insertions to
502     // the same parent node are handled correctly.
503     std::stable_sort(mInsertions.begin(), mInsertions.end(), CompareInsertion);
504     for (size_t ii = 0; ii < mInsertions.size(); ++ii)
505     {
506         // If two insertions are to the same position, insert them in the order they were specified.
507         // The std::stable_sort call above will automatically guarantee this.
508         const NodeInsertMultipleEntry &insertion = mInsertions[mInsertions.size() - ii - 1];
509         ASSERT(insertion.parent);
510         if (!insertion.insertionsAfter.empty())
511         {
512             bool inserted = insertion.parent->insertChildNodes(insertion.position + 1,
513                                                                insertion.insertionsAfter);
514             ASSERT(inserted);
515         }
516         if (!insertion.insertionsBefore.empty())
517         {
518             bool inserted =
519                 insertion.parent->insertChildNodes(insertion.position, insertion.insertionsBefore);
520             ASSERT(inserted);
521         }
522     }
523     for (size_t ii = 0; ii < mReplacements.size(); ++ii)
524     {
525         const NodeUpdateEntry &replacement = mReplacements[ii];
526         ASSERT(replacement.parent);
527         bool replaced =
528             replacement.parent->replaceChildNode(replacement.original, replacement.replacement);
529         ASSERT(replaced);
530 
531         if (!replacement.originalBecomesChildOfReplacement)
532         {
533             // In AST traversing, a parent is visited before its children.
534             // After we replace a node, if its immediate child is to
535             // be replaced, we need to make sure we don't update the replaced
536             // node; instead, we update the replacement node.
537             for (size_t jj = ii + 1; jj < mReplacements.size(); ++jj)
538             {
539                 NodeUpdateEntry &replacement2 = mReplacements[jj];
540                 if (replacement2.parent == replacement.original)
541                     replacement2.parent = replacement.replacement;
542             }
543         }
544     }
545     for (size_t ii = 0; ii < mMultiReplacements.size(); ++ii)
546     {
547         const NodeReplaceWithMultipleEntry &replacement = mMultiReplacements[ii];
548         ASSERT(replacement.parent);
549         bool replaced = replacement.parent->replaceChildNodeWithMultiple(replacement.original,
550                                                                          replacement.replacements);
551         ASSERT(replaced);
552     }
553 
554     clearReplacementQueue();
555 
556     return compiler->validateAST(node);
557 }
558 
clearReplacementQueue()559 void TIntermTraverser::clearReplacementQueue()
560 {
561     mReplacements.clear();
562     mMultiReplacements.clear();
563     mInsertions.clear();
564 }
565 
queueReplacement(TIntermNode * replacement,OriginalNode originalStatus)566 void TIntermTraverser::queueReplacement(TIntermNode *replacement, OriginalNode originalStatus)
567 {
568     queueReplacementWithParent(getParentNode(), mPath.back(), replacement, originalStatus);
569 }
570 
queueReplacementWithParent(TIntermNode * parent,TIntermNode * original,TIntermNode * replacement,OriginalNode originalStatus)571 void TIntermTraverser::queueReplacementWithParent(TIntermNode *parent,
572                                                   TIntermNode *original,
573                                                   TIntermNode *replacement,
574                                                   OriginalNode originalStatus)
575 {
576     bool originalBecomesChild = (originalStatus == OriginalNode::BECOMES_CHILD);
577     mReplacements.push_back(NodeUpdateEntry(parent, original, replacement, originalBecomesChild));
578 }
579 
TLValueTrackingTraverser(bool preVisitIn,bool inVisitIn,bool postVisitIn,TSymbolTable * symbolTable)580 TLValueTrackingTraverser::TLValueTrackingTraverser(bool preVisitIn,
581                                                    bool inVisitIn,
582                                                    bool postVisitIn,
583                                                    TSymbolTable *symbolTable)
584     : TIntermTraverser(preVisitIn, inVisitIn, postVisitIn, symbolTable),
585       mOperatorRequiresLValue(false),
586       mInFunctionCallOutParameter(false)
587 {
588     ASSERT(symbolTable);
589 }
590 
traverseAggregate(TIntermAggregate * node)591 void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
592 {
593     ScopedNodeInTraversalPath addToPath(this, node);
594     if (!addToPath.isWithinDepthLimit())
595         return;
596 
597     bool visit = true;
598 
599     TIntermSequence *sequence = node->getSequence();
600 
601     if (preVisit)
602         visit = node->visit(PreVisit, this);
603 
604     if (visit)
605     {
606         size_t paramIndex = 0u;
607         for (auto *child : *sequence)
608         {
609             if (visit)
610             {
611                 if (node->getFunction())
612                 {
613                     // Both built-ins and user defined functions should have the function symbol
614                     // set.
615                     ASSERT(paramIndex < node->getFunction()->getParamCount());
616                     TQualifier qualifier =
617                         node->getFunction()->getParam(paramIndex)->getType().getQualifier();
618                     setInFunctionCallOutParameter(qualifier == EvqOut || qualifier == EvqInOut);
619                     ++paramIndex;
620                 }
621                 else
622                 {
623                     ASSERT(node->isConstructor());
624                 }
625                 child->traverse(this);
626                 if (inVisit)
627                 {
628                     if (child != sequence->back())
629                         visit = node->visit(InVisit, this);
630                 }
631             }
632         }
633         setInFunctionCallOutParameter(false);
634 
635         if (visit && postVisit)
636             visit = node->visit(PostVisit, this);
637     }
638 }
639 
traverseLoop(TIntermLoop * node)640 void TIntermTraverser::traverseLoop(TIntermLoop *node)
641 {
642     traverse(node);
643 }
644 }  // namespace sh
645