1 //
2 // Copyright 2017 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RemoveUnreferencedVariables.cpp:
7 //  Drop variables that are declared but never referenced in the AST. This avoids adding unnecessary
8 //  initialization code for them. Also removes unreferenced struct types.
9 //
10 
11 #include "compiler/translator/tree_ops/RemoveUnreferencedVariables.h"
12 
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/tree_util/IntermTraverse.h"
15 
16 namespace sh
17 {
18 
19 namespace
20 {
21 
22 class CollectVariableRefCountsTraverser : public TIntermTraverser
23 {
24   public:
25     CollectVariableRefCountsTraverser();
26 
27     using RefCountMap = angle::HashMap<int, unsigned int>;
getSymbolIdRefCounts()28     RefCountMap &getSymbolIdRefCounts() { return mSymbolIdRefCounts; }
getStructIdRefCounts()29     RefCountMap &getStructIdRefCounts() { return mStructIdRefCounts; }
30 
31     void visitSymbol(TIntermSymbol *node) override;
32     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
33     void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
34 
35   private:
36     void incrementStructTypeRefCount(const TType &type);
37 
38     RefCountMap mSymbolIdRefCounts;
39 
40     // Structure reference counts are counted from symbols, constructors, function calls, function
41     // return values and from interface block and structure fields. We need to track both function
42     // calls and function return values since there's a compiler option not to prune unused
43     // functions. The type of a constant union may also be a struct, but statements that are just a
44     // constant union are always pruned, and if the constant union is used somehow it will get
45     // counted by something else.
46     RefCountMap mStructIdRefCounts;
47 };
48 
CollectVariableRefCountsTraverser()49 CollectVariableRefCountsTraverser::CollectVariableRefCountsTraverser()
50     : TIntermTraverser(true, false, false)
51 {}
52 
incrementStructTypeRefCount(const TType & type)53 void CollectVariableRefCountsTraverser::incrementStructTypeRefCount(const TType &type)
54 {
55     if (type.isInterfaceBlock())
56     {
57         const auto *block = type.getInterfaceBlock();
58         ASSERT(block);
59 
60         // We can end up incrementing ref counts of struct types referenced from an interface block
61         // multiple times for the same block. This doesn't matter, because interface blocks can't be
62         // pruned so we'll never do the reverse operation.
63         for (const auto &field : block->fields())
64         {
65             ASSERT(!field->type()->isInterfaceBlock());
66             incrementStructTypeRefCount(*field->type());
67         }
68         return;
69     }
70 
71     const auto *structure = type.getStruct();
72     if (structure != nullptr)
73     {
74         auto structIter = mStructIdRefCounts.find(structure->uniqueId().get());
75         if (structIter == mStructIdRefCounts.end())
76         {
77             mStructIdRefCounts[structure->uniqueId().get()] = 1u;
78 
79             for (const auto &field : structure->fields())
80             {
81                 incrementStructTypeRefCount(*field->type());
82             }
83 
84             return;
85         }
86         ++(structIter->second);
87     }
88 }
89 
visitSymbol(TIntermSymbol * node)90 void CollectVariableRefCountsTraverser::visitSymbol(TIntermSymbol *node)
91 {
92     incrementStructTypeRefCount(node->getType());
93 
94     auto iter = mSymbolIdRefCounts.find(node->uniqueId().get());
95     if (iter == mSymbolIdRefCounts.end())
96     {
97         mSymbolIdRefCounts[node->uniqueId().get()] = 1u;
98         return;
99     }
100     ++(iter->second);
101 }
102 
visitAggregate(Visit visit,TIntermAggregate * node)103 bool CollectVariableRefCountsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
104 {
105     // This tracks struct references in both function calls and constructors.
106     incrementStructTypeRefCount(node->getType());
107     return true;
108 }
109 
visitFunctionPrototype(TIntermFunctionPrototype * node)110 void CollectVariableRefCountsTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
111 {
112     incrementStructTypeRefCount(node->getType());
113     size_t paramCount = node->getFunction()->getParamCount();
114     for (size_t i = 0; i < paramCount; ++i)
115     {
116         incrementStructTypeRefCount(node->getFunction()->getParam(i)->getType());
117     }
118 }
119 
120 // Traverser that removes all unreferenced variables on one traversal.
121 class RemoveUnreferencedVariablesTraverser : public TIntermTraverser
122 {
123   public:
124     RemoveUnreferencedVariablesTraverser(
125         CollectVariableRefCountsTraverser::RefCountMap *symbolIdRefCounts,
126         CollectVariableRefCountsTraverser::RefCountMap *structIdRefCounts,
127         TSymbolTable *symbolTable);
128 
129     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
130     void visitSymbol(TIntermSymbol *node) override;
131     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
132 
133     // Traverse loop and block nodes in reverse order. Note that this traverser does not track
134     // parent block positions, so insertStatementInParentBlock is unusable!
135     void traverseBlock(TIntermBlock *block) override;
136     void traverseLoop(TIntermLoop *loop) override;
137 
138   private:
139     void removeVariableDeclaration(TIntermDeclaration *node, TIntermTyped *declarator);
140     void decrementStructTypeRefCount(const TType &type);
141 
142     CollectVariableRefCountsTraverser::RefCountMap *mSymbolIdRefCounts;
143     CollectVariableRefCountsTraverser::RefCountMap *mStructIdRefCounts;
144     bool mRemoveReferences;
145 };
146 
RemoveUnreferencedVariablesTraverser(CollectVariableRefCountsTraverser::RefCountMap * symbolIdRefCounts,CollectVariableRefCountsTraverser::RefCountMap * structIdRefCounts,TSymbolTable * symbolTable)147 RemoveUnreferencedVariablesTraverser::RemoveUnreferencedVariablesTraverser(
148     CollectVariableRefCountsTraverser::RefCountMap *symbolIdRefCounts,
149     CollectVariableRefCountsTraverser::RefCountMap *structIdRefCounts,
150     TSymbolTable *symbolTable)
151     : TIntermTraverser(true, false, true, symbolTable),
152       mSymbolIdRefCounts(symbolIdRefCounts),
153       mStructIdRefCounts(structIdRefCounts),
154       mRemoveReferences(false)
155 {}
156 
decrementStructTypeRefCount(const TType & type)157 void RemoveUnreferencedVariablesTraverser::decrementStructTypeRefCount(const TType &type)
158 {
159     auto *structure = type.getStruct();
160     if (structure != nullptr)
161     {
162         ASSERT(mStructIdRefCounts->find(structure->uniqueId().get()) != mStructIdRefCounts->end());
163         unsigned int structRefCount = --(*mStructIdRefCounts)[structure->uniqueId().get()];
164 
165         if (structRefCount == 0)
166         {
167             for (const auto &field : structure->fields())
168             {
169                 decrementStructTypeRefCount(*field->type());
170             }
171         }
172     }
173 }
174 
removeVariableDeclaration(TIntermDeclaration * node,TIntermTyped * declarator)175 void RemoveUnreferencedVariablesTraverser::removeVariableDeclaration(TIntermDeclaration *node,
176                                                                      TIntermTyped *declarator)
177 {
178     if (declarator->getType().isStructSpecifier() && !declarator->getType().isNamelessStruct())
179     {
180         unsigned int structId = declarator->getType().getStruct()->uniqueId().get();
181         unsigned int structRefCountInThisDeclarator = 1u;
182         if (declarator->getAsBinaryNode() &&
183             declarator->getAsBinaryNode()->getRight()->getAsAggregate())
184         {
185             ASSERT(declarator->getAsBinaryNode()->getLeft()->getType().getStruct() ==
186                    declarator->getType().getStruct());
187             ASSERT(declarator->getAsBinaryNode()->getRight()->getType().getStruct() ==
188                    declarator->getType().getStruct());
189             structRefCountInThisDeclarator = 2u;
190         }
191         if ((*mStructIdRefCounts)[structId] > structRefCountInThisDeclarator)
192         {
193             // If this declaration declares a named struct type that is used elsewhere, we need to
194             // keep it. We can still change the declarator though so that it doesn't declare an
195             // unreferenced variable.
196 
197             // Note that since we're not removing the entire declaration, the struct's reference
198             // count will end up being one less than the correct refcount. But since the struct
199             // declaration is kept, the incorrect refcount can't cause any other problems.
200 
201             if (declarator->getAsSymbolNode() &&
202                 declarator->getAsSymbolNode()->variable().symbolType() == SymbolType::Empty)
203             {
204                 // Already an empty declaration - nothing to do.
205                 return;
206             }
207             TVariable *emptyVariable =
208                 new TVariable(mSymbolTable, kEmptyImmutableString, new TType(declarator->getType()),
209                               SymbolType::Empty);
210             queueReplacementWithParent(node, declarator, new TIntermSymbol(emptyVariable),
211                                        OriginalNode::IS_DROPPED);
212             return;
213         }
214     }
215 
216     if (getParentNode()->getAsBlock())
217     {
218         TIntermSequence emptyReplacement;
219         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
220                                         std::move(emptyReplacement));
221     }
222     else
223     {
224         ASSERT(getParentNode()->getAsLoopNode());
225         queueReplacement(nullptr, OriginalNode::IS_DROPPED);
226     }
227 }
228 
visitDeclaration(Visit visit,TIntermDeclaration * node)229 bool RemoveUnreferencedVariablesTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
230 {
231     if (visit == PreVisit)
232     {
233         // SeparateDeclarations should have already been run.
234         ASSERT(node->getSequence()->size() == 1u);
235 
236         TIntermTyped *declarator = node->getSequence()->back()->getAsTyped();
237         ASSERT(declarator);
238 
239         // We can only remove variables that are not a part of the shader interface.
240         TQualifier qualifier = declarator->getQualifier();
241         if (qualifier != EvqTemporary && qualifier != EvqGlobal && qualifier != EvqConst)
242         {
243             return true;
244         }
245 
246         bool canRemoveVariable    = false;
247         TIntermSymbol *symbolNode = declarator->getAsSymbolNode();
248         if (symbolNode != nullptr)
249         {
250             canRemoveVariable = (*mSymbolIdRefCounts)[symbolNode->uniqueId().get()] == 1u ||
251                                 symbolNode->variable().symbolType() == SymbolType::Empty;
252         }
253         TIntermBinary *initNode = declarator->getAsBinaryNode();
254         if (initNode != nullptr)
255         {
256             ASSERT(initNode->getLeft()->getAsSymbolNode());
257             int symbolId = initNode->getLeft()->getAsSymbolNode()->uniqueId().get();
258             canRemoveVariable =
259                 (*mSymbolIdRefCounts)[symbolId] == 1u && !initNode->getRight()->hasSideEffects();
260         }
261 
262         if (canRemoveVariable)
263         {
264             removeVariableDeclaration(node, declarator);
265             mRemoveReferences = true;
266         }
267         return true;
268     }
269     ASSERT(visit == PostVisit);
270     mRemoveReferences = false;
271     return true;
272 }
273 
visitSymbol(TIntermSymbol * node)274 void RemoveUnreferencedVariablesTraverser::visitSymbol(TIntermSymbol *node)
275 {
276     if (mRemoveReferences)
277     {
278         ASSERT(mSymbolIdRefCounts->find(node->uniqueId().get()) != mSymbolIdRefCounts->end());
279         --(*mSymbolIdRefCounts)[node->uniqueId().get()];
280 
281         decrementStructTypeRefCount(node->getType());
282     }
283 }
284 
visitAggregate(Visit visit,TIntermAggregate * node)285 bool RemoveUnreferencedVariablesTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
286 {
287     if (visit == PreVisit && mRemoveReferences)
288     {
289         decrementStructTypeRefCount(node->getType());
290     }
291     return true;
292 }
293 
traverseBlock(TIntermBlock * node)294 void RemoveUnreferencedVariablesTraverser::traverseBlock(TIntermBlock *node)
295 {
296     // We traverse blocks in reverse order.  This way reference counts can be decremented when
297     // removing initializers, and variables that become unused when initializers are removed can be
298     // removed on the same traversal.
299 
300     ScopedNodeInTraversalPath addToPath(this, node);
301 
302     bool visit = true;
303 
304     TIntermSequence *sequence = node->getSequence();
305 
306     if (preVisit)
307         visit = visitBlock(PreVisit, node);
308 
309     if (visit)
310     {
311         for (auto iter = sequence->rbegin(); iter != sequence->rend(); ++iter)
312         {
313             (*iter)->traverse(this);
314             if (visit && inVisit)
315             {
316                 if ((iter + 1) != sequence->rend())
317                     visit = visitBlock(InVisit, node);
318             }
319         }
320     }
321 
322     if (visit && postVisit)
323         visitBlock(PostVisit, node);
324 }
325 
traverseLoop(TIntermLoop * node)326 void RemoveUnreferencedVariablesTraverser::traverseLoop(TIntermLoop *node)
327 {
328     // We traverse loops in reverse order as well. The loop body gets traversed before the init
329     // node.
330 
331     ScopedNodeInTraversalPath addToPath(this, node);
332 
333     bool visit = true;
334 
335     if (preVisit)
336         visit = visitLoop(PreVisit, node);
337 
338     if (visit)
339     {
340         // We don't need to traverse loop expressions or conditions since they can't be declarations
341         // in the AST (loops which have a declaration in their condition get transformed in the
342         // parsing stage).
343         ASSERT(node->getExpression() == nullptr ||
344                node->getExpression()->getAsDeclarationNode() == nullptr);
345         ASSERT(node->getCondition() == nullptr ||
346                node->getCondition()->getAsDeclarationNode() == nullptr);
347 
348         if (node->getBody())
349             node->getBody()->traverse(this);
350 
351         if (node->getInit())
352             node->getInit()->traverse(this);
353     }
354 
355     if (visit && postVisit)
356         visitLoop(PostVisit, node);
357 }
358 
359 }  // namespace
360 
RemoveUnreferencedVariables(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)361 bool RemoveUnreferencedVariables(TCompiler *compiler, TIntermBlock *root, TSymbolTable *symbolTable)
362 {
363     CollectVariableRefCountsTraverser collector;
364     root->traverse(&collector);
365     RemoveUnreferencedVariablesTraverser traverser(&collector.getSymbolIdRefCounts(),
366                                                    &collector.getStructIdRefCounts(), symbolTable);
367     root->traverse(&traverser);
368     return traverser.updateTree(compiler, root);
369 }
370 
371 }  // namespace sh
372