1 //
2 // Copyright (c) 2002-2014 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Scalarize vector and matrix constructor args, so that vectors built from components don't have
7 // matrix arguments, and matrices built from components don't have vector arguments. This avoids
8 // driver bugs around vector and matrix constructors.
9 //
10
11 #include "common/debug.h"
12 #include "compiler/translator/ScalarizeVecAndMatConstructorArgs.h"
13
14 #include <algorithm>
15
16 #include "angle_gl.h"
17 #include "common/angleutils.h"
18 #include "compiler/translator/IntermNode_util.h"
19 #include "compiler/translator/IntermTraverse.h"
20
21 namespace sh
22 {
23
24 namespace
25 {
26
ContainsMatrixNode(const TIntermSequence & sequence)27 bool ContainsMatrixNode(const TIntermSequence &sequence)
28 {
29 for (size_t ii = 0; ii < sequence.size(); ++ii)
30 {
31 TIntermTyped *node = sequence[ii]->getAsTyped();
32 if (node && node->isMatrix())
33 return true;
34 }
35 return false;
36 }
37
ContainsVectorNode(const TIntermSequence & sequence)38 bool ContainsVectorNode(const TIntermSequence &sequence)
39 {
40 for (size_t ii = 0; ii < sequence.size(); ++ii)
41 {
42 TIntermTyped *node = sequence[ii]->getAsTyped();
43 if (node && node->isVector())
44 return true;
45 }
46 return false;
47 }
48
ConstructVectorIndexBinaryNode(TIntermSymbol * symbolNode,int index)49 TIntermBinary *ConstructVectorIndexBinaryNode(TIntermSymbol *symbolNode, int index)
50 {
51 return new TIntermBinary(EOpIndexDirect, symbolNode, CreateIndexNode(index));
52 }
53
ConstructMatrixIndexBinaryNode(TIntermSymbol * symbolNode,int colIndex,int rowIndex)54 TIntermBinary *ConstructMatrixIndexBinaryNode(TIntermSymbol *symbolNode, int colIndex, int rowIndex)
55 {
56 TIntermBinary *colVectorNode = ConstructVectorIndexBinaryNode(symbolNode, colIndex);
57
58 return new TIntermBinary(EOpIndexDirect, colVectorNode, CreateIndexNode(rowIndex));
59 }
60
61 class ScalarizeArgsTraverser : public TIntermTraverser
62 {
63 public:
ScalarizeArgsTraverser(sh::GLenum shaderType,bool fragmentPrecisionHigh,TSymbolTable * symbolTable)64 ScalarizeArgsTraverser(sh::GLenum shaderType,
65 bool fragmentPrecisionHigh,
66 TSymbolTable *symbolTable)
67 : TIntermTraverser(true, false, false, symbolTable),
68 mShaderType(shaderType),
69 mFragmentPrecisionHigh(fragmentPrecisionHigh)
70 {
71 }
72
73 protected:
74 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
75 bool visitBlock(Visit visit, TIntermBlock *node) override;
76
77 private:
78 void scalarizeArgs(TIntermAggregate *aggregate, bool scalarizeVector, bool scalarizeMatrix);
79
80 // If we have the following code:
81 // mat4 m(0);
82 // vec4 v(1, m);
83 // We will rewrite to:
84 // mat4 m(0);
85 // mat4 s0 = m;
86 // vec4 v(1, s0[0][0], s0[0][1], s0[0][2]);
87 // This function is to create nodes for "mat4 s0 = m;" and insert it to the code sequence. This
88 // way the possible side effects of the constructor argument will only be evaluated once.
89 void createTempVariable(TIntermTyped *original);
90
91 std::vector<TIntermSequence> mBlockStack;
92
93 sh::GLenum mShaderType;
94 bool mFragmentPrecisionHigh;
95 };
96
visitAggregate(Visit visit,TIntermAggregate * node)97 bool ScalarizeArgsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
98 {
99 if (visit == PreVisit && node->getOp() == EOpConstruct)
100 {
101 if (node->getType().isVector() && ContainsMatrixNode(*(node->getSequence())))
102 scalarizeArgs(node, false, true);
103 else if (node->getType().isMatrix() && ContainsVectorNode(*(node->getSequence())))
104 scalarizeArgs(node, true, false);
105 }
106 return true;
107 }
108
visitBlock(Visit visit,TIntermBlock * node)109 bool ScalarizeArgsTraverser::visitBlock(Visit visit, TIntermBlock *node)
110 {
111 mBlockStack.push_back(TIntermSequence());
112 {
113 for (TIntermNode *child : *node->getSequence())
114 {
115 ASSERT(child != nullptr);
116 child->traverse(this);
117 mBlockStack.back().push_back(child);
118 }
119 }
120 if (mBlockStack.back().size() > node->getSequence()->size())
121 {
122 node->getSequence()->clear();
123 *(node->getSequence()) = mBlockStack.back();
124 }
125 mBlockStack.pop_back();
126 return false;
127 }
128
scalarizeArgs(TIntermAggregate * aggregate,bool scalarizeVector,bool scalarizeMatrix)129 void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate,
130 bool scalarizeVector,
131 bool scalarizeMatrix)
132 {
133 ASSERT(aggregate);
134 ASSERT(!aggregate->isArray());
135 int size = static_cast<int>(aggregate->getType().getObjectSize());
136 TIntermSequence *sequence = aggregate->getSequence();
137 TIntermSequence original(*sequence);
138 sequence->clear();
139 for (size_t ii = 0; ii < original.size(); ++ii)
140 {
141 ASSERT(size > 0);
142 TIntermTyped *node = original[ii]->getAsTyped();
143 ASSERT(node);
144 createTempVariable(node);
145 if (node->isScalar())
146 {
147 sequence->push_back(createTempSymbol(node->getType()));
148 size--;
149 }
150 else if (node->isVector())
151 {
152 if (scalarizeVector)
153 {
154 int repeat = std::min(size, node->getNominalSize());
155 size -= repeat;
156 for (int index = 0; index < repeat; ++index)
157 {
158 TIntermSymbol *symbolNode = createTempSymbol(node->getType());
159 TIntermBinary *newNode = ConstructVectorIndexBinaryNode(symbolNode, index);
160 sequence->push_back(newNode);
161 }
162 }
163 else
164 {
165 TIntermSymbol *symbolNode = createTempSymbol(node->getType());
166 sequence->push_back(symbolNode);
167 size -= node->getNominalSize();
168 }
169 }
170 else
171 {
172 ASSERT(node->isMatrix());
173 if (scalarizeMatrix)
174 {
175 int colIndex = 0, rowIndex = 0;
176 int repeat = std::min(size, node->getCols() * node->getRows());
177 size -= repeat;
178 while (repeat > 0)
179 {
180 TIntermSymbol *symbolNode = createTempSymbol(node->getType());
181 TIntermBinary *newNode =
182 ConstructMatrixIndexBinaryNode(symbolNode, colIndex, rowIndex);
183 sequence->push_back(newNode);
184 rowIndex++;
185 if (rowIndex >= node->getRows())
186 {
187 rowIndex = 0;
188 colIndex++;
189 }
190 repeat--;
191 }
192 }
193 else
194 {
195 TIntermSymbol *symbolNode = createTempSymbol(node->getType());
196 sequence->push_back(symbolNode);
197 size -= node->getCols() * node->getRows();
198 }
199 }
200 }
201 }
202
createTempVariable(TIntermTyped * original)203 void ScalarizeArgsTraverser::createTempVariable(TIntermTyped *original)
204 {
205 ASSERT(original);
206 nextTemporaryId();
207 TIntermDeclaration *decl = createTempInitDeclaration(original);
208
209 TType type = original->getType();
210 if (mShaderType == GL_FRAGMENT_SHADER && type.getBasicType() == EbtFloat &&
211 type.getPrecision() == EbpUndefined)
212 {
213 // We use the highest available precision for the temporary variable
214 // to avoid computing the actual precision using the rules defined
215 // in GLSL ES 1.0 Section 4.5.2.
216 TIntermBinary *init = decl->getSequence()->at(0)->getAsBinaryNode();
217 init->getTypePointer()->setPrecision(mFragmentPrecisionHigh ? EbpHigh : EbpMedium);
218 init->getLeft()->getTypePointer()->setPrecision(mFragmentPrecisionHigh ? EbpHigh
219 : EbpMedium);
220 }
221
222 ASSERT(mBlockStack.size() > 0);
223 TIntermSequence &sequence = mBlockStack.back();
224 sequence.push_back(decl);
225 }
226
227 } // namespace anonymous
228
ScalarizeVecAndMatConstructorArgs(TIntermBlock * root,sh::GLenum shaderType,bool fragmentPrecisionHigh,TSymbolTable * symbolTable)229 void ScalarizeVecAndMatConstructorArgs(TIntermBlock *root,
230 sh::GLenum shaderType,
231 bool fragmentPrecisionHigh,
232 TSymbolTable *symbolTable)
233 {
234 ScalarizeArgsTraverser scalarizer(shaderType, fragmentPrecisionHigh, symbolTable);
235 root->traverse(&scalarizer);
236 }
237
238 } // namespace sh
239