1 //
2 // Copyright (c) 2017 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RemoveUnreferencedVariables.cpp:
7 //  Drop variables that are declared but never referenced in the AST. This avoids adding unnecessary
8 //  initialization code for them. Also removes unreferenced struct types.
9 //
10 
11 #include "compiler/translator/RemoveUnreferencedVariables.h"
12 
13 #include "compiler/translator/IntermTraverse.h"
14 #include "compiler/translator/SymbolTable.h"
15 
16 namespace sh
17 {
18 
19 namespace
20 {
21 
22 class CollectVariableRefCountsTraverser : public TIntermTraverser
23 {
24   public:
25     CollectVariableRefCountsTraverser();
26 
27     using RefCountMap = std::unordered_map<int, unsigned int>;
getSymbolIdRefCounts()28     RefCountMap &getSymbolIdRefCounts() { return mSymbolIdRefCounts; }
getStructIdRefCounts()29     RefCountMap &getStructIdRefCounts() { return mStructIdRefCounts; }
30 
31     void visitSymbol(TIntermSymbol *node) override;
32     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
33     bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override;
34 
35   private:
36     void incrementStructTypeRefCount(const TType &type);
37 
38     RefCountMap mSymbolIdRefCounts;
39 
40     // Structure reference counts are counted from symbols, constructors, function calls, function
41     // return values and from interface block and structure fields. We need to track both function
42     // calls and function return values since there's a compiler option not to prune unused
43     // functions. The type of a constant union may also be a struct, but statements that are just a
44     // constant union are always pruned, and if the constant union is used somehow it will get
45     // counted by something else.
46     RefCountMap mStructIdRefCounts;
47 };
48 
CollectVariableRefCountsTraverser()49 CollectVariableRefCountsTraverser::CollectVariableRefCountsTraverser()
50     : TIntermTraverser(true, false, false)
51 {
52 }
53 
incrementStructTypeRefCount(const TType & type)54 void CollectVariableRefCountsTraverser::incrementStructTypeRefCount(const TType &type)
55 {
56     if (type.isInterfaceBlock())
57     {
58         const auto *block = type.getInterfaceBlock();
59         ASSERT(block);
60 
61         // We can end up incrementing ref counts of struct types referenced from an interface block
62         // multiple times for the same block. This doesn't matter, because interface blocks can't be
63         // pruned so we'll never do the reverse operation.
64         for (const auto &field : block->fields())
65         {
66             ASSERT(!field->type()->isInterfaceBlock());
67             incrementStructTypeRefCount(*field->type());
68         }
69         return;
70     }
71 
72     const auto *structure = type.getStruct();
73     if (structure != nullptr)
74     {
75         auto structIter = mStructIdRefCounts.find(structure->uniqueId());
76         if (structIter == mStructIdRefCounts.end())
77         {
78             mStructIdRefCounts[structure->uniqueId()] = 1u;
79 
80             for (const auto &field : structure->fields())
81             {
82                 incrementStructTypeRefCount(*field->type());
83             }
84 
85             return;
86         }
87         ++(structIter->second);
88     }
89 }
90 
visitSymbol(TIntermSymbol * node)91 void CollectVariableRefCountsTraverser::visitSymbol(TIntermSymbol *node)
92 {
93     incrementStructTypeRefCount(node->getType());
94 
95     auto iter = mSymbolIdRefCounts.find(node->getId());
96     if (iter == mSymbolIdRefCounts.end())
97     {
98         mSymbolIdRefCounts[node->getId()] = 1u;
99         return;
100     }
101     ++(iter->second);
102 }
103 
visitAggregate(Visit visit,TIntermAggregate * node)104 bool CollectVariableRefCountsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
105 {
106     // This tracks struct references in both function calls and constructors.
107     incrementStructTypeRefCount(node->getType());
108     return true;
109 }
110 
visitFunctionPrototype(Visit visit,TIntermFunctionPrototype * node)111 bool CollectVariableRefCountsTraverser::visitFunctionPrototype(Visit visit,
112                                                                TIntermFunctionPrototype *node)
113 {
114     incrementStructTypeRefCount(node->getType());
115     return true;
116 }
117 
118 // Traverser that removes all unreferenced variables on one traversal.
119 class RemoveUnreferencedVariablesTraverser : public TIntermTraverser
120 {
121   public:
122     RemoveUnreferencedVariablesTraverser(
123         CollectVariableRefCountsTraverser::RefCountMap *symbolIdRefCounts,
124         CollectVariableRefCountsTraverser::RefCountMap *structIdRefCounts,
125         TSymbolTable *symbolTable);
126 
127     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
128     void visitSymbol(TIntermSymbol *node) override;
129     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
130 
131     // Traverse loop and block nodes in reverse order. Note that this traverser does not track
132     // parent block positions, so insertStatementInParentBlock is unusable!
133     void traverseBlock(TIntermBlock *block) override;
134     void traverseLoop(TIntermLoop *loop) override;
135 
136   private:
137     void removeVariableDeclaration(TIntermDeclaration *node, TIntermTyped *declarator);
138     void decrementStructTypeRefCount(const TType &type);
139 
140     CollectVariableRefCountsTraverser::RefCountMap *mSymbolIdRefCounts;
141     CollectVariableRefCountsTraverser::RefCountMap *mStructIdRefCounts;
142     bool mRemoveReferences;
143 };
144 
RemoveUnreferencedVariablesTraverser(CollectVariableRefCountsTraverser::RefCountMap * symbolIdRefCounts,CollectVariableRefCountsTraverser::RefCountMap * structIdRefCounts,TSymbolTable * symbolTable)145 RemoveUnreferencedVariablesTraverser::RemoveUnreferencedVariablesTraverser(
146     CollectVariableRefCountsTraverser::RefCountMap *symbolIdRefCounts,
147     CollectVariableRefCountsTraverser::RefCountMap *structIdRefCounts,
148     TSymbolTable *symbolTable)
149     : TIntermTraverser(true, false, true, symbolTable),
150       mSymbolIdRefCounts(symbolIdRefCounts),
151       mStructIdRefCounts(structIdRefCounts),
152       mRemoveReferences(false)
153 {
154 }
155 
decrementStructTypeRefCount(const TType & type)156 void RemoveUnreferencedVariablesTraverser::decrementStructTypeRefCount(const TType &type)
157 {
158     auto *structure = type.getStruct();
159     if (structure != nullptr)
160     {
161         ASSERT(mStructIdRefCounts->find(structure->uniqueId()) != mStructIdRefCounts->end());
162         unsigned int structRefCount = --(*mStructIdRefCounts)[structure->uniqueId()];
163 
164         if (structRefCount == 0)
165         {
166             for (const auto &field : structure->fields())
167             {
168                 decrementStructTypeRefCount(*field->type());
169             }
170         }
171     }
172 }
173 
removeVariableDeclaration(TIntermDeclaration * node,TIntermTyped * declarator)174 void RemoveUnreferencedVariablesTraverser::removeVariableDeclaration(TIntermDeclaration *node,
175                                                                      TIntermTyped *declarator)
176 {
177     if (declarator->getType().isStructSpecifier() && !declarator->getType().isNamelessStruct())
178     {
179         unsigned int structId = declarator->getType().getStruct()->uniqueId();
180         if ((*mStructIdRefCounts)[structId] > 1u)
181         {
182             // If this declaration declares a named struct type that is used elsewhere, we need to
183             // keep it. We can still change the declarator though so that it doesn't declare an
184             // unreferenced variable.
185 
186             // Note that since we're not removing the entire declaration, the struct's reference
187             // count will end up being one less than the correct refcount. But since the struct
188             // declaration is kept, the incorrect refcount can't cause any other problems.
189 
190             if (declarator->getAsSymbolNode() && declarator->getAsSymbolNode()->getSymbol().empty())
191             {
192                 // Already an empty declaration - nothing to do.
193                 return;
194             }
195             queueReplacementWithParent(node, declarator,
196                                        new TIntermSymbol(mSymbolTable->getEmptySymbolId(),
197                                                          TString(""), declarator->getType()),
198                                        OriginalNode::IS_DROPPED);
199             return;
200         }
201     }
202 
203     if (getParentNode()->getAsBlock())
204     {
205         TIntermSequence emptyReplacement;
206         mMultiReplacements.push_back(
207             NodeReplaceWithMultipleEntry(getParentNode()->getAsBlock(), node, emptyReplacement));
208     }
209     else
210     {
211         ASSERT(getParentNode()->getAsLoopNode());
212         queueReplacement(nullptr, OriginalNode::IS_DROPPED);
213     }
214 }
215 
visitDeclaration(Visit visit,TIntermDeclaration * node)216 bool RemoveUnreferencedVariablesTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
217 {
218     if (visit == PreVisit)
219     {
220         // SeparateDeclarations should have already been run.
221         ASSERT(node->getSequence()->size() == 1u);
222 
223         TIntermTyped *declarator = node->getSequence()->back()->getAsTyped();
224         ASSERT(declarator);
225 
226         // We can only remove variables that are not a part of the shader interface.
227         TQualifier qualifier = declarator->getQualifier();
228         if (qualifier != EvqTemporary && qualifier != EvqGlobal)
229         {
230             return true;
231         }
232 
233         bool canRemoveVariable    = false;
234         TIntermSymbol *symbolNode = declarator->getAsSymbolNode();
235         if (symbolNode != nullptr)
236         {
237             canRemoveVariable =
238                 (*mSymbolIdRefCounts)[symbolNode->getId()] == 1u || symbolNode->getSymbol().empty();
239         }
240         TIntermBinary *initNode = declarator->getAsBinaryNode();
241         if (initNode != nullptr)
242         {
243             ASSERT(initNode->getLeft()->getAsSymbolNode());
244             int symbolId = initNode->getLeft()->getAsSymbolNode()->getId();
245             canRemoveVariable =
246                 (*mSymbolIdRefCounts)[symbolId] == 1u && !initNode->getRight()->hasSideEffects();
247         }
248 
249         if (canRemoveVariable)
250         {
251             removeVariableDeclaration(node, declarator);
252             mRemoveReferences = true;
253         }
254         return true;
255     }
256     ASSERT(visit == PostVisit);
257     mRemoveReferences = false;
258     return true;
259 }
260 
visitSymbol(TIntermSymbol * node)261 void RemoveUnreferencedVariablesTraverser::visitSymbol(TIntermSymbol *node)
262 {
263     if (mRemoveReferences)
264     {
265         ASSERT(mSymbolIdRefCounts->find(node->getId()) != mSymbolIdRefCounts->end());
266         --(*mSymbolIdRefCounts)[node->getId()];
267 
268         decrementStructTypeRefCount(node->getType());
269     }
270 }
271 
visitAggregate(Visit visit,TIntermAggregate * node)272 bool RemoveUnreferencedVariablesTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
273 {
274     if (mRemoveReferences)
275     {
276         decrementStructTypeRefCount(node->getType());
277     }
278     return true;
279 }
280 
traverseBlock(TIntermBlock * node)281 void RemoveUnreferencedVariablesTraverser::traverseBlock(TIntermBlock *node)
282 {
283     // We traverse blocks in reverse order.  This way reference counts can be decremented when
284     // removing initializers, and variables that become unused when initializers are removed can be
285     // removed on the same traversal.
286 
287     ScopedNodeInTraversalPath addToPath(this, node);
288 
289     bool visit = true;
290 
291     TIntermSequence *sequence = node->getSequence();
292 
293     if (preVisit)
294         visit = visitBlock(PreVisit, node);
295 
296     if (visit)
297     {
298         for (auto iter = sequence->rbegin(); iter != sequence->rend(); ++iter)
299         {
300             (*iter)->traverse(this);
301             if (visit && inVisit)
302             {
303                 if ((iter + 1) != sequence->rend())
304                     visit = visitBlock(InVisit, node);
305             }
306         }
307     }
308 
309     if (visit && postVisit)
310         visitBlock(PostVisit, node);
311 }
312 
traverseLoop(TIntermLoop * node)313 void RemoveUnreferencedVariablesTraverser::traverseLoop(TIntermLoop *node)
314 {
315     // We traverse loops in reverse order as well. The loop body gets traversed before the init
316     // node.
317 
318     ScopedNodeInTraversalPath addToPath(this, node);
319 
320     bool visit = true;
321 
322     if (preVisit)
323         visit = visitLoop(PreVisit, node);
324 
325     if (visit)
326     {
327         // We don't need to traverse loop expressions or conditions since they can't be declarations
328         // in the AST (loops which have a declaration in their condition get transformed in the
329         // parsing stage).
330         ASSERT(node->getExpression() == nullptr ||
331                node->getExpression()->getAsDeclarationNode() == nullptr);
332         ASSERT(node->getCondition() == nullptr ||
333                node->getCondition()->getAsDeclarationNode() == nullptr);
334 
335         if (node->getBody())
336             node->getBody()->traverse(this);
337 
338         if (node->getInit())
339             node->getInit()->traverse(this);
340     }
341 
342     if (visit && postVisit)
343         visitLoop(PostVisit, node);
344 }
345 
346 }  // namespace
347 
RemoveUnreferencedVariables(TIntermBlock * root,TSymbolTable * symbolTable)348 void RemoveUnreferencedVariables(TIntermBlock *root, TSymbolTable *symbolTable)
349 {
350     CollectVariableRefCountsTraverser collector;
351     root->traverse(&collector);
352     RemoveUnreferencedVariablesTraverser traverser(&collector.getSymbolIdRefCounts(),
353                                                    &collector.getStructIdRefCounts(), symbolTable);
354     root->traverse(&traverser);
355     traverser.updateTree();
356 }
357 
358 }  // namespace sh
359