1 //
2 // Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/ValidateLimitations.h"
8 
9 #include "angle_gl.h"
10 #include "compiler/translator/Diagnostics.h"
11 #include "compiler/translator/IntermTraverse.h"
12 #include "compiler/translator/ParseContext.h"
13 
14 namespace sh
15 {
16 
17 namespace
18 {
19 
GetLoopSymbolId(TIntermLoop * loop)20 int GetLoopSymbolId(TIntermLoop *loop)
21 {
22     // Here we assume all the operations are valid, because the loop node is
23     // already validated before this call.
24     TIntermSequence *declSeq = loop->getInit()->getAsDeclarationNode()->getSequence();
25     TIntermBinary *declInit  = (*declSeq)[0]->getAsBinaryNode();
26     TIntermSymbol *symbol    = declInit->getLeft()->getAsSymbolNode();
27 
28     return symbol->getId();
29 }
30 
31 // Traverses a node to check if it represents a constant index expression.
32 // Definition:
33 // constant-index-expressions are a superset of constant-expressions.
34 // Constant-index-expressions can include loop indices as defined in
35 // GLSL ES 1.0 spec, Appendix A, section 4.
36 // The following are constant-index-expressions:
37 // - Constant expressions
38 // - Loop indices as defined in section 4
39 // - Expressions composed of both of the above
40 class ValidateConstIndexExpr : public TIntermTraverser
41 {
42   public:
ValidateConstIndexExpr(const std::vector<int> & loopSymbols)43     ValidateConstIndexExpr(const std::vector<int> &loopSymbols)
44         : TIntermTraverser(true, false, false), mValid(true), mLoopSymbolIds(loopSymbols)
45     {
46     }
47 
48     // Returns true if the parsed node represents a constant index expression.
isValid() const49     bool isValid() const { return mValid; }
50 
visitSymbol(TIntermSymbol * symbol)51     void visitSymbol(TIntermSymbol *symbol) override
52     {
53         // Only constants and loop indices are allowed in a
54         // constant index expression.
55         if (mValid)
56         {
57             bool isLoopSymbol = std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(),
58                                           symbol->getId()) != mLoopSymbolIds.end();
59             mValid = (symbol->getQualifier() == EvqConst) || isLoopSymbol;
60         }
61     }
62 
63   private:
64     bool mValid;
65     const std::vector<int> mLoopSymbolIds;
66 };
67 
68 // Traverses intermediate tree to ensure that the shader does not exceed the
69 // minimum functionality mandated in GLSL 1.0 spec, Appendix A.
70 class ValidateLimitationsTraverser : public TLValueTrackingTraverser
71 {
72   public:
73     ValidateLimitationsTraverser(sh::GLenum shaderType,
74                                  TSymbolTable *symbolTable,
75                                  int shaderVersion,
76                                  TDiagnostics *diagnostics);
77 
78     void visitSymbol(TIntermSymbol *node) override;
79     bool visitBinary(Visit, TIntermBinary *) override;
80     bool visitLoop(Visit, TIntermLoop *) override;
81 
82   private:
83     void error(TSourceLoc loc, const char *reason, const char *token);
84 
85     bool withinLoopBody() const;
86     bool isLoopIndex(TIntermSymbol *symbol);
87     bool validateLoopType(TIntermLoop *node);
88 
89     bool validateForLoopHeader(TIntermLoop *node);
90     // If valid, return the index symbol id; Otherwise, return -1.
91     int validateForLoopInit(TIntermLoop *node);
92     bool validateForLoopCond(TIntermLoop *node, int indexSymbolId);
93     bool validateForLoopExpr(TIntermLoop *node, int indexSymbolId);
94 
95     // Returns true if indexing does not exceed the minimum functionality
96     // mandated in GLSL 1.0 spec, Appendix A, Section 5.
97     bool isConstExpr(TIntermNode *node);
98     bool isConstIndexExpr(TIntermNode *node);
99     bool validateIndexing(TIntermBinary *node);
100 
101     sh::GLenum mShaderType;
102     TDiagnostics *mDiagnostics;
103     std::vector<int> mLoopSymbolIds;
104 };
105 
ValidateLimitationsTraverser(sh::GLenum shaderType,TSymbolTable * symbolTable,int shaderVersion,TDiagnostics * diagnostics)106 ValidateLimitationsTraverser::ValidateLimitationsTraverser(sh::GLenum shaderType,
107                                                            TSymbolTable *symbolTable,
108                                                            int shaderVersion,
109                                                            TDiagnostics *diagnostics)
110     : TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
111       mShaderType(shaderType),
112       mDiagnostics(diagnostics)
113 {
114     ASSERT(diagnostics);
115 }
116 
visitSymbol(TIntermSymbol * node)117 void ValidateLimitationsTraverser::visitSymbol(TIntermSymbol *node)
118 {
119     if (isLoopIndex(node) && isLValueRequiredHere())
120     {
121         error(node->getLine(),
122               "Loop index cannot be statically assigned to within the body of the loop",
123               node->getSymbol().c_str());
124     }
125 }
126 
visitBinary(Visit,TIntermBinary * node)127 bool ValidateLimitationsTraverser::visitBinary(Visit, TIntermBinary *node)
128 {
129     // Check indexing.
130     switch (node->getOp())
131     {
132         case EOpIndexDirect:
133         case EOpIndexIndirect:
134             validateIndexing(node);
135             break;
136         default:
137             break;
138     }
139     return true;
140 }
141 
visitLoop(Visit,TIntermLoop * node)142 bool ValidateLimitationsTraverser::visitLoop(Visit, TIntermLoop *node)
143 {
144     if (!validateLoopType(node))
145         return false;
146 
147     if (!validateForLoopHeader(node))
148         return false;
149 
150     TIntermNode *body = node->getBody();
151     if (body != nullptr)
152     {
153         mLoopSymbolIds.push_back(GetLoopSymbolId(node));
154         body->traverse(this);
155         mLoopSymbolIds.pop_back();
156     }
157 
158     // The loop is fully processed - no need to visit children.
159     return false;
160 }
161 
error(TSourceLoc loc,const char * reason,const char * token)162 void ValidateLimitationsTraverser::error(TSourceLoc loc, const char *reason, const char *token)
163 {
164     mDiagnostics->error(loc, reason, token);
165 }
166 
withinLoopBody() const167 bool ValidateLimitationsTraverser::withinLoopBody() const
168 {
169     return !mLoopSymbolIds.empty();
170 }
171 
isLoopIndex(TIntermSymbol * symbol)172 bool ValidateLimitationsTraverser::isLoopIndex(TIntermSymbol *symbol)
173 {
174     return std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), symbol->getId()) !=
175            mLoopSymbolIds.end();
176 }
177 
validateLoopType(TIntermLoop * node)178 bool ValidateLimitationsTraverser::validateLoopType(TIntermLoop *node)
179 {
180     TLoopType type = node->getType();
181     if (type == ELoopFor)
182         return true;
183 
184     // Reject while and do-while loops.
185     error(node->getLine(), "This type of loop is not allowed", type == ELoopWhile ? "while" : "do");
186     return false;
187 }
188 
validateForLoopHeader(TIntermLoop * node)189 bool ValidateLimitationsTraverser::validateForLoopHeader(TIntermLoop *node)
190 {
191     ASSERT(node->getType() == ELoopFor);
192 
193     //
194     // The for statement has the form:
195     //    for ( init-declaration ; condition ; expression ) statement
196     //
197     int indexSymbolId = validateForLoopInit(node);
198     if (indexSymbolId < 0)
199         return false;
200     if (!validateForLoopCond(node, indexSymbolId))
201         return false;
202     if (!validateForLoopExpr(node, indexSymbolId))
203         return false;
204 
205     return true;
206 }
207 
validateForLoopInit(TIntermLoop * node)208 int ValidateLimitationsTraverser::validateForLoopInit(TIntermLoop *node)
209 {
210     TIntermNode *init = node->getInit();
211     if (init == nullptr)
212     {
213         error(node->getLine(), "Missing init declaration", "for");
214         return -1;
215     }
216 
217     //
218     // init-declaration has the form:
219     //     type-specifier identifier = constant-expression
220     //
221     TIntermDeclaration *decl = init->getAsDeclarationNode();
222     if (decl == nullptr)
223     {
224         error(init->getLine(), "Invalid init declaration", "for");
225         return -1;
226     }
227     // To keep things simple do not allow declaration list.
228     TIntermSequence *declSeq = decl->getSequence();
229     if (declSeq->size() != 1)
230     {
231         error(decl->getLine(), "Invalid init declaration", "for");
232         return -1;
233     }
234     TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode();
235     if ((declInit == nullptr) || (declInit->getOp() != EOpInitialize))
236     {
237         error(decl->getLine(), "Invalid init declaration", "for");
238         return -1;
239     }
240     TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
241     if (symbol == nullptr)
242     {
243         error(declInit->getLine(), "Invalid init declaration", "for");
244         return -1;
245     }
246     // The loop index has type int or float.
247     TBasicType type = symbol->getBasicType();
248     if ((type != EbtInt) && (type != EbtUInt) && (type != EbtFloat))
249     {
250         error(symbol->getLine(), "Invalid type for loop index", getBasicString(type));
251         return -1;
252     }
253     // The loop index is initialized with constant expression.
254     if (!isConstExpr(declInit->getRight()))
255     {
256         error(declInit->getLine(), "Loop index cannot be initialized with non-constant expression",
257               symbol->getSymbol().c_str());
258         return -1;
259     }
260 
261     return symbol->getId();
262 }
263 
validateForLoopCond(TIntermLoop * node,int indexSymbolId)264 bool ValidateLimitationsTraverser::validateForLoopCond(TIntermLoop *node, int indexSymbolId)
265 {
266     TIntermNode *cond = node->getCondition();
267     if (cond == nullptr)
268     {
269         error(node->getLine(), "Missing condition", "for");
270         return false;
271     }
272     //
273     // condition has the form:
274     //     loop_index relational_operator constant_expression
275     //
276     TIntermBinary *binOp = cond->getAsBinaryNode();
277     if (binOp == nullptr)
278     {
279         error(node->getLine(), "Invalid condition", "for");
280         return false;
281     }
282     // Loop index should be to the left of relational operator.
283     TIntermSymbol *symbol = binOp->getLeft()->getAsSymbolNode();
284     if (symbol == nullptr)
285     {
286         error(binOp->getLine(), "Invalid condition", "for");
287         return false;
288     }
289     if (symbol->getId() != indexSymbolId)
290     {
291         error(symbol->getLine(), "Expected loop index", symbol->getSymbol().c_str());
292         return false;
293     }
294     // Relational operator is one of: > >= < <= == or !=.
295     switch (binOp->getOp())
296     {
297         case EOpEqual:
298         case EOpNotEqual:
299         case EOpLessThan:
300         case EOpGreaterThan:
301         case EOpLessThanEqual:
302         case EOpGreaterThanEqual:
303             break;
304         default:
305             error(binOp->getLine(), "Invalid relational operator",
306                   GetOperatorString(binOp->getOp()));
307             break;
308     }
309     // Loop index must be compared with a constant.
310     if (!isConstExpr(binOp->getRight()))
311     {
312         error(binOp->getLine(), "Loop index cannot be compared with non-constant expression",
313               symbol->getSymbol().c_str());
314         return false;
315     }
316 
317     return true;
318 }
319 
validateForLoopExpr(TIntermLoop * node,int indexSymbolId)320 bool ValidateLimitationsTraverser::validateForLoopExpr(TIntermLoop *node, int indexSymbolId)
321 {
322     TIntermNode *expr = node->getExpression();
323     if (expr == nullptr)
324     {
325         error(node->getLine(), "Missing expression", "for");
326         return false;
327     }
328 
329     // for expression has one of the following forms:
330     //     loop_index++
331     //     loop_index--
332     //     loop_index += constant_expression
333     //     loop_index -= constant_expression
334     //     ++loop_index
335     //     --loop_index
336     // The last two forms are not specified in the spec, but I am assuming
337     // its an oversight.
338     TIntermUnary *unOp   = expr->getAsUnaryNode();
339     TIntermBinary *binOp = unOp ? nullptr : expr->getAsBinaryNode();
340 
341     TOperator op          = EOpNull;
342     TIntermSymbol *symbol = nullptr;
343     if (unOp != nullptr)
344     {
345         op     = unOp->getOp();
346         symbol = unOp->getOperand()->getAsSymbolNode();
347     }
348     else if (binOp != nullptr)
349     {
350         op     = binOp->getOp();
351         symbol = binOp->getLeft()->getAsSymbolNode();
352     }
353 
354     // The operand must be loop index.
355     if (symbol == nullptr)
356     {
357         error(expr->getLine(), "Invalid expression", "for");
358         return false;
359     }
360     if (symbol->getId() != indexSymbolId)
361     {
362         error(symbol->getLine(), "Expected loop index", symbol->getSymbol().c_str());
363         return false;
364     }
365 
366     // The operator is one of: ++ -- += -=.
367     switch (op)
368     {
369         case EOpPostIncrement:
370         case EOpPostDecrement:
371         case EOpPreIncrement:
372         case EOpPreDecrement:
373             ASSERT((unOp != nullptr) && (binOp == nullptr));
374             break;
375         case EOpAddAssign:
376         case EOpSubAssign:
377             ASSERT((unOp == nullptr) && (binOp != nullptr));
378             break;
379         default:
380             error(expr->getLine(), "Invalid operator", GetOperatorString(op));
381             return false;
382     }
383 
384     // Loop index must be incremented/decremented with a constant.
385     if (binOp != nullptr)
386     {
387         if (!isConstExpr(binOp->getRight()))
388         {
389             error(binOp->getLine(), "Loop index cannot be modified by non-constant expression",
390                   symbol->getSymbol().c_str());
391             return false;
392         }
393     }
394 
395     return true;
396 }
397 
isConstExpr(TIntermNode * node)398 bool ValidateLimitationsTraverser::isConstExpr(TIntermNode *node)
399 {
400     ASSERT(node != nullptr);
401     return node->getAsConstantUnion() != nullptr && node->getAsTyped()->getQualifier() == EvqConst;
402 }
403 
isConstIndexExpr(TIntermNode * node)404 bool ValidateLimitationsTraverser::isConstIndexExpr(TIntermNode *node)
405 {
406     ASSERT(node != nullptr);
407 
408     ValidateConstIndexExpr validate(mLoopSymbolIds);
409     node->traverse(&validate);
410     return validate.isValid();
411 }
412 
validateIndexing(TIntermBinary * node)413 bool ValidateLimitationsTraverser::validateIndexing(TIntermBinary *node)
414 {
415     ASSERT((node->getOp() == EOpIndexDirect) || (node->getOp() == EOpIndexIndirect));
416 
417     bool valid          = true;
418     TIntermTyped *index = node->getRight();
419     // The index expession must be a constant-index-expression unless
420     // the operand is a uniform in a vertex shader.
421     TIntermTyped *operand = node->getLeft();
422     bool skip = (mShaderType == GL_VERTEX_SHADER) && (operand->getQualifier() == EvqUniform);
423     if (!skip && !isConstIndexExpr(index))
424     {
425         error(index->getLine(), "Index expression must be constant", "[]");
426         valid = false;
427     }
428     return valid;
429 }
430 
431 }  // namespace anonymous
432 
ValidateLimitations(TIntermNode * root,GLenum shaderType,TSymbolTable * symbolTable,int shaderVersion,TDiagnostics * diagnostics)433 bool ValidateLimitations(TIntermNode *root,
434                          GLenum shaderType,
435                          TSymbolTable *symbolTable,
436                          int shaderVersion,
437                          TDiagnostics *diagnostics)
438 {
439     ValidateLimitationsTraverser validate(shaderType, symbolTable, shaderVersion, diagnostics);
440     root->traverse(&validate);
441     return diagnostics->numErrors() == 0;
442 }
443 
444 }  // namespace sh
445