1 //
2 // Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of vectors and matrices,
7 // replacing them with calls to functions that choose which component to return or write.
8 //
9 
10 #include "compiler/translator/RemoveDynamicIndexing.h"
11 
12 #include "compiler/translator/Diagnostics.h"
13 #include "compiler/translator/InfoSink.h"
14 #include "compiler/translator/IntermNodePatternMatcher.h"
15 #include "compiler/translator/IntermNode_util.h"
16 #include "compiler/translator/IntermTraverse.h"
17 #include "compiler/translator/StaticType.h"
18 #include "compiler/translator/SymbolTable.h"
19 
20 namespace sh
21 {
22 
23 namespace
24 {
25 
26 const TType *kIndexType = StaticType::Get<EbtInt, EbpHigh, EvqIn, 1, 1>();
27 
28 constexpr const ImmutableString kBaseName("base");
29 constexpr const ImmutableString kIndexName("index");
30 constexpr const ImmutableString kValueName("value");
31 
GetIndexFunctionName(const TType & type,bool write)32 std::string GetIndexFunctionName(const TType &type, bool write)
33 {
34     TInfoSinkBase nameSink;
35     nameSink << "dyn_index_";
36     if (write)
37     {
38         nameSink << "write_";
39     }
40     if (type.isMatrix())
41     {
42         nameSink << "mat" << type.getCols() << "x" << type.getRows();
43     }
44     else
45     {
46         switch (type.getBasicType())
47         {
48             case EbtInt:
49                 nameSink << "ivec";
50                 break;
51             case EbtBool:
52                 nameSink << "bvec";
53                 break;
54             case EbtUInt:
55                 nameSink << "uvec";
56                 break;
57             case EbtFloat:
58                 nameSink << "vec";
59                 break;
60             default:
61                 UNREACHABLE();
62         }
63         nameSink << type.getNominalSize();
64     }
65     return nameSink.str();
66 }
67 
CreateParameterSymbol(const TConstParameter & parameter,TSymbolTable * symbolTable)68 TIntermSymbol *CreateParameterSymbol(const TConstParameter &parameter, TSymbolTable *symbolTable)
69 {
70     TVariable *variable =
71         new TVariable(symbolTable, parameter.name, parameter.type, SymbolType::AngleInternal);
72     return new TIntermSymbol(variable);
73 }
74 
CreateIntConstantNode(int i)75 TIntermConstantUnion *CreateIntConstantNode(int i)
76 {
77     TConstantUnion *constant = new TConstantUnion();
78     constant->setIConst(i);
79     return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh));
80 }
81 
EnsureSignedInt(TIntermTyped * node)82 TIntermTyped *EnsureSignedInt(TIntermTyped *node)
83 {
84     if (node->getBasicType() == EbtInt)
85         return node;
86 
87     TIntermSequence *arguments = new TIntermSequence();
88     arguments->push_back(node);
89     return TIntermAggregate::CreateConstructor(TType(EbtInt), arguments);
90 }
91 
// Returns a newly allocated type describing what indexing |indexedType| yields: a type with the
// same basic type and precision, whose primary size is set to the matrix row count when
// |indexedType| is a matrix and is left at the constructor default otherwise.
TType *GetFieldType(const TType &indexedType)
{
    TType *elementType = new TType(indexedType.getBasicType(), indexedType.getPrecision());
    if (indexedType.isMatrix())
    {
        // Indexing a matrix selects one column, which has as many components as there are rows.
        elementType->setPrimarySize(static_cast<unsigned char>(indexedType.getRows()));
    }
    return elementType;
}
105 
// Returns a newly allocated copy of |type| suitable for the "base" parameter of an indexing
// helper: precision forced to highp, qualifier inout for write helpers and in for read helpers.
const TType *GetBaseType(const TType &type, bool write)
{
    TType *paramType = new TType(type);

    // highp is used regardless of the precision of the indexed type. This guarantees that a
    // mediump variant of a helper can never be picked for a highp value when a shader indexes
    // both mediump and highp values. Precision is irrelevant for HLSL, but this keeps the code
    // usable with other backends in principle.
    paramType->setPrecision(EbpHigh);
    paramType->setQualifier(write ? EvqInOut : EvqIn);
    return paramType;
}
119 
120 // Generate a read or write function for one field in a vector/matrix.
121 // Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range
122 // indices in other places.
123 // Note that indices can be either int or uint. We create only int versions of the functions,
124 // and convert uint indices to int at the call site.
125 // read function example:
126 // float dyn_index_vec2(in vec2 base, in int index)
127 // {
128 //    switch(index)
129 //    {
130 //      case (0):
131 //        return base[0];
132 //      case (1):
133 //        return base[1];
134 //      default:
135 //        break;
136 //    }
137 //    if (index < 0)
138 //      return base[0];
139 //    return base[1];
140 // }
141 // write function example:
142 // void dyn_index_write_vec2(inout vec2 base, in int index, in float value)
143 // {
144 //    switch(index)
145 //    {
146 //      case (0):
147 //        base[0] = value;
148 //        return;
149 //      case (1):
150 //        base[1] = value;
151 //        return;
152 //      default:
153 //        break;
154 //    }
155 //    if (index < 0)
156 //    {
157 //      base[0] = value;
158 //      return;
159 //    }
160 //    base[1] = value;
161 // }
162 // Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
GetIndexFunctionDefinition(const TType & type,bool write,const TFunction & func,TSymbolTable * symbolTable)163 TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type,
164                                                       bool write,
165                                                       const TFunction &func,
166                                                       TSymbolTable *symbolTable)
167 {
168     ASSERT(!type.isArray());
169 
170     int numCases    = 0;
171     if (type.isMatrix())
172     {
173         numCases = type.getCols();
174     }
175     else
176     {
177         numCases = type.getNominalSize();
178     }
179 
180     std::string functionName                = GetIndexFunctionName(type, write);
181     TIntermFunctionPrototype *prototypeNode = CreateInternalFunctionPrototypeNode(func);
182 
183     TIntermSymbol *baseParam = CreateParameterSymbol(func.getParam(0), symbolTable);
184     prototypeNode->getSequence()->push_back(baseParam);
185     TIntermSymbol *indexParam = CreateParameterSymbol(func.getParam(1), symbolTable);
186     prototypeNode->getSequence()->push_back(indexParam);
187     TIntermSymbol *valueParam = nullptr;
188     if (write)
189     {
190         valueParam = CreateParameterSymbol(func.getParam(2), symbolTable);
191         prototypeNode->getSequence()->push_back(valueParam);
192     }
193 
194     TIntermBlock *statementList = new TIntermBlock();
195     for (int i = 0; i < numCases; ++i)
196     {
197         TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i));
198         statementList->getSequence()->push_back(caseNode);
199 
200         TIntermBinary *indexNode =
201             new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(i));
202         if (write)
203         {
204             TIntermBinary *assignNode =
205                 new TIntermBinary(EOpAssign, indexNode, valueParam->deepCopy());
206             statementList->getSequence()->push_back(assignNode);
207             TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
208             statementList->getSequence()->push_back(returnNode);
209         }
210         else
211         {
212             TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode);
213             statementList->getSequence()->push_back(returnNode);
214         }
215     }
216 
217     // Default case
218     TIntermCase *defaultNode = new TIntermCase(nullptr);
219     statementList->getSequence()->push_back(defaultNode);
220     TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr);
221     statementList->getSequence()->push_back(breakNode);
222 
223     TIntermSwitch *switchNode = new TIntermSwitch(indexParam->deepCopy(), statementList);
224 
225     TIntermBlock *bodyNode = new TIntermBlock();
226     bodyNode->getSequence()->push_back(switchNode);
227 
228     TIntermBinary *cond =
229         new TIntermBinary(EOpLessThan, indexParam->deepCopy(), CreateIntConstantNode(0));
230 
231     // Two blocks: one accesses (either reads or writes) the first element and returns,
232     // the other accesses the last element.
233     TIntermBlock *useFirstBlock = new TIntermBlock();
234     TIntermBlock *useLastBlock  = new TIntermBlock();
235     TIntermBinary *indexFirstNode =
236         new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(0));
237     TIntermBinary *indexLastNode =
238         new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(numCases - 1));
239     if (write)
240     {
241         TIntermBinary *assignFirstNode =
242             new TIntermBinary(EOpAssign, indexFirstNode, valueParam->deepCopy());
243         useFirstBlock->getSequence()->push_back(assignFirstNode);
244         TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
245         useFirstBlock->getSequence()->push_back(returnNode);
246 
247         TIntermBinary *assignLastNode =
248             new TIntermBinary(EOpAssign, indexLastNode, valueParam->deepCopy());
249         useLastBlock->getSequence()->push_back(assignLastNode);
250     }
251     else
252     {
253         TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode);
254         useFirstBlock->getSequence()->push_back(returnFirstNode);
255 
256         TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode);
257         useLastBlock->getSequence()->push_back(returnLastNode);
258     }
259     TIntermIfElse *ifNode = new TIntermIfElse(cond, useFirstBlock, nullptr);
260     bodyNode->getSequence()->push_back(ifNode);
261     bodyNode->getSequence()->push_back(useLastBlock);
262 
263     TIntermFunctionDefinition *indexingFunction =
264         new TIntermFunctionDefinition(prototypeNode, bodyNode);
265     return indexingFunction;
266 }
267 
268 class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
269 {
270   public:
271     RemoveDynamicIndexingTraverser(TSymbolTable *symbolTable,
272                                    PerformanceDiagnostics *perfDiagnostics);
273 
274     bool visitBinary(Visit visit, TIntermBinary *node) override;
275 
276     void insertHelperDefinitions(TIntermNode *root);
277 
278     void nextIteration();
279 
usedTreeInsertion() const280     bool usedTreeInsertion() const { return mUsedTreeInsertion; }
281 
282   protected:
283     // Maps of types that are indexed to the indexing function ids used for them. Note that these
284     // can not store multiple variants of the same type with different precisions - only one
285     // precision gets stored.
286     std::map<TType, TFunction *> mIndexedVecAndMatrixTypes;
287     std::map<TType, TFunction *> mWrittenVecAndMatrixTypes;
288 
289     bool mUsedTreeInsertion;
290 
291     // When true, the traverser will remove side effects from any indexing expression.
292     // This is done so that in code like
293     //   V[j++][i]++.
294     // where V is an array of vectors, j++ will only be evaluated once.
295     bool mRemoveIndexSideEffectsInSubtree;
296 
297     PerformanceDiagnostics *mPerfDiagnostics;
298 };
299 
RemoveDynamicIndexingTraverser(TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics)300 RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(
301     TSymbolTable *symbolTable,
302     PerformanceDiagnostics *perfDiagnostics)
303     : TLValueTrackingTraverser(true, false, false, symbolTable),
304       mUsedTreeInsertion(false),
305       mRemoveIndexSideEffectsInSubtree(false),
306       mPerfDiagnostics(perfDiagnostics)
307 {
308 }
309 
insertHelperDefinitions(TIntermNode * root)310 void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
311 {
312     TIntermBlock *rootBlock = root->getAsBlock();
313     ASSERT(rootBlock != nullptr);
314     TIntermSequence insertions;
315     for (auto &type : mIndexedVecAndMatrixTypes)
316     {
317         insertions.push_back(
318             GetIndexFunctionDefinition(type.first, false, *type.second, mSymbolTable));
319     }
320     for (auto &type : mWrittenVecAndMatrixTypes)
321     {
322         insertions.push_back(
323             GetIndexFunctionDefinition(type.first, true, *type.second, mSymbolTable));
324     }
325     rootBlock->insertChildNodes(0, insertions);
326 }
327 
328 // Create a call to dyn_index_*() based on an indirect indexing op node
CreateIndexFunctionCall(TIntermBinary * node,TIntermTyped * index,TFunction * indexingFunction)329 TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
330                                           TIntermTyped *index,
331                                           TFunction *indexingFunction)
332 {
333     ASSERT(node->getOp() == EOpIndexIndirect);
334     TIntermSequence *arguments = new TIntermSequence();
335     arguments->push_back(node->getLeft());
336     arguments->push_back(index);
337 
338     TIntermAggregate *indexingCall =
339         TIntermAggregate::CreateFunctionCall(*indexingFunction, arguments);
340     indexingCall->setLine(node->getLine());
341     return indexingCall;
342 }
343 
CreateIndexedWriteFunctionCall(TIntermBinary * node,TVariable * index,TVariable * writtenValue,TFunction * indexedWriteFunction)344 TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
345                                                  TVariable *index,
346                                                  TVariable *writtenValue,
347                                                  TFunction *indexedWriteFunction)
348 {
349     ASSERT(node->getOp() == EOpIndexIndirect);
350     TIntermSequence *arguments = new TIntermSequence();
351     // Deep copy the child nodes so that two pointers to the same node don't end up in the tree.
352     arguments->push_back(node->getLeft()->deepCopy());
353     arguments->push_back(CreateTempSymbolNode(index));
354     arguments->push_back(CreateTempSymbolNode(writtenValue));
355 
356     TIntermAggregate *indexedWriteCall =
357         TIntermAggregate::CreateFunctionCall(*indexedWriteFunction, arguments);
358     indexedWriteCall->setLine(node->getLine());
359     return indexedWriteCall;
360 }
361 
visitBinary(Visit visit,TIntermBinary * node)362 bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node)
363 {
364     if (mUsedTreeInsertion)
365         return false;
366 
367     if (node->getOp() == EOpIndexIndirect)
368     {
369         if (mRemoveIndexSideEffectsInSubtree)
370         {
371             ASSERT(node->getRight()->hasSideEffects());
372             // In case we're just removing index side effects, convert
373             //   v_expr[index_expr]
374             // to this:
375             //   int s0 = index_expr; v_expr[s0];
376             // Now v_expr[s0] can be safely executed several times without unintended side effects.
377             TIntermDeclaration *indexVariableDeclaration = nullptr;
378             TVariable *indexVariable = DeclareTempVariable(mSymbolTable, node->getRight(),
379                                                            EvqTemporary, &indexVariableDeclaration);
380             insertStatementInParentBlock(indexVariableDeclaration);
381             mUsedTreeInsertion = true;
382 
383             // Replace the index with the temp variable
384             TIntermSymbol *tempIndex = CreateTempSymbolNode(indexVariable);
385             queueReplacementWithParent(node, node->getRight(), tempIndex, OriginalNode::IS_DROPPED);
386         }
387         else if (IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(node))
388         {
389             mPerfDiagnostics->warning(node->getLine(),
390                                       "Performance: dynamic indexing of vectors and "
391                                       "matrices is emulated and can be slow.",
392                                       "[]");
393             bool write = isLValueRequiredHere();
394 
395 #if defined(ANGLE_ENABLE_ASSERTS)
396             // Make sure that IntermNodePatternMatcher is consistent with the slightly differently
397             // implemented checks in this traverser.
398             IntermNodePatternMatcher matcher(
399                 IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue);
400             ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write);
401 #endif
402 
403             const TType &type = node->getLeft()->getType();
404             ImmutableString indexingFunctionName(GetIndexFunctionName(type, false));
405             TFunction *indexingFunction = nullptr;
406             if (mIndexedVecAndMatrixTypes.find(type) == mIndexedVecAndMatrixTypes.end())
407             {
408                 indexingFunction =
409                     new TFunction(mSymbolTable, indexingFunctionName, SymbolType::AngleInternal,
410                                   GetFieldType(type), true);
411                 indexingFunction->addParameter(
412                     TConstParameter(kBaseName, GetBaseType(type, false)));
413                 indexingFunction->addParameter(TConstParameter(kIndexName, kIndexType));
414                 mIndexedVecAndMatrixTypes[type] = indexingFunction;
415             }
416             else
417             {
418                 indexingFunction = mIndexedVecAndMatrixTypes[type];
419             }
420 
421             if (write)
422             {
423                 // Convert:
424                 //   v_expr[index_expr]++;
425                 // to this:
426                 //   int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++;
427                 //   dyn_index_write(v_expr, s0, s1);
428                 // This works even if index_expr has some side effects.
429                 if (node->getLeft()->hasSideEffects())
430                 {
431                     // If v_expr has side effects, those need to be removed before proceeding.
432                     // Otherwise the side effects of v_expr would be evaluated twice.
433                     // The only case where an l-value can have side effects is when it is
434                     // indexing. For example, it can be V[j++] where V is an array of vectors.
435                     mRemoveIndexSideEffectsInSubtree = true;
436                     return true;
437                 }
438 
439                 TIntermBinary *leftBinary = node->getLeft()->getAsBinaryNode();
440                 if (leftBinary != nullptr &&
441                     IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(leftBinary))
442                 {
443                     // This is a case like:
444                     // mat2 m;
445                     // m[a][b]++;
446                     // Process the child node m[a] first.
447                     return true;
448                 }
449 
450                 // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
451                 // only writes it and doesn't need the previous value. http://anglebug.com/1116
452 
453                 TFunction *indexedWriteFunction = nullptr;
454                 if (mWrittenVecAndMatrixTypes.find(type) == mWrittenVecAndMatrixTypes.end())
455                 {
456                     ImmutableString functionName(
457                         GetIndexFunctionName(node->getLeft()->getType(), true));
458                     indexedWriteFunction =
459                         new TFunction(mSymbolTable, functionName, SymbolType::AngleInternal,
460                                       StaticType::GetBasic<EbtVoid>(), false);
461                     indexedWriteFunction->addParameter(
462                         TConstParameter(kBaseName, GetBaseType(type, true)));
463                     indexedWriteFunction->addParameter(TConstParameter(kIndexName, kIndexType));
464                     TType *valueType = GetFieldType(type);
465                     valueType->setQualifier(EvqIn);
466                     indexedWriteFunction->addParameter(
467                         TConstParameter(kValueName, static_cast<const TType *>(valueType)));
468                     mWrittenVecAndMatrixTypes[type] = indexedWriteFunction;
469                 }
470                 else
471                 {
472                     indexedWriteFunction = mWrittenVecAndMatrixTypes[type];
473                 }
474 
475                 TIntermSequence insertionsBefore;
476                 TIntermSequence insertionsAfter;
477 
478                 // Store the index in a temporary signed int variable.
479                 // s0 = index_expr;
480                 TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight());
481                 TIntermDeclaration *indexVariableDeclaration = nullptr;
482                 TVariable *indexVariable                     = DeclareTempVariable(
483                     mSymbolTable, indexInitializer, EvqTemporary, &indexVariableDeclaration);
484                 insertionsBefore.push_back(indexVariableDeclaration);
485 
486                 // s1 = dyn_index(v_expr, s0);
487                 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
488                     node, CreateTempSymbolNode(indexVariable), indexingFunction);
489                 TIntermDeclaration *fieldVariableDeclaration = nullptr;
490                 TVariable *fieldVariable                     = DeclareTempVariable(
491                     mSymbolTable, indexingCall, EvqTemporary, &fieldVariableDeclaration);
492                 insertionsBefore.push_back(fieldVariableDeclaration);
493 
494                 // dyn_index_write(v_expr, s0, s1);
495                 TIntermAggregate *indexedWriteCall = CreateIndexedWriteFunctionCall(
496                     node, indexVariable, fieldVariable, indexedWriteFunction);
497                 insertionsAfter.push_back(indexedWriteCall);
498                 insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
499 
500                 // replace the node with s1
501                 queueReplacement(CreateTempSymbolNode(fieldVariable), OriginalNode::IS_DROPPED);
502                 mUsedTreeInsertion = true;
503             }
504             else
505             {
506                 // The indexed value is not being written, so we can simply convert
507                 //   v_expr[index_expr]
508                 // into
509                 //   dyn_index(v_expr, index_expr)
510                 // If the index_expr is unsigned, we'll convert it to signed.
511                 ASSERT(!mRemoveIndexSideEffectsInSubtree);
512                 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
513                     node, EnsureSignedInt(node->getRight()), indexingFunction);
514                 queueReplacement(indexingCall, OriginalNode::IS_DROPPED);
515             }
516         }
517     }
518     return !mUsedTreeInsertion;
519 }
520 
nextIteration()521 void RemoveDynamicIndexingTraverser::nextIteration()
522 {
523     mUsedTreeInsertion               = false;
524     mRemoveIndexSideEffectsInSubtree = false;
525 }
526 
527 }  // namespace
528 
RemoveDynamicIndexing(TIntermNode * root,TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics)529 void RemoveDynamicIndexing(TIntermNode *root,
530                            TSymbolTable *symbolTable,
531                            PerformanceDiagnostics *perfDiagnostics)
532 {
533     RemoveDynamicIndexingTraverser traverser(symbolTable, perfDiagnostics);
534     do
535     {
536         traverser.nextIteration();
537         root->traverse(&traverser);
538         traverser.updateTree();
539     } while (traverser.usedTreeInsertion());
540     // TODO(oetuaho@nvidia.com): It might be nicer to add the helper definitions also in the middle
541     // of traversal. Now the tree ends up in an inconsistent state in the middle, since there are
542     // function call nodes with no corresponding definition nodes. This needs special handling in
543     // TIntermLValueTrackingTraverser, and creates intricacies that are not easily apparent from a
544     // superficial reading of the code.
545     traverser.insertHelperDefinitions(root);
546 }
547 
548 }  // namespace sh
549