1 //
2 // Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of vectors and matrices,
7 // replacing them with calls to functions that choose which component to return or write.
8 //
9
10 #include "compiler/translator/RemoveDynamicIndexing.h"
11
12 #include "compiler/translator/Diagnostics.h"
13 #include "compiler/translator/InfoSink.h"
14 #include "compiler/translator/IntermNodePatternMatcher.h"
15 #include "compiler/translator/IntermNode_util.h"
16 #include "compiler/translator/IntermTraverse.h"
17 #include "compiler/translator/SymbolTable.h"
18
19 namespace sh
20 {
21
22 namespace
23 {
24
GetIndexFunctionName(const TType & type,bool write)25 std::string GetIndexFunctionName(const TType &type, bool write)
26 {
27 TInfoSinkBase nameSink;
28 nameSink << "dyn_index_";
29 if (write)
30 {
31 nameSink << "write_";
32 }
33 if (type.isMatrix())
34 {
35 nameSink << "mat" << type.getCols() << "x" << type.getRows();
36 }
37 else
38 {
39 switch (type.getBasicType())
40 {
41 case EbtInt:
42 nameSink << "ivec";
43 break;
44 case EbtBool:
45 nameSink << "bvec";
46 break;
47 case EbtUInt:
48 nameSink << "uvec";
49 break;
50 case EbtFloat:
51 nameSink << "vec";
52 break;
53 default:
54 UNREACHABLE();
55 }
56 nameSink << type.getNominalSize();
57 }
58 return nameSink.str();
59 }
60
CreateBaseSymbol(const TType & type,TQualifier qualifier,TSymbolTable * symbolTable)61 TIntermSymbol *CreateBaseSymbol(const TType &type, TQualifier qualifier, TSymbolTable *symbolTable)
62 {
63 TIntermSymbol *symbol = new TIntermSymbol(symbolTable->nextUniqueId(), "base", type);
64 symbol->setInternal(true);
65 symbol->getTypePointer()->setQualifier(qualifier);
66 return symbol;
67 }
68
CreateIndexSymbol(TSymbolTable * symbolTable)69 TIntermSymbol *CreateIndexSymbol(TSymbolTable *symbolTable)
70 {
71 TIntermSymbol *symbol =
72 new TIntermSymbol(symbolTable->nextUniqueId(), "index", TType(EbtInt, EbpHigh));
73 symbol->setInternal(true);
74 symbol->getTypePointer()->setQualifier(EvqIn);
75 return symbol;
76 }
77
CreateValueSymbol(const TType & type,TSymbolTable * symbolTable)78 TIntermSymbol *CreateValueSymbol(const TType &type, TSymbolTable *symbolTable)
79 {
80 TIntermSymbol *symbol = new TIntermSymbol(symbolTable->nextUniqueId(), "value", type);
81 symbol->setInternal(true);
82 symbol->getTypePointer()->setQualifier(EvqIn);
83 return symbol;
84 }
85
CreateIntConstantNode(int i)86 TIntermConstantUnion *CreateIntConstantNode(int i)
87 {
88 TConstantUnion *constant = new TConstantUnion();
89 constant->setIConst(i);
90 return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh));
91 }
92
EnsureSignedInt(TIntermTyped * node)93 TIntermTyped *EnsureSignedInt(TIntermTyped *node)
94 {
95 if (node->getBasicType() == EbtInt)
96 return node;
97
98 TIntermSequence *arguments = new TIntermSequence();
99 arguments->push_back(node);
100 return TIntermAggregate::CreateConstructor(TType(EbtInt), arguments);
101 }
102
// Returns the type obtained by indexing |indexedType| once. The result keeps the basic type
// and precision of |indexedType|. For a matrix the primary size is set to the matrix's row
// count; otherwise the type is left as constructed from basic type and precision alone.
TType GetFieldType(const TType &indexedType)
{
    TType fieldType(indexedType.getBasicType(), indexedType.getPrecision());
    if (indexedType.isMatrix())
    {
        fieldType.setPrimarySize(static_cast<unsigned char>(indexedType.getRows()));
    }
    return fieldType;
}
116
117 // Generate a read or write function for one field in a vector/matrix.
118 // Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range
119 // indices in other places.
120 // Note that indices can be either int or uint. We create only int versions of the functions,
121 // and convert uint indices to int at the call site.
122 // read function example:
123 // float dyn_index_vec2(in vec2 base, in int index)
124 // {
125 // switch(index)
126 // {
127 // case (0):
128 // return base[0];
129 // case (1):
130 // return base[1];
131 // default:
132 // break;
133 // }
134 // if (index < 0)
135 // return base[0];
136 // return base[1];
137 // }
138 // write function example:
139 // void dyn_index_write_vec2(inout vec2 base, in int index, in float value)
140 // {
141 // switch(index)
142 // {
143 // case (0):
144 // base[0] = value;
145 // return;
146 // case (1):
147 // base[1] = value;
148 // return;
149 // default:
150 // break;
151 // }
152 // if (index < 0)
153 // {
154 // base[0] = value;
155 // return;
156 // }
157 // base[1] = value;
158 // }
159 // Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
GetIndexFunctionDefinition(TType type,bool write,const TSymbolUniqueId & functionId,TSymbolTable * symbolTable)160 TIntermFunctionDefinition *GetIndexFunctionDefinition(TType type,
161 bool write,
162 const TSymbolUniqueId &functionId,
163 TSymbolTable *symbolTable)
164 {
165 ASSERT(!type.isArray());
166 // Conservatively use highp here, even if the indexed type is not highp. That way the code can't
167 // end up using mediump version of an indexing function for a highp value, if both mediump and
168 // highp values are being indexed in the shader. For HLSL precision doesn't matter, but in
169 // principle this code could be used with multiple backends.
170 type.setPrecision(EbpHigh);
171
172 TType fieldType = GetFieldType(type);
173 int numCases = 0;
174 if (type.isMatrix())
175 {
176 numCases = type.getCols();
177 }
178 else
179 {
180 numCases = type.getNominalSize();
181 }
182
183 TType returnType(EbtVoid);
184 if (!write)
185 {
186 returnType = fieldType;
187 }
188
189 std::string functionName = GetIndexFunctionName(type, write);
190 TIntermFunctionPrototype *prototypeNode =
191 CreateInternalFunctionPrototypeNode(returnType, functionName.c_str(), functionId);
192
193 TQualifier baseQualifier = EvqInOut;
194 if (!write)
195 baseQualifier = EvqIn;
196 TIntermSymbol *baseParam = CreateBaseSymbol(type, baseQualifier, symbolTable);
197 prototypeNode->getSequence()->push_back(baseParam);
198 TIntermSymbol *indexParam = CreateIndexSymbol(symbolTable);
199 prototypeNode->getSequence()->push_back(indexParam);
200 TIntermSymbol *valueParam = nullptr;
201 if (write)
202 {
203 valueParam = CreateValueSymbol(fieldType, symbolTable);
204 prototypeNode->getSequence()->push_back(valueParam);
205 }
206
207 TIntermBlock *statementList = new TIntermBlock();
208 for (int i = 0; i < numCases; ++i)
209 {
210 TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i));
211 statementList->getSequence()->push_back(caseNode);
212
213 TIntermBinary *indexNode =
214 new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(i));
215 if (write)
216 {
217 TIntermBinary *assignNode =
218 new TIntermBinary(EOpAssign, indexNode, valueParam->deepCopy());
219 statementList->getSequence()->push_back(assignNode);
220 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
221 statementList->getSequence()->push_back(returnNode);
222 }
223 else
224 {
225 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode);
226 statementList->getSequence()->push_back(returnNode);
227 }
228 }
229
230 // Default case
231 TIntermCase *defaultNode = new TIntermCase(nullptr);
232 statementList->getSequence()->push_back(defaultNode);
233 TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr);
234 statementList->getSequence()->push_back(breakNode);
235
236 TIntermSwitch *switchNode = new TIntermSwitch(indexParam->deepCopy(), statementList);
237
238 TIntermBlock *bodyNode = new TIntermBlock();
239 bodyNode->getSequence()->push_back(switchNode);
240
241 TIntermBinary *cond =
242 new TIntermBinary(EOpLessThan, indexParam->deepCopy(), CreateIntConstantNode(0));
243 cond->setType(TType(EbtBool, EbpUndefined));
244
245 // Two blocks: one accesses (either reads or writes) the first element and returns,
246 // the other accesses the last element.
247 TIntermBlock *useFirstBlock = new TIntermBlock();
248 TIntermBlock *useLastBlock = new TIntermBlock();
249 TIntermBinary *indexFirstNode =
250 new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(0));
251 TIntermBinary *indexLastNode =
252 new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(numCases - 1));
253 if (write)
254 {
255 TIntermBinary *assignFirstNode =
256 new TIntermBinary(EOpAssign, indexFirstNode, valueParam->deepCopy());
257 useFirstBlock->getSequence()->push_back(assignFirstNode);
258 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
259 useFirstBlock->getSequence()->push_back(returnNode);
260
261 TIntermBinary *assignLastNode =
262 new TIntermBinary(EOpAssign, indexLastNode, valueParam->deepCopy());
263 useLastBlock->getSequence()->push_back(assignLastNode);
264 }
265 else
266 {
267 TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode);
268 useFirstBlock->getSequence()->push_back(returnFirstNode);
269
270 TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode);
271 useLastBlock->getSequence()->push_back(returnLastNode);
272 }
273 TIntermIfElse *ifNode = new TIntermIfElse(cond, useFirstBlock, nullptr);
274 bodyNode->getSequence()->push_back(ifNode);
275 bodyNode->getSequence()->push_back(useLastBlock);
276
277 TIntermFunctionDefinition *indexingFunction =
278 new TIntermFunctionDefinition(prototypeNode, bodyNode);
279 return indexingFunction;
280 }
281
282 class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
283 {
284 public:
285 RemoveDynamicIndexingTraverser(TSymbolTable *symbolTable,
286 int shaderVersion,
287 PerformanceDiagnostics *perfDiagnostics);
288
289 bool visitBinary(Visit visit, TIntermBinary *node) override;
290
291 void insertHelperDefinitions(TIntermNode *root);
292
293 void nextIteration();
294
usedTreeInsertion() const295 bool usedTreeInsertion() const { return mUsedTreeInsertion; }
296
297 protected:
298 // Maps of types that are indexed to the indexing function ids used for them. Note that these
299 // can not store multiple variants of the same type with different precisions - only one
300 // precision gets stored.
301 std::map<TType, TSymbolUniqueId *> mIndexedVecAndMatrixTypes;
302 std::map<TType, TSymbolUniqueId *> mWrittenVecAndMatrixTypes;
303
304 bool mUsedTreeInsertion;
305
306 // When true, the traverser will remove side effects from any indexing expression.
307 // This is done so that in code like
308 // V[j++][i]++.
309 // where V is an array of vectors, j++ will only be evaluated once.
310 bool mRemoveIndexSideEffectsInSubtree;
311
312 PerformanceDiagnostics *mPerfDiagnostics;
313 };
314
RemoveDynamicIndexingTraverser(TSymbolTable * symbolTable,int shaderVersion,PerformanceDiagnostics * perfDiagnostics)315 RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(
316 TSymbolTable *symbolTable,
317 int shaderVersion,
318 PerformanceDiagnostics *perfDiagnostics)
319 : TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
320 mUsedTreeInsertion(false),
321 mRemoveIndexSideEffectsInSubtree(false),
322 mPerfDiagnostics(perfDiagnostics)
323 {
324 }
325
insertHelperDefinitions(TIntermNode * root)326 void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
327 {
328 TIntermBlock *rootBlock = root->getAsBlock();
329 ASSERT(rootBlock != nullptr);
330 TIntermSequence insertions;
331 for (auto &type : mIndexedVecAndMatrixTypes)
332 {
333 insertions.push_back(
334 GetIndexFunctionDefinition(type.first, false, *type.second, mSymbolTable));
335 }
336 for (auto &type : mWrittenVecAndMatrixTypes)
337 {
338 insertions.push_back(
339 GetIndexFunctionDefinition(type.first, true, *type.second, mSymbolTable));
340 }
341 rootBlock->insertChildNodes(0, insertions);
342 }
343
344 // Create a call to dyn_index_*() based on an indirect indexing op node
CreateIndexFunctionCall(TIntermBinary * node,TIntermTyped * index,const TSymbolUniqueId & functionId)345 TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
346 TIntermTyped *index,
347 const TSymbolUniqueId &functionId)
348 {
349 ASSERT(node->getOp() == EOpIndexIndirect);
350 TIntermSequence *arguments = new TIntermSequence();
351 arguments->push_back(node->getLeft());
352 arguments->push_back(index);
353
354 TType fieldType = GetFieldType(node->getLeft()->getType());
355 std::string functionName = GetIndexFunctionName(node->getLeft()->getType(), false);
356 TIntermAggregate *indexingCall =
357 CreateInternalFunctionCallNode(fieldType, functionName.c_str(), functionId, arguments);
358 indexingCall->setLine(node->getLine());
359 indexingCall->getFunctionSymbolInfo()->setKnownToNotHaveSideEffects(true);
360 return indexingCall;
361 }
362
CreateIndexedWriteFunctionCall(TIntermBinary * node,TIntermTyped * index,TIntermTyped * writtenValue,const TSymbolUniqueId & functionId)363 TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
364 TIntermTyped *index,
365 TIntermTyped *writtenValue,
366 const TSymbolUniqueId &functionId)
367 {
368 ASSERT(node->getOp() == EOpIndexIndirect);
369 TIntermSequence *arguments = new TIntermSequence();
370 // Deep copy the child nodes so that two pointers to the same node don't end up in the tree.
371 arguments->push_back(node->getLeft()->deepCopy());
372 arguments->push_back(index->deepCopy());
373 arguments->push_back(writtenValue);
374
375 std::string functionName = GetIndexFunctionName(node->getLeft()->getType(), true);
376 TIntermAggregate *indexedWriteCall =
377 CreateInternalFunctionCallNode(TType(EbtVoid), functionName.c_str(), functionId, arguments);
378 indexedWriteCall->setLine(node->getLine());
379 return indexedWriteCall;
380 }
381
visitBinary(Visit visit,TIntermBinary * node)382 bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node)
383 {
384 if (mUsedTreeInsertion)
385 return false;
386
387 if (node->getOp() == EOpIndexIndirect)
388 {
389 if (mRemoveIndexSideEffectsInSubtree)
390 {
391 ASSERT(node->getRight()->hasSideEffects());
392 // In case we're just removing index side effects, convert
393 // v_expr[index_expr]
394 // to this:
395 // int s0 = index_expr; v_expr[s0];
396 // Now v_expr[s0] can be safely executed several times without unintended side effects.
397
398 // Init the temp variable holding the index
399 TIntermDeclaration *initIndex = createTempInitDeclaration(node->getRight());
400 insertStatementInParentBlock(initIndex);
401 mUsedTreeInsertion = true;
402
403 // Replace the index with the temp variable
404 TIntermSymbol *tempIndex = createTempSymbol(node->getRight()->getType());
405 queueReplacementWithParent(node, node->getRight(), tempIndex, OriginalNode::IS_DROPPED);
406 }
407 else if (IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(node))
408 {
409 mPerfDiagnostics->warning(node->getLine(),
410 "Performance: dynamic indexing of vectors and "
411 "matrices is emulated and can be slow.",
412 "[]");
413 bool write = isLValueRequiredHere();
414
415 #if defined(ANGLE_ENABLE_ASSERTS)
416 // Make sure that IntermNodePatternMatcher is consistent with the slightly differently
417 // implemented checks in this traverser.
418 IntermNodePatternMatcher matcher(
419 IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue);
420 ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write);
421 #endif
422
423 const TType &type = node->getLeft()->getType();
424 TSymbolUniqueId *indexingFunctionId = new TSymbolUniqueId(mSymbolTable);
425 if (mIndexedVecAndMatrixTypes.find(type) == mIndexedVecAndMatrixTypes.end())
426 {
427 mIndexedVecAndMatrixTypes[type] = indexingFunctionId;
428 }
429 else
430 {
431 indexingFunctionId = mIndexedVecAndMatrixTypes[type];
432 }
433
434 if (write)
435 {
436 // Convert:
437 // v_expr[index_expr]++;
438 // to this:
439 // int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++;
440 // dyn_index_write(v_expr, s0, s1);
441 // This works even if index_expr has some side effects.
442 if (node->getLeft()->hasSideEffects())
443 {
444 // If v_expr has side effects, those need to be removed before proceeding.
445 // Otherwise the side effects of v_expr would be evaluated twice.
446 // The only case where an l-value can have side effects is when it is
447 // indexing. For example, it can be V[j++] where V is an array of vectors.
448 mRemoveIndexSideEffectsInSubtree = true;
449 return true;
450 }
451
452 TIntermBinary *leftBinary = node->getLeft()->getAsBinaryNode();
453 if (leftBinary != nullptr &&
454 IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(leftBinary))
455 {
456 // This is a case like:
457 // mat2 m;
458 // m[a][b]++;
459 // Process the child node m[a] first.
460 return true;
461 }
462
463 // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
464 // only writes it and doesn't need the previous value. http://anglebug.com/1116
465
466 TSymbolUniqueId *indexedWriteFunctionId = new TSymbolUniqueId(mSymbolTable);
467 if (mWrittenVecAndMatrixTypes.find(type) == mWrittenVecAndMatrixTypes.end())
468 {
469 mWrittenVecAndMatrixTypes[type] = indexedWriteFunctionId;
470 }
471 else
472 {
473 indexedWriteFunctionId = mWrittenVecAndMatrixTypes[type];
474 }
475 TType fieldType = GetFieldType(type);
476
477 TIntermSequence insertionsBefore;
478 TIntermSequence insertionsAfter;
479
480 // Store the index in a temporary signed int variable.
481 TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight());
482 TIntermDeclaration *initIndex = createTempInitDeclaration(indexInitializer);
483 initIndex->setLine(node->getLine());
484 insertionsBefore.push_back(initIndex);
485
486 // Create a node for referring to the index after the nextTemporaryId() call
487 // below.
488 TIntermSymbol *tempIndex = createTempSymbol(indexInitializer->getType());
489
490 TIntermAggregate *indexingCall =
491 CreateIndexFunctionCall(node, tempIndex, *indexingFunctionId);
492
493 nextTemporaryId(); // From now on, creating temporary symbols that refer to the
494 // field value.
495 insertionsBefore.push_back(createTempInitDeclaration(indexingCall));
496
497 TIntermAggregate *indexedWriteCall = CreateIndexedWriteFunctionCall(
498 node, tempIndex, createTempSymbol(fieldType), *indexedWriteFunctionId);
499 insertionsAfter.push_back(indexedWriteCall);
500 insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
501 queueReplacement(createTempSymbol(fieldType), OriginalNode::IS_DROPPED);
502 mUsedTreeInsertion = true;
503 }
504 else
505 {
506 // The indexed value is not being written, so we can simply convert
507 // v_expr[index_expr]
508 // into
509 // dyn_index(v_expr, index_expr)
510 // If the index_expr is unsigned, we'll convert it to signed.
511 ASSERT(!mRemoveIndexSideEffectsInSubtree);
512 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
513 node, EnsureSignedInt(node->getRight()), *indexingFunctionId);
514 queueReplacement(indexingCall, OriginalNode::IS_DROPPED);
515 }
516 }
517 }
518 return !mUsedTreeInsertion;
519 }
520
nextIteration()521 void RemoveDynamicIndexingTraverser::nextIteration()
522 {
523 mUsedTreeInsertion = false;
524 mRemoveIndexSideEffectsInSubtree = false;
525 nextTemporaryId();
526 }
527
528 } // namespace
529
RemoveDynamicIndexing(TIntermNode * root,TSymbolTable * symbolTable,int shaderVersion,PerformanceDiagnostics * perfDiagnostics)530 void RemoveDynamicIndexing(TIntermNode *root,
531 TSymbolTable *symbolTable,
532 int shaderVersion,
533 PerformanceDiagnostics *perfDiagnostics)
534 {
535 RemoveDynamicIndexingTraverser traverser(symbolTable, shaderVersion, perfDiagnostics);
536 do
537 {
538 traverser.nextIteration();
539 root->traverse(&traverser);
540 traverser.updateTree();
541 } while (traverser.usedTreeInsertion());
542 // TODO(oetuaho@nvidia.com): It might be nicer to add the helper definitions also in the middle
543 // of traversal. Now the tree ends up in an inconsistent state in the middle, since there are
544 // function call nodes with no corresponding definition nodes. This needs special handling in
545 // TIntermLValueTrackingTraverser, and creates intricacies that are not easily apparent from a
546 // superficial reading of the code.
547 traverser.insertHelperDefinitions(root);
548 }
549
550 } // namespace sh
551