1 //
2 // Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of vectors and matrices,
7 // replacing them with calls to functions that choose which component to return or write.
8 //
9
10 #include "compiler/translator/RemoveDynamicIndexing.h"
11
12 #include "compiler/translator/Diagnostics.h"
13 #include "compiler/translator/InfoSink.h"
14 #include "compiler/translator/IntermNodePatternMatcher.h"
15 #include "compiler/translator/IntermNode_util.h"
16 #include "compiler/translator/IntermTraverse.h"
17 #include "compiler/translator/StaticType.h"
18 #include "compiler/translator/SymbolTable.h"
19
20 namespace sh
21 {
22
23 namespace
24 {
25
26 const TType *kIndexType = StaticType::Get<EbtInt, EbpHigh, EvqIn, 1, 1>();
27
28 constexpr const ImmutableString kBaseName("base");
29 constexpr const ImmutableString kIndexName("index");
30 constexpr const ImmutableString kValueName("value");
31
GetIndexFunctionName(const TType & type,bool write)32 std::string GetIndexFunctionName(const TType &type, bool write)
33 {
34 TInfoSinkBase nameSink;
35 nameSink << "dyn_index_";
36 if (write)
37 {
38 nameSink << "write_";
39 }
40 if (type.isMatrix())
41 {
42 nameSink << "mat" << type.getCols() << "x" << type.getRows();
43 }
44 else
45 {
46 switch (type.getBasicType())
47 {
48 case EbtInt:
49 nameSink << "ivec";
50 break;
51 case EbtBool:
52 nameSink << "bvec";
53 break;
54 case EbtUInt:
55 nameSink << "uvec";
56 break;
57 case EbtFloat:
58 nameSink << "vec";
59 break;
60 default:
61 UNREACHABLE();
62 }
63 nameSink << type.getNominalSize();
64 }
65 return nameSink.str();
66 }
67
CreateParameterSymbol(const TConstParameter & parameter,TSymbolTable * symbolTable)68 TIntermSymbol *CreateParameterSymbol(const TConstParameter ¶meter, TSymbolTable *symbolTable)
69 {
70 TVariable *variable =
71 new TVariable(symbolTable, parameter.name, parameter.type, SymbolType::AngleInternal);
72 return new TIntermSymbol(variable);
73 }
74
CreateIntConstantNode(int i)75 TIntermConstantUnion *CreateIntConstantNode(int i)
76 {
77 TConstantUnion *constant = new TConstantUnion();
78 constant->setIConst(i);
79 return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh));
80 }
81
EnsureSignedInt(TIntermTyped * node)82 TIntermTyped *EnsureSignedInt(TIntermTyped *node)
83 {
84 if (node->getBasicType() == EbtInt)
85 return node;
86
87 TIntermSequence *arguments = new TIntermSequence();
88 arguments->push_back(node);
89 return TIntermAggregate::CreateConstructor(TType(EbtInt), arguments);
90 }
91
// Returns a newly allocated type for one field of |indexedType|, keeping its basic type and
// precision. For a matrix the field is a column, so its primary size is the matrix row count.
TType *GetFieldType(const TType &indexedType)
{
    TType *fieldType = new TType(indexedType.getBasicType(), indexedType.getPrecision());
    if (indexedType.isMatrix())
    {
        fieldType->setPrimarySize(static_cast<unsigned char>(indexedType.getRows()));
    }
    return fieldType;
}
105
// Returns a copy of |type| to be used for the "base" parameter of an indexing helper: "inout" for
// write helpers, "in" for read helpers.
const TType *GetBaseType(const TType &type, bool write)
{
    TType *baseType = new TType(type);
    // Always highp, regardless of the precision of the indexed value. Otherwise a mediump helper
    // could end up being used for a highp value when the shader indexes values of both
    // precisions. HLSL ignores precision, but this code could in principle serve other backends.
    baseType->setPrecision(EbpHigh);
    baseType->setQualifier(write ? EvqInOut : EvqIn);
    return baseType;
}
119
120 // Generate a read or write function for one field in a vector/matrix.
121 // Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range
122 // indices in other places.
123 // Note that indices can be either int or uint. We create only int versions of the functions,
124 // and convert uint indices to int at the call site.
125 // read function example:
126 // float dyn_index_vec2(in vec2 base, in int index)
127 // {
128 // switch(index)
129 // {
130 // case (0):
131 // return base[0];
132 // case (1):
133 // return base[1];
134 // default:
135 // break;
136 // }
137 // if (index < 0)
138 // return base[0];
139 // return base[1];
140 // }
141 // write function example:
142 // void dyn_index_write_vec2(inout vec2 base, in int index, in float value)
143 // {
144 // switch(index)
145 // {
146 // case (0):
147 // base[0] = value;
148 // return;
149 // case (1):
150 // base[1] = value;
151 // return;
152 // default:
153 // break;
154 // }
155 // if (index < 0)
156 // {
157 // base[0] = value;
158 // return;
159 // }
160 // base[1] = value;
161 // }
162 // Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
GetIndexFunctionDefinition(const TType & type,bool write,const TFunction & func,TSymbolTable * symbolTable)163 TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type,
164 bool write,
165 const TFunction &func,
166 TSymbolTable *symbolTable)
167 {
168 ASSERT(!type.isArray());
169
170 int numCases = 0;
171 if (type.isMatrix())
172 {
173 numCases = type.getCols();
174 }
175 else
176 {
177 numCases = type.getNominalSize();
178 }
179
180 std::string functionName = GetIndexFunctionName(type, write);
181 TIntermFunctionPrototype *prototypeNode = CreateInternalFunctionPrototypeNode(func);
182
183 TIntermSymbol *baseParam = CreateParameterSymbol(func.getParam(0), symbolTable);
184 prototypeNode->getSequence()->push_back(baseParam);
185 TIntermSymbol *indexParam = CreateParameterSymbol(func.getParam(1), symbolTable);
186 prototypeNode->getSequence()->push_back(indexParam);
187 TIntermSymbol *valueParam = nullptr;
188 if (write)
189 {
190 valueParam = CreateParameterSymbol(func.getParam(2), symbolTable);
191 prototypeNode->getSequence()->push_back(valueParam);
192 }
193
194 TIntermBlock *statementList = new TIntermBlock();
195 for (int i = 0; i < numCases; ++i)
196 {
197 TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i));
198 statementList->getSequence()->push_back(caseNode);
199
200 TIntermBinary *indexNode =
201 new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(i));
202 if (write)
203 {
204 TIntermBinary *assignNode =
205 new TIntermBinary(EOpAssign, indexNode, valueParam->deepCopy());
206 statementList->getSequence()->push_back(assignNode);
207 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
208 statementList->getSequence()->push_back(returnNode);
209 }
210 else
211 {
212 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode);
213 statementList->getSequence()->push_back(returnNode);
214 }
215 }
216
217 // Default case
218 TIntermCase *defaultNode = new TIntermCase(nullptr);
219 statementList->getSequence()->push_back(defaultNode);
220 TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr);
221 statementList->getSequence()->push_back(breakNode);
222
223 TIntermSwitch *switchNode = new TIntermSwitch(indexParam->deepCopy(), statementList);
224
225 TIntermBlock *bodyNode = new TIntermBlock();
226 bodyNode->getSequence()->push_back(switchNode);
227
228 TIntermBinary *cond =
229 new TIntermBinary(EOpLessThan, indexParam->deepCopy(), CreateIntConstantNode(0));
230
231 // Two blocks: one accesses (either reads or writes) the first element and returns,
232 // the other accesses the last element.
233 TIntermBlock *useFirstBlock = new TIntermBlock();
234 TIntermBlock *useLastBlock = new TIntermBlock();
235 TIntermBinary *indexFirstNode =
236 new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(0));
237 TIntermBinary *indexLastNode =
238 new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(numCases - 1));
239 if (write)
240 {
241 TIntermBinary *assignFirstNode =
242 new TIntermBinary(EOpAssign, indexFirstNode, valueParam->deepCopy());
243 useFirstBlock->getSequence()->push_back(assignFirstNode);
244 TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
245 useFirstBlock->getSequence()->push_back(returnNode);
246
247 TIntermBinary *assignLastNode =
248 new TIntermBinary(EOpAssign, indexLastNode, valueParam->deepCopy());
249 useLastBlock->getSequence()->push_back(assignLastNode);
250 }
251 else
252 {
253 TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode);
254 useFirstBlock->getSequence()->push_back(returnFirstNode);
255
256 TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode);
257 useLastBlock->getSequence()->push_back(returnLastNode);
258 }
259 TIntermIfElse *ifNode = new TIntermIfElse(cond, useFirstBlock, nullptr);
260 bodyNode->getSequence()->push_back(ifNode);
261 bodyNode->getSequence()->push_back(useLastBlock);
262
263 TIntermFunctionDefinition *indexingFunction =
264 new TIntermFunctionDefinition(prototypeNode, bodyNode);
265 return indexingFunction;
266 }
267
268 class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
269 {
270 public:
271 RemoveDynamicIndexingTraverser(TSymbolTable *symbolTable,
272 PerformanceDiagnostics *perfDiagnostics);
273
274 bool visitBinary(Visit visit, TIntermBinary *node) override;
275
276 void insertHelperDefinitions(TIntermNode *root);
277
278 void nextIteration();
279
usedTreeInsertion() const280 bool usedTreeInsertion() const { return mUsedTreeInsertion; }
281
282 protected:
283 // Maps of types that are indexed to the indexing function ids used for them. Note that these
284 // can not store multiple variants of the same type with different precisions - only one
285 // precision gets stored.
286 std::map<TType, TFunction *> mIndexedVecAndMatrixTypes;
287 std::map<TType, TFunction *> mWrittenVecAndMatrixTypes;
288
289 bool mUsedTreeInsertion;
290
291 // When true, the traverser will remove side effects from any indexing expression.
292 // This is done so that in code like
293 // V[j++][i]++.
294 // where V is an array of vectors, j++ will only be evaluated once.
295 bool mRemoveIndexSideEffectsInSubtree;
296
297 PerformanceDiagnostics *mPerfDiagnostics;
298 };
299
RemoveDynamicIndexingTraverser(TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics)300 RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(
301 TSymbolTable *symbolTable,
302 PerformanceDiagnostics *perfDiagnostics)
303 : TLValueTrackingTraverser(true, false, false, symbolTable),
304 mUsedTreeInsertion(false),
305 mRemoveIndexSideEffectsInSubtree(false),
306 mPerfDiagnostics(perfDiagnostics)
307 {
308 }
309
insertHelperDefinitions(TIntermNode * root)310 void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
311 {
312 TIntermBlock *rootBlock = root->getAsBlock();
313 ASSERT(rootBlock != nullptr);
314 TIntermSequence insertions;
315 for (auto &type : mIndexedVecAndMatrixTypes)
316 {
317 insertions.push_back(
318 GetIndexFunctionDefinition(type.first, false, *type.second, mSymbolTable));
319 }
320 for (auto &type : mWrittenVecAndMatrixTypes)
321 {
322 insertions.push_back(
323 GetIndexFunctionDefinition(type.first, true, *type.second, mSymbolTable));
324 }
325 rootBlock->insertChildNodes(0, insertions);
326 }
327
328 // Create a call to dyn_index_*() based on an indirect indexing op node
CreateIndexFunctionCall(TIntermBinary * node,TIntermTyped * index,TFunction * indexingFunction)329 TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
330 TIntermTyped *index,
331 TFunction *indexingFunction)
332 {
333 ASSERT(node->getOp() == EOpIndexIndirect);
334 TIntermSequence *arguments = new TIntermSequence();
335 arguments->push_back(node->getLeft());
336 arguments->push_back(index);
337
338 TIntermAggregate *indexingCall =
339 TIntermAggregate::CreateFunctionCall(*indexingFunction, arguments);
340 indexingCall->setLine(node->getLine());
341 return indexingCall;
342 }
343
CreateIndexedWriteFunctionCall(TIntermBinary * node,TVariable * index,TVariable * writtenValue,TFunction * indexedWriteFunction)344 TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
345 TVariable *index,
346 TVariable *writtenValue,
347 TFunction *indexedWriteFunction)
348 {
349 ASSERT(node->getOp() == EOpIndexIndirect);
350 TIntermSequence *arguments = new TIntermSequence();
351 // Deep copy the child nodes so that two pointers to the same node don't end up in the tree.
352 arguments->push_back(node->getLeft()->deepCopy());
353 arguments->push_back(CreateTempSymbolNode(index));
354 arguments->push_back(CreateTempSymbolNode(writtenValue));
355
356 TIntermAggregate *indexedWriteCall =
357 TIntermAggregate::CreateFunctionCall(*indexedWriteFunction, arguments);
358 indexedWriteCall->setLine(node->getLine());
359 return indexedWriteCall;
360 }
361
visitBinary(Visit visit,TIntermBinary * node)362 bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node)
363 {
364 if (mUsedTreeInsertion)
365 return false;
366
367 if (node->getOp() == EOpIndexIndirect)
368 {
369 if (mRemoveIndexSideEffectsInSubtree)
370 {
371 ASSERT(node->getRight()->hasSideEffects());
372 // In case we're just removing index side effects, convert
373 // v_expr[index_expr]
374 // to this:
375 // int s0 = index_expr; v_expr[s0];
376 // Now v_expr[s0] can be safely executed several times without unintended side effects.
377 TIntermDeclaration *indexVariableDeclaration = nullptr;
378 TVariable *indexVariable = DeclareTempVariable(mSymbolTable, node->getRight(),
379 EvqTemporary, &indexVariableDeclaration);
380 insertStatementInParentBlock(indexVariableDeclaration);
381 mUsedTreeInsertion = true;
382
383 // Replace the index with the temp variable
384 TIntermSymbol *tempIndex = CreateTempSymbolNode(indexVariable);
385 queueReplacementWithParent(node, node->getRight(), tempIndex, OriginalNode::IS_DROPPED);
386 }
387 else if (IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(node))
388 {
389 mPerfDiagnostics->warning(node->getLine(),
390 "Performance: dynamic indexing of vectors and "
391 "matrices is emulated and can be slow.",
392 "[]");
393 bool write = isLValueRequiredHere();
394
395 #if defined(ANGLE_ENABLE_ASSERTS)
396 // Make sure that IntermNodePatternMatcher is consistent with the slightly differently
397 // implemented checks in this traverser.
398 IntermNodePatternMatcher matcher(
399 IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue);
400 ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write);
401 #endif
402
403 const TType &type = node->getLeft()->getType();
404 ImmutableString indexingFunctionName(GetIndexFunctionName(type, false));
405 TFunction *indexingFunction = nullptr;
406 if (mIndexedVecAndMatrixTypes.find(type) == mIndexedVecAndMatrixTypes.end())
407 {
408 indexingFunction =
409 new TFunction(mSymbolTable, indexingFunctionName, SymbolType::AngleInternal,
410 GetFieldType(type), true);
411 indexingFunction->addParameter(
412 TConstParameter(kBaseName, GetBaseType(type, false)));
413 indexingFunction->addParameter(TConstParameter(kIndexName, kIndexType));
414 mIndexedVecAndMatrixTypes[type] = indexingFunction;
415 }
416 else
417 {
418 indexingFunction = mIndexedVecAndMatrixTypes[type];
419 }
420
421 if (write)
422 {
423 // Convert:
424 // v_expr[index_expr]++;
425 // to this:
426 // int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++;
427 // dyn_index_write(v_expr, s0, s1);
428 // This works even if index_expr has some side effects.
429 if (node->getLeft()->hasSideEffects())
430 {
431 // If v_expr has side effects, those need to be removed before proceeding.
432 // Otherwise the side effects of v_expr would be evaluated twice.
433 // The only case where an l-value can have side effects is when it is
434 // indexing. For example, it can be V[j++] where V is an array of vectors.
435 mRemoveIndexSideEffectsInSubtree = true;
436 return true;
437 }
438
439 TIntermBinary *leftBinary = node->getLeft()->getAsBinaryNode();
440 if (leftBinary != nullptr &&
441 IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(leftBinary))
442 {
443 // This is a case like:
444 // mat2 m;
445 // m[a][b]++;
446 // Process the child node m[a] first.
447 return true;
448 }
449
450 // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
451 // only writes it and doesn't need the previous value. http://anglebug.com/1116
452
453 TFunction *indexedWriteFunction = nullptr;
454 if (mWrittenVecAndMatrixTypes.find(type) == mWrittenVecAndMatrixTypes.end())
455 {
456 ImmutableString functionName(
457 GetIndexFunctionName(node->getLeft()->getType(), true));
458 indexedWriteFunction =
459 new TFunction(mSymbolTable, functionName, SymbolType::AngleInternal,
460 StaticType::GetBasic<EbtVoid>(), false);
461 indexedWriteFunction->addParameter(
462 TConstParameter(kBaseName, GetBaseType(type, true)));
463 indexedWriteFunction->addParameter(TConstParameter(kIndexName, kIndexType));
464 TType *valueType = GetFieldType(type);
465 valueType->setQualifier(EvqIn);
466 indexedWriteFunction->addParameter(
467 TConstParameter(kValueName, static_cast<const TType *>(valueType)));
468 mWrittenVecAndMatrixTypes[type] = indexedWriteFunction;
469 }
470 else
471 {
472 indexedWriteFunction = mWrittenVecAndMatrixTypes[type];
473 }
474
475 TIntermSequence insertionsBefore;
476 TIntermSequence insertionsAfter;
477
478 // Store the index in a temporary signed int variable.
479 // s0 = index_expr;
480 TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight());
481 TIntermDeclaration *indexVariableDeclaration = nullptr;
482 TVariable *indexVariable = DeclareTempVariable(
483 mSymbolTable, indexInitializer, EvqTemporary, &indexVariableDeclaration);
484 insertionsBefore.push_back(indexVariableDeclaration);
485
486 // s1 = dyn_index(v_expr, s0);
487 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
488 node, CreateTempSymbolNode(indexVariable), indexingFunction);
489 TIntermDeclaration *fieldVariableDeclaration = nullptr;
490 TVariable *fieldVariable = DeclareTempVariable(
491 mSymbolTable, indexingCall, EvqTemporary, &fieldVariableDeclaration);
492 insertionsBefore.push_back(fieldVariableDeclaration);
493
494 // dyn_index_write(v_expr, s0, s1);
495 TIntermAggregate *indexedWriteCall = CreateIndexedWriteFunctionCall(
496 node, indexVariable, fieldVariable, indexedWriteFunction);
497 insertionsAfter.push_back(indexedWriteCall);
498 insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
499
500 // replace the node with s1
501 queueReplacement(CreateTempSymbolNode(fieldVariable), OriginalNode::IS_DROPPED);
502 mUsedTreeInsertion = true;
503 }
504 else
505 {
506 // The indexed value is not being written, so we can simply convert
507 // v_expr[index_expr]
508 // into
509 // dyn_index(v_expr, index_expr)
510 // If the index_expr is unsigned, we'll convert it to signed.
511 ASSERT(!mRemoveIndexSideEffectsInSubtree);
512 TIntermAggregate *indexingCall = CreateIndexFunctionCall(
513 node, EnsureSignedInt(node->getRight()), indexingFunction);
514 queueReplacement(indexingCall, OriginalNode::IS_DROPPED);
515 }
516 }
517 }
518 return !mUsedTreeInsertion;
519 }
520
nextIteration()521 void RemoveDynamicIndexingTraverser::nextIteration()
522 {
523 mUsedTreeInsertion = false;
524 mRemoveIndexSideEffectsInSubtree = false;
525 }
526
527 } // namespace
528
RemoveDynamicIndexing(TIntermNode * root,TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics)529 void RemoveDynamicIndexing(TIntermNode *root,
530 TSymbolTable *symbolTable,
531 PerformanceDiagnostics *perfDiagnostics)
532 {
533 RemoveDynamicIndexingTraverser traverser(symbolTable, perfDiagnostics);
534 do
535 {
536 traverser.nextIteration();
537 root->traverse(&traverser);
538 traverser.updateTree();
539 } while (traverser.usedTreeInsertion());
540 // TODO(oetuaho@nvidia.com): It might be nicer to add the helper definitions also in the middle
541 // of traversal. Now the tree ends up in an inconsistent state in the middle, since there are
542 // function call nodes with no corresponding definition nodes. This needs special handling in
543 // TIntermLValueTrackingTraverser, and creates intricacies that are not easily apparent from a
544 // superficial reading of the code.
545 traverser.insertHelperDefinitions(root);
546 }
547
548 } // namespace sh
549