1 //
2 // Copyright (c) 2017 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // IntermTraverse.h : base classes for AST traversers that walk the AST and
7 //   also have the ability to transform it by replacing nodes.
8 
9 #ifndef COMPILER_TRANSLATOR_INTERMTRAVERSE_H_
10 #define COMPILER_TRANSLATOR_INTERMTRAVERSE_H_
11 
#include <algorithm>
#include <utility>
#include <vector>

#include "compiler/translator/IntermNode.h"
13 
14 namespace sh
15 {
16 
17 class TSymbolTable;
18 class TSymbolUniqueId;
19 
// Phase in which a bool-returning TIntermTraverser::visit*() callback is being invoked for a node.
// The enumerator values are spelled out explicitly; they match the implicit 0, 1, 2 ordering.
enum Visit
{
    PreVisit  = 0,  // presumably: before the node's children are traversed
    InVisit   = 1,  // presumably: between traversals of the node's children
    PostVisit = 2   // presumably: after the node's children are traversed
};
26 
27 // For traversing the tree.  User should derive from this class overriding the visit functions,
28 // and then pass an object of the subclass to a traverse method of a node.
29 //
30 // The traverse*() functions may also be overridden to do other bookkeeping on the tree to provide
31 // contextual information to the visit functions, such as whether the node is the target of an
32 // assignment. This is complex to maintain and so should only be done in special cases.
33 //
34 // When using this, just fill in the methods for nodes you want visited.
35 // Return false from a pre-visit to skip visiting that node's subtree.
36 class TIntermTraverser : angle::NonCopyable
37 {
38   public:
39     POOL_ALLOCATOR_NEW_DELETE();
40     TIntermTraverser(bool preVisit,
41                      bool inVisit,
42                      bool postVisit,
43                      TSymbolTable *symbolTable = nullptr);
44     virtual ~TIntermTraverser();
45 
visitSymbol(TIntermSymbol * node)46     virtual void visitSymbol(TIntermSymbol *node) {}
visitRaw(TIntermRaw * node)47     virtual void visitRaw(TIntermRaw *node) {}
visitConstantUnion(TIntermConstantUnion * node)48     virtual void visitConstantUnion(TIntermConstantUnion *node) {}
visitSwizzle(Visit visit,TIntermSwizzle * node)49     virtual bool visitSwizzle(Visit visit, TIntermSwizzle *node) { return true; }
visitBinary(Visit visit,TIntermBinary * node)50     virtual bool visitBinary(Visit visit, TIntermBinary *node) { return true; }
visitUnary(Visit visit,TIntermUnary * node)51     virtual bool visitUnary(Visit visit, TIntermUnary *node) { return true; }
visitTernary(Visit visit,TIntermTernary * node)52     virtual bool visitTernary(Visit visit, TIntermTernary *node) { return true; }
visitIfElse(Visit visit,TIntermIfElse * node)53     virtual bool visitIfElse(Visit visit, TIntermIfElse *node) { return true; }
visitSwitch(Visit visit,TIntermSwitch * node)54     virtual bool visitSwitch(Visit visit, TIntermSwitch *node) { return true; }
visitCase(Visit visit,TIntermCase * node)55     virtual bool visitCase(Visit visit, TIntermCase *node) { return true; }
visitFunctionPrototype(Visit visit,TIntermFunctionPrototype * node)56     virtual bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node)
57     {
58         return true;
59     }
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)60     virtual bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
61     {
62         return true;
63     }
visitAggregate(Visit visit,TIntermAggregate * node)64     virtual bool visitAggregate(Visit visit, TIntermAggregate *node) { return true; }
visitBlock(Visit visit,TIntermBlock * node)65     virtual bool visitBlock(Visit visit, TIntermBlock *node) { return true; }
visitInvariantDeclaration(Visit visit,TIntermInvariantDeclaration * node)66     virtual bool visitInvariantDeclaration(Visit visit, TIntermInvariantDeclaration *node)
67     {
68         return true;
69     }
visitDeclaration(Visit visit,TIntermDeclaration * node)70     virtual bool visitDeclaration(Visit visit, TIntermDeclaration *node) { return true; }
visitLoop(Visit visit,TIntermLoop * node)71     virtual bool visitLoop(Visit visit, TIntermLoop *node) { return true; }
visitBranch(Visit visit,TIntermBranch * node)72     virtual bool visitBranch(Visit visit, TIntermBranch *node) { return true; }
73 
74     // The traverse functions contain logic for iterating over the children of the node
75     // and calling the visit functions in the appropriate places. They also track some
76     // context that may be used by the visit functions.
77     virtual void traverseSymbol(TIntermSymbol *node);
78     virtual void traverseRaw(TIntermRaw *node);
79     virtual void traverseConstantUnion(TIntermConstantUnion *node);
80     virtual void traverseSwizzle(TIntermSwizzle *node);
81     virtual void traverseBinary(TIntermBinary *node);
82     virtual void traverseUnary(TIntermUnary *node);
83     virtual void traverseTernary(TIntermTernary *node);
84     virtual void traverseIfElse(TIntermIfElse *node);
85     virtual void traverseSwitch(TIntermSwitch *node);
86     virtual void traverseCase(TIntermCase *node);
87     virtual void traverseFunctionPrototype(TIntermFunctionPrototype *node);
88     virtual void traverseFunctionDefinition(TIntermFunctionDefinition *node);
89     virtual void traverseAggregate(TIntermAggregate *node);
90     virtual void traverseBlock(TIntermBlock *node);
91     virtual void traverseInvariantDeclaration(TIntermInvariantDeclaration *node);
92     virtual void traverseDeclaration(TIntermDeclaration *node);
93     virtual void traverseLoop(TIntermLoop *node);
94     virtual void traverseBranch(TIntermBranch *node);
95 
getMaxDepth()96     int getMaxDepth() const { return mMaxDepth; }
97 
98     // If traversers need to replace nodes, they can add the replacements in
99     // mReplacements/mMultiReplacements during traversal and the user of the traverser should call
100     // this function after traversal to perform them.
101     void updateTree();
102 
103   protected:
104     // Should only be called from traverse*() functions
incrementDepth(TIntermNode * current)105     void incrementDepth(TIntermNode *current)
106     {
107         mDepth++;
108         mMaxDepth = std::max(mMaxDepth, mDepth);
109         mPath.push_back(current);
110     }
111 
112     // Should only be called from traverse*() functions
decrementDepth()113     void decrementDepth()
114     {
115         mDepth--;
116         mPath.pop_back();
117     }
118 
119     // RAII helper for incrementDepth/decrementDepth
120     class ScopedNodeInTraversalPath
121     {
122       public:
ScopedNodeInTraversalPath(TIntermTraverser * traverser,TIntermNode * current)123         ScopedNodeInTraversalPath(TIntermTraverser *traverser, TIntermNode *current)
124             : mTraverser(traverser)
125         {
126             mTraverser->incrementDepth(current);
127         }
~ScopedNodeInTraversalPath()128         ~ScopedNodeInTraversalPath() { mTraverser->decrementDepth(); }
129 
130       private:
131         TIntermTraverser *mTraverser;
132     };
133 
getParentNode()134     TIntermNode *getParentNode() { return mPath.size() <= 1 ? nullptr : mPath[mPath.size() - 2u]; }
135 
136     // Return the nth ancestor of the node being traversed. getAncestorNode(0) == getParentNode()
getAncestorNode(unsigned int n)137     TIntermNode *getAncestorNode(unsigned int n)
138     {
139         if (mPath.size() > n + 1u)
140         {
141             return mPath[mPath.size() - n - 2u];
142         }
143         return nullptr;
144     }
145 
146     const TIntermBlock *getParentBlock() const;
147 
148     void pushParentBlock(TIntermBlock *node);
149     void incrementParentBlockPos();
150     void popParentBlock();
151 
152     // To replace a single node with multiple nodes in the parent aggregate. May be used with blocks
153     // but also with other nodes like declarations.
154     struct NodeReplaceWithMultipleEntry
155     {
NodeReplaceWithMultipleEntryNodeReplaceWithMultipleEntry156         NodeReplaceWithMultipleEntry(TIntermAggregateBase *_parent,
157                                      TIntermNode *_original,
158                                      TIntermSequence _replacements)
159             : parent(_parent), original(_original), replacements(_replacements)
160         {
161         }
162 
163         TIntermAggregateBase *parent;
164         TIntermNode *original;
165         TIntermSequence replacements;
166     };
167 
168     // Helper to insert statements in the parent block of the node currently being traversed.
169     // The statements will be inserted before the node being traversed once updateTree is called.
170     // Should only be called during PreVisit or PostVisit if called from block nodes.
171     // Note that two insertions to the same position in the same block are not supported.
172     void insertStatementsInParentBlock(const TIntermSequence &insertions);
173 
174     // Same as above, but supports simultaneous insertion of statements before and after the node
175     // currently being traversed.
176     void insertStatementsInParentBlock(const TIntermSequence &insertionsBefore,
177                                        const TIntermSequence &insertionsAfter);
178 
179     // Helper to insert a single statement.
180     void insertStatementInParentBlock(TIntermNode *statement);
181 
182     // Helper to create a temporary symbol node with the given qualifier.
183     TIntermSymbol *createTempSymbol(const TType &type, TQualifier qualifier);
184     // Helper to create a temporary symbol node.
185     TIntermSymbol *createTempSymbol(const TType &type);
186     // Create a node that declares but doesn't initialize a temporary symbol.
187     TIntermDeclaration *createTempDeclaration(const TType &type);
188     // Create a node that initializes the current temporary symbol with initializer. The symbol will
189     // have the given qualifier.
190     TIntermDeclaration *createTempInitDeclaration(TIntermTyped *initializer, TQualifier qualifier);
191     // Create a node that initializes the current temporary symbol with initializer.
192     TIntermDeclaration *createTempInitDeclaration(TIntermTyped *initializer);
193     // Create a node that assigns rightNode to the current temporary symbol.
194     TIntermBinary *createTempAssignment(TIntermTyped *rightNode);
195     // Increment temporary symbol index.
196     void nextTemporaryId();
197 
198     enum class OriginalNode
199     {
200         BECOMES_CHILD,
201         IS_DROPPED
202     };
203 
204     void clearReplacementQueue();
205 
206     // Replace the node currently being visited with replacement.
207     void queueReplacement(TIntermNode *replacement, OriginalNode originalStatus);
208     // Explicitly specify a node to replace with replacement.
209     void queueReplacementWithParent(TIntermNode *parent,
210                                     TIntermNode *original,
211                                     TIntermNode *replacement,
212                                     OriginalNode originalStatus);
213 
214     const bool preVisit;
215     const bool inVisit;
216     const bool postVisit;
217 
218     int mDepth;
219     int mMaxDepth;
220 
221     bool mInGlobalScope;
222 
223     // During traversing, save all the changes that need to happen into
224     // mReplacements/mMultiReplacements, then do them by calling updateTree().
225     // Multi replacements are processed after single replacements.
226     std::vector<NodeReplaceWithMultipleEntry> mMultiReplacements;
227 
228     TSymbolTable *mSymbolTable;
229 
230   private:
231     // To insert multiple nodes into the parent block.
232     struct NodeInsertMultipleEntry
233     {
NodeInsertMultipleEntryNodeInsertMultipleEntry234         NodeInsertMultipleEntry(TIntermBlock *_parent,
235                                 TIntermSequence::size_type _position,
236                                 TIntermSequence _insertionsBefore,
237                                 TIntermSequence _insertionsAfter)
238             : parent(_parent),
239               position(_position),
240               insertionsBefore(_insertionsBefore),
241               insertionsAfter(_insertionsAfter)
242         {
243         }
244 
245         TIntermBlock *parent;
246         TIntermSequence::size_type position;
247         TIntermSequence insertionsBefore;
248         TIntermSequence insertionsAfter;
249     };
250 
251     static bool CompareInsertion(const NodeInsertMultipleEntry &a,
252                                  const NodeInsertMultipleEntry &b);
253 
254     // To replace a single node with another on the parent node
255     struct NodeUpdateEntry
256     {
NodeUpdateEntryNodeUpdateEntry257         NodeUpdateEntry(TIntermNode *_parent,
258                         TIntermNode *_original,
259                         TIntermNode *_replacement,
260                         bool _originalBecomesChildOfReplacement)
261             : parent(_parent),
262               original(_original),
263               replacement(_replacement),
264               originalBecomesChildOfReplacement(_originalBecomesChildOfReplacement)
265         {
266         }
267 
268         TIntermNode *parent;
269         TIntermNode *original;
270         TIntermNode *replacement;
271         bool originalBecomesChildOfReplacement;
272     };
273 
274     struct ParentBlock
275     {
ParentBlockParentBlock276         ParentBlock(TIntermBlock *nodeIn, TIntermSequence::size_type posIn)
277             : node(nodeIn), pos(posIn)
278         {
279         }
280 
281         TIntermBlock *node;
282         TIntermSequence::size_type pos;
283     };
284 
285     std::vector<NodeInsertMultipleEntry> mInsertions;
286     std::vector<NodeUpdateEntry> mReplacements;
287 
288     // All the nodes from root to the current node during traversing.
289     TVector<TIntermNode *> mPath;
290 
291     // All the code blocks from the root to the current node's parent during traversal.
292     std::vector<ParentBlock> mParentBlockStack;
293 
294     TSymbolUniqueId *mTemporaryId;
295 };
296 
297 // Traverser parent class that tracks where a node is a destination of a write operation and so is
298 // required to be an l-value.
299 class TLValueTrackingTraverser : public TIntermTraverser
300 {
301   public:
302     TLValueTrackingTraverser(bool preVisit,
303                              bool inVisit,
304                              bool postVisit,
305                              TSymbolTable *symbolTable,
306                              int shaderVersion);
~TLValueTrackingTraverser()307     virtual ~TLValueTrackingTraverser() {}
308 
309     void traverseBinary(TIntermBinary *node) final;
310     void traverseUnary(TIntermUnary *node) final;
311     void traverseFunctionPrototype(TIntermFunctionPrototype *node) final;
312     void traverseAggregate(TIntermAggregate *node) final;
313 
314   protected:
isLValueRequiredHere()315     bool isLValueRequiredHere() const
316     {
317         return mOperatorRequiresLValue || mInFunctionCallOutParameter;
318     }
319 
320   private:
321     // Track whether an l-value is required in the node that is currently being traversed by the
322     // surrounding operator.
323     // Use isLValueRequiredHere to check all conditions which require an l-value.
setOperatorRequiresLValue(bool lValueRequired)324     void setOperatorRequiresLValue(bool lValueRequired)
325     {
326         mOperatorRequiresLValue = lValueRequired;
327     }
operatorRequiresLValue()328     bool operatorRequiresLValue() const { return mOperatorRequiresLValue; }
329 
330     // Add a function encountered during traversal to the function map.
331     void addToFunctionMap(const TSymbolUniqueId &id, TIntermSequence *paramSequence);
332 
333     // Return true if the prototype or definition of the function being called has been encountered
334     // during traversal.
335     bool isInFunctionMap(const TIntermAggregate *callNode) const;
336 
337     // Return the parameters sequence from the function definition or prototype.
338     TIntermSequence *getFunctionParameters(const TIntermAggregate *callNode);
339 
340     // Track whether an l-value is required inside a function call.
341     void setInFunctionCallOutParameter(bool inOutParameter);
342     bool isInFunctionCallOutParameter() const;
343 
344     bool mOperatorRequiresLValue;
345     bool mInFunctionCallOutParameter;
346 
347     // Map from function symbol id values to their parameter sequences
348     TMap<int, TIntermSequence *> mFunctionMap;
349 
350     const int mShaderVersion;
351 };
352 
353 }  // namespace sh
354 
355 #endif  // COMPILER_TRANSLATOR_INTERMTRAVERSE_H_
356