1 //
2 // Copyright (c) 2002-2014 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Scalarize vector and matrix constructor args, so that vectors built from components don't have
7 // matrix arguments, and matrices built from components don't have vector arguments. This avoids
8 // driver bugs around vector and matrix constructors.
9 //
10 
#include "common/debug.h"
#include "compiler/translator/ScalarizeVecAndMatConstructorArgs.h"

#include <algorithm>
#include <utility>

#include "angle_gl.h"
#include "common/angleutils.h"
#include "compiler/translator/IntermNodePatternMatcher.h"
#include "compiler/translator/IntermNode_util.h"
#include "compiler/translator/IntermTraverse.h"
21 
22 namespace sh
23 {
24 
25 namespace
26 {
27 
ConstructVectorIndexBinaryNode(TIntermSymbol * symbolNode,int index)28 TIntermBinary *ConstructVectorIndexBinaryNode(TIntermSymbol *symbolNode, int index)
29 {
30     return new TIntermBinary(EOpIndexDirect, symbolNode, CreateIndexNode(index));
31 }
32 
ConstructMatrixIndexBinaryNode(TIntermSymbol * symbolNode,int colIndex,int rowIndex)33 TIntermBinary *ConstructMatrixIndexBinaryNode(TIntermSymbol *symbolNode, int colIndex, int rowIndex)
34 {
35     TIntermBinary *colVectorNode = ConstructVectorIndexBinaryNode(symbolNode, colIndex);
36 
37     return new TIntermBinary(EOpIndexDirect, colVectorNode, CreateIndexNode(rowIndex));
38 }
39 
40 class ScalarizeArgsTraverser : public TIntermTraverser
41 {
42   public:
ScalarizeArgsTraverser(sh::GLenum shaderType,bool fragmentPrecisionHigh,TSymbolTable * symbolTable)43     ScalarizeArgsTraverser(sh::GLenum shaderType,
44                            bool fragmentPrecisionHigh,
45                            TSymbolTable *symbolTable)
46         : TIntermTraverser(true, false, false, symbolTable),
47           mShaderType(shaderType),
48           mFragmentPrecisionHigh(fragmentPrecisionHigh),
49           mNodesToScalarize(IntermNodePatternMatcher::kScalarizedVecOrMatConstructor)
50     {
51     }
52 
53   protected:
54     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
55     bool visitBlock(Visit visit, TIntermBlock *node) override;
56 
57   private:
58     void scalarizeArgs(TIntermAggregate *aggregate, bool scalarizeVector, bool scalarizeMatrix);
59 
60     // If we have the following code:
61     //   mat4 m(0);
62     //   vec4 v(1, m);
63     // We will rewrite to:
64     //   mat4 m(0);
65     //   mat4 s0 = m;
66     //   vec4 v(1, s0[0][0], s0[0][1], s0[0][2]);
67     // This function is to create nodes for "mat4 s0 = m;" and insert it to the code sequence. This
68     // way the possible side effects of the constructor argument will only be evaluated once.
69     TVariable *createTempVariable(TIntermTyped *original);
70 
71     std::vector<TIntermSequence> mBlockStack;
72 
73     sh::GLenum mShaderType;
74     bool mFragmentPrecisionHigh;
75 
76     IntermNodePatternMatcher mNodesToScalarize;
77 };
78 
visitAggregate(Visit visit,TIntermAggregate * node)79 bool ScalarizeArgsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
80 {
81     ASSERT(visit == PreVisit);
82     if (mNodesToScalarize.match(node, getParentNode()))
83     {
84         if (node->getType().isVector())
85         {
86             scalarizeArgs(node, false, true);
87         }
88         else
89         {
90             ASSERT(node->getType().isMatrix());
91             scalarizeArgs(node, true, false);
92         }
93     }
94     return true;
95 }
96 
visitBlock(Visit visit,TIntermBlock * node)97 bool ScalarizeArgsTraverser::visitBlock(Visit visit, TIntermBlock *node)
98 {
99     mBlockStack.push_back(TIntermSequence());
100     {
101         for (TIntermNode *child : *node->getSequence())
102         {
103             ASSERT(child != nullptr);
104             child->traverse(this);
105             mBlockStack.back().push_back(child);
106         }
107     }
108     if (mBlockStack.back().size() > node->getSequence()->size())
109     {
110         node->getSequence()->clear();
111         *(node->getSequence()) = mBlockStack.back();
112     }
113     mBlockStack.pop_back();
114     return false;
115 }
116 
scalarizeArgs(TIntermAggregate * aggregate,bool scalarizeVector,bool scalarizeMatrix)117 void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate,
118                                            bool scalarizeVector,
119                                            bool scalarizeMatrix)
120 {
121     ASSERT(aggregate);
122     ASSERT(!aggregate->isArray());
123     int size                  = static_cast<int>(aggregate->getType().getObjectSize());
124     TIntermSequence *sequence = aggregate->getSequence();
125     TIntermSequence originalArgs(*sequence);
126     sequence->clear();
127     for (TIntermNode *originalArgNode : originalArgs)
128     {
129         ASSERT(size > 0);
130         TIntermTyped *originalArg = originalArgNode->getAsTyped();
131         ASSERT(originalArg);
132         TVariable *argVariable = createTempVariable(originalArg);
133         if (originalArg->isScalar())
134         {
135             sequence->push_back(CreateTempSymbolNode(argVariable));
136             size--;
137         }
138         else if (originalArg->isVector())
139         {
140             if (scalarizeVector)
141             {
142                 int repeat = std::min(size, originalArg->getNominalSize());
143                 size -= repeat;
144                 for (int index = 0; index < repeat; ++index)
145                 {
146                     TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable);
147                     TIntermBinary *newNode    = ConstructVectorIndexBinaryNode(symbolNode, index);
148                     sequence->push_back(newNode);
149                 }
150             }
151             else
152             {
153                 TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable);
154                 sequence->push_back(symbolNode);
155                 size -= originalArg->getNominalSize();
156             }
157         }
158         else
159         {
160             ASSERT(originalArg->isMatrix());
161             if (scalarizeMatrix)
162             {
163                 int colIndex = 0, rowIndex = 0;
164                 int repeat = std::min(size, originalArg->getCols() * originalArg->getRows());
165                 size -= repeat;
166                 while (repeat > 0)
167                 {
168                     TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable);
169                     TIntermBinary *newNode =
170                         ConstructMatrixIndexBinaryNode(symbolNode, colIndex, rowIndex);
171                     sequence->push_back(newNode);
172                     rowIndex++;
173                     if (rowIndex >= originalArg->getRows())
174                     {
175                         rowIndex = 0;
176                         colIndex++;
177                     }
178                     repeat--;
179                 }
180             }
181             else
182             {
183                 TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable);
184                 sequence->push_back(symbolNode);
185                 size -= originalArg->getCols() * originalArg->getRows();
186             }
187         }
188     }
189 }
190 
createTempVariable(TIntermTyped * original)191 TVariable *ScalarizeArgsTraverser::createTempVariable(TIntermTyped *original)
192 {
193     ASSERT(original);
194 
195     TType *type = new TType(original->getType());
196     type->setQualifier(EvqTemporary);
197     if (mShaderType == GL_FRAGMENT_SHADER && type->getBasicType() == EbtFloat &&
198         type->getPrecision() == EbpUndefined)
199     {
200         // We use the highest available precision for the temporary variable
201         // to avoid computing the actual precision using the rules defined
202         // in GLSL ES 1.0 Section 4.5.2.
203         type->setPrecision(mFragmentPrecisionHigh ? EbpHigh : EbpMedium);
204     }
205 
206     TVariable *variable = CreateTempVariable(mSymbolTable, type);
207 
208     ASSERT(mBlockStack.size() > 0);
209     TIntermSequence &sequence = mBlockStack.back();
210     TIntermDeclaration *declaration = CreateTempInitDeclarationNode(variable, original);
211     sequence.push_back(declaration);
212 
213     return variable;
214 }
215 
216 }  // namespace anonymous
217 
ScalarizeVecAndMatConstructorArgs(TIntermBlock * root,sh::GLenum shaderType,bool fragmentPrecisionHigh,TSymbolTable * symbolTable)218 void ScalarizeVecAndMatConstructorArgs(TIntermBlock *root,
219                                        sh::GLenum shaderType,
220                                        bool fragmentPrecisionHigh,
221                                        TSymbolTable *symbolTable)
222 {
223     ScalarizeArgsTraverser scalarizer(shaderType, fragmentPrecisionHigh, symbolTable);
224     root->traverse(&scalarizer);
225 }
226 
227 }  // namespace sh
228