1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved. Use of this
3 // source code is governed by a BSD-style license that can be found in the
4 // LICENSE file.
5 //
// ReplaceArrayOfMatrixVarying: Find all varyings declared as arrays of matrices and replace each
// of them with an array of vectors.
8 //
9 
10 #include "compiler/translator/tree_util/ReplaceArrayOfMatrixVarying.h"
11 
12 #include <vector>
13 
14 #include "common/bitset_utils.h"
15 #include "common/debug.h"
16 #include "common/utilities.h"
17 #include "compiler/translator/Compiler.h"
18 #include "compiler/translator/SymbolTable.h"
19 #include "compiler/translator/tree_util/BuiltIn.h"
20 #include "compiler/translator/tree_util/FindMain.h"
21 #include "compiler/translator/tree_util/IntermNode_util.h"
22 #include "compiler/translator/tree_util/IntermTraverse.h"
23 #include "compiler/translator/tree_util/ReplaceVariable.h"
24 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
25 #include "compiler/translator/util.h"
26 
27 namespace sh
28 {
29 
// We create two variables to replace the given varying:
// - A new varying, an array of vectors, that is used only for input/output.
// - A new global variable with the same type as the given varying, used temporarily as its
//   replacement in assignments, arithmetic ops and so on. During the input/output phase, this
//   temp variable is copied from/to the array-of-vectors varying above.
// NOTE(hqle): Consider eliminating the need for using temp variable.
36 
37 namespace
38 {
39 class CollectVaryingTraverser : public TIntermTraverser
40 {
41   public:
CollectVaryingTraverser(std::vector<const TVariable * > * varyingsOut)42     CollectVaryingTraverser(std::vector<const TVariable *> *varyingsOut)
43         : TIntermTraverser(true, false, false), mVaryingsOut(varyingsOut)
44     {}
45 
visitDeclaration(Visit visit,TIntermDeclaration * node)46     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
47     {
48         const TIntermSequence &sequence = *(node->getSequence());
49 
50         if (sequence.size() != 1)
51         {
52             return false;
53         }
54 
55         TIntermTyped *variableType = sequence.front()->getAsTyped();
56         if (!variableType || !IsVarying(variableType->getQualifier()) ||
57             !variableType->isMatrix() || !variableType->isArray())
58         {
59             return false;
60         }
61 
62         TIntermSymbol *variableSymbol = variableType->getAsSymbolNode();
63         if (!variableSymbol)
64         {
65             return false;
66         }
67 
68         mVaryingsOut->push_back(&variableSymbol->variable());
69 
70         return false;
71     }
72 
73   private:
74     std::vector<const TVariable *> *mVaryingsOut;
75 };
76 }  // namespace
77 
ReplaceArrayOfMatrixVarying(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const TVariable * varying)78 ANGLE_NO_DISCARD bool ReplaceArrayOfMatrixVarying(TCompiler *compiler,
79                                                   TIntermBlock *root,
80                                                   TSymbolTable *symbolTable,
81                                                   const TVariable *varying)
82 {
83     const TType &type = varying->getType();
84 
85     // Create global variable to temporarily acts as the given variable in places such as
86     // arithmetic, assignments an so on.
87     TType *tmpReplacementType = new TType(type);
88     tmpReplacementType->setQualifier(EvqGlobal);
89     tmpReplacementType->realize();
90 
91     TVariable *tempReplaceVar = new TVariable(
92         symbolTable, ImmutableString(std::string("ANGLE_AOM_Temp_") + varying->name().data()),
93         tmpReplacementType, SymbolType::AngleInternal);
94 
95     if (!ReplaceVariable(compiler, root, varying, tempReplaceVar))
96     {
97         return false;
98     }
99 
100     // Create array of vectors type
101     TType *varyingReplaceType =
102         new TType(type.getBasicType(), type.getPrecision(), type.getQualifier(),
103                   static_cast<unsigned char>(type.getRows()), 1);
104     varyingReplaceType->setInvariant(type.isInvariant());
105     varyingReplaceType->setMemoryQualifier(type.getMemoryQualifier());
106     varyingReplaceType->setLayoutQualifier(type.getLayoutQualifier());
107     varyingReplaceType->makeArray(type.getCols() * type.getOutermostArraySize());
108     varyingReplaceType->realize();
109 
110     TVariable *varyingReplaceVar =
111         new TVariable(symbolTable, varying->name(), varyingReplaceType, SymbolType::UserDefined);
112 
113     TIntermSymbol *varyingReplaceDeclarator = new TIntermSymbol(varyingReplaceVar);
114     TIntermDeclaration *varyingReplaceDecl  = new TIntermDeclaration;
115     varyingReplaceDecl->appendDeclarator(varyingReplaceDeclarator);
116     root->insertStatement(0, varyingReplaceDecl);
117 
118     // Copy from/to the temp variable
119     TIntermBlock *reassignBlock         = new TIntermBlock;
120     TIntermSymbol *tempReplaceSymbol    = new TIntermSymbol(tempReplaceVar);
121     TIntermSymbol *varyingReplaceSymbol = new TIntermSymbol(varyingReplaceVar);
122     bool isInput                        = IsVaryingIn(type.getQualifier());
123 
124     for (unsigned int i = 0; i < type.getOutermostArraySize(); ++i)
125     {
126         TIntermBinary *tempMatrixIndexed =
127             new TIntermBinary(EOpIndexDirect, tempReplaceSymbol->deepCopy(), CreateIndexNode(i));
128         for (int col = 0; col < type.getCols(); ++col)
129         {
130 
131             TIntermBinary *tempMatrixColIndexed = new TIntermBinary(
132                 EOpIndexDirect, tempMatrixIndexed->deepCopy(), CreateIndexNode(col));
133             TIntermBinary *vectorIndexed =
134                 new TIntermBinary(EOpIndexDirect, varyingReplaceSymbol->deepCopy(),
135                                   CreateIndexNode(i * type.getCols() + col));
136             TIntermBinary *assignment;
137             if (isInput)
138             {
139                 assignment = new TIntermBinary(EOpAssign, tempMatrixColIndexed, vectorIndexed);
140             }
141             else
142             {
143                 assignment = new TIntermBinary(EOpAssign, vectorIndexed, tempMatrixColIndexed);
144             }
145             reassignBlock->appendStatement(assignment);
146         }
147     }
148 
149     if (isInput)
150     {
151         TIntermFunctionDefinition *main = FindMain(root);
152         main->getBody()->insertStatement(0, reassignBlock);
153         return compiler->validateAST(root);
154     }
155     else
156     {
157         return RunAtTheEndOfShader(compiler, root, reassignBlock, symbolTable);
158     }
159 }
160 
ReplaceArrayOfMatrixVaryings(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)161 ANGLE_NO_DISCARD bool ReplaceArrayOfMatrixVaryings(TCompiler *compiler,
162                                                    TIntermBlock *root,
163                                                    TSymbolTable *symbolTable)
164 {
165     std::vector<const TVariable *> arrayOfMatrixVars;
166     CollectVaryingTraverser varCollector(&arrayOfMatrixVars);
167     root->traverse(&varCollector);
168 
169     for (const TVariable *var : arrayOfMatrixVars)
170     {
171         if (!ReplaceArrayOfMatrixVarying(compiler, root, symbolTable, var))
172         {
173             return false;
174         }
175     }
176 
177     return true;
178 }
179 
180 }  // namespace sh
181