1 //
2 // Copyright (c) 2017 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RemoveUnreferencedVariables.cpp:
7 //  Drop variables that are declared but never referenced in the AST. This avoids adding unnecessary
8 //  initialization code for them. Also removes unreferenced struct types.
9 //
10 
11 #include "compiler/translator/RemoveUnreferencedVariables.h"
12 
13 #include "compiler/translator/IntermTraverse.h"
14 #include "compiler/translator/SymbolTable.h"
15 
16 namespace sh
17 {
18 
19 namespace
20 {
21 
22 class CollectVariableRefCountsTraverser : public TIntermTraverser
23 {
24   public:
25     CollectVariableRefCountsTraverser();
26 
27     using RefCountMap = std::unordered_map<int, unsigned int>;
getSymbolIdRefCounts()28     RefCountMap &getSymbolIdRefCounts() { return mSymbolIdRefCounts; }
getStructIdRefCounts()29     RefCountMap &getStructIdRefCounts() { return mStructIdRefCounts; }
30 
31     void visitSymbol(TIntermSymbol *node) override;
32     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
33     bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override;
34 
35   private:
36     void incrementStructTypeRefCount(const TType &type);
37 
38     RefCountMap mSymbolIdRefCounts;
39 
40     // Structure reference counts are counted from symbols, constructors, function calls, function
41     // return values and from interface block and structure fields. We need to track both function
42     // calls and function return values since there's a compiler option not to prune unused
43     // functions. The type of a constant union may also be a struct, but statements that are just a
44     // constant union are always pruned, and if the constant union is used somehow it will get
45     // counted by something else.
46     RefCountMap mStructIdRefCounts;
47 };
48 
CollectVariableRefCountsTraverser()49 CollectVariableRefCountsTraverser::CollectVariableRefCountsTraverser()
50     : TIntermTraverser(true, false, false)
51 {
52 }
53 
incrementStructTypeRefCount(const TType & type)54 void CollectVariableRefCountsTraverser::incrementStructTypeRefCount(const TType &type)
55 {
56     if (type.isInterfaceBlock())
57     {
58         const auto *block = type.getInterfaceBlock();
59         ASSERT(block);
60 
61         // We can end up incrementing ref counts of struct types referenced from an interface block
62         // multiple times for the same block. This doesn't matter, because interface blocks can't be
63         // pruned so we'll never do the reverse operation.
64         for (const auto &field : block->fields())
65         {
66             ASSERT(!field->type()->isInterfaceBlock());
67             incrementStructTypeRefCount(*field->type());
68         }
69         return;
70     }
71 
72     const auto *structure = type.getStruct();
73     if (structure != nullptr)
74     {
75         auto structIter = mStructIdRefCounts.find(structure->uniqueId().get());
76         if (structIter == mStructIdRefCounts.end())
77         {
78             mStructIdRefCounts[structure->uniqueId().get()] = 1u;
79 
80             for (const auto &field : structure->fields())
81             {
82                 incrementStructTypeRefCount(*field->type());
83             }
84 
85             return;
86         }
87         ++(structIter->second);
88     }
89 }
90 
visitSymbol(TIntermSymbol * node)91 void CollectVariableRefCountsTraverser::visitSymbol(TIntermSymbol *node)
92 {
93     incrementStructTypeRefCount(node->getType());
94 
95     auto iter = mSymbolIdRefCounts.find(node->uniqueId().get());
96     if (iter == mSymbolIdRefCounts.end())
97     {
98         mSymbolIdRefCounts[node->uniqueId().get()] = 1u;
99         return;
100     }
101     ++(iter->second);
102 }
103 
visitAggregate(Visit visit,TIntermAggregate * node)104 bool CollectVariableRefCountsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
105 {
106     // This tracks struct references in both function calls and constructors.
107     incrementStructTypeRefCount(node->getType());
108     return true;
109 }
110 
visitFunctionPrototype(Visit visit,TIntermFunctionPrototype * node)111 bool CollectVariableRefCountsTraverser::visitFunctionPrototype(Visit visit,
112                                                                TIntermFunctionPrototype *node)
113 {
114     incrementStructTypeRefCount(node->getType());
115     return true;
116 }
117 
118 // Traverser that removes all unreferenced variables on one traversal.
119 class RemoveUnreferencedVariablesTraverser : public TIntermTraverser
120 {
121   public:
122     RemoveUnreferencedVariablesTraverser(
123         CollectVariableRefCountsTraverser::RefCountMap *symbolIdRefCounts,
124         CollectVariableRefCountsTraverser::RefCountMap *structIdRefCounts,
125         TSymbolTable *symbolTable);
126 
127     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
128     void visitSymbol(TIntermSymbol *node) override;
129     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
130 
131     // Traverse loop and block nodes in reverse order. Note that this traverser does not track
132     // parent block positions, so insertStatementInParentBlock is unusable!
133     void traverseBlock(TIntermBlock *block) override;
134     void traverseLoop(TIntermLoop *loop) override;
135 
136   private:
137     void removeVariableDeclaration(TIntermDeclaration *node, TIntermTyped *declarator);
138     void decrementStructTypeRefCount(const TType &type);
139 
140     CollectVariableRefCountsTraverser::RefCountMap *mSymbolIdRefCounts;
141     CollectVariableRefCountsTraverser::RefCountMap *mStructIdRefCounts;
142     bool mRemoveReferences;
143 };
144 
RemoveUnreferencedVariablesTraverser(CollectVariableRefCountsTraverser::RefCountMap * symbolIdRefCounts,CollectVariableRefCountsTraverser::RefCountMap * structIdRefCounts,TSymbolTable * symbolTable)145 RemoveUnreferencedVariablesTraverser::RemoveUnreferencedVariablesTraverser(
146     CollectVariableRefCountsTraverser::RefCountMap *symbolIdRefCounts,
147     CollectVariableRefCountsTraverser::RefCountMap *structIdRefCounts,
148     TSymbolTable *symbolTable)
149     : TIntermTraverser(true, false, true, symbolTable),
150       mSymbolIdRefCounts(symbolIdRefCounts),
151       mStructIdRefCounts(structIdRefCounts),
152       mRemoveReferences(false)
153 {
154 }
155 
decrementStructTypeRefCount(const TType & type)156 void RemoveUnreferencedVariablesTraverser::decrementStructTypeRefCount(const TType &type)
157 {
158     auto *structure = type.getStruct();
159     if (structure != nullptr)
160     {
161         ASSERT(mStructIdRefCounts->find(structure->uniqueId().get()) != mStructIdRefCounts->end());
162         unsigned int structRefCount = --(*mStructIdRefCounts)[structure->uniqueId().get()];
163 
164         if (structRefCount == 0)
165         {
166             for (const auto &field : structure->fields())
167             {
168                 decrementStructTypeRefCount(*field->type());
169             }
170         }
171     }
172 }
173 
removeVariableDeclaration(TIntermDeclaration * node,TIntermTyped * declarator)174 void RemoveUnreferencedVariablesTraverser::removeVariableDeclaration(TIntermDeclaration *node,
175                                                                      TIntermTyped *declarator)
176 {
177     if (declarator->getType().isStructSpecifier() && !declarator->getType().isNamelessStruct())
178     {
179         unsigned int structId = declarator->getType().getStruct()->uniqueId().get();
180         unsigned int structRefCountInThisDeclarator = 1u;
181         if (declarator->getAsBinaryNode() &&
182             declarator->getAsBinaryNode()->getRight()->getAsAggregate())
183         {
184             ASSERT(declarator->getAsBinaryNode()->getLeft()->getType().getStruct() ==
185                    declarator->getType().getStruct());
186             ASSERT(declarator->getAsBinaryNode()->getRight()->getType().getStruct() ==
187                    declarator->getType().getStruct());
188             structRefCountInThisDeclarator = 2u;
189         }
190         if ((*mStructIdRefCounts)[structId] > structRefCountInThisDeclarator)
191         {
192             // If this declaration declares a named struct type that is used elsewhere, we need to
193             // keep it. We can still change the declarator though so that it doesn't declare an
194             // unreferenced variable.
195 
196             // Note that since we're not removing the entire declaration, the struct's reference
197             // count will end up being one less than the correct refcount. But since the struct
198             // declaration is kept, the incorrect refcount can't cause any other problems.
199 
200             if (declarator->getAsSymbolNode() &&
201                 declarator->getAsSymbolNode()->variable().symbolType() == SymbolType::Empty)
202             {
203                 // Already an empty declaration - nothing to do.
204                 return;
205             }
206             TVariable *emptyVariable =
207                 new TVariable(mSymbolTable, ImmutableString(""), new TType(declarator->getType()),
208                               SymbolType::Empty);
209             queueReplacementWithParent(node, declarator, new TIntermSymbol(emptyVariable),
210                                        OriginalNode::IS_DROPPED);
211             return;
212         }
213     }
214 
215     if (getParentNode()->getAsBlock())
216     {
217         TIntermSequence emptyReplacement;
218         mMultiReplacements.push_back(
219             NodeReplaceWithMultipleEntry(getParentNode()->getAsBlock(), node, emptyReplacement));
220     }
221     else
222     {
223         ASSERT(getParentNode()->getAsLoopNode());
224         queueReplacement(nullptr, OriginalNode::IS_DROPPED);
225     }
226 }
227 
visitDeclaration(Visit visit,TIntermDeclaration * node)228 bool RemoveUnreferencedVariablesTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
229 {
230     if (visit == PreVisit)
231     {
232         // SeparateDeclarations should have already been run.
233         ASSERT(node->getSequence()->size() == 1u);
234 
235         TIntermTyped *declarator = node->getSequence()->back()->getAsTyped();
236         ASSERT(declarator);
237 
238         // We can only remove variables that are not a part of the shader interface.
239         TQualifier qualifier = declarator->getQualifier();
240         if (qualifier != EvqTemporary && qualifier != EvqGlobal && qualifier != EvqConst)
241         {
242             return true;
243         }
244 
245         bool canRemoveVariable    = false;
246         TIntermSymbol *symbolNode = declarator->getAsSymbolNode();
247         if (symbolNode != nullptr)
248         {
249             canRemoveVariable = (*mSymbolIdRefCounts)[symbolNode->uniqueId().get()] == 1u ||
250                                 symbolNode->variable().symbolType() == SymbolType::Empty;
251         }
252         TIntermBinary *initNode = declarator->getAsBinaryNode();
253         if (initNode != nullptr)
254         {
255             ASSERT(initNode->getLeft()->getAsSymbolNode());
256             int symbolId = initNode->getLeft()->getAsSymbolNode()->uniqueId().get();
257             canRemoveVariable =
258                 (*mSymbolIdRefCounts)[symbolId] == 1u && !initNode->getRight()->hasSideEffects();
259         }
260 
261         if (canRemoveVariable)
262         {
263             removeVariableDeclaration(node, declarator);
264             mRemoveReferences = true;
265         }
266         return true;
267     }
268     ASSERT(visit == PostVisit);
269     mRemoveReferences = false;
270     return true;
271 }
272 
visitSymbol(TIntermSymbol * node)273 void RemoveUnreferencedVariablesTraverser::visitSymbol(TIntermSymbol *node)
274 {
275     if (mRemoveReferences)
276     {
277         ASSERT(mSymbolIdRefCounts->find(node->uniqueId().get()) != mSymbolIdRefCounts->end());
278         --(*mSymbolIdRefCounts)[node->uniqueId().get()];
279 
280         decrementStructTypeRefCount(node->getType());
281     }
282 }
283 
visitAggregate(Visit visit,TIntermAggregate * node)284 bool RemoveUnreferencedVariablesTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
285 {
286     if (visit == PreVisit && mRemoveReferences)
287     {
288         decrementStructTypeRefCount(node->getType());
289     }
290     return true;
291 }
292 
traverseBlock(TIntermBlock * node)293 void RemoveUnreferencedVariablesTraverser::traverseBlock(TIntermBlock *node)
294 {
295     // We traverse blocks in reverse order.  This way reference counts can be decremented when
296     // removing initializers, and variables that become unused when initializers are removed can be
297     // removed on the same traversal.
298 
299     ScopedNodeInTraversalPath addToPath(this, node);
300 
301     bool visit = true;
302 
303     TIntermSequence *sequence = node->getSequence();
304 
305     if (preVisit)
306         visit = visitBlock(PreVisit, node);
307 
308     if (visit)
309     {
310         for (auto iter = sequence->rbegin(); iter != sequence->rend(); ++iter)
311         {
312             (*iter)->traverse(this);
313             if (visit && inVisit)
314             {
315                 if ((iter + 1) != sequence->rend())
316                     visit = visitBlock(InVisit, node);
317             }
318         }
319     }
320 
321     if (visit && postVisit)
322         visitBlock(PostVisit, node);
323 }
324 
traverseLoop(TIntermLoop * node)325 void RemoveUnreferencedVariablesTraverser::traverseLoop(TIntermLoop *node)
326 {
327     // We traverse loops in reverse order as well. The loop body gets traversed before the init
328     // node.
329 
330     ScopedNodeInTraversalPath addToPath(this, node);
331 
332     bool visit = true;
333 
334     if (preVisit)
335         visit = visitLoop(PreVisit, node);
336 
337     if (visit)
338     {
339         // We don't need to traverse loop expressions or conditions since they can't be declarations
340         // in the AST (loops which have a declaration in their condition get transformed in the
341         // parsing stage).
342         ASSERT(node->getExpression() == nullptr ||
343                node->getExpression()->getAsDeclarationNode() == nullptr);
344         ASSERT(node->getCondition() == nullptr ||
345                node->getCondition()->getAsDeclarationNode() == nullptr);
346 
347         if (node->getBody())
348             node->getBody()->traverse(this);
349 
350         if (node->getInit())
351             node->getInit()->traverse(this);
352     }
353 
354     if (visit && postVisit)
355         visitLoop(PostVisit, node);
356 }
357 
358 }  // namespace
359 
RemoveUnreferencedVariables(TIntermBlock * root,TSymbolTable * symbolTable)360 void RemoveUnreferencedVariables(TIntermBlock *root, TSymbolTable *symbolTable)
361 {
362     CollectVariableRefCountsTraverser collector;
363     root->traverse(&collector);
364     RemoveUnreferencedVariablesTraverser traverser(&collector.getSymbolIdRefCounts(),
365                                                    &collector.getStructIdRefCounts(), symbolTable);
366     root->traverse(&traverser);
367     traverser.updateTree();
368 }
369 
370 }  // namespace sh
371