1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteRowMajorMatrices: Rewrite row-major matrices as column-major.
7 //
8 
9 #include "compiler/translator/tree_ops/gl/mac/RewriteRowMajorMatrices.h"
10 
11 #include "compiler/translator/Compiler.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 #include "compiler/translator/tree_util/ReplaceVariable.h"
18 
19 namespace sh
20 {
21 namespace
22 {
23 // Only structs with matrices are tracked.  If layout(row_major) is applied to a struct that doesn't
24 // have matrices, it's silently dropped.  This is also used to avoid creating duplicates for inner
25 // structs that don't have matrices.
26 struct StructConversionData
27 {
28     // The converted struct with every matrix transposed.
29     TStructure *convertedStruct = nullptr;
30 
31     // The copy-from and copy-to functions copying from a struct to its converted version and back.
32     TFunction *copyFromOriginal = nullptr;
33     TFunction *copyToOriginal   = nullptr;
34 };
35 
DoesFieldContainRowMajorMatrix(const TField * field,bool isBlockRowMajor)36 bool DoesFieldContainRowMajorMatrix(const TField *field, bool isBlockRowMajor)
37 {
38     TLayoutMatrixPacking matrixPacking = field->type()->getLayoutQualifier().matrixPacking;
39 
40     // The field is row major if either explicitly specified as such, or if it inherits it from the
41     // block layout qualifier.
42     if (matrixPacking == EmpColumnMajor || (matrixPacking == EmpUnspecified && !isBlockRowMajor))
43     {
44         return false;
45     }
46 
47     // The field is qualified with row_major, but if it's not a matrix or a struct containing
48     // matrices, that's a useless qualifier.
49     const TType *type = field->type();
50     return type->isMatrix() || type->isStructureContainingMatrices();
51 }
52 
// Creates a new field with the same name, line and symbol type as |field|, and with its own copy
// of the field's type.
TField *DuplicateField(const TField *field)
{
    TType *typeCopy = new TType(*field->type());
    return new TField(typeCopy, field->name(), field->line(), field->symbolType());
}
57 
SetColumnMajor(TType * type)58 void SetColumnMajor(TType *type)
59 {
60     TLayoutQualifier layoutQualifier = type->getLayoutQualifier();
61     layoutQualifier.matrixPacking    = EmpColumnMajor;
62     type->setLayoutQualifier(layoutQualifier);
63 }
64 
// Returns a new column-major copy of |type| with its two dimensions swapped: the primary size
// becomes the original row count and the secondary size the original column count.
TType *TransposeMatrixType(const TType *type)
{
    const unsigned char newPrimarySize   = static_cast<unsigned char>(type->getRows());
    const unsigned char newSecondarySize = static_cast<unsigned char>(type->getCols());

    TType *transposed = new TType(*type);
    SetColumnMajor(transposed);
    transposed->setPrimarySize(newPrimarySize);
    transposed->setSecondarySize(newSecondarySize);

    return transposed;
}
76 
CopyArraySizes(const TType * from,TType * to)77 void CopyArraySizes(const TType *from, TType *to)
78 {
79     if (from->isArray())
80     {
81         to->makeArrays(from->getArraySizes());
82     }
83 }
84 
85 // Determine if the node is an index node (array index or struct field selection).  For the purposes
86 // of this transformation, swizzle nodes are considered index nodes too.
IsIndexNode(TIntermNode * node,TIntermNode * child)87 bool IsIndexNode(TIntermNode *node, TIntermNode *child)
88 {
89     if (node->getAsSwizzleNode())
90     {
91         return true;
92     }
93 
94     TIntermBinary *binaryNode = node->getAsBinaryNode();
95     if (binaryNode == nullptr || child != binaryNode->getLeft())
96     {
97         return false;
98     }
99 
100     TOperator op = binaryNode->getOp();
101 
102     return op == EOpIndexDirect || op == EOpIndexDirectInterfaceBlock ||
103            op == EOpIndexDirectStruct || op == EOpIndexIndirect;
104 }
105 
CopyToTempVariable(TSymbolTable * symbolTable,TIntermTyped * node,TIntermSequence * prependStatements)106 TIntermSymbol *CopyToTempVariable(TSymbolTable *symbolTable,
107                                   TIntermTyped *node,
108                                   TIntermSequence *prependStatements)
109 {
110     TVariable *temp              = CreateTempVariable(symbolTable, &node->getType());
111     TIntermDeclaration *tempDecl = CreateTempInitDeclarationNode(temp, node);
112     prependStatements->push_back(tempDecl);
113 
114     return new TIntermSymbol(temp);
115 }
116 
CreateStructCopyCall(const TFunction * copyFunc,TIntermTyped * expression)117 TIntermAggregate *CreateStructCopyCall(const TFunction *copyFunc, TIntermTyped *expression)
118 {
119     TIntermSequence args = {expression};
120     return TIntermAggregate::CreateFunctionCall(*copyFunc, &args);
121 }
122 
CreateTransposeCall(TSymbolTable * symbolTable,TIntermTyped * expression)123 TIntermTyped *CreateTransposeCall(TSymbolTable *symbolTable, TIntermTyped *expression)
124 {
125     TIntermSequence args = {expression};
126     return CreateBuiltInFunctionCallNode("transpose", &args, *symbolTable, 300);
127 }
128 
GetIndex(TSymbolTable * symbolTable,TIntermNode * node,TIntermSequence * indices,TIntermSequence * prependStatements)129 TOperator GetIndex(TSymbolTable *symbolTable,
130                    TIntermNode *node,
131                    TIntermSequence *indices,
132                    TIntermSequence *prependStatements)
133 {
134     // Swizzle nodes are converted EOpIndexDirect for simplicity, with one index per swizzle
135     // channel.
136     TIntermSwizzle *asSwizzle = node->getAsSwizzleNode();
137     if (asSwizzle)
138     {
139         for (int channel : asSwizzle->getSwizzleOffsets())
140         {
141             indices->push_back(CreateIndexNode(channel));
142         }
143         return EOpIndexDirect;
144     }
145 
146     TIntermBinary *binaryNode = node->getAsBinaryNode();
147     ASSERT(binaryNode);
148 
149     TOperator op = binaryNode->getOp();
150     ASSERT(op == EOpIndexDirect || op == EOpIndexDirectInterfaceBlock ||
151            op == EOpIndexDirectStruct || op == EOpIndexIndirect);
152 
153     TIntermTyped *rhs = binaryNode->getRight()->deepCopy();
154     if (rhs->getAsConstantUnion() == nullptr)
155     {
156         rhs = CopyToTempVariable(symbolTable, rhs, prependStatements);
157     }
158 
159     indices->push_back(rhs);
160     return op;
161 }
162 
ReplicateIndexNode(TSymbolTable * symbolTable,TIntermNode * node,TIntermTyped * lhs,TIntermSequence * indices)163 TIntermTyped *ReplicateIndexNode(TSymbolTable *symbolTable,
164                                  TIntermNode *node,
165                                  TIntermTyped *lhs,
166                                  TIntermSequence *indices)
167 {
168     TIntermSwizzle *asSwizzle = node->getAsSwizzleNode();
169     if (asSwizzle)
170     {
171         return new TIntermSwizzle(lhs, asSwizzle->getSwizzleOffsets());
172     }
173 
174     TIntermBinary *binaryNode = node->getAsBinaryNode();
175     ASSERT(binaryNode);
176 
177     ASSERT(indices->size() == 1);
178     TIntermTyped *rhs = indices->front()->getAsTyped();
179 
180     return new TIntermBinary(binaryNode->getOp(), lhs, rhs);
181 }
182 
GetIndexOp(TIntermNode * node)183 TOperator GetIndexOp(TIntermNode *node)
184 {
185     return node->getAsConstantUnion() ? EOpIndexDirect : EOpIndexIndirect;
186 }
187 
IsConvertedField(TIntermTyped * indexNode,const angle::HashMap<const TField *,bool> & convertedFields)188 bool IsConvertedField(TIntermTyped *indexNode,
189                       const angle::HashMap<const TField *, bool> &convertedFields)
190 {
191     TIntermBinary *asBinary = indexNode->getAsBinaryNode();
192     if (asBinary == nullptr)
193     {
194         return false;
195     }
196 
197     if (asBinary->getOp() != EOpIndexDirectInterfaceBlock)
198     {
199         return false;
200     }
201 
202     const TInterfaceBlock *interfaceBlock = asBinary->getLeft()->getType().getInterfaceBlock();
203     ASSERT(interfaceBlock);
204 
205     TIntermConstantUnion *fieldIndexNode = asBinary->getRight()->getAsConstantUnion();
206     ASSERT(fieldIndexNode);
207     ASSERT(fieldIndexNode->getConstantValue() != nullptr);
208 
209     int fieldIndex      = fieldIndexNode->getConstantValue()->getIConst();
210     const TField *field = interfaceBlock->fields()[fieldIndex];
211 
212     return convertedFields.count(field) > 0 && convertedFields.at(field);
213 }
214 
215 // A helper class to transform expressions of array type.  Iterates over every element of the
216 // array.
217 class TransformArrayHelper
218 {
219   public:
TransformArrayHelper(TIntermTyped * baseExpression)220     TransformArrayHelper(TIntermTyped *baseExpression)
221         : mBaseExpression(baseExpression),
222           mBaseExpressionType(baseExpression->getType()),
223           mArrayIndices(mBaseExpressionType.getArraySizes().size(), 0)
224     {}
225 
getNextElement(TIntermTyped * valueExpression,TIntermTyped ** valueElementOut)226     TIntermTyped *getNextElement(TIntermTyped *valueExpression, TIntermTyped **valueElementOut)
227     {
228         const TSpan<const unsigned int> &arraySizes = mBaseExpressionType.getArraySizes();
229 
230         // If the last index overflows, element enumeration is done.
231         if (mArrayIndices.back() >= arraySizes.back())
232         {
233             return nullptr;
234         }
235 
236         TIntermTyped *element = getCurrentElement(mBaseExpression);
237         if (valueExpression)
238         {
239             *valueElementOut = getCurrentElement(valueExpression);
240         }
241 
242         incrementIndices(arraySizes);
243         return element;
244     }
245 
accumulateForRead(TSymbolTable * symbolTable,TIntermTyped * transformedElement,TIntermSequence * prependStatements)246     void accumulateForRead(TSymbolTable *symbolTable,
247                            TIntermTyped *transformedElement,
248                            TIntermSequence *prependStatements)
249     {
250         TIntermTyped *temp = CopyToTempVariable(symbolTable, transformedElement, prependStatements);
251         mReadTransformConstructorArgs.push_back(temp);
252     }
253 
constructReadTransformExpression()254     TIntermTyped *constructReadTransformExpression()
255     {
256         const TSpan<const unsigned int> &baseTypeArraySizes = mBaseExpressionType.getArraySizes();
257         TVector<unsigned int> arraySizes(baseTypeArraySizes.begin(), baseTypeArraySizes.end());
258         TIntermTyped *firstElement = mReadTransformConstructorArgs.front()->getAsTyped();
259         const TType &baseType      = firstElement->getType();
260 
261         // If N dimensions, acc[0] == size[0] and acc[i] == size[i] * acc[i-1].
262         // The last value is unused, and is not present.
263         TVector<unsigned int> accumulatedArraySizes(arraySizes.size() - 1);
264 
265         if (accumulatedArraySizes.size() > 0)
266         {
267             accumulatedArraySizes[0] = arraySizes[0];
268         }
269         for (size_t index = 1; index + 1 < arraySizes.size(); ++index)
270         {
271             accumulatedArraySizes[index] = accumulatedArraySizes[index - 1] * arraySizes[index];
272         }
273 
274         return constructReadTransformExpressionHelper(arraySizes, accumulatedArraySizes, baseType,
275                                                       0);
276     }
277 
278   private:
getCurrentElement(TIntermTyped * expression)279     TIntermTyped *getCurrentElement(TIntermTyped *expression)
280     {
281         TIntermTyped *element = expression->deepCopy();
282         for (auto it = mArrayIndices.rbegin(); it != mArrayIndices.rend(); ++it)
283         {
284             unsigned int index = *it;
285             element            = new TIntermBinary(EOpIndexDirect, element, CreateIndexNode(index));
286         }
287         return element;
288     }
289 
incrementIndices(const TSpan<const unsigned int> & arraySizes)290     void incrementIndices(const TSpan<const unsigned int> &arraySizes)
291     {
292         // Assume mArrayIndices is an N digit number, where digit i is in the range
293         // [0, arraySizes[i]).  This function increments this number.  Last digit is the most
294         // significant digit.
295         for (size_t digitIndex = 0; digitIndex < arraySizes.size(); ++digitIndex)
296         {
297             ++mArrayIndices[digitIndex];
298             if (mArrayIndices[digitIndex] < arraySizes[digitIndex])
299             {
300                 break;
301             }
302             if (digitIndex + 1 != arraySizes.size())
303             {
304                 // This digit has now overflown and is reset to 0, carry will be added to the next
305                 // digit.  The most significant digit will keep the overflow though, to make it
306                 // clear we have exhausted the range.
307                 mArrayIndices[digitIndex] = 0;
308             }
309         }
310     }
311 
constructReadTransformExpressionHelper(const TVector<unsigned int> & arraySizes,const TVector<unsigned int> & accumulatedArraySizes,const TType & baseType,size_t elementsOffset)312     TIntermTyped *constructReadTransformExpressionHelper(
313         const TVector<unsigned int> &arraySizes,
314         const TVector<unsigned int> &accumulatedArraySizes,
315         const TType &baseType,
316         size_t elementsOffset)
317     {
318         ASSERT(!arraySizes.empty());
319 
320         TType *transformedType = new TType(baseType);
321         transformedType->makeArrays(arraySizes);
322 
323         // If one dimensional, create the constructor with the given elements.
324         if (arraySizes.size() == 1)
325         {
326             ASSERT(accumulatedArraySizes.size() == 0);
327 
328             auto sliceStart = mReadTransformConstructorArgs.begin() + elementsOffset;
329             TIntermSequence slice(sliceStart, sliceStart + arraySizes[0]);
330 
331             return TIntermAggregate::CreateConstructor(*transformedType, &slice);
332         }
333 
334         // If not, create constructors for every column recursively.
335         TVector<unsigned int> subArraySizes(arraySizes.begin(), arraySizes.end() - 1);
336         TVector<unsigned int> subArrayAccumulatedSizes(accumulatedArraySizes.begin(),
337                                                        accumulatedArraySizes.end() - 1);
338 
339         TIntermSequence constructorArgs;
340         unsigned int colStride = accumulatedArraySizes.back();
341         for (size_t col = 0; col < arraySizes.back(); ++col)
342         {
343             size_t colElementsOffset = elementsOffset + col * colStride;
344 
345             constructorArgs.push_back(constructReadTransformExpressionHelper(
346                 subArraySizes, subArrayAccumulatedSizes, baseType, colElementsOffset));
347         }
348 
349         return TIntermAggregate::CreateConstructor(*transformedType, &constructorArgs);
350     }
351 
352     TIntermTyped *mBaseExpression;
353     const TType &mBaseExpressionType;
354     TVector<unsigned int> mArrayIndices;
355 
356     TIntermSequence mReadTransformConstructorArgs;
357 };
358 
359 // Traverser that:
360 //
361 // 1. Converts |layout(row_major) matCxR M| to |layout(column_major) matRxC Mt|.
362 // 2. Converts |layout(row_major) S s| to |layout(column_major) St st|, where S is a struct that
363 //    contains matrices, and St is a new struct with the transformation in 1 applied to matrix
364 //    members (recursively).
365 // 3. When read from, the following transformations are applied:
366 //
367 //            M       -> transpose(Mt)
368 //            M[c]    -> gvecN(Mt[0][c], Mt[1][c], ..., Mt[N-1][c])
369 //            M[c][r] -> Mt[r][c]
370 //            M[c].yz -> gvec2(Mt[1][c], Mt[2][c])
371 //            MArr    -> MType[D1]..[DN](transpose(MtArr[0]...[0]), ...)
372 //            s       -> copy_St_to_S(st)
373 //            sArr    -> SType[D1]...[DN](copy_St_to_S(stArr[0]..[0]), ...)
374 //            (matrix reads through struct are transformed similarly to M)
375 //
376 // 4. When written to, the following transformations are applied:
377 //
378 //      M = exp       -> Mt = transpose(exp)
379 //      M[c] = exp    -> temp = exp
380 //                       Mt[0][c] = temp[0]
381 //                       Mt[1][c] = temp[1]
382 //                       ...
383 //                       Mt[N-1][c] = temp[N-1]
384 //      M[c][r] = exp -> Mt[r][c] = exp
385 //      M[c].yz = exp -> temp = exp
386 //                       Mt[1][c] = temp[0]
387 //                       Mt[2][c] = temp[1]
//      MArr = exp    -> temp = exp
//                       MtArr = MtType[D1]..[DN](transpose(temp[0]...[0]), ...)
//      s = exp       -> st = copy_S_to_St(exp)
//      sArr = exp    -> temp = exp
//                       stArr = StType[D1]...[DN](copy_S_to_St(temp[0]..[0]), ...)
393 //      (matrix writes through struct are transformed similarly to M)
394 //
395 // 5. If any of the above is passed to an `inout` parameter, both transformations are applied:
396 //
397 //            f(M[c]) -> temp = gvecN(Mt[0][c], Mt[1][c], ..., Mt[N-1][c])
398 //                       f(temp)
399 //                       Mt[0][c] = temp[0]
400 //                       Mt[1][c] = temp[1]
401 //                       ...
402 //                       Mt[N-1][c] = temp[N-1]
403 //
404 //               f(s) -> temp = copy_St_to_S(st)
405 //                       f(temp)
406 //                       st = copy_S_to_St(temp)
407 //
408 //    If passed to an `out` parameter, the `temp` parameter is simply not initialized.
409 //
410 // 6. If the expression leading to the matrix or struct has array subscripts, temp values are
411 //    created for them to avoid duplicating side effects.
412 //
413 class RewriteRowMajorMatricesTraverser : public TIntermTraverser
414 {
415   public:
    // Constructs the outer traverser.  All the shared-state members are bound to this object's
    // own mOuterPass storage, there is no outer traverser (this is it) and no inner pass root.
    RewriteRowMajorMatricesTraverser(TCompiler *compiler, TSymbolTable *symbolTable)
        : TIntermTraverser(true, true, true, symbolTable),
          mCompiler(compiler),
          mStructMapOut(&mOuterPass.structMap),
          mInterfaceBlockMap(&mOuterPass.interfaceBlockMap),
          mInterfaceBlockFieldConvertedIn(mOuterPass.interfaceBlockFieldConverted),
          mCopyFunctionDefinitionsOut(&mOuterPass.copyFunctionDefinitions),
          mOuterTraverser(nullptr),
          mInnerPassRoot(nullptr),
          mIsProcessingInnerPassSubtree(false)
    {}
427 
visitDeclaration(Visit visit,TIntermDeclaration * node)428     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
429     {
430         // No need to process declarations in inner passes.
431         if (mInnerPassRoot != nullptr)
432         {
433             return true;
434         }
435 
436         if (visit != PreVisit)
437         {
438             return true;
439         }
440 
441         const TIntermSequence &sequence = *(node->getSequence());
442 
443         TIntermTyped *variable = sequence.front()->getAsTyped();
444         const TType &type      = variable->getType();
445 
446         // If it's a struct declaration that has matrices, remember it.  If a row-major instance
447         // of it is created, it will have to be converted.
448         if (type.isStructSpecifier() && type.isStructureContainingMatrices())
449         {
450             const TStructure *structure = type.getStruct();
451             ASSERT(structure);
452 
453             ASSERT(mOuterPass.structMap.count(structure) == 0);
454 
455             StructConversionData structData;
456             mOuterPass.structMap[structure] = structData;
457 
458             return false;
459         }
460 
461         // If it's an interface block, it may have to be converted if it contains any row-major
462         // fields.
463         if (type.isInterfaceBlock() && type.getInterfaceBlock()->containsMatrices())
464         {
465             const TInterfaceBlock *block = type.getInterfaceBlock();
466             ASSERT(block);
467             bool isBlockRowMajor = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
468 
469             const TFieldList &fields = block->fields();
470             bool anyRowMajor         = isBlockRowMajor;
471 
472             for (const TField *field : fields)
473             {
474                 if (DoesFieldContainRowMajorMatrix(field, isBlockRowMajor))
475                 {
476                     anyRowMajor = true;
477                     break;
478                 }
479             }
480 
481             if (anyRowMajor)
482             {
483                 convertInterfaceBlock(node);
484             }
485 
486             return false;
487         }
488 
489         return true;
490     }
491 
visitSymbol(TIntermSymbol * symbol)492     void visitSymbol(TIntermSymbol *symbol) override
493     {
494         // If in inner pass, only process if the symbol is under that root.
495         if (mInnerPassRoot != nullptr && !mIsProcessingInnerPassSubtree)
496         {
497             return;
498         }
499 
500         const TVariable *variable = &symbol->variable();
501         bool needsRewrite         = mInterfaceBlockMap->count(variable) != 0;
502 
503         // If it's a field of a nameless interface block, it may still need conversion.
504         if (!needsRewrite)
505         {
506             // Nameless interface block field symbols have the interface block pointer set, but are
507             // not interface blocks.
508             if (symbol->getType().getInterfaceBlock() && !variable->getType().isInterfaceBlock())
509             {
510                 needsRewrite = convertNamelessInterfaceBlockField(symbol);
511             }
512         }
513 
514         if (needsRewrite)
515         {
516             transformExpression(symbol);
517         }
518     }
519 
visitBinary(Visit visit,TIntermBinary * node)520     bool visitBinary(Visit visit, TIntermBinary *node) override
521     {
522         if (node == mInnerPassRoot)
523         {
524             // We only want to process the right-hand side of an assignment in inner passes.  When
525             // visit is InVisit, the left-hand side is already processed, and the right-hand side is
526             // next.  Set a flag to mark this duration.
527             mIsProcessingInnerPassSubtree = visit == InVisit;
528         }
529 
530         return true;
531     }
532 
getStructCopyFunctions()533     TIntermSequence *getStructCopyFunctions() { return &mOuterPass.copyFunctionDefinitions; }
534 
535   private:
536     typedef angle::HashMap<const TStructure *, StructConversionData> StructMap;
537     typedef angle::HashMap<const TVariable *, TVariable *> InterfaceBlockMap;
538     typedef angle::HashMap<const TField *, bool> InterfaceBlockFieldConverted;
539 
RewriteRowMajorMatricesTraverser(TSymbolTable * symbolTable,RewriteRowMajorMatricesTraverser * outerTraverser,InterfaceBlockMap * interfaceBlockMap,const InterfaceBlockFieldConverted & interfaceBlockFieldConverted,StructMap * structMap,TIntermSequence * copyFunctionDefinitions,TIntermBinary * innerPassRoot)540     RewriteRowMajorMatricesTraverser(
541         TSymbolTable *symbolTable,
542         RewriteRowMajorMatricesTraverser *outerTraverser,
543         InterfaceBlockMap *interfaceBlockMap,
544         const InterfaceBlockFieldConverted &interfaceBlockFieldConverted,
545         StructMap *structMap,
546         TIntermSequence *copyFunctionDefinitions,
547         TIntermBinary *innerPassRoot)
548         : TIntermTraverser(true, true, true, symbolTable),
549           mStructMapOut(structMap),
550           mInterfaceBlockMap(interfaceBlockMap),
551           mInterfaceBlockFieldConvertedIn(interfaceBlockFieldConverted),
552           mCopyFunctionDefinitionsOut(copyFunctionDefinitions),
553           mOuterTraverser(outerTraverser),
554           mInnerPassRoot(innerPassRoot),
555           mIsProcessingInnerPassSubtree(false)
556     {}
557 
convertInterfaceBlock(TIntermDeclaration * node)558     void convertInterfaceBlock(TIntermDeclaration *node)
559     {
560         ASSERT(mInnerPassRoot == nullptr);
561 
562         const TIntermSequence &sequence = *(node->getSequence());
563 
564         TIntermTyped *variableNode   = sequence.front()->getAsTyped();
565         const TType &type            = variableNode->getType();
566         const TInterfaceBlock *block = type.getInterfaceBlock();
567         ASSERT(block);
568 
569         bool isBlockRowMajor = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
570 
571         // Recreate the struct with its row-major fields converted to column-major equivalents.
572         TIntermSequence newDeclarations;
573 
574         TFieldList *newFields = new TFieldList;
575         for (const TField *field : block->fields())
576         {
577             TField *newField = nullptr;
578 
579             if (DoesFieldContainRowMajorMatrix(field, isBlockRowMajor))
580             {
581                 newField = convertField(field, &newDeclarations);
582 
583                 // Remember that this field was converted.
584                 mOuterPass.interfaceBlockFieldConverted[field] = true;
585             }
586             else
587             {
588                 newField = DuplicateField(field);
589             }
590 
591             newFields->push_back(newField);
592         }
593 
594         // Create a new interface block with these fields.
595         TLayoutQualifier blockLayoutQualifier = type.getLayoutQualifier();
596         blockLayoutQualifier.matrixPacking    = EmpColumnMajor;
597 
598         TInterfaceBlock *newInterfaceBlock =
599             new TInterfaceBlock(mSymbolTable, block->name(), newFields, blockLayoutQualifier,
600                                 block->symbolType(), block->extension());
601 
602         // Create a new declaration with the new type.  Declarations are separated at this point,
603         // so there should be only one variable here.
604         ASSERT(sequence.size() == 1);
605 
606         TType *newInterfaceBlockType =
607             new TType(newInterfaceBlock, type.getQualifier(), blockLayoutQualifier);
608 
609         TIntermDeclaration *newDeclaration = new TIntermDeclaration;
610         const TVariable *variable          = &variableNode->getAsSymbolNode()->variable();
611 
612         const TType *newType = newInterfaceBlockType;
613         if (type.isArray())
614         {
615             TType *newArrayType = new TType(*newType);
616             CopyArraySizes(&type, newArrayType);
617             newType = newArrayType;
618         }
619 
620         // If the interface block variable itself is temp, use an empty name.
621         bool variableIsTemp = variable->symbolType() == SymbolType::Empty;
622         const ImmutableString &variableName =
623             variableIsTemp ? kEmptyImmutableString : variable->name();
624 
625         TVariable *newVariable = new TVariable(mSymbolTable, variableName, newType,
626                                                variable->symbolType(), variable->extension());
627 
628         newDeclaration->appendDeclarator(new TIntermSymbol(newVariable));
629 
630         mOuterPass.interfaceBlockMap[variable] = newVariable;
631 
632         newDeclarations.push_back(newDeclaration);
633 
634         // Replace the interface block definition with the new one, prepending any new struct
635         // definitions.
636         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
637                                         std::move(newDeclarations));
638     }
639 
convertNamelessInterfaceBlockField(TIntermSymbol * symbol)640     bool convertNamelessInterfaceBlockField(TIntermSymbol *symbol)
641     {
642         const TVariable *variable             = &symbol->variable();
643         const TInterfaceBlock *interfaceBlock = symbol->getType().getInterfaceBlock();
644 
645         // Find the variable corresponding to this interface block.  If the interface block
646         // is not rewritten, or this refers to a field that is not rewritten, there's
647         // nothing to do.
648         for (auto iter : *mInterfaceBlockMap)
649         {
650             // Skip other rewritten nameless interface block fields.
651             if (!iter.first->getType().isInterfaceBlock())
652             {
653                 continue;
654             }
655 
656             // Skip if this is not a field of this rewritten interface block.
657             if (iter.first->getType().getInterfaceBlock() != interfaceBlock)
658             {
659                 continue;
660             }
661 
662             const ImmutableString symbolName = symbol->getName();
663 
664             // Find which field it is
665             const TVector<TField *> fields = interfaceBlock->fields();
666             const size_t fieldIndex        = variable->getType().getInterfaceBlockFieldIndex();
667             ASSERT(fieldIndex < fields.size());
668 
669             const TField *field = fields[fieldIndex];
670             ASSERT(field->name() == symbolName);
671 
672             // If this field doesn't need a rewrite, there's nothing to do.
673             if (mInterfaceBlockFieldConvertedIn.count(field) == 0 ||
674                 !mInterfaceBlockFieldConvertedIn.at(field))
675             {
676                 break;
677             }
678 
679             // Create a new variable that references the replaced interface block.
680             TType *newType = new TType(variable->getType());
681             newType->setInterfaceBlockField(iter.second->getType().getInterfaceBlock(), fieldIndex);
682 
683             TVariable *newVariable = new TVariable(mSymbolTable, variable->name(), newType,
684                                                    variable->symbolType(), variable->extension());
685 
686             (*mInterfaceBlockMap)[variable] = newVariable;
687 
688             return true;
689         }
690 
691         return false;
692     }
693 
convertStruct(const TStructure * structure,TIntermSequence * newDeclarations)694     void convertStruct(const TStructure *structure, TIntermSequence *newDeclarations)
695     {
696         ASSERT(mInnerPassRoot == nullptr);
697 
698         ASSERT(mOuterPass.structMap.count(structure) != 0);
699         StructConversionData *structData = &mOuterPass.structMap[structure];
700 
701         if (structData->convertedStruct)
702         {
703             return;
704         }
705 
706         TFieldList *newFields = new TFieldList;
707         for (const TField *field : structure->fields())
708         {
709             newFields->push_back(convertField(field, newDeclarations));
710         }
711 
712         // Create unique names for the converted structs.  We can't leave them nameless and have
713         // a name autogenerated similar to temp variables, as nameless structs exist.  A fake
714         // variable is created for the sole purpose of generating a temp name.
715         TVariable *newStructTypeName =
716             new TVariable(mSymbolTable, kEmptyImmutableString, StaticType::GetBasic<EbtUInt>(),
717                           SymbolType::Empty);
718 
719         TStructure *newStruct = new TStructure(mSymbolTable, newStructTypeName->name(), newFields,
720                                                SymbolType::AngleInternal);
721         TType *newType        = new TType(newStruct, true);
722         TVariable *newStructVar =
723             new TVariable(mSymbolTable, kEmptyImmutableString, newType, SymbolType::Empty);
724 
725         TIntermDeclaration *structDecl = new TIntermDeclaration;
726         structDecl->appendDeclarator(new TIntermSymbol(newStructVar));
727 
728         newDeclarations->push_back(structDecl);
729 
730         structData->convertedStruct = newStruct;
731     }
732 
convertField(const TField * field,TIntermSequence * newDeclarations)733     TField *convertField(const TField *field, TIntermSequence *newDeclarations)
734     {
735         ASSERT(mInnerPassRoot == nullptr);
736 
737         TField *newField = nullptr;
738 
739         const TType *fieldType = field->type();
740         TType *newType         = nullptr;
741 
742         if (fieldType->isStructureContainingMatrices())
743         {
744             // If the field is a struct instance, convert the struct and replace the field
745             // with an instance of the new struct.
746             const TStructure *fieldTypeStruct = fieldType->getStruct();
747             convertStruct(fieldTypeStruct, newDeclarations);
748 
749             StructConversionData &structData = mOuterPass.structMap[fieldTypeStruct];
750             newType                          = new TType(structData.convertedStruct, false);
751             SetColumnMajor(newType);
752             CopyArraySizes(fieldType, newType);
753         }
754         else if (fieldType->isMatrix())
755         {
756             // If the field is a matrix, transpose the matrix and replace the field with
757             // that, removing the matrix packing qualifier.
758             newType = TransposeMatrixType(fieldType);
759         }
760 
761         if (newType)
762         {
763             newField = new TField(newType, field->name(), field->line(), field->symbolType());
764         }
765         else
766         {
767             newField = DuplicateField(field);
768         }
769 
770         return newField;
771     }
772 
determineAccess(TIntermNode * expression,TIntermNode * accessor,bool * isReadOut,bool * isWriteOut)773     void determineAccess(TIntermNode *expression,
774                          TIntermNode *accessor,
775                          bool *isReadOut,
776                          bool *isWriteOut)
777     {
778         // If passing to a function, look at whether the parameter is in, out or inout.
779         TIntermAggregate *functionCall = accessor->getAsAggregate();
780 
781         if (functionCall)
782         {
783             TIntermSequence *arguments = functionCall->getSequence();
784             for (size_t argIndex = 0; argIndex < arguments->size(); ++argIndex)
785             {
786                 if ((*arguments)[argIndex] == expression)
787                 {
788                     TQualifier qualifier = EvqIn;
789 
790                     // If the aggregate is not a function call, it's a constructor, and so every
791                     // argument is an input.
792                     const TFunction *function = functionCall->getFunction();
793                     if (function)
794                     {
795                         const TVariable *param = function->getParam(argIndex);
796                         qualifier              = param->getType().getQualifier();
797                     }
798 
799                     *isReadOut  = qualifier != EvqOut;
800                     *isWriteOut = qualifier == EvqOut || qualifier == EvqInOut;
801                     break;
802                 }
803             }
804             return;
805         }
806 
807         TIntermBinary *assignment = accessor->getAsBinaryNode();
808         if (assignment && IsAssignment(assignment->getOp()))
809         {
810             // If expression is on the right of assignment, it's being read from.
811             *isReadOut = assignment->getRight() == expression;
812             // If it's on the left of assignment, it's being written to.
813             *isWriteOut = assignment->getLeft() == expression;
814             return;
815         }
816 
817         // Any other usage is a read.
818         *isReadOut  = true;
819         *isWriteOut = false;
820     }
821 
    // Rewrites the expression rooted at |symbol| so that it refers to the replacement variable
    // registered in mInterfaceBlockMap.  The chain of index/swizzle ancestors of |symbol| is
    // walked to find the part of the expression that needs conversion; the expression is then
    // replaced through queueReplacementWithParent() and any helper statements are inserted in
    // the parent block through insertStatementsInParentBlock() (of the outer traverser if there
    // is one).  Reads go through transformReadExpression() and writes through
    // transformWriteExpression().
    void transformExpression(TIntermSymbol *symbol)
    {
        // Walk up the parent chain while the nodes are EOpIndex* (whether array indexing or struct
        // field selection) or swizzle and construct the replacement expression.  This traversal can
        // lead to one of the following possibilities:
        //
        // - a.b[N].etc.s (struct, or struct array): copy function should be declared and used,
        // - a.b[N].etc.M (matrix or matrix array): transpose() should be used,
        // - a.b[N].etc.M[c] (a column): each element in column needs to be handled separately,
        // - a.b[N].etc.M[c].yz (multiple elements): similar to whole column, but a subset of
        //   elements,
        // - a.b[N].etc.M[c][r] (an element): single element to handle.
        // - a.b[N].etc.x (not struct or matrix): not modified
        //
        // primaryIndex will contain c, if any.  secondaryIndices will contain {0, ..., R-1}
        // (if no [r] or swizzle), {r} (if [r]), or {1, 2} (corresponding to .yz) if any.
        //
        // In all cases, the base symbol is replaced.  |baseExpression| will contain everything up
        // to (and not including) the last index/swizzle operations, i.e. a.b[N].etc.s/M/x.  Any
        // non constant array subscript is assigned to a temp variable to avoid duplicating side
        // effects.
        //
        // ---
        //
        // NOTE that due to the use of insertStatementsInParentBlock, cases like this will be
        // mistranslated, and this bug is likely present in most transformations that use this
        // feature:
        //
        //     if (x == 1 && a.b[x = 2].etc.M = value)
        //
        // which will translate to:
        //
        //     temp = (x = 2)
        //     if (x == 1 && a.b[temp].etc.M = transpose(value))
        //
        // See http://anglebug.com/3829.
        //
        TIntermTyped *baseExpression =
            new TIntermSymbol(mInterfaceBlockMap->at(&symbol->variable()));
        const TStructure *structure = nullptr;

        TIntermNode *primaryIndex = nullptr;
        TIntermSequence secondaryIndices;

        // In some cases, it is necessary to prepend or append statements.  Those are captured in
        // |prependStatements| and |appendStatements|.
        TIntermSequence prependStatements;
        TIntermSequence appendStatements;

        // If the expression is neither a struct or matrix, no modification is necessary.
        // If it's a struct that doesn't have matrices, again there's no transformation necessary.
        // If it's an interface block matrix field that didn't need to be transposed, no
        // transformation is necessary.
        //
        // In all these cases, |baseExpression| contains all of the original expression.
        //
        // If the starting symbol itself is a field of a nameless interface block, it needs
        // conversion if we reach here.
        bool requiresTransformation = !symbol->getType().isInterfaceBlock();

        // |accessorIndex| counts how many ancestors have been consumed; |previousAncestor| is the
        // node whose type decides how the next index operation is interpreted.
        uint32_t accessorIndex         = 0;
        TIntermTyped *previousAncestor = symbol;
        while (IsIndexNode(getAncestorNode(accessorIndex), previousAncestor))
        {
            TIntermTyped *ancestor = getAncestorNode(accessorIndex)->getAsTyped();
            ASSERT(ancestor);

            const TType &previousAncestorType = previousAncestor->getType();

            TIntermSequence indices;
            TOperator op = GetIndex(mSymbolTable, ancestor, &indices, &prependStatements);

            bool opIsIndex     = op == EOpIndexDirect || op == EOpIndexIndirect;
            bool isArrayIndex  = opIsIndex && previousAncestorType.isArray();
            bool isMatrixIndex = opIsIndex && previousAncestorType.isMatrix();

            // If it's a direct index in a matrix, it's the primary index.
            bool isMatrixPrimarySubscript = isMatrixIndex && !isArrayIndex;
            ASSERT(!isMatrixPrimarySubscript ||
                   (primaryIndex == nullptr && secondaryIndices.empty()));
            // If primary index is seen and the ancestor is still an index, it must be a direct
            // index as the secondary one.  Note that if primaryIndex is set, there can only ever be
            // one more parent of interest, and that's subscripting the second dimension.
            bool isMatrixSecondarySubscript = primaryIndex != nullptr;
            ASSERT(!isMatrixSecondarySubscript || (opIsIndex && !isArrayIndex));

            if (requiresTransformation && isMatrixPrimarySubscript)
            {
                ASSERT(indices.size() == 1);
                primaryIndex = indices.front();

                // Default the secondary indices to include every row.  If there's a secondary
                // subscript provided, it will override this.
                int rows = previousAncestorType.getRows();
                for (int r = 0; r < rows; ++r)
                {
                    secondaryIndices.push_back(CreateIndexNode(r));
                }
            }
            else if (isMatrixSecondarySubscript)
            {
                ASSERT(requiresTransformation);

                secondaryIndices = indices;

                // Indices after this point are not interesting.  There can't actually be any other
                // index nodes other than desktop GLSL's swizzles on scalars, like M[1][2].yyy.
                ++accessorIndex;
                break;
            }
            else
            {
                // Replicate the expression otherwise.
                baseExpression =
                    ReplicateIndexNode(mSymbolTable, ancestor, baseExpression, &indices);

                const TType &ancestorType = ancestor->getType();
                structure                 = ancestorType.getStruct();

                requiresTransformation =
                    requiresTransformation ||
                    IsConvertedField(ancestor, mInterfaceBlockFieldConvertedIn);

                // If we reach a point where the expression is neither a matrix-containing struct
                // nor a matrix, there's no transformation required.  This can happen if we descend
                // through a struct marked with row-major but arrive at a member that doesn't
                // include a matrix.
                if (!ancestorType.isMatrix() && !ancestorType.isStructureContainingMatrices())
                {
                    requiresTransformation = false;
                }
            }

            previousAncestor = ancestor;
            ++accessorIndex;
        }

        // |originalExpression| is the last node consumed by the walk above (or the symbol itself
        // if none was), and |accessor| is its parent, i.e. the node that uses the expression.
        TIntermNode *originalExpression =
            accessorIndex == 0 ? symbol : getAncestorNode(accessorIndex - 1);
        TIntermNode *accessor = getAncestorNode(accessorIndex);

        // if accessor is EOpArrayLength, we don't need to perform any transformations either.
        // Note that this only applies to unsized arrays, as the RemoveArrayLengthMethod()
        // transformation would have removed this operation otherwise.
        TIntermUnary *accessorAsUnary = accessor->getAsUnaryNode();
        if (requiresTransformation && accessorAsUnary && accessorAsUnary->getOp() == EOpArrayLength)
        {
            ASSERT(accessorAsUnary->getOperand() == originalExpression);
            ASSERT(accessorAsUnary->getOperand()->getType().isUnsizedArray());

            requiresTransformation = false;

            // We need to replace the whole expression including the EOpArrayLength, to avoid
            // confusing the replacement code as the original and new expressions don't have the
            // same type (one is the transpose of the other).  This doesn't affect the .length()
            // operation, so this replacement is ok, though it's not worth special-casing this in
            // the node replacement algorithm.
            //
            // Note: the |if (!requiresTransformation)| immediately below will be entered after
            // this.
            originalExpression = accessor;
            accessor           = getAncestorNode(accessorIndex + 1);
            baseExpression     = new TIntermUnary(EOpArrayLength, baseExpression, nullptr);
        }

        if (!requiresTransformation)
        {
            // Only the base symbol changed; substitute the replicated expression as is.
            ASSERT(primaryIndex == nullptr);
            queueReplacementWithParent(accessor, originalExpression, baseExpression,
                                       OriginalNode::IS_DROPPED);

            RewriteRowMajorMatricesTraverser *traverser = mOuterTraverser ? mOuterTraverser : this;
            traverser->insertStatementsInParentBlock(prependStatements, appendStatements);
            return;
        }

        ASSERT(structure == nullptr || primaryIndex == nullptr);
        ASSERT(structure != nullptr || baseExpression->getType().isMatrix());

        // At the end, we can determine if the expression is being read from or written to (or both,
        // if sent as an inout parameter to a function).  For the sake of the transformation, the
        // left-hand side of operations like += can be treated as "written to", without necessarily
        // "read from".
        bool isRead  = false;
        bool isWrite = false;

        determineAccess(originalExpression, accessor, &isRead, &isWrite);

        ASSERT(isRead || isWrite);

        TIntermTyped *readExpression = nullptr;
        if (isRead)
        {
            readExpression = transformReadExpression(
                baseExpression, primaryIndex, &secondaryIndices, structure, &prependStatements);

            // If both read from and written to (i.e. passed to inout parameter), store the
            // expression in a temp variable and pass that to the function.
            if (isWrite)
            {
                readExpression =
                    CopyToTempVariable(mSymbolTable, readExpression, &prependStatements);
            }

            // Replace the original expression with the transformed one.  Read transformations
            // always generate a single expression that can be used in place of the original (as
            // opposed to write transformations that can generate multiple statements).
            queueReplacementWithParent(accessor, originalExpression, readExpression,
                                       OriginalNode::IS_DROPPED);
        }

        // By default the statements that perform the write are appended after the current
        // statement; the assignment case below redirects them to |prependStatements|.
        TIntermSequence postTransformPrependStatements;
        TIntermSequence *writeStatements = &appendStatements;
        TOperator assignmentOperator     = EOpAssign;

        if (isWrite)
        {
            TIntermTyped *valueExpression = readExpression;

            if (!valueExpression)
            {
                // If there's already a read expression, this was an inout parameter and
                // |valueExpression| will contain the temp variable that was passed to the function
                // instead.
                //
                // If not, then the modification is either through being passed as an out parameter
                // to a function, or an assignment.  In the former case, create a temp variable to
                // be passed to the function.  In the latter case, create a temp variable that holds
                // the right hand side expression.
                //
                // In either case, use that temp value as the value to assign to |baseExpression|.

                TVariable *temp =
                    CreateTempVariable(mSymbolTable, &originalExpression->getAsTyped()->getType());
                TIntermDeclaration *tempDecl = nullptr;

                valueExpression = new TIntermSymbol(temp);

                TIntermBinary *assignment = accessor->getAsBinaryNode();
                if (assignment)
                {
                    assignmentOperator = assignment->getOp();
                    ASSERT(IsAssignment(assignmentOperator));

                    // We are converting the assignment to the left-hand side of an expression in
                    // the form M=exp.  A subexpression of exp itself could require a
                    // transformation.  This complicates things as there would be two replacements:
                    //
                    // - Replace M=exp with temp (because the return value of the assignment could
                    //   be used)
                    // - Replace exp with exp2, where parent is M=exp
                    //
                    // The second replacement however is ineffective as the whole of M=exp is
                    // already transformed.  What's worse, M=exp is transformed without taking exp's
                    // transformations into account.  To address this issue, this same traverser is
                    // called on the right-hand side expression, with a special flag such that it
                    // only processes that expression.
                    //
                    RewriteRowMajorMatricesTraverser *outerTraverser =
                        mOuterTraverser ? mOuterTraverser : this;
                    RewriteRowMajorMatricesTraverser rhsTraverser(
                        mSymbolTable, outerTraverser, mInterfaceBlockMap,
                        mInterfaceBlockFieldConvertedIn, mStructMapOut, mCopyFunctionDefinitionsOut,
                        assignment);
                    getRootNode()->traverse(&rhsTraverser);
                    bool valid = rhsTraverser.updateTree(mCompiler, getRootNode());
                    ASSERT(valid);

                    tempDecl = CreateTempInitDeclarationNode(temp, assignment->getRight());

                    // Replace the whole assignment expression with the right-hand side as a read
                    // expression, in case the result of the assignment is used.  For example, this
                    // transforms:
                    //
                    //     if ((M += exp) == X)
                    //     {
                    //         // use M
                    //     }
                    //
                    // to:
                    //
                    //     temp = exp;
                    //     M += transform(temp);
                    //     if (transform(M) == X)
                    //     {
                    //         // use M
                    //     }
                    //
                    // Note that in this case the assignment to M must be prepended in the parent
                    // block.  In contrast, when sent to a function, the assignment to M should be
                    // done after the current function call is done.
                    //
                    // If the read from M itself (to replace assignment) needs to generate extra
                    // statements, they should be appended after the statements that write to M.
                    // These statements are stored in postTransformPrependStatements and appended to
                    // prependStatements in the end.
                    //
                    writeStatements = &prependStatements;

                    TIntermTyped *assignmentResultExpression = transformReadExpression(
                        baseExpression->deepCopy(), primaryIndex, &secondaryIndices, structure,
                        &postTransformPrependStatements);

                    // Replace the whole assignment, instead of just the right hand side.
                    TIntermNode *accessorParent = getAncestorNode(accessorIndex + 1);
                    queueReplacementWithParent(accessorParent, accessor, assignmentResultExpression,
                                               OriginalNode::IS_DROPPED);
                }
                else
                {
                    tempDecl = CreateTempDeclarationNode(temp);

                    // Replace the write expression (a function call argument) with the temp
                    // variable.
                    queueReplacementWithParent(accessor, originalExpression, valueExpression,
                                               OriginalNode::IS_DROPPED);
                }
                prependStatements.push_back(tempDecl);
            }

            // |baseExpression| was already handed to transformReadExpression() above if the
            // expression is also read, so the write transformation is given its own copy.
            if (isRead)
            {
                baseExpression = baseExpression->deepCopy();
            }
            transformWriteExpression(baseExpression, primaryIndex, &secondaryIndices, structure,
                                     valueExpression, assignmentOperator, writeStatements);
        }

        prependStatements.insert(prependStatements.end(), postTransformPrependStatements.begin(),
                                 postTransformPrependStatements.end());

        RewriteRowMajorMatricesTraverser *traverser = mOuterTraverser ? mOuterTraverser : this;
        traverser->insertStatementsInParentBlock(prependStatements, appendStatements);
    }
1156 
transformReadExpression(TIntermTyped * baseExpression,TIntermNode * primaryIndex,TIntermSequence * secondaryIndices,const TStructure * structure,TIntermSequence * prependStatements)1157     TIntermTyped *transformReadExpression(TIntermTyped *baseExpression,
1158                                           TIntermNode *primaryIndex,
1159                                           TIntermSequence *secondaryIndices,
1160                                           const TStructure *structure,
1161                                           TIntermSequence *prependStatements)
1162     {
1163         const TType &baseExpressionType = baseExpression->getType();
1164 
1165         if (structure)
1166         {
1167             ASSERT(primaryIndex == nullptr && secondaryIndices->empty());
1168             ASSERT(mStructMapOut->count(structure) != 0);
1169             ASSERT((*mStructMapOut)[structure].convertedStruct != nullptr);
1170 
1171             // Declare copy-from-converted-to-original-struct function (if not already).
1172             declareStructCopyToOriginal(structure);
1173 
1174             const TFunction *copyToOriginal = (*mStructMapOut)[structure].copyToOriginal;
1175 
1176             if (baseExpressionType.isArray())
1177             {
1178                 // If base expression is an array, transform every element.
1179                 TransformArrayHelper transformHelper(baseExpression);
1180 
1181                 TIntermTyped *element = nullptr;
1182                 while ((element = transformHelper.getNextElement(nullptr, nullptr)) != nullptr)
1183                 {
1184                     TIntermTyped *transformedElement =
1185                         CreateStructCopyCall(copyToOriginal, element);
1186                     transformHelper.accumulateForRead(mSymbolTable, transformedElement,
1187                                                       prependStatements);
1188                 }
1189                 return transformHelper.constructReadTransformExpression();
1190             }
1191             else
1192             {
1193                 // If not reading an array, the result is simply a call to this function with the
1194                 // base expression.
1195                 return CreateStructCopyCall(copyToOriginal, baseExpression);
1196             }
1197         }
1198 
1199         // If not indexed, the result is transpose(exp)
1200         if (primaryIndex == nullptr)
1201         {
1202             ASSERT(secondaryIndices->empty());
1203 
1204             if (baseExpressionType.isArray())
1205             {
1206                 // If array, transpose every element.
1207                 TransformArrayHelper transformHelper(baseExpression);
1208 
1209                 TIntermTyped *element = nullptr;
1210                 while ((element = transformHelper.getNextElement(nullptr, nullptr)) != nullptr)
1211                 {
1212                     TIntermTyped *transformedElement = CreateTransposeCall(mSymbolTable, element);
1213                     transformHelper.accumulateForRead(mSymbolTable, transformedElement,
1214                                                       prependStatements);
1215                 }
1216                 return transformHelper.constructReadTransformExpression();
1217             }
1218             else
1219             {
1220                 return CreateTransposeCall(mSymbolTable, baseExpression);
1221             }
1222         }
1223 
1224         // If indexed the result is a vector (or just one element) where the primary and secondary
1225         // indices are swapped.
1226         ASSERT(!secondaryIndices->empty());
1227 
1228         TOperator primaryIndexOp          = GetIndexOp(primaryIndex);
1229         TIntermTyped *primaryIndexAsTyped = primaryIndex->getAsTyped();
1230 
1231         TIntermSequence transposedColumn;
1232         for (TIntermNode *secondaryIndex : *secondaryIndices)
1233         {
1234             TOperator secondaryIndexOp          = GetIndexOp(secondaryIndex);
1235             TIntermTyped *secondaryIndexAsTyped = secondaryIndex->getAsTyped();
1236 
1237             TIntermBinary *colIndexed = new TIntermBinary(
1238                 secondaryIndexOp, baseExpression->deepCopy(), secondaryIndexAsTyped->deepCopy());
1239             TIntermBinary *colRowIndexed =
1240                 new TIntermBinary(primaryIndexOp, colIndexed, primaryIndexAsTyped->deepCopy());
1241 
1242             transposedColumn.push_back(colRowIndexed);
1243         }
1244 
1245         if (secondaryIndices->size() == 1)
1246         {
1247             // If only one element, return that directly.
1248             return transposedColumn.front()->getAsTyped();
1249         }
1250 
1251         // Otherwise create a constructor with the appropriate dimension.
1252         TType *vecType = new TType(baseExpressionType.getBasicType(), secondaryIndices->size());
1253         return TIntermAggregate::CreateConstructor(*vecType, &transposedColumn);
1254     }
1255 
transformWriteExpression(TIntermTyped * baseExpression,TIntermNode * primaryIndex,TIntermSequence * secondaryIndices,const TStructure * structure,TIntermTyped * valueExpression,TOperator assignmentOperator,TIntermSequence * writeStatements)1256     void transformWriteExpression(TIntermTyped *baseExpression,
1257                                   TIntermNode *primaryIndex,
1258                                   TIntermSequence *secondaryIndices,
1259                                   const TStructure *structure,
1260                                   TIntermTyped *valueExpression,
1261                                   TOperator assignmentOperator,
1262                                   TIntermSequence *writeStatements)
1263     {
1264         const TType &baseExpressionType = baseExpression->getType();
1265 
1266         if (structure)
1267         {
1268             ASSERT(primaryIndex == nullptr && secondaryIndices->empty());
1269             ASSERT(mStructMapOut->count(structure) != 0);
1270             ASSERT((*mStructMapOut)[structure].convertedStruct != nullptr);
1271 
1272             // Declare copy-to-converted-from-original-struct function (if not already).
1273             declareStructCopyFromOriginal(structure);
1274 
1275             // The result is call to this function with the value expression assigned to base
1276             // expression.
1277             const TFunction *copyFromOriginal = (*mStructMapOut)[structure].copyFromOriginal;
1278 
1279             if (baseExpressionType.isArray())
1280             {
1281                 // If array, assign every element.
1282                 TransformArrayHelper transformHelper(baseExpression);
1283 
1284                 TIntermTyped *element      = nullptr;
1285                 TIntermTyped *valueElement = nullptr;
1286                 while ((element = transformHelper.getNextElement(valueExpression, &valueElement)) !=
1287                        nullptr)
1288                 {
1289                     TIntermTyped *functionCall =
1290                         CreateStructCopyCall(copyFromOriginal, valueElement);
1291                     writeStatements->push_back(new TIntermBinary(EOpAssign, element, functionCall));
1292                 }
1293             }
1294             else
1295             {
1296                 TIntermTyped *functionCall =
1297                     CreateStructCopyCall(copyFromOriginal, valueExpression->deepCopy());
1298                 writeStatements->push_back(
1299                     new TIntermBinary(EOpAssign, baseExpression, functionCall));
1300             }
1301 
1302             return;
1303         }
1304 
1305         // If not indexed, the result is transpose(exp)
1306         if (primaryIndex == nullptr)
1307         {
1308             ASSERT(secondaryIndices->empty());
1309 
1310             if (baseExpressionType.isArray())
1311             {
1312                 // If array, assign every element.
1313                 TransformArrayHelper transformHelper(baseExpression);
1314 
1315                 TIntermTyped *element      = nullptr;
1316                 TIntermTyped *valueElement = nullptr;
1317                 while ((element = transformHelper.getNextElement(valueExpression, &valueElement)) !=
1318                        nullptr)
1319                 {
1320                     TIntermTyped *valueTransposed = CreateTransposeCall(mSymbolTable, valueElement);
1321                     writeStatements->push_back(
1322                         new TIntermBinary(EOpAssign, element, valueTransposed));
1323                 }
1324             }
1325             else
1326             {
1327                 TIntermTyped *valueTransposed =
1328                     CreateTransposeCall(mSymbolTable, valueExpression->deepCopy());
1329                 writeStatements->push_back(
1330                     new TIntermBinary(assignmentOperator, baseExpression, valueTransposed));
1331             }
1332 
1333             return;
1334         }
1335 
1336         // If indexed, create one assignment per secondary index.  If the right-hand side is a
1337         // scalar, it's used with every assignment.  If it's a vector, the assignment is
1338         // per-component.  The right-hand side cannot be a matrix as that would imply left-hand
1339         // side being a matrix too, which is covered above where |primaryIndex == nullptr|.
1340         ASSERT(!secondaryIndices->empty());
1341 
1342         bool isValueExpressionScalar = valueExpression->getType().getNominalSize() == 1;
1343         ASSERT(isValueExpressionScalar || valueExpression->getType().getNominalSize() ==
1344                                               static_cast<int>(secondaryIndices->size()));
1345 
1346         TOperator primaryIndexOp          = GetIndexOp(primaryIndex);
1347         TIntermTyped *primaryIndexAsTyped = primaryIndex->getAsTyped();
1348 
1349         for (TIntermNode *secondaryIndex : *secondaryIndices)
1350         {
1351             TOperator secondaryIndexOp          = GetIndexOp(secondaryIndex);
1352             TIntermTyped *secondaryIndexAsTyped = secondaryIndex->getAsTyped();
1353 
1354             TIntermBinary *colIndexed = new TIntermBinary(
1355                 secondaryIndexOp, baseExpression->deepCopy(), secondaryIndexAsTyped->deepCopy());
1356             TIntermBinary *colRowIndexed =
1357                 new TIntermBinary(primaryIndexOp, colIndexed, primaryIndexAsTyped->deepCopy());
1358 
1359             TIntermTyped *valueExpressionIndexed = valueExpression->deepCopy();
1360             if (!isValueExpressionScalar)
1361             {
1362                 valueExpressionIndexed = new TIntermBinary(secondaryIndexOp, valueExpressionIndexed,
1363                                                            secondaryIndexAsTyped->deepCopy());
1364             }
1365 
1366             writeStatements->push_back(
1367                 new TIntermBinary(assignmentOperator, colRowIndexed, valueExpressionIndexed));
1368         }
1369     }
1370 
getCopyStructFieldFunction(const TType * fromFieldType,const TType * toFieldType,bool isCopyToOriginal)1371     const TFunction *getCopyStructFieldFunction(const TType *fromFieldType,
1372                                                 const TType *toFieldType,
1373                                                 bool isCopyToOriginal)
1374     {
1375         ASSERT(fromFieldType->getStruct());
1376         ASSERT(toFieldType->getStruct());
1377 
1378         // If copying from or to the original struct, the "to" field struct could require
1379         // conversion to or from the "from" field struct.  |isCopyToOriginal| tells us if we
1380         // should expect to find toField or fromField in mStructMapOut, if true or false
1381         // respectively.
1382         const TFunction *fieldCopyFunction = nullptr;
1383         if (isCopyToOriginal)
1384         {
1385             const TStructure *toFieldStruct = toFieldType->getStruct();
1386 
1387             auto iter = mStructMapOut->find(toFieldStruct);
1388             if (iter != mStructMapOut->end())
1389             {
1390                 declareStructCopyToOriginal(toFieldStruct);
1391                 fieldCopyFunction = iter->second.copyToOriginal;
1392             }
1393         }
1394         else
1395         {
1396             const TStructure *fromFieldStruct = fromFieldType->getStruct();
1397 
1398             auto iter = mStructMapOut->find(fromFieldStruct);
1399             if (iter != mStructMapOut->end())
1400             {
1401                 declareStructCopyFromOriginal(fromFieldStruct);
1402                 fieldCopyFunction = iter->second.copyFromOriginal;
1403             }
1404         }
1405 
1406         return fieldCopyFunction;
1407     }
1408 
addFieldCopy(TIntermBlock * body,TIntermTyped * to,TIntermTyped * from,bool isCopyToOriginal)1409     void addFieldCopy(TIntermBlock *body,
1410                       TIntermTyped *to,
1411                       TIntermTyped *from,
1412                       bool isCopyToOriginal)
1413     {
1414         const TType &fromType = from->getType();
1415         const TType &toType   = to->getType();
1416 
1417         TIntermTyped *rhs = from;
1418 
1419         if (fromType.getStruct())
1420         {
1421             const TFunction *fieldCopyFunction =
1422                 getCopyStructFieldFunction(&fromType, &toType, isCopyToOriginal);
1423 
1424             if (fieldCopyFunction)
1425             {
1426                 rhs = CreateStructCopyCall(fieldCopyFunction, from);
1427             }
1428         }
1429         else if (fromType.isMatrix())
1430         {
1431             rhs = CreateTransposeCall(mSymbolTable, from);
1432         }
1433 
1434         body->appendStatement(new TIntermBinary(EOpAssign, to, rhs));
1435     }
1436 
declareStructCopy(const TStructure * from,const TStructure * to,bool isCopyToOriginal)1437     TFunction *declareStructCopy(const TStructure *from,
1438                                  const TStructure *to,
1439                                  bool isCopyToOriginal)
1440     {
1441         TType *fromType = new TType(from, true);
1442         TType *toType   = new TType(to, true);
1443 
1444         // Create the parameter and return value variables.
1445         TVariable *fromVar = new TVariable(mSymbolTable, ImmutableString("from"), fromType,
1446                                            SymbolType::AngleInternal);
1447         TVariable *toVar =
1448             new TVariable(mSymbolTable, ImmutableString("to"), toType, SymbolType::AngleInternal);
1449 
1450         TIntermSymbol *fromSymbol = new TIntermSymbol(fromVar);
1451         TIntermSymbol *toSymbol   = new TIntermSymbol(toVar);
1452 
1453         // Create the function body as statements are generated.
1454         TIntermBlock *body = new TIntermBlock;
1455 
1456         // Declare the result variable.
1457         TIntermDeclaration *toDecl = new TIntermDeclaration();
1458         toDecl->appendDeclarator(toSymbol);
1459         body->appendStatement(toDecl);
1460 
1461         // Iterate over fields of the struct and copy one by one, transposing the matrices.  If a
1462         // struct is encountered that requires a transformation, this function is recursively
1463         // called.  As a result, it is important that the copy functions are placed in the code in
1464         // order.
1465         const TFieldList &fromFields = from->fields();
1466         const TFieldList &toFields   = to->fields();
1467         ASSERT(fromFields.size() == toFields.size());
1468 
1469         for (size_t fieldIndex = 0; fieldIndex < fromFields.size(); ++fieldIndex)
1470         {
1471             TIntermTyped *fieldIndexNode = CreateIndexNode(static_cast<int>(fieldIndex));
1472 
1473             TIntermTyped *fromField =
1474                 new TIntermBinary(EOpIndexDirectStruct, fromSymbol->deepCopy(), fieldIndexNode);
1475             TIntermTyped *toField = new TIntermBinary(EOpIndexDirectStruct, toSymbol->deepCopy(),
1476                                                       fieldIndexNode->deepCopy());
1477 
1478             const TType *fromFieldType = fromFields[fieldIndex]->type();
1479             bool isStructOrMatrix      = fromFieldType->getStruct() || fromFieldType->isMatrix();
1480 
1481             if (fromFieldType->isArray() && isStructOrMatrix)
1482             {
1483                 // If struct or matrix array, we need to copy element by element.
1484                 TransformArrayHelper transformHelper(toField);
1485 
1486                 TIntermTyped *toElement   = nullptr;
1487                 TIntermTyped *fromElement = nullptr;
1488                 while ((toElement = transformHelper.getNextElement(fromField, &fromElement)) !=
1489                        nullptr)
1490                 {
1491                     addFieldCopy(body, toElement, fromElement, isCopyToOriginal);
1492                 }
1493             }
1494             else
1495             {
1496                 addFieldCopy(body, toField, fromField, isCopyToOriginal);
1497             }
1498         }
1499 
1500         // Add return statement.
1501         body->appendStatement(new TIntermBranch(EOpReturn, toSymbol->deepCopy()));
1502 
1503         // Declare the function
1504         TFunction *copyFunction = new TFunction(mSymbolTable, kEmptyImmutableString,
1505                                                 SymbolType::AngleInternal, toType, true);
1506         copyFunction->addParameter(fromVar);
1507 
1508         TIntermFunctionDefinition *functionDef =
1509             CreateInternalFunctionDefinitionNode(*copyFunction, body);
1510         mCopyFunctionDefinitionsOut->push_back(functionDef);
1511 
1512         return copyFunction;
1513     }
1514 
declareStructCopyFromOriginal(const TStructure * structure)1515     void declareStructCopyFromOriginal(const TStructure *structure)
1516     {
1517         StructConversionData *structData = &(*mStructMapOut)[structure];
1518         if (structData->copyFromOriginal)
1519         {
1520             return;
1521         }
1522 
1523         structData->copyFromOriginal =
1524             declareStructCopy(structure, structData->convertedStruct, false);
1525     }
1526 
declareStructCopyToOriginal(const TStructure * structure)1527     void declareStructCopyToOriginal(const TStructure *structure)
1528     {
1529         StructConversionData *structData = &(*mStructMapOut)[structure];
1530         if (structData->copyToOriginal)
1531         {
1532             return;
1533         }
1534 
1535         structData->copyToOriginal =
1536             declareStructCopy(structData->convertedStruct, structure, true);
1537     }
1538 
1539     TCompiler *mCompiler;
1540 
1541     // This traverser can call itself to transform a subexpression before moving on.  However, it
1542     // needs to accumulate conversion functions in inner passes.  The fields below marked with Out
1543     // or In are inherited from the outer pass (for inner passes), or point to storage fields in
1544     // mOuterPass (for the outer pass).  The latter should not be used by the inner passes as they
1545     // would be empty, so they are placed inside a struct to make them explicit.
    struct
    {
        // Storage owned by the outer pass.  Presumably mStructMapOut, mInterfaceBlockMap,
        // mInterfaceBlockFieldConvertedIn and mCopyFunctionDefinitionsOut respectively refer to
        // these members when this traverser is the outer pass -- the wiring is not visible
        // here, confirm against the constructor.
        StructMap structMap;
        InterfaceBlockMap interfaceBlockMap;
        InterfaceBlockFieldConverted interfaceBlockFieldConverted;
        TIntermSequence copyFunctionDefinitions;
    } mOuterPass;
1553 
1554     // A map from structures with matrices to their converted version.
1555     StructMap *mStructMapOut;
    // A map from interface block instances with row-major matrices to their converted variable.  If
    // an interface block is nameless, its fields are placed in this map instead.  When a variable
    // in this map is encountered, it signals the start of an expression that may need conversion,
    // which is either "interfaceBlock.field..." or "field..." if nameless.
1560     InterfaceBlockMap *mInterfaceBlockMap;
1561     // A map from interface block fields to whether they need to be converted.  If a field was
1562     // already column-major, it shouldn't be transposed.
1563     const InterfaceBlockFieldConverted &mInterfaceBlockFieldConvertedIn;
1564 
1565     TIntermSequence *mCopyFunctionDefinitionsOut;
1566 
1567     // If set, it's an inner pass and this will point to the outer pass traverser.  All statement
1568     // insertions are stored in the outer traverser and applied at once in the end.  This prevents
1569     // the inner passes from adding statements which invalidates the outer traverser's statement
1570     // position tracking.
1571     RewriteRowMajorMatricesTraverser *mOuterTraverser;
1572 
1573     // If set, it's an inner pass that should only process the right-hand side of this particular
1574     // node.
1575     TIntermBinary *mInnerPassRoot;
1576     bool mIsProcessingInnerPassSubtree;
1577 };
1578 
1579 }  // anonymous namespace
1580 
RewriteRowMajorMatrices(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)1581 bool RewriteRowMajorMatrices(TCompiler *compiler, TIntermBlock *root, TSymbolTable *symbolTable)
1582 {
1583     RewriteRowMajorMatricesTraverser traverser(compiler, symbolTable);
1584     root->traverse(&traverser);
1585     if (!traverser.updateTree(compiler, root))
1586     {
1587         return false;
1588     }
1589 
1590     size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(root);
1591     root->insertChildNodes(firstFunctionIndex, *traverser.getStructCopyFunctions());
1592 
1593     return compiler->validateAST(root);
1594 }
1595 }  // namespace sh
1596