1 //
2 // Copyright (c) 2002-2010 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/IntermTraverse.h"
8 
9 #include "compiler/translator/InfoSink.h"
10 #include "compiler/translator/IntermNode_util.h"
11 #include "compiler/translator/SymbolTable.h"
12 
13 namespace sh
14 {
15 
traverse(TIntermTraverser * it)16 void TIntermSymbol::traverse(TIntermTraverser *it)
17 {
18     it->traverseSymbol(this);
19 }
20 
traverse(TIntermTraverser * it)21 void TIntermRaw::traverse(TIntermTraverser *it)
22 {
23     it->traverseRaw(this);
24 }
25 
traverse(TIntermTraverser * it)26 void TIntermConstantUnion::traverse(TIntermTraverser *it)
27 {
28     it->traverseConstantUnion(this);
29 }
30 
traverse(TIntermTraverser * it)31 void TIntermSwizzle::traverse(TIntermTraverser *it)
32 {
33     it->traverseSwizzle(this);
34 }
35 
traverse(TIntermTraverser * it)36 void TIntermBinary::traverse(TIntermTraverser *it)
37 {
38     it->traverseBinary(this);
39 }
40 
traverse(TIntermTraverser * it)41 void TIntermUnary::traverse(TIntermTraverser *it)
42 {
43     it->traverseUnary(this);
44 }
45 
traverse(TIntermTraverser * it)46 void TIntermTernary::traverse(TIntermTraverser *it)
47 {
48     it->traverseTernary(this);
49 }
50 
traverse(TIntermTraverser * it)51 void TIntermIfElse::traverse(TIntermTraverser *it)
52 {
53     it->traverseIfElse(this);
54 }
55 
traverse(TIntermTraverser * it)56 void TIntermSwitch::traverse(TIntermTraverser *it)
57 {
58     it->traverseSwitch(this);
59 }
60 
traverse(TIntermTraverser * it)61 void TIntermCase::traverse(TIntermTraverser *it)
62 {
63     it->traverseCase(this);
64 }
65 
traverse(TIntermTraverser * it)66 void TIntermFunctionDefinition::traverse(TIntermTraverser *it)
67 {
68     it->traverseFunctionDefinition(this);
69 }
70 
traverse(TIntermTraverser * it)71 void TIntermBlock::traverse(TIntermTraverser *it)
72 {
73     it->traverseBlock(this);
74 }
75 
traverse(TIntermTraverser * it)76 void TIntermInvariantDeclaration::traverse(TIntermTraverser *it)
77 {
78     it->traverseInvariantDeclaration(this);
79 }
80 
traverse(TIntermTraverser * it)81 void TIntermDeclaration::traverse(TIntermTraverser *it)
82 {
83     it->traverseDeclaration(this);
84 }
85 
traverse(TIntermTraverser * it)86 void TIntermFunctionPrototype::traverse(TIntermTraverser *it)
87 {
88     it->traverseFunctionPrototype(this);
89 }
90 
traverse(TIntermTraverser * it)91 void TIntermAggregate::traverse(TIntermTraverser *it)
92 {
93     it->traverseAggregate(this);
94 }
95 
traverse(TIntermTraverser * it)96 void TIntermLoop::traverse(TIntermTraverser *it)
97 {
98     it->traverseLoop(this);
99 }
100 
traverse(TIntermTraverser * it)101 void TIntermBranch::traverse(TIntermTraverser *it)
102 {
103     it->traverseBranch(this);
104 }
105 
TIntermTraverser(bool preVisit,bool inVisit,bool postVisit,TSymbolTable * symbolTable)106 TIntermTraverser::TIntermTraverser(bool preVisit,
107                                    bool inVisit,
108                                    bool postVisit,
109                                    TSymbolTable *symbolTable)
110     : preVisit(preVisit),
111       inVisit(inVisit),
112       postVisit(postVisit),
113       mDepth(-1),
114       mMaxDepth(0),
115       mInGlobalScope(true),
116       mSymbolTable(symbolTable),
117       mTemporaryId(nullptr)
118 {
119 }
120 
~TIntermTraverser()121 TIntermTraverser::~TIntermTraverser()
122 {
123 }
124 
getParentBlock() const125 const TIntermBlock *TIntermTraverser::getParentBlock() const
126 {
127     if (!mParentBlockStack.empty())
128     {
129         return mParentBlockStack.back().node;
130     }
131     return nullptr;
132 }
133 
pushParentBlock(TIntermBlock * node)134 void TIntermTraverser::pushParentBlock(TIntermBlock *node)
135 {
136     mParentBlockStack.push_back(ParentBlock(node, 0));
137 }
138 
incrementParentBlockPos()139 void TIntermTraverser::incrementParentBlockPos()
140 {
141     ++mParentBlockStack.back().pos;
142 }
143 
popParentBlock()144 void TIntermTraverser::popParentBlock()
145 {
146     ASSERT(!mParentBlockStack.empty());
147     mParentBlockStack.pop_back();
148 }
149 
insertStatementsInParentBlock(const TIntermSequence & insertions)150 void TIntermTraverser::insertStatementsInParentBlock(const TIntermSequence &insertions)
151 {
152     TIntermSequence emptyInsertionsAfter;
153     insertStatementsInParentBlock(insertions, emptyInsertionsAfter);
154 }
155 
insertStatementsInParentBlock(const TIntermSequence & insertionsBefore,const TIntermSequence & insertionsAfter)156 void TIntermTraverser::insertStatementsInParentBlock(const TIntermSequence &insertionsBefore,
157                                                      const TIntermSequence &insertionsAfter)
158 {
159     ASSERT(!mParentBlockStack.empty());
160     ParentBlock &parentBlock = mParentBlockStack.back();
161     if (mPath.back() == parentBlock.node)
162     {
163         ASSERT(mParentBlockStack.size() >= 2u);
164         // The current node is a block node, so the parent block is not the topmost one in the block
165         // stack, but the one below that.
166         parentBlock = mParentBlockStack.at(mParentBlockStack.size() - 2u);
167     }
168     NodeInsertMultipleEntry insert(parentBlock.node, parentBlock.pos, insertionsBefore,
169                                    insertionsAfter);
170     mInsertions.push_back(insert);
171 }
172 
insertStatementInParentBlock(TIntermNode * statement)173 void TIntermTraverser::insertStatementInParentBlock(TIntermNode *statement)
174 {
175     TIntermSequence insertions;
176     insertions.push_back(statement);
177     insertStatementsInParentBlock(insertions);
178 }
179 
createTempSymbol(const TType & type,TQualifier qualifier)180 TIntermSymbol *TIntermTraverser::createTempSymbol(const TType &type, TQualifier qualifier)
181 {
182     ASSERT(mTemporaryId != nullptr);
183     // nextTemporaryId() needs to be called when the code wants to start using another temporary
184     // symbol.
185     return CreateTempSymbolNode(*mTemporaryId, type, qualifier);
186 }
187 
createTempSymbol(const TType & type)188 TIntermSymbol *TIntermTraverser::createTempSymbol(const TType &type)
189 {
190     return createTempSymbol(type, EvqTemporary);
191 }
192 
createTempDeclaration(const TType & type)193 TIntermDeclaration *TIntermTraverser::createTempDeclaration(const TType &type)
194 {
195     ASSERT(mTemporaryId != nullptr);
196     TIntermDeclaration *tempDeclaration = new TIntermDeclaration();
197     tempDeclaration->appendDeclarator(CreateTempSymbolNode(*mTemporaryId, type, EvqTemporary));
198     return tempDeclaration;
199 }
200 
createTempInitDeclaration(TIntermTyped * initializer,TQualifier qualifier)201 TIntermDeclaration *TIntermTraverser::createTempInitDeclaration(TIntermTyped *initializer,
202                                                                 TQualifier qualifier)
203 {
204     ASSERT(mTemporaryId != nullptr);
205     return CreateTempInitDeclarationNode(*mTemporaryId, initializer, qualifier);
206 }
207 
createTempInitDeclaration(TIntermTyped * initializer)208 TIntermDeclaration *TIntermTraverser::createTempInitDeclaration(TIntermTyped *initializer)
209 {
210     return createTempInitDeclaration(initializer, EvqTemporary);
211 }
212 
createTempAssignment(TIntermTyped * rightNode)213 TIntermBinary *TIntermTraverser::createTempAssignment(TIntermTyped *rightNode)
214 {
215     ASSERT(rightNode != nullptr);
216     TIntermSymbol *tempSymbol = createTempSymbol(rightNode->getType());
217     TIntermBinary *assignment = new TIntermBinary(EOpAssign, tempSymbol, rightNode);
218     return assignment;
219 }
220 
nextTemporaryId()221 void TIntermTraverser::nextTemporaryId()
222 {
223     ASSERT(mSymbolTable);
224     if (!mTemporaryId)
225     {
226         mTemporaryId = new TSymbolUniqueId(mSymbolTable);
227         return;
228     }
229     *mTemporaryId = TSymbolUniqueId(mSymbolTable);
230 }
231 
addToFunctionMap(const TSymbolUniqueId & id,TIntermSequence * paramSequence)232 void TLValueTrackingTraverser::addToFunctionMap(const TSymbolUniqueId &id,
233                                                 TIntermSequence *paramSequence)
234 {
235     mFunctionMap[id.get()] = paramSequence;
236 }
237 
isInFunctionMap(const TIntermAggregate * callNode) const238 bool TLValueTrackingTraverser::isInFunctionMap(const TIntermAggregate *callNode) const
239 {
240     ASSERT(callNode->getOp() == EOpCallFunctionInAST);
241     return (mFunctionMap.find(callNode->getFunctionSymbolInfo()->getId().get()) !=
242             mFunctionMap.end());
243 }
244 
getFunctionParameters(const TIntermAggregate * callNode)245 TIntermSequence *TLValueTrackingTraverser::getFunctionParameters(const TIntermAggregate *callNode)
246 {
247     ASSERT(isInFunctionMap(callNode));
248     return mFunctionMap[callNode->getFunctionSymbolInfo()->getId().get()];
249 }
250 
setInFunctionCallOutParameter(bool inOutParameter)251 void TLValueTrackingTraverser::setInFunctionCallOutParameter(bool inOutParameter)
252 {
253     mInFunctionCallOutParameter = inOutParameter;
254 }
255 
isInFunctionCallOutParameter() const256 bool TLValueTrackingTraverser::isInFunctionCallOutParameter() const
257 {
258     return mInFunctionCallOutParameter;
259 }
260 
//
// Traverse the intermediate representation tree, and
// call a node-type-specific function for each node.
// Done recursively through the member function traverse().
// Node types can be skipped if their function to call is 0,
// but their subtree will still be traversed.
// Nodes with children can have their whole subtree skipped
// if preVisit is turned on and the type-specific function
// returns false.
//
271 
//
// Traversal functions for terminals are straightforward.
//
traverseSymbol(TIntermSymbol * node)275 void TIntermTraverser::traverseSymbol(TIntermSymbol *node)
276 {
277     ScopedNodeInTraversalPath addToPath(this, node);
278     visitSymbol(node);
279 }
280 
traverseConstantUnion(TIntermConstantUnion * node)281 void TIntermTraverser::traverseConstantUnion(TIntermConstantUnion *node)
282 {
283     ScopedNodeInTraversalPath addToPath(this, node);
284     visitConstantUnion(node);
285 }
286 
traverseSwizzle(TIntermSwizzle * node)287 void TIntermTraverser::traverseSwizzle(TIntermSwizzle *node)
288 {
289     ScopedNodeInTraversalPath addToPath(this, node);
290 
291     bool visit = true;
292 
293     if (preVisit)
294         visit = visitSwizzle(PreVisit, node);
295 
296     if (visit)
297     {
298         node->getOperand()->traverse(this);
299     }
300 
301     if (visit && postVisit)
302         visitSwizzle(PostVisit, node);
303 }
304 
305 //
306 // Traverse a binary node.
307 //
traverseBinary(TIntermBinary * node)308 void TIntermTraverser::traverseBinary(TIntermBinary *node)
309 {
310     ScopedNodeInTraversalPath addToPath(this, node);
311 
312     bool visit = true;
313 
314     //
315     // visit the node before children if pre-visiting.
316     //
317     if (preVisit)
318         visit = visitBinary(PreVisit, node);
319 
320     //
321     // Visit the children, in the right order.
322     //
323     if (visit)
324     {
325         if (node->getLeft())
326             node->getLeft()->traverse(this);
327 
328         if (inVisit)
329             visit = visitBinary(InVisit, node);
330 
331         if (visit && node->getRight())
332             node->getRight()->traverse(this);
333     }
334 
335     //
336     // Visit the node after the children, if requested and the traversal
337     // hasn't been cancelled yet.
338     //
339     if (visit && postVisit)
340         visitBinary(PostVisit, node);
341 }
342 
// L-value-tracking variant of traverseBinary(). Visiting order is the same as the base version.
// It also maintains the operatorRequiresLValue() / isInFunctionCallOutParameter() flags while
// the operands are traversed:
//  - the left operand of an assignment is traversed with operatorRequiresLValue() set;
//  - the right (index) operand of an indexing op is traversed with both flags cleared;
//  - both flags are restored to their values on entry once the children are done.
void TLValueTrackingTraverser::traverseBinary(TIntermBinary *node)
{
    ScopedNodeInTraversalPath addToPath(this, node);

    bool visit = true;

    //
    // visit the node before children if pre-visiting.
    //
    if (preVisit)
        visit = visitBinary(PreVisit, node);

    //
    // Visit the children, in the right order.
    //
    if (visit)
    {
        // Some binary operations like indexing can be inside an expression which must be an
        // l-value.
        // Save the flags inherited from the enclosing expression; they are restored below so
        // sibling expressions are unaffected by what this node sets.
        bool parentOperatorRequiresLValue     = operatorRequiresLValue();
        bool parentInFunctionCallOutParameter = isInFunctionCallOutParameter();
        if (node->isAssignment())
        {
            // NOTE(review): isLValueRequiredHere() is declared elsewhere; presumably it combines
            // both flags, i.e. an assignment is not expected to appear in an l-value context.
            ASSERT(!isLValueRequiredHere());
            // The assignment target (left operand) is written, so it must be an l-value.
            setOperatorRequiresLValue(true);
        }

        if (node->getLeft())
            node->getLeft()->traverse(this);

        if (inVisit)
            visit = visitBinary(InVisit, node);

        // The right-hand side of an assignment is only read.
        if (node->isAssignment())
            setOperatorRequiresLValue(false);

        // Index is not required to be an l-value even when the surrounding expression is required
        // to be an l-value.
        TOperator op = node->getOp();
        if (op == EOpIndexDirect || op == EOpIndexDirectInterfaceBlock ||
            op == EOpIndexDirectStruct || op == EOpIndexIndirect)
        {
            setOperatorRequiresLValue(false);
            setInFunctionCallOutParameter(false);
        }

        if (visit && node->getRight())
            node->getRight()->traverse(this);

        // Restore the state inherited from the enclosing expression.
        setOperatorRequiresLValue(parentOperatorRequiresLValue);
        setInFunctionCallOutParameter(parentInFunctionCallOutParameter);
    }

    //
    // Visit the node after the children, if requested and the traversal
    // hasn't been cancelled yet.
    //
    if (visit && postVisit)
        visitBinary(PostVisit, node);
}
403 
404 //
405 // Traverse a unary node.  Same comments in binary node apply here.
406 //
traverseUnary(TIntermUnary * node)407 void TIntermTraverser::traverseUnary(TIntermUnary *node)
408 {
409     ScopedNodeInTraversalPath addToPath(this, node);
410 
411     bool visit = true;
412 
413     if (preVisit)
414         visit = visitUnary(PreVisit, node);
415 
416     if (visit)
417     {
418         node->getOperand()->traverse(this);
419     }
420 
421     if (visit && postVisit)
422         visitUnary(PostVisit, node);
423 }
424 
traverseUnary(TIntermUnary * node)425 void TLValueTrackingTraverser::traverseUnary(TIntermUnary *node)
426 {
427     ScopedNodeInTraversalPath addToPath(this, node);
428 
429     bool visit = true;
430 
431     if (preVisit)
432         visit = visitUnary(PreVisit, node);
433 
434     if (visit)
435     {
436         ASSERT(!operatorRequiresLValue());
437         switch (node->getOp())
438         {
439             case EOpPostIncrement:
440             case EOpPostDecrement:
441             case EOpPreIncrement:
442             case EOpPreDecrement:
443                 setOperatorRequiresLValue(true);
444                 break;
445             default:
446                 break;
447         }
448 
449         node->getOperand()->traverse(this);
450 
451         setOperatorRequiresLValue(false);
452     }
453 
454     if (visit && postVisit)
455         visitUnary(PostVisit, node);
456 }
457 
458 // Traverse a function definition node.
traverseFunctionDefinition(TIntermFunctionDefinition * node)459 void TIntermTraverser::traverseFunctionDefinition(TIntermFunctionDefinition *node)
460 {
461     ScopedNodeInTraversalPath addToPath(this, node);
462 
463     bool visit = true;
464 
465     if (preVisit)
466         visit = visitFunctionDefinition(PreVisit, node);
467 
468     if (visit)
469     {
470         mInGlobalScope = false;
471 
472         node->getFunctionPrototype()->traverse(this);
473         if (inVisit)
474             visit = visitFunctionDefinition(InVisit, node);
475         node->getBody()->traverse(this);
476 
477         mInGlobalScope = true;
478     }
479 
480     if (visit && postVisit)
481         visitFunctionDefinition(PostVisit, node);
482 }
483 
484 // Traverse a block node.
traverseBlock(TIntermBlock * node)485 void TIntermTraverser::traverseBlock(TIntermBlock *node)
486 {
487     ScopedNodeInTraversalPath addToPath(this, node);
488     pushParentBlock(node);
489 
490     bool visit = true;
491 
492     TIntermSequence *sequence = node->getSequence();
493 
494     if (preVisit)
495         visit = visitBlock(PreVisit, node);
496 
497     if (visit)
498     {
499         for (auto *child : *sequence)
500         {
501             child->traverse(this);
502             if (visit && inVisit)
503             {
504                 if (child != sequence->back())
505                     visit = visitBlock(InVisit, node);
506             }
507 
508             incrementParentBlockPos();
509         }
510     }
511 
512     if (visit && postVisit)
513         visitBlock(PostVisit, node);
514 
515     popParentBlock();
516 }
517 
traverseInvariantDeclaration(TIntermInvariantDeclaration * node)518 void TIntermTraverser::traverseInvariantDeclaration(TIntermInvariantDeclaration *node)
519 {
520     ScopedNodeInTraversalPath addToPath(this, node);
521 
522     bool visit = true;
523 
524     if (preVisit)
525     {
526         visit = visitInvariantDeclaration(PreVisit, node);
527     }
528 
529     if (visit)
530     {
531         node->getSymbol()->traverse(this);
532         if (postVisit)
533         {
534             visitInvariantDeclaration(PostVisit, node);
535         }
536     }
537 }
538 
539 // Traverse a declaration node.
traverseDeclaration(TIntermDeclaration * node)540 void TIntermTraverser::traverseDeclaration(TIntermDeclaration *node)
541 {
542     ScopedNodeInTraversalPath addToPath(this, node);
543 
544     bool visit = true;
545 
546     TIntermSequence *sequence = node->getSequence();
547 
548     if (preVisit)
549         visit = visitDeclaration(PreVisit, node);
550 
551     if (visit)
552     {
553         for (auto *child : *sequence)
554         {
555             child->traverse(this);
556             if (visit && inVisit)
557             {
558                 if (child != sequence->back())
559                     visit = visitDeclaration(InVisit, node);
560             }
561         }
562     }
563 
564     if (visit && postVisit)
565         visitDeclaration(PostVisit, node);
566 }
567 
traverseFunctionPrototype(TIntermFunctionPrototype * node)568 void TIntermTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node)
569 {
570     ScopedNodeInTraversalPath addToPath(this, node);
571 
572     bool visit = true;
573 
574     TIntermSequence *sequence = node->getSequence();
575 
576     if (preVisit)
577         visit = visitFunctionPrototype(PreVisit, node);
578 
579     if (visit)
580     {
581         for (auto *child : *sequence)
582         {
583             child->traverse(this);
584             if (visit && inVisit)
585             {
586                 if (child != sequence->back())
587                     visit = visitFunctionPrototype(InVisit, node);
588             }
589         }
590     }
591 
592     if (visit && postVisit)
593         visitFunctionPrototype(PostVisit, node);
594 }
595 
596 // Traverse an aggregate node.  Same comments in binary node apply here.
traverseAggregate(TIntermAggregate * node)597 void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
598 {
599     ScopedNodeInTraversalPath addToPath(this, node);
600 
601     bool visit = true;
602 
603     TIntermSequence *sequence = node->getSequence();
604 
605     if (preVisit)
606         visit = visitAggregate(PreVisit, node);
607 
608     if (visit)
609     {
610         for (auto *child : *sequence)
611         {
612             child->traverse(this);
613             if (visit && inVisit)
614             {
615                 if (child != sequence->back())
616                     visit = visitAggregate(InVisit, node);
617             }
618         }
619     }
620 
621     if (visit && postVisit)
622         visitAggregate(PostVisit, node);
623 }
624 
CompareInsertion(const NodeInsertMultipleEntry & a,const NodeInsertMultipleEntry & b)625 bool TIntermTraverser::CompareInsertion(const NodeInsertMultipleEntry &a,
626                                         const NodeInsertMultipleEntry &b)
627 {
628     if (a.parent != b.parent)
629     {
630         return a.parent > b.parent;
631     }
632     return a.position > b.position;
633 }
634 
// Applies every modification queued during traversal to the tree, then clears all queues.
// Application order: statement insertions, then single-node replacements (mReplacements), then
// multi-node replacements (mMultiReplacements).
void TIntermTraverser::updateTree()
{
    // Sort the insertions so that insertion position is decreasing. This way multiple insertions to
    // the same parent node are handled correctly.
    // (Inserting at a higher index first leaves the indices of lower positions unchanged.)
    std::sort(mInsertions.begin(), mInsertions.end(), CompareInsertion);
    for (size_t ii = 0; ii < mInsertions.size(); ++ii)
    {
        // We can't know here what the intended ordering of two insertions to the same position is,
        // so it is not supported.
        ASSERT(ii == 0 || mInsertions[ii].position != mInsertions[ii - 1].position ||
               mInsertions[ii].parent != mInsertions[ii - 1].parent);
        const NodeInsertMultipleEntry &insertion = mInsertions[ii];
        ASSERT(insertion.parent);
        // The "after" statements are inserted first. Inserting the "before" statements first
        // would shift the original statement, and |position + 1| would no longer follow it.
        if (!insertion.insertionsAfter.empty())
        {
            bool inserted = insertion.parent->insertChildNodes(insertion.position + 1,
                                                               insertion.insertionsAfter);
            ASSERT(inserted);
        }
        if (!insertion.insertionsBefore.empty())
        {
            bool inserted =
                insertion.parent->insertChildNodes(insertion.position, insertion.insertionsBefore);
            ASSERT(inserted);
        }
    }
    for (size_t ii = 0; ii < mReplacements.size(); ++ii)
    {
        const NodeUpdateEntry &replacement = mReplacements[ii];
        ASSERT(replacement.parent);
        bool replaced =
            replacement.parent->replaceChildNode(replacement.original, replacement.replacement);
        ASSERT(replaced);

        // When the original is kept as a child of its replacement, later entries whose parent is
        // the original still target the right node, so they are not redirected.
        if (!replacement.originalBecomesChildOfReplacement)
        {
            // In AST traversing, a parent is visited before its children.
            // After we replace a node, if its immediate child is to
            // be replaced, we need to make sure we don't update the replaced
            // node; instead, we update the replacement node.
            for (size_t jj = ii + 1; jj < mReplacements.size(); ++jj)
            {
                NodeUpdateEntry &replacement2 = mReplacements[jj];
                if (replacement2.parent == replacement.original)
                    replacement2.parent = replacement.replacement;
            }
        }
    }
    // Multi-node replacements are applied last. Unlike single replacements above, no parent
    // redirection is done for them here.
    for (size_t ii = 0; ii < mMultiReplacements.size(); ++ii)
    {
        const NodeReplaceWithMultipleEntry &replacement = mMultiReplacements[ii];
        ASSERT(replacement.parent);
        bool replaced = replacement.parent->replaceChildNodeWithMultiple(replacement.original,
                                                                         replacement.replacements);
        ASSERT(replaced);
    }

    clearReplacementQueue();
}
694 
clearReplacementQueue()695 void TIntermTraverser::clearReplacementQueue()
696 {
697     mReplacements.clear();
698     mMultiReplacements.clear();
699     mInsertions.clear();
700 }
701 
queueReplacement(TIntermNode * replacement,OriginalNode originalStatus)702 void TIntermTraverser::queueReplacement(TIntermNode *replacement, OriginalNode originalStatus)
703 {
704     queueReplacementWithParent(getParentNode(), mPath.back(), replacement, originalStatus);
705 }
706 
queueReplacementWithParent(TIntermNode * parent,TIntermNode * original,TIntermNode * replacement,OriginalNode originalStatus)707 void TIntermTraverser::queueReplacementWithParent(TIntermNode *parent,
708                                                   TIntermNode *original,
709                                                   TIntermNode *replacement,
710                                                   OriginalNode originalStatus)
711 {
712     bool originalBecomesChild = (originalStatus == OriginalNode::BECOMES_CHILD);
713     mReplacements.push_back(NodeUpdateEntry(parent, original, replacement, originalBecomesChild));
714 }
715 
TLValueTrackingTraverser(bool preVisit,bool inVisit,bool postVisit,TSymbolTable * symbolTable,int shaderVersion)716 TLValueTrackingTraverser::TLValueTrackingTraverser(bool preVisit,
717                                                    bool inVisit,
718                                                    bool postVisit,
719                                                    TSymbolTable *symbolTable,
720                                                    int shaderVersion)
721     : TIntermTraverser(preVisit, inVisit, postVisit, symbolTable),
722       mOperatorRequiresLValue(false),
723       mInFunctionCallOutParameter(false),
724       mShaderVersion(shaderVersion)
725 {
726     ASSERT(symbolTable);
727 }
728 
traverseFunctionPrototype(TIntermFunctionPrototype * node)729 void TLValueTrackingTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node)
730 {
731     TIntermSequence *sequence = node->getSequence();
732     addToFunctionMap(node->getFunctionSymbolInfo()->getId(), sequence);
733 
734     TIntermTraverser::traverseFunctionPrototype(node);
735 }
736 
// L-value-tracking variant of traverseAggregate(). Visiting order is the same as the base
// version. Each operand is traversed with isInFunctionCallOutParameter() set according to the
// qualifier of the matching parameter:
//  - EOpCallFunctionInAST whose prototype was recorded: qualifiers come from the recorded
//    parameter list.
//  - EOpCallFunctionInAST without a recorded prototype: the flag is cleared for all operands.
//  - Other aggregates that are neither function calls nor constructors: qualifiers come from the
//    built-in function found by mangled name. Otherwise operands are treated as "in".
// The flag is cleared (not restored to its previous value) once all operands are traversed.
void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
{
    ScopedNodeInTraversalPath addToPath(this, node);

    bool visit = true;

    TIntermSequence *sequence = node->getSequence();

    if (preVisit)
        visit = visitAggregate(PreVisit, node);

    if (visit)
    {
        if (node->getOp() == EOpCallFunctionInAST)
        {
            if (isInFunctionMap(node))
            {
                // Walk the arguments and the recorded parameters in lockstep.
                TIntermSequence *params             = getFunctionParameters(node);
                TIntermSequence::iterator paramIter = params->begin();
                for (auto *child : *sequence)
                {
                    ASSERT(paramIter != params->end());
                    TQualifier qualifier = (*paramIter)->getAsTyped()->getQualifier();
                    setInFunctionCallOutParameter(qualifier == EvqOut || qualifier == EvqInOut);

                    child->traverse(this);
                    if (visit && inVisit)
                    {
                        if (child != sequence->back())
                            visit = visitAggregate(InVisit, node);
                    }

                    ++paramIter;
                }
            }
            else
            {
                // The node might not be in the function map in case we're in the middle of
                // transforming the AST, and have inserted function call nodes without inserting the
                // function definitions yet.
                setInFunctionCallOutParameter(false);
                for (auto *child : *sequence)
                {
                    child->traverse(this);
                    if (visit && inVisit)
                    {
                        if (child != sequence->back())
                            visit = visitAggregate(InVisit, node);
                    }
                }
            }

            setInFunctionCallOutParameter(false);
        }
        else
        {
            // Find the built-in function corresponding to this op so that we can determine the
            // in/out qualifiers of its parameters.
            // Function calls and constructors are excluded from the lookup, so their operands
            // keep the default "in" qualifier below.
            TFunction *builtInFunc = nullptr;
            if (!node->isFunctionCall() && !node->isConstructor())
            {
                builtInFunc = static_cast<TFunction *>(
                    mSymbolTable->findBuiltIn(node->getSymbolTableMangledName(), mShaderVersion));
            }

            size_t paramIndex = 0;

            for (auto *child : *sequence)
            {
                // This assumes that raw functions called with
                // EOpCallInternalRawFunction don't have out parameters.
                TQualifier qualifier = EvqIn;
                if (builtInFunc != nullptr)
                    qualifier = builtInFunc->getParam(paramIndex).type->getQualifier();
                setInFunctionCallOutParameter(qualifier == EvqOut || qualifier == EvqInOut);
                child->traverse(this);

                if (visit && inVisit)
                {
                    if (child != sequence->back())
                        visit = visitAggregate(InVisit, node);
                }

                ++paramIndex;
            }

            setInFunctionCallOutParameter(false);
        }
    }

    if (visit && postVisit)
        visitAggregate(PostVisit, node);
}
830 
831 //
832 // Traverse a ternary node.  Same comments in binary node apply here.
833 //
traverseTernary(TIntermTernary * node)834 void TIntermTraverser::traverseTernary(TIntermTernary *node)
835 {
836     ScopedNodeInTraversalPath addToPath(this, node);
837 
838     bool visit = true;
839 
840     if (preVisit)
841         visit = visitTernary(PreVisit, node);
842 
843     if (visit)
844     {
845         node->getCondition()->traverse(this);
846         if (node->getTrueExpression())
847             node->getTrueExpression()->traverse(this);
848         if (node->getFalseExpression())
849             node->getFalseExpression()->traverse(this);
850     }
851 
852     if (visit && postVisit)
853         visitTernary(PostVisit, node);
854 }
855 
856 // Traverse an if-else node.  Same comments in binary node apply here.
traverseIfElse(TIntermIfElse * node)857 void TIntermTraverser::traverseIfElse(TIntermIfElse *node)
858 {
859     ScopedNodeInTraversalPath addToPath(this, node);
860 
861     bool visit = true;
862 
863     if (preVisit)
864         visit = visitIfElse(PreVisit, node);
865 
866     if (visit)
867     {
868         node->getCondition()->traverse(this);
869         if (node->getTrueBlock())
870             node->getTrueBlock()->traverse(this);
871         if (node->getFalseBlock())
872             node->getFalseBlock()->traverse(this);
873     }
874 
875     if (visit && postVisit)
876         visitIfElse(PostVisit, node);
877 }
878 
879 //
880 // Traverse a switch node.  Same comments in binary node apply here.
881 //
traverseSwitch(TIntermSwitch * node)882 void TIntermTraverser::traverseSwitch(TIntermSwitch *node)
883 {
884     ScopedNodeInTraversalPath addToPath(this, node);
885 
886     bool visit = true;
887 
888     if (preVisit)
889         visit = visitSwitch(PreVisit, node);
890 
891     if (visit)
892     {
893         node->getInit()->traverse(this);
894         if (inVisit)
895             visit = visitSwitch(InVisit, node);
896         if (visit && node->getStatementList())
897             node->getStatementList()->traverse(this);
898     }
899 
900     if (visit && postVisit)
901         visitSwitch(PostVisit, node);
902 }
903 
904 //
905 // Traverse a case node.  Same comments in binary node apply here.
906 //
traverseCase(TIntermCase * node)907 void TIntermTraverser::traverseCase(TIntermCase *node)
908 {
909     ScopedNodeInTraversalPath addToPath(this, node);
910 
911     bool visit = true;
912 
913     if (preVisit)
914         visit = visitCase(PreVisit, node);
915 
916     if (visit && node->getCondition())
917     {
918         node->getCondition()->traverse(this);
919     }
920 
921     if (visit && postVisit)
922         visitCase(PostVisit, node);
923 }
924 
925 //
926 // Traverse a loop node.  Same comments in binary node apply here.
927 //
traverseLoop(TIntermLoop * node)928 void TIntermTraverser::traverseLoop(TIntermLoop *node)
929 {
930     ScopedNodeInTraversalPath addToPath(this, node);
931 
932     bool visit = true;
933 
934     if (preVisit)
935         visit = visitLoop(PreVisit, node);
936 
937     if (visit)
938     {
939         if (node->getInit())
940             node->getInit()->traverse(this);
941 
942         if (node->getCondition())
943             node->getCondition()->traverse(this);
944 
945         if (node->getBody())
946             node->getBody()->traverse(this);
947 
948         if (node->getExpression())
949             node->getExpression()->traverse(this);
950     }
951 
952     if (visit && postVisit)
953         visitLoop(PostVisit, node);
954 }
955 
956 //
957 // Traverse a branch node.  Same comments in binary node apply here.
958 //
traverseBranch(TIntermBranch * node)959 void TIntermTraverser::traverseBranch(TIntermBranch *node)
960 {
961     ScopedNodeInTraversalPath addToPath(this, node);
962 
963     bool visit = true;
964 
965     if (preVisit)
966         visit = visitBranch(PreVisit, node);
967 
968     if (visit && node->getExpression())
969     {
970         node->getExpression()->traverse(this);
971     }
972 
973     if (visit && postVisit)
974         visitBranch(PostVisit, node);
975 }
976 
traverseRaw(TIntermRaw * node)977 void TIntermTraverser::traverseRaw(TIntermRaw *node)
978 {
979     ScopedNodeInTraversalPath addToPath(this, node);
980     visitRaw(node);
981 }
982 
983 }  // namespace sh
984