1 // 2 // Copyright (c) 2017 The ANGLE Project Authors. All rights reserved. 3 // Use of this source code is governed by a BSD-style license that can be 4 // found in the LICENSE file. 5 // 6 // IntermTraverse.h : base classes for AST traversers that walk the AST and 7 // also have the ability to transform it by replacing nodes. 8 9 #ifndef COMPILER_TRANSLATOR_INTERMTRAVERSE_H_ 10 #define COMPILER_TRANSLATOR_INTERMTRAVERSE_H_ 11 12 #include "compiler/translator/IntermNode.h" 13 14 namespace sh 15 { 16 17 class TSymbolTable; 18 class TSymbolUniqueId; 19 20 enum Visit 21 { 22 PreVisit, 23 InVisit, 24 PostVisit 25 }; 26 27 // For traversing the tree. User should derive from this class overriding the visit functions, 28 // and then pass an object of the subclass to a traverse method of a node. 29 // 30 // The traverse*() functions may also be overridden to do other bookkeeping on the tree to provide 31 // contextual information to the visit functions, such as whether the node is the target of an 32 // assignment. This is complex to maintain and so should only be done in special cases. 33 // 34 // When using this, just fill in the methods for nodes you want visited. 35 // Return false from a pre-visit to skip visiting that node's subtree. 36 class TIntermTraverser : angle::NonCopyable 37 { 38 public: 39 POOL_ALLOCATOR_NEW_DELETE(); 40 TIntermTraverser(bool preVisit, 41 bool inVisit, 42 bool postVisit, 43 TSymbolTable *symbolTable = nullptr); 44 virtual ~TIntermTraverser(); 45 visitSymbol(TIntermSymbol * node)46 virtual void visitSymbol(TIntermSymbol *node) {} visitRaw(TIntermRaw * node)47 virtual void visitRaw(TIntermRaw *node) {} visitConstantUnion(TIntermConstantUnion * node)48 virtual void visitConstantUnion(TIntermConstantUnion *node) {} visitSwizzle(Visit visit,TIntermSwizzle * node)49 virtual bool visitSwizzle(Visit visit, TIntermSwizzle *node) { return true; } visitBinary(Visit visit,TIntermBinary * node)50 virtual bool visitBinary(Visit visit, TIntermBinary *node) { return true; } visitUnary(Visit visit,TIntermUnary * node)51 virtual bool visitUnary(Visit visit, TIntermUnary *node) { return true; } visitTernary(Visit visit,TIntermTernary * node)52 virtual bool visitTernary(Visit visit, TIntermTernary *node) { return true; } visitIfElse(Visit visit,TIntermIfElse * node)53 virtual bool visitIfElse(Visit visit, TIntermIfElse *node) { return true; } visitSwitch(Visit visit,TIntermSwitch * node)54 virtual bool visitSwitch(Visit visit, TIntermSwitch *node) { return true; } visitCase(Visit visit,TIntermCase * node)55 virtual bool visitCase(Visit visit, TIntermCase *node) { return true; } visitFunctionPrototype(Visit visit,TIntermFunctionPrototype * node)56 virtual bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) 57 { 58 return true; 59 } visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)60 virtual bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) 61 { 62 return true; 63 } visitAggregate(Visit visit,TIntermAggregate * node)64 virtual bool visitAggregate(Visit visit, TIntermAggregate *node) { return true; } visitBlock(Visit visit,TIntermBlock * node)65 virtual bool visitBlock(Visit visit, TIntermBlock *node) { return true; } visitInvariantDeclaration(Visit visit,TIntermInvariantDeclaration * node)66 virtual bool visitInvariantDeclaration(Visit visit, TIntermInvariantDeclaration *node) 67 { 68 return true; 69 } visitDeclaration(Visit visit,TIntermDeclaration * node)70 virtual bool visitDeclaration(Visit visit, TIntermDeclaration *node) { return true; } visitLoop(Visit visit,TIntermLoop * node)71 virtual bool visitLoop(Visit visit, TIntermLoop *node) { return true; } visitBranch(Visit visit,TIntermBranch * node)72 virtual bool visitBranch(Visit visit, TIntermBranch *node) { return true; } 73 74 // The traverse functions contain logic for iterating over the children of the node 75 // and calling the visit functions in the appropriate places. They also track some 76 // context that may be used by the visit functions. 77 virtual void traverseSymbol(TIntermSymbol *node); 78 virtual void traverseRaw(TIntermRaw *node); 79 virtual void traverseConstantUnion(TIntermConstantUnion *node); 80 virtual void traverseSwizzle(TIntermSwizzle *node); 81 virtual void traverseBinary(TIntermBinary *node); 82 virtual void traverseUnary(TIntermUnary *node); 83 virtual void traverseTernary(TIntermTernary *node); 84 virtual void traverseIfElse(TIntermIfElse *node); 85 virtual void traverseSwitch(TIntermSwitch *node); 86 virtual void traverseCase(TIntermCase *node); 87 virtual void traverseFunctionPrototype(TIntermFunctionPrototype *node); 88 virtual void traverseFunctionDefinition(TIntermFunctionDefinition *node); 89 virtual void traverseAggregate(TIntermAggregate *node); 90 virtual void traverseBlock(TIntermBlock *node); 91 virtual void traverseInvariantDeclaration(TIntermInvariantDeclaration *node); 92 virtual void traverseDeclaration(TIntermDeclaration *node); 93 virtual void traverseLoop(TIntermLoop *node); 94 virtual void traverseBranch(TIntermBranch *node); 95 getMaxDepth()96 int getMaxDepth() const { return mMaxDepth; } 97 98 // If traversers need to replace nodes, they can add the replacements in 99 // mReplacements/mMultiReplacements during traversal and the user of the traverser should call 100 // this function after traversal to perform them. 101 void updateTree(); 102 103 protected: 104 // Should only be called from traverse*() functions incrementDepth(TIntermNode * current)105 void incrementDepth(TIntermNode *current) 106 { 107 mDepth++; 108 mMaxDepth = std::max(mMaxDepth, mDepth); 109 mPath.push_back(current); 110 } 111 112 // Should only be called from traverse*() functions decrementDepth()113 void decrementDepth() 114 { 115 mDepth--; 116 mPath.pop_back(); 117 } 118 119 // RAII helper for incrementDepth/decrementDepth 120 class ScopedNodeInTraversalPath 121 { 122 public: ScopedNodeInTraversalPath(TIntermTraverser * traverser,TIntermNode * current)123 ScopedNodeInTraversalPath(TIntermTraverser *traverser, TIntermNode *current) 124 : mTraverser(traverser) 125 { 126 mTraverser->incrementDepth(current); 127 } ~ScopedNodeInTraversalPath()128 ~ScopedNodeInTraversalPath() { mTraverser->decrementDepth(); } 129 130 private: 131 TIntermTraverser *mTraverser; 132 }; 133 getParentNode()134 TIntermNode *getParentNode() { return mPath.size() <= 1 ? nullptr : mPath[mPath.size() - 2u]; } 135 136 // Return the nth ancestor of the node being traversed. getAncestorNode(0) == getParentNode() getAncestorNode(unsigned int n)137 TIntermNode *getAncestorNode(unsigned int n) 138 { 139 if (mPath.size() > n + 1u) 140 { 141 return mPath[mPath.size() - n - 2u]; 142 } 143 return nullptr; 144 } 145 146 const TIntermBlock *getParentBlock() const; 147 148 void pushParentBlock(TIntermBlock *node); 149 void incrementParentBlockPos(); 150 void popParentBlock(); 151 152 // To replace a single node with multiple nodes in the parent aggregate. May be used with blocks 153 // but also with other nodes like declarations. 154 struct NodeReplaceWithMultipleEntry 155 { NodeReplaceWithMultipleEntryNodeReplaceWithMultipleEntry156 NodeReplaceWithMultipleEntry(TIntermAggregateBase *_parent, 157 TIntermNode *_original, 158 TIntermSequence _replacements) 159 : parent(_parent), original(_original), replacements(_replacements) 160 { 161 } 162 163 TIntermAggregateBase *parent; 164 TIntermNode *original; 165 TIntermSequence replacements; 166 }; 167 168 // Helper to insert statements in the parent block of the node currently being traversed. 169 // The statements will be inserted before the node being traversed once updateTree is called. 170 // Should only be called during PreVisit or PostVisit if called from block nodes. 171 // Note that two insertions to the same position in the same block are not supported. 172 void insertStatementsInParentBlock(const TIntermSequence &insertions); 173 174 // Same as above, but supports simultaneous insertion of statements before and after the node 175 // currently being traversed. 176 void insertStatementsInParentBlock(const TIntermSequence &insertionsBefore, 177 const TIntermSequence &insertionsAfter); 178 179 // Helper to insert a single statement. 180 void insertStatementInParentBlock(TIntermNode *statement); 181 182 // Helper to create a temporary symbol node with the given qualifier. 183 TIntermSymbol *createTempSymbol(const TType &type, TQualifier qualifier); 184 // Helper to create a temporary symbol node. 185 TIntermSymbol *createTempSymbol(const TType &type); 186 // Create a node that declares but doesn't initialize a temporary symbol. 187 TIntermDeclaration *createTempDeclaration(const TType &type); 188 // Create a node that initializes the current temporary symbol with initializer. The symbol will 189 // have the given qualifier. 190 TIntermDeclaration *createTempInitDeclaration(TIntermTyped *initializer, TQualifier qualifier); 191 // Create a node that initializes the current temporary symbol with initializer. 192 TIntermDeclaration *createTempInitDeclaration(TIntermTyped *initializer); 193 // Create a node that assigns rightNode to the current temporary symbol. 194 TIntermBinary *createTempAssignment(TIntermTyped *rightNode); 195 // Increment temporary symbol index. 196 void nextTemporaryId(); 197 198 enum class OriginalNode 199 { 200 BECOMES_CHILD, 201 IS_DROPPED 202 }; 203 204 void clearReplacementQueue(); 205 206 // Replace the node currently being visited with replacement. 207 void queueReplacement(TIntermNode *replacement, OriginalNode originalStatus); 208 // Explicitly specify a node to replace with replacement. 209 void queueReplacementWithParent(TIntermNode *parent, 210 TIntermNode *original, 211 TIntermNode *replacement, 212 OriginalNode originalStatus); 213 214 const bool preVisit; 215 const bool inVisit; 216 const bool postVisit; 217 218 int mDepth; 219 int mMaxDepth; 220 221 bool mInGlobalScope; 222 223 // During traversing, save all the changes that need to happen into 224 // mReplacements/mMultiReplacements, then do them by calling updateTree(). 225 // Multi replacements are processed after single replacements. 226 std::vector<NodeReplaceWithMultipleEntry> mMultiReplacements; 227 228 TSymbolTable *mSymbolTable; 229 230 private: 231 // To insert multiple nodes into the parent block. 232 struct NodeInsertMultipleEntry 233 { NodeInsertMultipleEntryNodeInsertMultipleEntry234 NodeInsertMultipleEntry(TIntermBlock *_parent, 235 TIntermSequence::size_type _position, 236 TIntermSequence _insertionsBefore, 237 TIntermSequence _insertionsAfter) 238 : parent(_parent), 239 position(_position), 240 insertionsBefore(_insertionsBefore), 241 insertionsAfter(_insertionsAfter) 242 { 243 } 244 245 TIntermBlock *parent; 246 TIntermSequence::size_type position; 247 TIntermSequence insertionsBefore; 248 TIntermSequence insertionsAfter; 249 }; 250 251 static bool CompareInsertion(const NodeInsertMultipleEntry &a, 252 const NodeInsertMultipleEntry &b); 253 254 // To replace a single node with another on the parent node 255 struct NodeUpdateEntry 256 { NodeUpdateEntryNodeUpdateEntry257 NodeUpdateEntry(TIntermNode *_parent, 258 TIntermNode *_original, 259 TIntermNode *_replacement, 260 bool _originalBecomesChildOfReplacement) 261 : parent(_parent), 262 original(_original), 263 replacement(_replacement), 264 originalBecomesChildOfReplacement(_originalBecomesChildOfReplacement) 265 { 266 } 267 268 TIntermNode *parent; 269 TIntermNode *original; 270 TIntermNode *replacement; 271 bool originalBecomesChildOfReplacement; 272 }; 273 274 struct ParentBlock 275 { ParentBlockParentBlock276 ParentBlock(TIntermBlock *nodeIn, TIntermSequence::size_type posIn) 277 : node(nodeIn), pos(posIn) 278 { 279 } 280 281 TIntermBlock *node; 282 TIntermSequence::size_type pos; 283 }; 284 285 std::vector<NodeInsertMultipleEntry> mInsertions; 286 std::vector<NodeUpdateEntry> mReplacements; 287 288 // All the nodes from root to the current node during traversing. 289 TVector<TIntermNode *> mPath; 290 291 // All the code blocks from the root to the current node's parent during traversal. 292 std::vector<ParentBlock> mParentBlockStack; 293 294 TSymbolUniqueId *mTemporaryId; 295 }; 296 297 // Traverser parent class that tracks where a node is a destination of a write operation and so is 298 // required to be an l-value. 299 class TLValueTrackingTraverser : public TIntermTraverser 300 { 301 public: 302 TLValueTrackingTraverser(bool preVisit, 303 bool inVisit, 304 bool postVisit, 305 TSymbolTable *symbolTable, 306 int shaderVersion); ~TLValueTrackingTraverser()307 virtual ~TLValueTrackingTraverser() {} 308 309 void traverseBinary(TIntermBinary *node) final; 310 void traverseUnary(TIntermUnary *node) final; 311 void traverseFunctionPrototype(TIntermFunctionPrototype *node) final; 312 void traverseAggregate(TIntermAggregate *node) final; 313 314 protected: isLValueRequiredHere()315 bool isLValueRequiredHere() const 316 { 317 return mOperatorRequiresLValue || mInFunctionCallOutParameter; 318 } 319 320 private: 321 // Track whether an l-value is required in the node that is currently being traversed by the 322 // surrounding operator. 323 // Use isLValueRequiredHere to check all conditions which require an l-value. setOperatorRequiresLValue(bool lValueRequired)324 void setOperatorRequiresLValue(bool lValueRequired) 325 { 326 mOperatorRequiresLValue = lValueRequired; 327 } operatorRequiresLValue()328 bool operatorRequiresLValue() const { return mOperatorRequiresLValue; } 329 330 // Add a function encountered during traversal to the function map. 331 void addToFunctionMap(const TSymbolUniqueId &id, TIntermSequence *paramSequence); 332 333 // Return true if the prototype or definition of the function being called has been encountered 334 // during traversal. 335 bool isInFunctionMap(const TIntermAggregate *callNode) const; 336 337 // Return the parameters sequence from the function definition or prototype. 338 TIntermSequence *getFunctionParameters(const TIntermAggregate *callNode); 339 340 // Track whether an l-value is required inside a function call. 341 void setInFunctionCallOutParameter(bool inOutParameter); 342 bool isInFunctionCallOutParameter() const; 343 344 bool mOperatorRequiresLValue; 345 bool mInFunctionCallOutParameter; 346 347 // Map from function symbol id values to their parameter sequences 348 TMap<int, TIntermSequence *> mFunctionMap; 349 350 const int mShaderVersion; 351 }; 352 353 } // namespace sh 354 355 #endif // COMPILER_TRANSLATOR_INTERMTRAVERSE_H_ 356