1 //
2 // Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/ValidateLimitations.h"
8
9 #include "angle_gl.h"
10 #include "compiler/translator/Diagnostics.h"
11 #include "compiler/translator/IntermTraverse.h"
12 #include "compiler/translator/ParseContext.h"
13
14 namespace sh
15 {
16
17 namespace
18 {
19
GetLoopSymbolId(TIntermLoop * loop)20 int GetLoopSymbolId(TIntermLoop *loop)
21 {
22 // Here we assume all the operations are valid, because the loop node is
23 // already validated before this call.
24 TIntermSequence *declSeq = loop->getInit()->getAsDeclarationNode()->getSequence();
25 TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode();
26 TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
27
28 return symbol->getId();
29 }
30
31 // Traverses a node to check if it represents a constant index expression.
32 // Definition:
33 // constant-index-expressions are a superset of constant-expressions.
34 // Constant-index-expressions can include loop indices as defined in
35 // GLSL ES 1.0 spec, Appendix A, section 4.
36 // The following are constant-index-expressions:
37 // - Constant expressions
38 // - Loop indices as defined in section 4
39 // - Expressions composed of both of the above
40 class ValidateConstIndexExpr : public TIntermTraverser
41 {
42 public:
ValidateConstIndexExpr(const std::vector<int> & loopSymbols)43 ValidateConstIndexExpr(const std::vector<int> &loopSymbols)
44 : TIntermTraverser(true, false, false), mValid(true), mLoopSymbolIds(loopSymbols)
45 {
46 }
47
48 // Returns true if the parsed node represents a constant index expression.
isValid() const49 bool isValid() const { return mValid; }
50
visitSymbol(TIntermSymbol * symbol)51 void visitSymbol(TIntermSymbol *symbol) override
52 {
53 // Only constants and loop indices are allowed in a
54 // constant index expression.
55 if (mValid)
56 {
57 bool isLoopSymbol = std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(),
58 symbol->getId()) != mLoopSymbolIds.end();
59 mValid = (symbol->getQualifier() == EvqConst) || isLoopSymbol;
60 }
61 }
62
63 private:
64 bool mValid;
65 const std::vector<int> mLoopSymbolIds;
66 };
67
68 // Traverses intermediate tree to ensure that the shader does not exceed the
69 // minimum functionality mandated in GLSL 1.0 spec, Appendix A.
70 class ValidateLimitationsTraverser : public TLValueTrackingTraverser
71 {
72 public:
73 ValidateLimitationsTraverser(sh::GLenum shaderType,
74 TSymbolTable *symbolTable,
75 int shaderVersion,
76 TDiagnostics *diagnostics);
77
78 void visitSymbol(TIntermSymbol *node) override;
79 bool visitBinary(Visit, TIntermBinary *) override;
80 bool visitLoop(Visit, TIntermLoop *) override;
81
82 private:
83 void error(TSourceLoc loc, const char *reason, const char *token);
84
85 bool withinLoopBody() const;
86 bool isLoopIndex(TIntermSymbol *symbol);
87 bool validateLoopType(TIntermLoop *node);
88
89 bool validateForLoopHeader(TIntermLoop *node);
90 // If valid, return the index symbol id; Otherwise, return -1.
91 int validateForLoopInit(TIntermLoop *node);
92 bool validateForLoopCond(TIntermLoop *node, int indexSymbolId);
93 bool validateForLoopExpr(TIntermLoop *node, int indexSymbolId);
94
95 // Returns true if indexing does not exceed the minimum functionality
96 // mandated in GLSL 1.0 spec, Appendix A, Section 5.
97 bool isConstExpr(TIntermNode *node);
98 bool isConstIndexExpr(TIntermNode *node);
99 bool validateIndexing(TIntermBinary *node);
100
101 sh::GLenum mShaderType;
102 TDiagnostics *mDiagnostics;
103 std::vector<int> mLoopSymbolIds;
104 };
105
ValidateLimitationsTraverser(sh::GLenum shaderType,TSymbolTable * symbolTable,int shaderVersion,TDiagnostics * diagnostics)106 ValidateLimitationsTraverser::ValidateLimitationsTraverser(sh::GLenum shaderType,
107 TSymbolTable *symbolTable,
108 int shaderVersion,
109 TDiagnostics *diagnostics)
110 : TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
111 mShaderType(shaderType),
112 mDiagnostics(diagnostics)
113 {
114 ASSERT(diagnostics);
115 }
116
visitSymbol(TIntermSymbol * node)117 void ValidateLimitationsTraverser::visitSymbol(TIntermSymbol *node)
118 {
119 if (isLoopIndex(node) && isLValueRequiredHere())
120 {
121 error(node->getLine(),
122 "Loop index cannot be statically assigned to within the body of the loop",
123 node->getSymbol().c_str());
124 }
125 }
126
visitBinary(Visit,TIntermBinary * node)127 bool ValidateLimitationsTraverser::visitBinary(Visit, TIntermBinary *node)
128 {
129 // Check indexing.
130 switch (node->getOp())
131 {
132 case EOpIndexDirect:
133 case EOpIndexIndirect:
134 validateIndexing(node);
135 break;
136 default:
137 break;
138 }
139 return true;
140 }
141
visitLoop(Visit,TIntermLoop * node)142 bool ValidateLimitationsTraverser::visitLoop(Visit, TIntermLoop *node)
143 {
144 if (!validateLoopType(node))
145 return false;
146
147 if (!validateForLoopHeader(node))
148 return false;
149
150 TIntermNode *body = node->getBody();
151 if (body != nullptr)
152 {
153 mLoopSymbolIds.push_back(GetLoopSymbolId(node));
154 body->traverse(this);
155 mLoopSymbolIds.pop_back();
156 }
157
158 // The loop is fully processed - no need to visit children.
159 return false;
160 }
161
error(TSourceLoc loc,const char * reason,const char * token)162 void ValidateLimitationsTraverser::error(TSourceLoc loc, const char *reason, const char *token)
163 {
164 mDiagnostics->error(loc, reason, token);
165 }
166
withinLoopBody() const167 bool ValidateLimitationsTraverser::withinLoopBody() const
168 {
169 return !mLoopSymbolIds.empty();
170 }
171
isLoopIndex(TIntermSymbol * symbol)172 bool ValidateLimitationsTraverser::isLoopIndex(TIntermSymbol *symbol)
173 {
174 return std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), symbol->getId()) !=
175 mLoopSymbolIds.end();
176 }
177
validateLoopType(TIntermLoop * node)178 bool ValidateLimitationsTraverser::validateLoopType(TIntermLoop *node)
179 {
180 TLoopType type = node->getType();
181 if (type == ELoopFor)
182 return true;
183
184 // Reject while and do-while loops.
185 error(node->getLine(), "This type of loop is not allowed", type == ELoopWhile ? "while" : "do");
186 return false;
187 }
188
validateForLoopHeader(TIntermLoop * node)189 bool ValidateLimitationsTraverser::validateForLoopHeader(TIntermLoop *node)
190 {
191 ASSERT(node->getType() == ELoopFor);
192
193 //
194 // The for statement has the form:
195 // for ( init-declaration ; condition ; expression ) statement
196 //
197 int indexSymbolId = validateForLoopInit(node);
198 if (indexSymbolId < 0)
199 return false;
200 if (!validateForLoopCond(node, indexSymbolId))
201 return false;
202 if (!validateForLoopExpr(node, indexSymbolId))
203 return false;
204
205 return true;
206 }
207
validateForLoopInit(TIntermLoop * node)208 int ValidateLimitationsTraverser::validateForLoopInit(TIntermLoop *node)
209 {
210 TIntermNode *init = node->getInit();
211 if (init == nullptr)
212 {
213 error(node->getLine(), "Missing init declaration", "for");
214 return -1;
215 }
216
217 //
218 // init-declaration has the form:
219 // type-specifier identifier = constant-expression
220 //
221 TIntermDeclaration *decl = init->getAsDeclarationNode();
222 if (decl == nullptr)
223 {
224 error(init->getLine(), "Invalid init declaration", "for");
225 return -1;
226 }
227 // To keep things simple do not allow declaration list.
228 TIntermSequence *declSeq = decl->getSequence();
229 if (declSeq->size() != 1)
230 {
231 error(decl->getLine(), "Invalid init declaration", "for");
232 return -1;
233 }
234 TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode();
235 if ((declInit == nullptr) || (declInit->getOp() != EOpInitialize))
236 {
237 error(decl->getLine(), "Invalid init declaration", "for");
238 return -1;
239 }
240 TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
241 if (symbol == nullptr)
242 {
243 error(declInit->getLine(), "Invalid init declaration", "for");
244 return -1;
245 }
246 // The loop index has type int or float.
247 TBasicType type = symbol->getBasicType();
248 if ((type != EbtInt) && (type != EbtUInt) && (type != EbtFloat))
249 {
250 error(symbol->getLine(), "Invalid type for loop index", getBasicString(type));
251 return -1;
252 }
253 // The loop index is initialized with constant expression.
254 if (!isConstExpr(declInit->getRight()))
255 {
256 error(declInit->getLine(), "Loop index cannot be initialized with non-constant expression",
257 symbol->getSymbol().c_str());
258 return -1;
259 }
260
261 return symbol->getId();
262 }
263
validateForLoopCond(TIntermLoop * node,int indexSymbolId)264 bool ValidateLimitationsTraverser::validateForLoopCond(TIntermLoop *node, int indexSymbolId)
265 {
266 TIntermNode *cond = node->getCondition();
267 if (cond == nullptr)
268 {
269 error(node->getLine(), "Missing condition", "for");
270 return false;
271 }
272 //
273 // condition has the form:
274 // loop_index relational_operator constant_expression
275 //
276 TIntermBinary *binOp = cond->getAsBinaryNode();
277 if (binOp == nullptr)
278 {
279 error(node->getLine(), "Invalid condition", "for");
280 return false;
281 }
282 // Loop index should be to the left of relational operator.
283 TIntermSymbol *symbol = binOp->getLeft()->getAsSymbolNode();
284 if (symbol == nullptr)
285 {
286 error(binOp->getLine(), "Invalid condition", "for");
287 return false;
288 }
289 if (symbol->getId() != indexSymbolId)
290 {
291 error(symbol->getLine(), "Expected loop index", symbol->getSymbol().c_str());
292 return false;
293 }
294 // Relational operator is one of: > >= < <= == or !=.
295 switch (binOp->getOp())
296 {
297 case EOpEqual:
298 case EOpNotEqual:
299 case EOpLessThan:
300 case EOpGreaterThan:
301 case EOpLessThanEqual:
302 case EOpGreaterThanEqual:
303 break;
304 default:
305 error(binOp->getLine(), "Invalid relational operator",
306 GetOperatorString(binOp->getOp()));
307 break;
308 }
309 // Loop index must be compared with a constant.
310 if (!isConstExpr(binOp->getRight()))
311 {
312 error(binOp->getLine(), "Loop index cannot be compared with non-constant expression",
313 symbol->getSymbol().c_str());
314 return false;
315 }
316
317 return true;
318 }
319
validateForLoopExpr(TIntermLoop * node,int indexSymbolId)320 bool ValidateLimitationsTraverser::validateForLoopExpr(TIntermLoop *node, int indexSymbolId)
321 {
322 TIntermNode *expr = node->getExpression();
323 if (expr == nullptr)
324 {
325 error(node->getLine(), "Missing expression", "for");
326 return false;
327 }
328
329 // for expression has one of the following forms:
330 // loop_index++
331 // loop_index--
332 // loop_index += constant_expression
333 // loop_index -= constant_expression
334 // ++loop_index
335 // --loop_index
336 // The last two forms are not specified in the spec, but I am assuming
337 // its an oversight.
338 TIntermUnary *unOp = expr->getAsUnaryNode();
339 TIntermBinary *binOp = unOp ? nullptr : expr->getAsBinaryNode();
340
341 TOperator op = EOpNull;
342 TIntermSymbol *symbol = nullptr;
343 if (unOp != nullptr)
344 {
345 op = unOp->getOp();
346 symbol = unOp->getOperand()->getAsSymbolNode();
347 }
348 else if (binOp != nullptr)
349 {
350 op = binOp->getOp();
351 symbol = binOp->getLeft()->getAsSymbolNode();
352 }
353
354 // The operand must be loop index.
355 if (symbol == nullptr)
356 {
357 error(expr->getLine(), "Invalid expression", "for");
358 return false;
359 }
360 if (symbol->getId() != indexSymbolId)
361 {
362 error(symbol->getLine(), "Expected loop index", symbol->getSymbol().c_str());
363 return false;
364 }
365
366 // The operator is one of: ++ -- += -=.
367 switch (op)
368 {
369 case EOpPostIncrement:
370 case EOpPostDecrement:
371 case EOpPreIncrement:
372 case EOpPreDecrement:
373 ASSERT((unOp != nullptr) && (binOp == nullptr));
374 break;
375 case EOpAddAssign:
376 case EOpSubAssign:
377 ASSERT((unOp == nullptr) && (binOp != nullptr));
378 break;
379 default:
380 error(expr->getLine(), "Invalid operator", GetOperatorString(op));
381 return false;
382 }
383
384 // Loop index must be incremented/decremented with a constant.
385 if (binOp != nullptr)
386 {
387 if (!isConstExpr(binOp->getRight()))
388 {
389 error(binOp->getLine(), "Loop index cannot be modified by non-constant expression",
390 symbol->getSymbol().c_str());
391 return false;
392 }
393 }
394
395 return true;
396 }
397
isConstExpr(TIntermNode * node)398 bool ValidateLimitationsTraverser::isConstExpr(TIntermNode *node)
399 {
400 ASSERT(node != nullptr);
401 return node->getAsConstantUnion() != nullptr && node->getAsTyped()->getQualifier() == EvqConst;
402 }
403
isConstIndexExpr(TIntermNode * node)404 bool ValidateLimitationsTraverser::isConstIndexExpr(TIntermNode *node)
405 {
406 ASSERT(node != nullptr);
407
408 ValidateConstIndexExpr validate(mLoopSymbolIds);
409 node->traverse(&validate);
410 return validate.isValid();
411 }
412
validateIndexing(TIntermBinary * node)413 bool ValidateLimitationsTraverser::validateIndexing(TIntermBinary *node)
414 {
415 ASSERT((node->getOp() == EOpIndexDirect) || (node->getOp() == EOpIndexIndirect));
416
417 bool valid = true;
418 TIntermTyped *index = node->getRight();
419 // The index expession must be a constant-index-expression unless
420 // the operand is a uniform in a vertex shader.
421 TIntermTyped *operand = node->getLeft();
422 bool skip = (mShaderType == GL_VERTEX_SHADER) && (operand->getQualifier() == EvqUniform);
423 if (!skip && !isConstIndexExpr(index))
424 {
425 error(index->getLine(), "Index expression must be constant", "[]");
426 valid = false;
427 }
428 return valid;
429 }
430
431 } // namespace anonymous
432
ValidateLimitations(TIntermNode * root,GLenum shaderType,TSymbolTable * symbolTable,int shaderVersion,TDiagnostics * diagnostics)433 bool ValidateLimitations(TIntermNode *root,
434 GLenum shaderType,
435 TSymbolTable *symbolTable,
436 int shaderVersion,
437 TDiagnostics *diagnostics)
438 {
439 ValidateLimitationsTraverser validate(shaderType, symbolTable, shaderVersion, diagnostics);
440 root->traverse(&validate);
441 return diagnostics->numErrors() == 0;
442 }
443
444 } // namespace sh
445