1 //
2 // Copyright (c) 2002-2014 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Scalarize vector and matrix constructor args, so that vectors built from components don't have
7 // matrix arguments, and matrices built from components don't have vector arguments. This avoids
8 // driver bugs around vector and matrix constructors.
9 //
10 
11 #include "common/debug.h"
12 #include "compiler/translator/ScalarizeVecAndMatConstructorArgs.h"
13 
14 #include <algorithm>
15 
16 #include "angle_gl.h"
17 #include "common/angleutils.h"
18 #include "compiler/translator/IntermNode_util.h"
19 #include "compiler/translator/IntermTraverse.h"
20 
21 namespace sh
22 {
23 
24 namespace
25 {
26 
ContainsMatrixNode(const TIntermSequence & sequence)27 bool ContainsMatrixNode(const TIntermSequence &sequence)
28 {
29     for (size_t ii = 0; ii < sequence.size(); ++ii)
30     {
31         TIntermTyped *node = sequence[ii]->getAsTyped();
32         if (node && node->isMatrix())
33             return true;
34     }
35     return false;
36 }
37 
ContainsVectorNode(const TIntermSequence & sequence)38 bool ContainsVectorNode(const TIntermSequence &sequence)
39 {
40     for (size_t ii = 0; ii < sequence.size(); ++ii)
41     {
42         TIntermTyped *node = sequence[ii]->getAsTyped();
43         if (node && node->isVector())
44             return true;
45     }
46     return false;
47 }
48 
ConstructVectorIndexBinaryNode(TIntermSymbol * symbolNode,int index)49 TIntermBinary *ConstructVectorIndexBinaryNode(TIntermSymbol *symbolNode, int index)
50 {
51     return new TIntermBinary(EOpIndexDirect, symbolNode, CreateIndexNode(index));
52 }
53 
ConstructMatrixIndexBinaryNode(TIntermSymbol * symbolNode,int colIndex,int rowIndex)54 TIntermBinary *ConstructMatrixIndexBinaryNode(TIntermSymbol *symbolNode, int colIndex, int rowIndex)
55 {
56     TIntermBinary *colVectorNode = ConstructVectorIndexBinaryNode(symbolNode, colIndex);
57 
58     return new TIntermBinary(EOpIndexDirect, colVectorNode, CreateIndexNode(rowIndex));
59 }
60 
61 class ScalarizeArgsTraverser : public TIntermTraverser
62 {
63   public:
ScalarizeArgsTraverser(sh::GLenum shaderType,bool fragmentPrecisionHigh,TSymbolTable * symbolTable)64     ScalarizeArgsTraverser(sh::GLenum shaderType,
65                            bool fragmentPrecisionHigh,
66                            TSymbolTable *symbolTable)
67         : TIntermTraverser(true, false, false, symbolTable),
68           mShaderType(shaderType),
69           mFragmentPrecisionHigh(fragmentPrecisionHigh)
70     {
71     }
72 
73   protected:
74     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
75     bool visitBlock(Visit visit, TIntermBlock *node) override;
76 
77   private:
78     void scalarizeArgs(TIntermAggregate *aggregate, bool scalarizeVector, bool scalarizeMatrix);
79 
80     // If we have the following code:
81     //   mat4 m(0);
82     //   vec4 v(1, m);
83     // We will rewrite to:
84     //   mat4 m(0);
85     //   mat4 s0 = m;
86     //   vec4 v(1, s0[0][0], s0[0][1], s0[0][2]);
87     // This function is to create nodes for "mat4 s0 = m;" and insert it to the code sequence. This
88     // way the possible side effects of the constructor argument will only be evaluated once.
89     void createTempVariable(TIntermTyped *original);
90 
91     std::vector<TIntermSequence> mBlockStack;
92 
93     sh::GLenum mShaderType;
94     bool mFragmentPrecisionHigh;
95 };
96 
visitAggregate(Visit visit,TIntermAggregate * node)97 bool ScalarizeArgsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
98 {
99     if (visit == PreVisit && node->getOp() == EOpConstruct)
100     {
101         if (node->getType().isVector() && ContainsMatrixNode(*(node->getSequence())))
102             scalarizeArgs(node, false, true);
103         else if (node->getType().isMatrix() && ContainsVectorNode(*(node->getSequence())))
104             scalarizeArgs(node, true, false);
105     }
106     return true;
107 }
108 
visitBlock(Visit visit,TIntermBlock * node)109 bool ScalarizeArgsTraverser::visitBlock(Visit visit, TIntermBlock *node)
110 {
111     mBlockStack.push_back(TIntermSequence());
112     {
113         for (TIntermNode *child : *node->getSequence())
114         {
115             ASSERT(child != nullptr);
116             child->traverse(this);
117             mBlockStack.back().push_back(child);
118         }
119     }
120     if (mBlockStack.back().size() > node->getSequence()->size())
121     {
122         node->getSequence()->clear();
123         *(node->getSequence()) = mBlockStack.back();
124     }
125     mBlockStack.pop_back();
126     return false;
127 }
128 
scalarizeArgs(TIntermAggregate * aggregate,bool scalarizeVector,bool scalarizeMatrix)129 void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate,
130                                            bool scalarizeVector,
131                                            bool scalarizeMatrix)
132 {
133     ASSERT(aggregate);
134     ASSERT(!aggregate->isArray());
135     int size                  = static_cast<int>(aggregate->getType().getObjectSize());
136     TIntermSequence *sequence = aggregate->getSequence();
137     TIntermSequence original(*sequence);
138     sequence->clear();
139     for (size_t ii = 0; ii < original.size(); ++ii)
140     {
141         ASSERT(size > 0);
142         TIntermTyped *node = original[ii]->getAsTyped();
143         ASSERT(node);
144         createTempVariable(node);
145         if (node->isScalar())
146         {
147             sequence->push_back(createTempSymbol(node->getType()));
148             size--;
149         }
150         else if (node->isVector())
151         {
152             if (scalarizeVector)
153             {
154                 int repeat = std::min(size, node->getNominalSize());
155                 size -= repeat;
156                 for (int index = 0; index < repeat; ++index)
157                 {
158                     TIntermSymbol *symbolNode = createTempSymbol(node->getType());
159                     TIntermBinary *newNode    = ConstructVectorIndexBinaryNode(symbolNode, index);
160                     sequence->push_back(newNode);
161                 }
162             }
163             else
164             {
165                 TIntermSymbol *symbolNode = createTempSymbol(node->getType());
166                 sequence->push_back(symbolNode);
167                 size -= node->getNominalSize();
168             }
169         }
170         else
171         {
172             ASSERT(node->isMatrix());
173             if (scalarizeMatrix)
174             {
175                 int colIndex = 0, rowIndex = 0;
176                 int repeat = std::min(size, node->getCols() * node->getRows());
177                 size -= repeat;
178                 while (repeat > 0)
179                 {
180                     TIntermSymbol *symbolNode = createTempSymbol(node->getType());
181                     TIntermBinary *newNode =
182                         ConstructMatrixIndexBinaryNode(symbolNode, colIndex, rowIndex);
183                     sequence->push_back(newNode);
184                     rowIndex++;
185                     if (rowIndex >= node->getRows())
186                     {
187                         rowIndex = 0;
188                         colIndex++;
189                     }
190                     repeat--;
191                 }
192             }
193             else
194             {
195                 TIntermSymbol *symbolNode = createTempSymbol(node->getType());
196                 sequence->push_back(symbolNode);
197                 size -= node->getCols() * node->getRows();
198             }
199         }
200     }
201 }
202 
createTempVariable(TIntermTyped * original)203 void ScalarizeArgsTraverser::createTempVariable(TIntermTyped *original)
204 {
205     ASSERT(original);
206     nextTemporaryId();
207     TIntermDeclaration *decl = createTempInitDeclaration(original);
208 
209     TType type = original->getType();
210     if (mShaderType == GL_FRAGMENT_SHADER && type.getBasicType() == EbtFloat &&
211         type.getPrecision() == EbpUndefined)
212     {
213         // We use the highest available precision for the temporary variable
214         // to avoid computing the actual precision using the rules defined
215         // in GLSL ES 1.0 Section 4.5.2.
216         TIntermBinary *init = decl->getSequence()->at(0)->getAsBinaryNode();
217         init->getTypePointer()->setPrecision(mFragmentPrecisionHigh ? EbpHigh : EbpMedium);
218         init->getLeft()->getTypePointer()->setPrecision(mFragmentPrecisionHigh ? EbpHigh
219                                                                                : EbpMedium);
220     }
221 
222     ASSERT(mBlockStack.size() > 0);
223     TIntermSequence &sequence = mBlockStack.back();
224     sequence.push_back(decl);
225 }
226 
227 }  // namespace anonymous
228 
ScalarizeVecAndMatConstructorArgs(TIntermBlock * root,sh::GLenum shaderType,bool fragmentPrecisionHigh,TSymbolTable * symbolTable)229 void ScalarizeVecAndMatConstructorArgs(TIntermBlock *root,
230                                        sh::GLenum shaderType,
231                                        bool fragmentPrecisionHigh,
232                                        TSymbolTable *symbolTable)
233 {
234     ScalarizeArgsTraverser scalarizer(shaderType, fragmentPrecisionHigh, symbolTable);
235     root->traverse(&scalarizer);
236 }
237 
238 }  // namespace sh
239