1 //
2 // Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of vectors and matrices,
7 // replacing them with calls to functions that choose which component to return or write.
8 //
9 
10 #include "compiler/translator/RemoveDynamicIndexing.h"
11 
12 #include "compiler/translator/Diagnostics.h"
13 #include "compiler/translator/InfoSink.h"
14 #include "compiler/translator/IntermNodePatternMatcher.h"
15 #include "compiler/translator/IntermNode_util.h"
16 #include "compiler/translator/IntermTraverse.h"
17 #include "compiler/translator/SymbolTable.h"
18 
19 namespace sh
20 {
21 
22 namespace
23 {
24 
GetIndexFunctionName(const TType & type,bool write)25 std::string GetIndexFunctionName(const TType &type, bool write)
26 {
27     TInfoSinkBase nameSink;
28     nameSink << "dyn_index_";
29     if (write)
30     {
31         nameSink << "write_";
32     }
33     if (type.isMatrix())
34     {
35         nameSink << "mat" << type.getCols() << "x" << type.getRows();
36     }
37     else
38     {
39         switch (type.getBasicType())
40         {
41             case EbtInt:
42                 nameSink << "ivec";
43                 break;
44             case EbtBool:
45                 nameSink << "bvec";
46                 break;
47             case EbtUInt:
48                 nameSink << "uvec";
49                 break;
50             case EbtFloat:
51                 nameSink << "vec";
52                 break;
53             default:
54                 UNREACHABLE();
55         }
56         nameSink << type.getNominalSize();
57     }
58     return nameSink.str();
59 }
60 
CreateBaseSymbol(const TType & type,TQualifier qualifier,TSymbolTable * symbolTable)61 TIntermSymbol *CreateBaseSymbol(const TType &type, TQualifier qualifier, TSymbolTable *symbolTable)
62 {
63     TIntermSymbol *symbol = new TIntermSymbol(symbolTable->nextUniqueId(), "base", type);
64     symbol->setInternal(true);
65     symbol->getTypePointer()->setQualifier(qualifier);
66     return symbol;
67 }
68 
CreateIndexSymbol(TSymbolTable * symbolTable)69 TIntermSymbol *CreateIndexSymbol(TSymbolTable *symbolTable)
70 {
71     TIntermSymbol *symbol =
72         new TIntermSymbol(symbolTable->nextUniqueId(), "index", TType(EbtInt, EbpHigh));
73     symbol->setInternal(true);
74     symbol->getTypePointer()->setQualifier(EvqIn);
75     return symbol;
76 }
77 
CreateValueSymbol(const TType & type,TSymbolTable * symbolTable)78 TIntermSymbol *CreateValueSymbol(const TType &type, TSymbolTable *symbolTable)
79 {
80     TIntermSymbol *symbol = new TIntermSymbol(symbolTable->nextUniqueId(), "value", type);
81     symbol->setInternal(true);
82     symbol->getTypePointer()->setQualifier(EvqIn);
83     return symbol;
84 }
85 
CreateIntConstantNode(int i)86 TIntermConstantUnion *CreateIntConstantNode(int i)
87 {
88     TConstantUnion *constant = new TConstantUnion();
89     constant->setIConst(i);
90     return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh));
91 }
92 
EnsureSignedInt(TIntermTyped * node)93 TIntermTyped *EnsureSignedInt(TIntermTyped *node)
94 {
95     if (node->getBasicType() == EbtInt)
96         return node;
97 
98     TIntermSequence *arguments = new TIntermSequence();
99     arguments->push_back(node);
100     return TIntermAggregate::CreateConstructor(TType(EbtInt), arguments);
101 }
102 
// Returns the type produced by indexing a value of |indexedType| once. For a vector this is a
// scalar of the same basic type and precision. For a matrix it is a column vector with as many
// components as the matrix has rows.
TType GetFieldType(const TType &indexedType)
{
    TType fieldType(indexedType.getBasicType(), indexedType.getPrecision());
    if (indexedType.isMatrix())
    {
        fieldType.setPrimarySize(static_cast<unsigned char>(indexedType.getRows()));
    }
    return fieldType;
}
116 
117 // Generate a read or write function for one field in a vector/matrix.
118 // Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range
119 // indices in other places.
120 // Note that indices can be either int or uint. We create only int versions of the functions,
121 // and convert uint indices to int at the call site.
122 // read function example:
123 // float dyn_index_vec2(in vec2 base, in int index)
124 // {
125 //    switch(index)
126 //    {
127 //      case (0):
128 //        return base[0];
129 //      case (1):
130 //        return base[1];
131 //      default:
132 //        break;
133 //    }
134 //    if (index < 0)
135 //      return base[0];
136 //    return base[1];
137 // }
138 // write function example:
139 // void dyn_index_write_vec2(inout vec2 base, in int index, in float value)
140 // {
141 //    switch(index)
142 //    {
143 //      case (0):
144 //        base[0] = value;
145 //        return;
146 //      case (1):
147 //        base[1] = value;
148 //        return;
149 //      default:
150 //        break;
151 //    }
152 //    if (index < 0)
153 //    {
154 //      base[0] = value;
155 //      return;
156 //    }
157 //    base[1] = value;
158 // }
159 // Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
GetIndexFunctionDefinition(TType type,bool write,const TSymbolUniqueId & functionId,TSymbolTable * symbolTable)160 TIntermFunctionDefinition *GetIndexFunctionDefinition(TType type,
161                                                       bool write,
162                                                       const TSymbolUniqueId &functionId,
163                                                       TSymbolTable *symbolTable)
164 {
165     ASSERT(!type.isArray());
166     // Conservatively use highp here, even if the indexed type is not highp. That way the code can't
167     // end up using mediump version of an indexing function for a highp value, if both mediump and
168     // highp values are being indexed in the shader. For HLSL precision doesn't matter, but in
169     // principle this code could be used with multiple backends.
170     type.setPrecision(EbpHigh);
171 
172     TType fieldType = GetFieldType(type);
173     int numCases    = 0;
174     if (type.isMatrix())
175     {
176         numCases = type.getCols();
177     }
178     else
179     {
180         numCases = type.getNominalSize();
181     }
182 
183     TType returnType(EbtVoid);
184     if (!write)
185     {
186         returnType = fieldType;
187     }
188 
189     std::string functionName                = GetIndexFunctionName(type, write);
190     TIntermFunctionPrototype *prototypeNode =
191         CreateInternalFunctionPrototypeNode(returnType, functionName.c_str(), functionId);
192 
193     TQualifier baseQualifier     = EvqInOut;
194     if (!write)
195         baseQualifier        = EvqIn;
196     TIntermSymbol *baseParam = CreateBaseSymbol(type, baseQualifier, symbolTable);
197     prototypeNode->getSequence()->push_back(baseParam);
198     TIntermSymbol *indexParam = CreateIndexSymbol(symbolTable);
199     prototypeNode->getSequence()->push_back(indexParam);
200     TIntermSymbol *valueParam = nullptr;
201     if (write)
202     {
203         valueParam = CreateValueSymbol(fieldType, symbolTable);
204         prototypeNode->getSequence()->push_back(valueParam);
205     }
206 
207     TIntermBlock *statementList = new TIntermBlock();
208     for (int i = 0; i < numCases; ++i)
209     {
210         TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i));
211         statementList->getSequence()->push_back(caseNode);
212 
213         TIntermBinary *indexNode =
214             new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(i));
215         if (write)
216         {
217             TIntermBinary *assignNode =
218                 new TIntermBinary(EOpAssign, indexNode, valueParam->deepCopy());
219             statementList->getSequence()->push_back(assignNode);
220             TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
221             statementList->getSequence()->push_back(returnNode);
222         }
223         else
224         {
225             TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode);
226             statementList->getSequence()->push_back(returnNode);
227         }
228     }
229 
230     // Default case
231     TIntermCase *defaultNode = new TIntermCase(nullptr);
232     statementList->getSequence()->push_back(defaultNode);
233     TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr);
234     statementList->getSequence()->push_back(breakNode);
235 
236     TIntermSwitch *switchNode = new TIntermSwitch(indexParam->deepCopy(), statementList);
237 
238     TIntermBlock *bodyNode = new TIntermBlock();
239     bodyNode->getSequence()->push_back(switchNode);
240 
241     TIntermBinary *cond =
242         new TIntermBinary(EOpLessThan, indexParam->deepCopy(), CreateIntConstantNode(0));
243     cond->setType(TType(EbtBool, EbpUndefined));
244 
245     // Two blocks: one accesses (either reads or writes) the first element and returns,
246     // the other accesses the last element.
247     TIntermBlock *useFirstBlock = new TIntermBlock();
248     TIntermBlock *useLastBlock  = new TIntermBlock();
249     TIntermBinary *indexFirstNode =
250         new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(0));
251     TIntermBinary *indexLastNode =
252         new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(numCases - 1));
253     if (write)
254     {
255         TIntermBinary *assignFirstNode =
256             new TIntermBinary(EOpAssign, indexFirstNode, valueParam->deepCopy());
257         useFirstBlock->getSequence()->push_back(assignFirstNode);
258         TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
259         useFirstBlock->getSequence()->push_back(returnNode);
260 
261         TIntermBinary *assignLastNode =
262             new TIntermBinary(EOpAssign, indexLastNode, valueParam->deepCopy());
263         useLastBlock->getSequence()->push_back(assignLastNode);
264     }
265     else
266     {
267         TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode);
268         useFirstBlock->getSequence()->push_back(returnFirstNode);
269 
270         TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode);
271         useLastBlock->getSequence()->push_back(returnLastNode);
272     }
273     TIntermIfElse *ifNode = new TIntermIfElse(cond, useFirstBlock, nullptr);
274     bodyNode->getSequence()->push_back(ifNode);
275     bodyNode->getSequence()->push_back(useLastBlock);
276 
277     TIntermFunctionDefinition *indexingFunction =
278         new TIntermFunctionDefinition(prototypeNode, bodyNode);
279     return indexingFunction;
280 }
281 
282 class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
283 {
284   public:
285     RemoveDynamicIndexingTraverser(TSymbolTable *symbolTable,
286                                    int shaderVersion,
287                                    PerformanceDiagnostics *perfDiagnostics);
288 
289     bool visitBinary(Visit visit, TIntermBinary *node) override;
290 
291     void insertHelperDefinitions(TIntermNode *root);
292 
293     void nextIteration();
294 
usedTreeInsertion() const295     bool usedTreeInsertion() const { return mUsedTreeInsertion; }
296 
297   protected:
298     // Maps of types that are indexed to the indexing function ids used for them. Note that these
299     // can not store multiple variants of the same type with different precisions - only one
300     // precision gets stored.
301     std::map<TType, TSymbolUniqueId *> mIndexedVecAndMatrixTypes;
302     std::map<TType, TSymbolUniqueId *> mWrittenVecAndMatrixTypes;
303 
304     bool mUsedTreeInsertion;
305 
306     // When true, the traverser will remove side effects from any indexing expression.
307     // This is done so that in code like
308     //   V[j++][i]++.
309     // where V is an array of vectors, j++ will only be evaluated once.
310     bool mRemoveIndexSideEffectsInSubtree;
311 
312     PerformanceDiagnostics *mPerfDiagnostics;
313 };
314 
RemoveDynamicIndexingTraverser(TSymbolTable * symbolTable,int shaderVersion,PerformanceDiagnostics * perfDiagnostics)315 RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(
316     TSymbolTable *symbolTable,
317     int shaderVersion,
318     PerformanceDiagnostics *perfDiagnostics)
319     : TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
320       mUsedTreeInsertion(false),
321       mRemoveIndexSideEffectsInSubtree(false),
322       mPerfDiagnostics(perfDiagnostics)
323 {
324 }
325 
insertHelperDefinitions(TIntermNode * root)326 void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
327 {
328     TIntermBlock *rootBlock = root->getAsBlock();
329     ASSERT(rootBlock != nullptr);
330     TIntermSequence insertions;
331     for (auto &type : mIndexedVecAndMatrixTypes)
332     {
333         insertions.push_back(
334             GetIndexFunctionDefinition(type.first, false, *type.second, mSymbolTable));
335     }
336     for (auto &type : mWrittenVecAndMatrixTypes)
337     {
338         insertions.push_back(
339             GetIndexFunctionDefinition(type.first, true, *type.second, mSymbolTable));
340     }
341     rootBlock->insertChildNodes(0, insertions);
342 }
343 
344 // Create a call to dyn_index_*() based on an indirect indexing op node
CreateIndexFunctionCall(TIntermBinary * node,TIntermTyped * index,const TSymbolUniqueId & functionId)345 TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
346                                           TIntermTyped *index,
347                                           const TSymbolUniqueId &functionId)
348 {
349     ASSERT(node->getOp() == EOpIndexIndirect);
350     TIntermSequence *arguments = new TIntermSequence();
351     arguments->push_back(node->getLeft());
352     arguments->push_back(index);
353 
354     TType fieldType                = GetFieldType(node->getLeft()->getType());
355     std::string functionName       = GetIndexFunctionName(node->getLeft()->getType(), false);
356     TIntermAggregate *indexingCall =
357         CreateInternalFunctionCallNode(fieldType, functionName.c_str(), functionId, arguments);
358     indexingCall->setLine(node->getLine());
359     indexingCall->getFunctionSymbolInfo()->setKnownToNotHaveSideEffects(true);
360     return indexingCall;
361 }
362 
CreateIndexedWriteFunctionCall(TIntermBinary * node,TIntermTyped * index,TIntermTyped * writtenValue,const TSymbolUniqueId & functionId)363 TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
364                                                  TIntermTyped *index,
365                                                  TIntermTyped *writtenValue,
366                                                  const TSymbolUniqueId &functionId)
367 {
368     ASSERT(node->getOp() == EOpIndexIndirect);
369     TIntermSequence *arguments = new TIntermSequence();
370     // Deep copy the child nodes so that two pointers to the same node don't end up in the tree.
371     arguments->push_back(node->getLeft()->deepCopy());
372     arguments->push_back(index->deepCopy());
373     arguments->push_back(writtenValue);
374 
375     std::string functionName           = GetIndexFunctionName(node->getLeft()->getType(), true);
376     TIntermAggregate *indexedWriteCall =
377         CreateInternalFunctionCallNode(TType(EbtVoid), functionName.c_str(), functionId, arguments);
378     indexedWriteCall->setLine(node->getLine());
379     return indexedWriteCall;
380 }
381 
visitBinary(Visit visit,TIntermBinary * node)382 bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node)
383 {
384     if (mUsedTreeInsertion)
385         return false;
386 
387     if (node->getOp() == EOpIndexIndirect)
388     {
389         if (mRemoveIndexSideEffectsInSubtree)
390         {
391             ASSERT(node->getRight()->hasSideEffects());
392             // In case we're just removing index side effects, convert
393             //   v_expr[index_expr]
394             // to this:
395             //   int s0 = index_expr; v_expr[s0];
396             // Now v_expr[s0] can be safely executed several times without unintended side effects.
397 
398             // Init the temp variable holding the index
399             TIntermDeclaration *initIndex = createTempInitDeclaration(node->getRight());
400             insertStatementInParentBlock(initIndex);
401             mUsedTreeInsertion = true;
402 
403             // Replace the index with the temp variable
404             TIntermSymbol *tempIndex = createTempSymbol(node->getRight()->getType());
405             queueReplacementWithParent(node, node->getRight(), tempIndex, OriginalNode::IS_DROPPED);
406         }
407         else if (IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(node))
408         {
409             mPerfDiagnostics->warning(node->getLine(),
410                                       "Performance: dynamic indexing of vectors and "
411                                       "matrices is emulated and can be slow.",
412                                       "[]");
413             bool write = isLValueRequiredHere();
414 
415 #if defined(ANGLE_ENABLE_ASSERTS)
416             // Make sure that IntermNodePatternMatcher is consistent with the slightly differently
417             // implemented checks in this traverser.
418             IntermNodePatternMatcher matcher(
419                 IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue);
420             ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write);
421 #endif
422 
423             const TType &type = node->getLeft()->getType();
424             TSymbolUniqueId *indexingFunctionId = new TSymbolUniqueId(mSymbolTable);
425             if (mIndexedVecAndMatrixTypes.find(type) == mIndexedVecAndMatrixTypes.end())
426             {
427                 mIndexedVecAndMatrixTypes[type] = indexingFunctionId;
428             }
429             else
430             {
431                 indexingFunctionId = mIndexedVecAndMatrixTypes[type];
432             }
433 
434             if (write)
435             {
436                 // Convert:
437                 //   v_expr[index_expr]++;
438                 // to this:
439                 //   int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++;
440                 //   dyn_index_write(v_expr, s0, s1);
441                 // This works even if index_expr has some side effects.
442                 if (node->getLeft()->hasSideEffects())
443                 {
444                     // If v_expr has side effects, those need to be removed before proceeding.
445                     // Otherwise the side effects of v_expr would be evaluated twice.
446                     // The only case where an l-value can have side effects is when it is
447                     // indexing. For example, it can be V[j++] where V is an array of vectors.
448                     mRemoveIndexSideEffectsInSubtree = true;
449                     return true;
450                 }
451 
452                 TIntermBinary *leftBinary = node->getLeft()->getAsBinaryNode();
453                 if (leftBinary != nullptr &&
454                     IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(leftBinary))
455                 {
456                     // This is a case like:
457                     // mat2 m;
458                     // m[a][b]++;
459                     // Process the child node m[a] first.
460                     return true;
461                 }
462 
463                 // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
464                 // only writes it and doesn't need the previous value. http://anglebug.com/1116
465 
466                 TSymbolUniqueId *indexedWriteFunctionId = new TSymbolUniqueId(mSymbolTable);
467                 if (mWrittenVecAndMatrixTypes.find(type) == mWrittenVecAndMatrixTypes.end())
468                 {
469                     mWrittenVecAndMatrixTypes[type] = indexedWriteFunctionId;
470                 }
471                 else
472                 {
473                     indexedWriteFunctionId = mWrittenVecAndMatrixTypes[type];
474                 }
475                 TType fieldType = GetFieldType(type);
476 
477                 TIntermSequence insertionsBefore;
478                 TIntermSequence insertionsAfter;
479 
480                 // Store the index in a temporary signed int variable.
481                 TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight());
482                 TIntermDeclaration *initIndex  = createTempInitDeclaration(indexInitializer);
483                 initIndex->setLine(node->getLine());
484                 insertionsBefore.push_back(initIndex);
485 
486                 // Create a node for referring to the index after the nextTemporaryId() call
487                 // below.
488                 TIntermSymbol *tempIndex = createTempSymbol(indexInitializer->getType());
489 
490                 TIntermAggregate *indexingCall =
491                     CreateIndexFunctionCall(node, tempIndex, *indexingFunctionId);
492 
493                 nextTemporaryId();  // From now on, creating temporary symbols that refer to the
494                                     // field value.
495                 insertionsBefore.push_back(createTempInitDeclaration(indexingCall));
496 
497                 TIntermAggregate *indexedWriteCall = CreateIndexedWriteFunctionCall(
498                     node, tempIndex, createTempSymbol(fieldType), *indexedWriteFunctionId);
499                 insertionsAfter.push_back(indexedWriteCall);
500                 insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
501                 queueReplacement(createTempSymbol(fieldType), OriginalNode::IS_DROPPED);
502                 mUsedTreeInsertion = true;
503             }
504             else
505             {
506                 // The indexed value is not being written, so we can simply convert
507                 //   v_expr[index_expr]
508                 // into
509                 //   dyn_index(v_expr, index_expr)
510                 // If the index_expr is unsigned, we'll convert it to signed.
511                 ASSERT(!mRemoveIndexSideEffectsInSubtree);
512                 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
513                     node, EnsureSignedInt(node->getRight()), *indexingFunctionId);
514                 queueReplacement(indexingCall, OriginalNode::IS_DROPPED);
515             }
516         }
517     }
518     return !mUsedTreeInsertion;
519 }
520 
nextIteration()521 void RemoveDynamicIndexingTraverser::nextIteration()
522 {
523     mUsedTreeInsertion               = false;
524     mRemoveIndexSideEffectsInSubtree = false;
525     nextTemporaryId();
526 }
527 
528 }  // namespace
529 
RemoveDynamicIndexing(TIntermNode * root,TSymbolTable * symbolTable,int shaderVersion,PerformanceDiagnostics * perfDiagnostics)530 void RemoveDynamicIndexing(TIntermNode *root,
531                            TSymbolTable *symbolTable,
532                            int shaderVersion,
533                            PerformanceDiagnostics *perfDiagnostics)
534 {
535     RemoveDynamicIndexingTraverser traverser(symbolTable, shaderVersion, perfDiagnostics);
536     do
537     {
538         traverser.nextIteration();
539         root->traverse(&traverser);
540         traverser.updateTree();
541     } while (traverser.usedTreeInsertion());
542     // TODO(oetuaho@nvidia.com): It might be nicer to add the helper definitions also in the middle
543     // of traversal. Now the tree ends up in an inconsistent state in the middle, since there are
544     // function call nodes with no corresponding definition nodes. This needs special handling in
545     // TIntermLValueTrackingTraverser, and creates intricacies that are not easily apparent from a
546     // superficial reading of the code.
547     traverser.insertHelperDefinitions(root);
548 }
549 
550 }  // namespace sh
551