1 //
2 // Copyright (C) 2015-2016 Google, Inc.
3 //
4 // All rights reserved.
5 //
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions
8 // are met:
9 //
10 //    Redistributions of source code must retain the above copyright
11 //    notice, this list of conditions and the following disclaimer.
12 //
13 //    Redistributions in binary form must reproduce the above
14 //    copyright notice, this list of conditions and the following
15 //    disclaimer in the documentation and/or other materials provided
16 //    with the distribution.
17 //
18 //    Neither the name of Google Inc. nor the names of its
19 //    contributors may be used to endorse or promote products derived
20 //    from this software without specific prior written permission.
21 //
22 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
25 // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
26 // COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
27 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
28 // BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30 // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
31 // LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
32 // ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
33 // POSSIBILITY OF SUCH DAMAGE.
34 
35 //
36 // Visit the nodes in the glslang intermediate tree representation to
37 // propagate the 'noContraction' qualifier.
38 //
39 
40 #ifndef GLSLANG_WEB
41 
42 #include "propagateNoContraction.h"
43 
44 #include <cstdlib>
45 #include <string>
46 #include <tuple>
47 #include <unordered_map>
48 #include <unordered_set>
49 
50 #include "localintermediate.h"
51 namespace {
52 
53 // Use a string to hold the access chain information, as in most cases the
54 // access chain is short and may contain only one element, which is the symbol
55 // ID.
56 // Example: struct {float a; float b;} s;
57 //  Object s.a will be represented with: <symbol ID of s>/0
58 //  Object s.b will be represented with: <symbol ID of s>/1
59 //  Object s will be represented with: <symbol ID of s>
60 // For members of vector, matrix and arrays, they will be represented with the
61 // same symbol ID of their container symbol objects. This is because their
62 // preciseness is always the same as their container symbol objects.
63 typedef std::string ObjectAccessChain;
64 
65 // The delimiter used in the ObjectAccessChain string to separate symbol ID and
66 // different level of struct indices.
67 const char ObjectAccesschainDelimiter = '/';
68 
69 // Mapping from Symbol IDs of symbol nodes, to their defining operation
70 // nodes.
71 typedef std::unordered_multimap<ObjectAccessChain, glslang::TIntermOperator*> NodeMapping;
72 // Mapping from object nodes to their access chain info string.
73 typedef std::unordered_map<glslang::TIntermTyped*, ObjectAccessChain> AccessChainMapping;
74 
75 // Set of object IDs.
76 typedef std::unordered_set<ObjectAccessChain> ObjectAccesschainSet;
77 // Set of return branch nodes.
78 typedef std::unordered_set<glslang::TIntermBranch*> ReturnBranchNodeSet;
79 
80 // A helper function to tell whether a node is 'noContraction'. Returns true if
81 // the node has 'noContraction' qualifier, otherwise false.
isPreciseObjectNode(glslang::TIntermTyped * node)82 bool isPreciseObjectNode(glslang::TIntermTyped* node)
83 {
84     return node->getType().getQualifier().isNoContraction();
85 }
86 
87 // Returns true if the opcode is a dereferencing one.
isDereferenceOperation(glslang::TOperator op)88 bool isDereferenceOperation(glslang::TOperator op)
89 {
90     switch (op) {
91     case glslang::EOpIndexDirect:
92     case glslang::EOpIndexDirectStruct:
93     case glslang::EOpIndexIndirect:
94     case glslang::EOpVectorSwizzle:
95     case glslang::EOpMatrixSwizzle:
96         return true;
97     default:
98         return false;
99     }
100 }
101 
102 // Returns true if the opcode leads to an assignment operation.
isAssignOperation(glslang::TOperator op)103 bool isAssignOperation(glslang::TOperator op)
104 {
105     switch (op) {
106     case glslang::EOpAssign:
107     case glslang::EOpAddAssign:
108     case glslang::EOpSubAssign:
109     case glslang::EOpMulAssign:
110     case glslang::EOpVectorTimesMatrixAssign:
111     case glslang::EOpVectorTimesScalarAssign:
112     case glslang::EOpMatrixTimesScalarAssign:
113     case glslang::EOpMatrixTimesMatrixAssign:
114     case glslang::EOpDivAssign:
115     case glslang::EOpModAssign:
116     case glslang::EOpAndAssign:
117     case glslang::EOpLeftShiftAssign:
118     case glslang::EOpRightShiftAssign:
119     case glslang::EOpInclusiveOrAssign:
120     case glslang::EOpExclusiveOrAssign:
121 
122     case glslang::EOpPostIncrement:
123     case glslang::EOpPostDecrement:
124     case glslang::EOpPreIncrement:
125     case glslang::EOpPreDecrement:
126         return true;
127     default:
128         return false;
129     }
130 }
131 
132 // A helper function to get the unsigned int from a given constant union node.
133 // Note the node should only hold a uint scalar.
getStructIndexFromConstantUnion(glslang::TIntermTyped * node)134 unsigned getStructIndexFromConstantUnion(glslang::TIntermTyped* node)
135 {
136     assert(node->getAsConstantUnion() && node->getAsConstantUnion()->isScalar());
137     unsigned struct_dereference_index = node->getAsConstantUnion()->getConstArray()[0].getUConst();
138     return struct_dereference_index;
139 }
140 
141 // A helper function to generate symbol_label.
generateSymbolLabel(glslang::TIntermSymbol * node)142 ObjectAccessChain generateSymbolLabel(glslang::TIntermSymbol* node)
143 {
144     ObjectAccessChain symbol_id =
145         std::to_string(node->getId()) + "(" + node->getName().c_str() + ")";
146     return symbol_id;
147 }
148 
149 // Returns true if the operation is an arithmetic operation and valid for
150 // the 'NoContraction' decoration.
isArithmeticOperation(glslang::TOperator op)151 bool isArithmeticOperation(glslang::TOperator op)
152 {
153     switch (op) {
154     case glslang::EOpAddAssign:
155     case glslang::EOpSubAssign:
156     case glslang::EOpMulAssign:
157     case glslang::EOpVectorTimesMatrixAssign:
158     case glslang::EOpVectorTimesScalarAssign:
159     case glslang::EOpMatrixTimesScalarAssign:
160     case glslang::EOpMatrixTimesMatrixAssign:
161     case glslang::EOpDivAssign:
162     case glslang::EOpModAssign:
163 
164     case glslang::EOpNegative:
165 
166     case glslang::EOpAdd:
167     case glslang::EOpSub:
168     case glslang::EOpMul:
169     case glslang::EOpDiv:
170     case glslang::EOpMod:
171 
172     case glslang::EOpVectorTimesScalar:
173     case glslang::EOpVectorTimesMatrix:
174     case glslang::EOpMatrixTimesVector:
175     case glslang::EOpMatrixTimesScalar:
176     case glslang::EOpMatrixTimesMatrix:
177 
178     case glslang::EOpDot:
179 
180     case glslang::EOpPostIncrement:
181     case glslang::EOpPostDecrement:
182     case glslang::EOpPreIncrement:
183     case glslang::EOpPreDecrement:
184         return true;
185     default:
186         return false;
187     }
188 }
189 
// Scope guard that temporarily overrides a variable: the variable's value at
// construction time is saved, and it is written back when the guard is
// destroyed. In this file it is used, for example, to track the function
// definition node currently being traversed.
template <typename T> class StateSettingGuard {
public:
    // Saves the current value of *state_ptr, then overwrites it with
    // new_state_value.
    StateSettingGuard(T* state_ptr, T new_state_value)
        : state_ptr_(state_ptr), previous_state_(*state_ptr)
    {
        *state_ptr = new_state_value;
    }

    // Saves the current value of *state_ptr without changing it; call
    // setState() afterwards to override it. Explicit so that a bare pointer
    // never converts into a guard by accident.
    explicit StateSettingGuard(T* state_ptr) : state_ptr_(state_ptr), previous_state_(*state_ptr) {}

    // A copy would restore the saved value a second time (possibly clobbering
    // a newer value), so guards can be neither copied nor assigned.
    StateSettingGuard(const StateSettingGuard&) = delete;
    StateSettingGuard& operator=(const StateSettingGuard&) = delete;

    // Overwrites the guarded variable; the value saved at construction time
    // is still the one restored on destruction.
    void setState(T new_state_value) { *state_ptr_ = new_state_value; }

    // Restores the value saved at construction time.
    ~StateSettingGuard() { *state_ptr_ = previous_state_; }

private:
    T* state_ptr_;
    T previous_state_;
};
206 
207 // A helper function to get the front element from a given ObjectAccessChain
getFrontElement(const ObjectAccessChain & chain)208 ObjectAccessChain getFrontElement(const ObjectAccessChain& chain)
209 {
210     size_t pos_delimiter = chain.find(ObjectAccesschainDelimiter);
211     return pos_delimiter == std::string::npos ? chain : chain.substr(0, pos_delimiter);
212 }
213 
214 // A helper function to get the access chain starting from the second element.
subAccessChainFromSecondElement(const ObjectAccessChain & chain)215 ObjectAccessChain subAccessChainFromSecondElement(const ObjectAccessChain& chain)
216 {
217     size_t pos_delimiter = chain.find(ObjectAccesschainDelimiter);
218     return pos_delimiter == std::string::npos ? "" : chain.substr(pos_delimiter + 1);
219 }
220 
221 // A helper function to get the access chain after removing a given prefix.
getSubAccessChainAfterPrefix(const ObjectAccessChain & chain,const ObjectAccessChain & prefix)222 ObjectAccessChain getSubAccessChainAfterPrefix(const ObjectAccessChain& chain,
223                                                const ObjectAccessChain& prefix)
224 {
225     size_t pos = chain.find(prefix);
226     if (pos != 0)
227         return chain;
228     return chain.substr(prefix.length() + sizeof(ObjectAccesschainDelimiter));
229 }
230 
231 //
232 // A traverser which traverses the whole AST and populates:
233 //  1) A mapping from symbol nodes' IDs to their defining operation nodes.
234 //  2) A set of access chains of the initial precise object nodes.
235 //
236 class TSymbolDefinitionCollectingTraverser : public glslang::TIntermTraverser {
237 public:
238     TSymbolDefinitionCollectingTraverser(NodeMapping* symbol_definition_mapping,
239                                          AccessChainMapping* accesschain_mapping,
240                                          ObjectAccesschainSet* precise_objects,
241                                          ReturnBranchNodeSet* precise_return_nodes);
242 
243     bool visitUnary(glslang::TVisit, glslang::TIntermUnary*) override;
244     bool visitBinary(glslang::TVisit, glslang::TIntermBinary*) override;
245     void visitSymbol(glslang::TIntermSymbol*) override;
246     bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate*) override;
247     bool visitBranch(glslang::TVisit, glslang::TIntermBranch*) override;
248 
249 protected:
250     TSymbolDefinitionCollectingTraverser& operator=(const TSymbolDefinitionCollectingTraverser&);
251 
252     // The mapping from symbol node IDs to their defining nodes. This should be
253     // populated along traversing the AST.
254     NodeMapping& symbol_definition_mapping_;
255     // The set of symbol node IDs for precise symbol nodes, the ones marked as
256     // 'noContraction'.
257     ObjectAccesschainSet& precise_objects_;
258     // The set of precise return nodes.
259     ReturnBranchNodeSet& precise_return_nodes_;
260     // A temporary cache of the symbol node whose defining node is to be found
261     // currently along traversing the AST.
262     ObjectAccessChain current_object_;
263     // A map from object node to its access chain. This traverser stores
264     // the built access chains into this map for each object node it has
265     // visited.
266     AccessChainMapping& accesschain_mapping_;
267     // The pointer to the Function Definition node, so we can get the
268     // preciseness of the return expression from it when we traverse the
269     // return branch node.
270     glslang::TIntermAggregate* current_function_definition_node_;
271 };
272 
TSymbolDefinitionCollectingTraverser(NodeMapping * symbol_definition_mapping,AccessChainMapping * accesschain_mapping,ObjectAccesschainSet * precise_objects,std::unordered_set<glslang::TIntermBranch * > * precise_return_nodes)273 TSymbolDefinitionCollectingTraverser::TSymbolDefinitionCollectingTraverser(
274     NodeMapping* symbol_definition_mapping, AccessChainMapping* accesschain_mapping,
275     ObjectAccesschainSet* precise_objects,
276     std::unordered_set<glslang::TIntermBranch*>* precise_return_nodes)
277     : TIntermTraverser(true, false, false), symbol_definition_mapping_(*symbol_definition_mapping),
278       precise_objects_(*precise_objects), precise_return_nodes_(*precise_return_nodes),
279       current_object_(), accesschain_mapping_(*accesschain_mapping),
280       current_function_definition_node_(nullptr) {}
281 
282 // Visits a symbol node, set the current_object_ to the
283 // current node symbol ID, and record a mapping from this node to the current
284 // current_object_, which is the just obtained symbol
285 // ID.
visitSymbol(glslang::TIntermSymbol * node)286 void TSymbolDefinitionCollectingTraverser::visitSymbol(glslang::TIntermSymbol* node)
287 {
288     current_object_ = generateSymbolLabel(node);
289     accesschain_mapping_[node] = current_object_;
290 }
291 
292 // Visits an aggregate node, traverses all of its children.
visitAggregate(glslang::TVisit,glslang::TIntermAggregate * node)293 bool TSymbolDefinitionCollectingTraverser::visitAggregate(glslang::TVisit,
294                                                           glslang::TIntermAggregate* node)
295 {
296     // This aggregate node might be a function definition node, in which case we need to
297     // cache this node, so we can get the preciseness information of the return value
298     // of this function later.
299     StateSettingGuard<glslang::TIntermAggregate*> current_function_definition_node_setting_guard(
300         &current_function_definition_node_);
301     if (node->getOp() == glslang::EOpFunction) {
302         // This is function definition node, we need to cache this node so that we can
303         // get the preciseness of the return value later.
304         current_function_definition_node_setting_guard.setState(node);
305     }
306     // Traverse the items in the sequence.
307     glslang::TIntermSequence& seq = node->getSequence();
308     for (int i = 0; i < (int)seq.size(); ++i) {
309         current_object_.clear();
310         seq[i]->traverse(this);
311     }
312     return false;
313 }
314 
visitBranch(glslang::TVisit,glslang::TIntermBranch * node)315 bool TSymbolDefinitionCollectingTraverser::visitBranch(glslang::TVisit,
316                                                        glslang::TIntermBranch* node)
317 {
318     if (node->getFlowOp() == glslang::EOpReturn && node->getExpression() &&
319         current_function_definition_node_ &&
320         current_function_definition_node_->getType().getQualifier().noContraction) {
321         // This node is a return node with an expression, and its function has a
322         // precise return value. We need to find the involved objects in its
323         // expression and add them to the set of initial precise objects.
324         precise_return_nodes_.insert(node);
325         node->getExpression()->traverse(this);
326     }
327     return false;
328 }
329 
330 // Visits a unary node. This might be an implicit assignment like i++, i--. etc.
visitUnary(glslang::TVisit,glslang::TIntermUnary * node)331 bool TSymbolDefinitionCollectingTraverser::visitUnary(glslang::TVisit /* visit */,
332                                                       glslang::TIntermUnary* node)
333 {
334     current_object_.clear();
335     node->getOperand()->traverse(this);
336     if (isAssignOperation(node->getOp())) {
337         // We should always be able to get an access chain of the operand node.
338         assert(!current_object_.empty());
339 
340         // If the operand node object is 'precise', we collect its access chain
341         // for the initial set of 'precise' objects.
342         if (isPreciseObjectNode(node->getOperand())) {
343             // The operand node is an 'precise' object node, add its
344             // access chain to the set of 'precise' objects. This is to collect
345             // the initial set of 'precise' objects.
346             precise_objects_.insert(current_object_);
347         }
348         // Gets the symbol ID from the object's access chain.
349         ObjectAccessChain id_symbol = getFrontElement(current_object_);
350         // Add a mapping from the symbol ID to this assignment operation node.
351         symbol_definition_mapping_.insert(std::make_pair(id_symbol, node));
352     }
353     // A unary node is not a dereference node, so we clear the access chain which
354     // is under construction.
355     current_object_.clear();
356     return false;
357 }
358 
359 // Visits a binary node and updates the mapping from symbol IDs to the definition
360 // nodes. Also collects the access chains for the initial precise objects.
visitBinary(glslang::TVisit,glslang::TIntermBinary * node)361 bool TSymbolDefinitionCollectingTraverser::visitBinary(glslang::TVisit /* visit */,
362                                                        glslang::TIntermBinary* node)
363 {
364     // Traverses the left node to build the access chain info for the object.
365     current_object_.clear();
366     node->getLeft()->traverse(this);
367 
368     if (isAssignOperation(node->getOp())) {
369         // We should always be able to get an access chain for the left node.
370         assert(!current_object_.empty());
371 
372         // If the left node object is 'precise', it is an initial precise object
373         // specified in the shader source. Adds it to the initial work list to
374         // process later.
375         if (isPreciseObjectNode(node->getLeft())) {
376             // The left node is an 'precise' object node, add its access chain to
377             // the set of 'precise' objects. This is to collect the initial set
378             // of 'precise' objects.
379             precise_objects_.insert(current_object_);
380         }
381         // Gets the symbol ID from the object access chain, which should be the
382         // first element recorded in the access chain.
383         ObjectAccessChain id_symbol = getFrontElement(current_object_);
384         // Adds a mapping from the symbol ID to this assignment operation node.
385         symbol_definition_mapping_.insert(std::make_pair(id_symbol, node));
386 
387         // Traverses the right node, there may be other 'assignment'
388         // operations in the right.
389         current_object_.clear();
390         node->getRight()->traverse(this);
391 
392     } else if (isDereferenceOperation(node->getOp())) {
393         // The left node (parent node) is a struct type object. We need to
394         // record the access chain information of the current node into its
395         // object id.
396         if (node->getOp() == glslang::EOpIndexDirectStruct) {
397             unsigned struct_dereference_index = getStructIndexFromConstantUnion(node->getRight());
398             current_object_.push_back(ObjectAccesschainDelimiter);
399             current_object_.append(std::to_string(struct_dereference_index));
400         }
401         accesschain_mapping_[node] = current_object_;
402 
403         // For a dereference node, there is no need to traverse the right child
404         // node as the right node should always be an integer type object.
405 
406     } else {
407         // For other binary nodes, still traverse the right node.
408         current_object_.clear();
409         node->getRight()->traverse(this);
410     }
411     return false;
412 }
413 
414 // Traverses the AST and returns a tuple of four members:
415 // 1) a mapping from symbol IDs to the definition nodes (aka. assignment nodes) of these symbols.
416 // 2) a mapping from object nodes in the AST to the access chains of these objects.
417 // 3) a set of access chains of precise objects.
418 // 4) a set of return nodes with precise expressions.
419 std::tuple<NodeMapping, AccessChainMapping, ObjectAccesschainSet, ReturnBranchNodeSet>
getSymbolToDefinitionMappingAndPreciseSymbolIDs(const glslang::TIntermediate & intermediate)420 getSymbolToDefinitionMappingAndPreciseSymbolIDs(const glslang::TIntermediate& intermediate)
421 {
422     auto result_tuple = std::make_tuple(NodeMapping(), AccessChainMapping(), ObjectAccesschainSet(),
423                                         ReturnBranchNodeSet());
424 
425     TIntermNode* root = intermediate.getTreeRoot();
426     if (root == 0)
427         return result_tuple;
428 
429     NodeMapping& symbol_definition_mapping = std::get<0>(result_tuple);
430     AccessChainMapping& accesschain_mapping = std::get<1>(result_tuple);
431     ObjectAccesschainSet& precise_objects = std::get<2>(result_tuple);
432     ReturnBranchNodeSet& precise_return_nodes = std::get<3>(result_tuple);
433 
434     // Traverses the AST and populate the results.
435     TSymbolDefinitionCollectingTraverser collector(&symbol_definition_mapping, &accesschain_mapping,
436                                                    &precise_objects, &precise_return_nodes);
437     root->traverse(&collector);
438 
439     return result_tuple;
440 }
441 
442 //
443 // A traverser that determine whether the left node (or operand node for unary
444 // node) of an assignment node is 'precise', containing 'precise' or not,
445 // according to the access chain a given precise object which share the same
446 // symbol as the left node.
447 //
448 // Post-orderly traverses the left node subtree of an binary assignment node and:
449 //
450 //  1) Propagates the 'precise' from the left object nodes to this object node.
451 //
452 //  2) Builds object access chain along the traversal, and also compares with
453 //  the access chain of the given 'precise' object along with the traversal to
454 //  tell if the node to be defined is 'precise' or not.
455 //
456 class TNoContractionAssigneeCheckingTraverser : public glslang::TIntermTraverser {
457 
458     enum DecisionStatus {
459         // The object node to be assigned to may contain 'precise' objects and also not 'precise' objects.
460         Mixed = 0,
461         // The object node to be assigned to is either a 'precise' object or a struct objects whose members are all 'precise'.
462         Precise = 1,
463         // The object node to be assigned to is not a 'precise' object.
464         NotPreicse = 2,
465     };
466 
467 public:
TNoContractionAssigneeCheckingTraverser(const AccessChainMapping & accesschain_mapping)468     TNoContractionAssigneeCheckingTraverser(const AccessChainMapping& accesschain_mapping)
469         : TIntermTraverser(true, false, false), accesschain_mapping_(accesschain_mapping),
470           precise_object_(nullptr) {}
471 
472     // Checks the preciseness of a given assignment node with a precise object
473     // represented as access chain. The precise object shares the same symbol
474     // with the assignee of the given assignment node. Return a tuple of two:
475     //
476     //  1) The preciseness of the assignee node of this assignment node. True
477     //  if the assignee contains 'precise' objects or is 'precise', false if
478     //  the assignee is not 'precise' according to the access chain of the given
479     //  precise object.
480     //
481     //  2) The incremental access chain from the assignee node to its nested
482     //  'precise' object, according to the access chain of the given precise
483     //  object. This incremental access chain can be empty, which means the
484     //  assignee is 'precise'. Otherwise it shows the path to the nested
485     //  precise object.
486     std::tuple<bool, ObjectAccessChain>
getPrecisenessAndRemainedAccessChain(glslang::TIntermOperator * node,const ObjectAccessChain & precise_object)487     getPrecisenessAndRemainedAccessChain(glslang::TIntermOperator* node,
488                                          const ObjectAccessChain& precise_object)
489     {
490         assert(isAssignOperation(node->getOp()));
491         precise_object_ = &precise_object;
492         ObjectAccessChain assignee_object;
493         if (glslang::TIntermBinary* BN = node->getAsBinaryNode()) {
494             // This is a binary assignment node, we need to check the
495             // preciseness of the left node.
496             assert(accesschain_mapping_.count(BN->getLeft()));
497             // The left node (assignee node) is an object node, traverse the
498             // node to let the 'precise' of nesting objects being transfered to
499             // nested objects.
500             BN->getLeft()->traverse(this);
501             // After traversing the left node, if the left node is 'precise',
502             // we can conclude this assignment should propagate 'precise'.
503             if (isPreciseObjectNode(BN->getLeft())) {
504                 return make_tuple(true, ObjectAccessChain());
505             }
506             // If the preciseness of the left node (assignee node) can not
507             // be determined by now, we need to compare the access chain string
508             // of the assignee object with the given precise object.
509             assignee_object = accesschain_mapping_.at(BN->getLeft());
510 
511         } else if (glslang::TIntermUnary* UN = node->getAsUnaryNode()) {
512             // This is a unary assignment node, we need to check the
513             // preciseness of the operand node. For unary assignment node, the
514             // operand node should always be an object node.
515             assert(accesschain_mapping_.count(UN->getOperand()));
516             // Traverse the operand node to let the 'precise' being propagated
517             // from lower nodes to upper nodes.
518             UN->getOperand()->traverse(this);
519             // After traversing the operand node, if the operand node is
520             // 'precise', this assignment should propagate 'precise'.
521             if (isPreciseObjectNode(UN->getOperand())) {
522                 return make_tuple(true, ObjectAccessChain());
523             }
524             // If the preciseness of the operand node (assignee node) can not
525             // be determined by now, we need to compare the access chain string
526             // of the assignee object with the given precise object.
527             assignee_object = accesschain_mapping_.at(UN->getOperand());
528         } else {
529             // Not a binary or unary node, should not happen.
530             assert(false);
531         }
532 
533         // Compare the access chain string of the assignee node with the given
534         // precise object to determine if this assignment should propagate
535         // 'precise'.
536         if (assignee_object.find(precise_object) == 0) {
537             // The access chain string of the given precise object is a prefix
538             // of assignee's access chain string. The assignee should be
539             // 'precise'.
540             return make_tuple(true, ObjectAccessChain());
541         } else if (precise_object.find(assignee_object) == 0) {
542             // The assignee's access chain string is a prefix of the given
543             // precise object, the assignee object contains 'precise' object,
544             // and we need to pass the remained access chain to the object nodes
545             // in the right.
546             return make_tuple(true, getSubAccessChainAfterPrefix(precise_object, assignee_object));
547         } else {
548             // The access chain strings do not match, the assignee object can
549             // not be labeled as 'precise' according to the given precise
550             // object.
551             return make_tuple(false, ObjectAccessChain());
552         }
553     }
554 
555 protected:
556     TNoContractionAssigneeCheckingTraverser& operator=(const TNoContractionAssigneeCheckingTraverser&);
557 
558     bool visitBinary(glslang::TVisit, glslang::TIntermBinary* node) override;
559     void visitSymbol(glslang::TIntermSymbol* node) override;
560 
561     // A map from object nodes to their access chain string (used as object ID).
562     const AccessChainMapping& accesschain_mapping_;
563     // A given precise object, represented in it access chain string. This
564     // precise object is used to be compared with the assignee node to tell if
565     // the assignee node is 'precise', contains 'precise' object or not
566     // 'precise'.
567     const ObjectAccessChain* precise_object_;
568 };
569 
570 // Visits a binary node. If the node is an object node, it must be a dereference
571 // node. In such cases, if the left node is 'precise', this node should also be
572 // 'precise'.
visitBinary(glslang::TVisit,glslang::TIntermBinary * node)573 bool TNoContractionAssigneeCheckingTraverser::visitBinary(glslang::TVisit,
574                                                           glslang::TIntermBinary* node)
575 {
576     // Traverses the left so that we transfer the 'precise' from nesting object
577     // to its nested object.
578     node->getLeft()->traverse(this);
579     // If this binary node is an object node, we should have it in the
580     // accesschain_mapping_.
581     if (accesschain_mapping_.count(node)) {
582         // A binary object node must be a dereference node.
583         assert(isDereferenceOperation(node->getOp()));
584         // If the left node is 'precise', this node should also be precise,
585         // otherwise, compare with the given precise_object_. If the
586         // access chain of this node matches with the given precise_object_,
587         // this node should be marked as 'precise'.
588         if (isPreciseObjectNode(node->getLeft())) {
589             node->getWritableType().getQualifier().noContraction = true;
590         } else if (accesschain_mapping_.at(node) == *precise_object_) {
591             node->getWritableType().getQualifier().noContraction = true;
592         }
593     }
594     return false;
595 }
596 
597 // Visits a symbol node, if the symbol node ID (its access chain string) matches
598 // with the given precise object, this node should be 'precise'.
visitSymbol(glslang::TIntermSymbol * node)599 void TNoContractionAssigneeCheckingTraverser::visitSymbol(glslang::TIntermSymbol* node)
600 {
601     // A symbol node should always be an object node, and should have been added
602     // to the map from object nodes to their access chain strings.
603     assert(accesschain_mapping_.count(node));
604     if (accesschain_mapping_.at(node) == *precise_object_) {
605         node->getWritableType().getQualifier().noContraction = true;
606     }
607 }
608 
609 //
610 // A traverser that only traverses the right side of binary assignment nodes
611 // and the operand node of unary assignment nodes.
612 //
613 // 1) Marks arithmetic operations as 'NoContraction'.
614 //
615 // 2) Find the object which should be marked as 'precise' in the right and
616 //    update the 'precise' object work list.
617 //
618 class TNoContractionPropagator : public glslang::TIntermTraverser {
619 public:
TNoContractionPropagator(ObjectAccesschainSet * precise_objects,const AccessChainMapping & accesschain_mapping)620     TNoContractionPropagator(ObjectAccesschainSet* precise_objects,
621                              const AccessChainMapping& accesschain_mapping)
622         : TIntermTraverser(true, false, false),
623           precise_objects_(*precise_objects), added_precise_object_ids_(),
624           remained_accesschain_(), accesschain_mapping_(accesschain_mapping) {}
625 
626     // Propagates 'precise' in the right nodes of a given assignment node with
627     // access chain record from the assignee node to a 'precise' object it
628     // contains.
629     void
propagateNoContractionInOneExpression(glslang::TIntermTyped * defining_node,const ObjectAccessChain & assignee_remained_accesschain)630     propagateNoContractionInOneExpression(glslang::TIntermTyped* defining_node,
631                                           const ObjectAccessChain& assignee_remained_accesschain)
632     {
633         remained_accesschain_ = assignee_remained_accesschain;
634         if (glslang::TIntermBinary* BN = defining_node->getAsBinaryNode()) {
635             assert(isAssignOperation(BN->getOp()));
636             BN->getRight()->traverse(this);
637             if (isArithmeticOperation(BN->getOp())) {
638                 BN->getWritableType().getQualifier().noContraction = true;
639             }
640         } else if (glslang::TIntermUnary* UN = defining_node->getAsUnaryNode()) {
641             assert(isAssignOperation(UN->getOp()));
642             UN->getOperand()->traverse(this);
643             if (isArithmeticOperation(UN->getOp())) {
644                 UN->getWritableType().getQualifier().noContraction = true;
645             }
646         }
647     }
648 
649     // Propagates 'precise' in a given precise return node.
propagateNoContractionInReturnNode(glslang::TIntermBranch * return_node)650     void propagateNoContractionInReturnNode(glslang::TIntermBranch* return_node)
651     {
652         remained_accesschain_ = "";
653         assert(return_node->getFlowOp() == glslang::EOpReturn && return_node->getExpression());
654         return_node->getExpression()->traverse(this);
655     }
656 
657 protected:
658     TNoContractionPropagator& operator=(const TNoContractionPropagator&);
659 
660     // Visits an aggregate node. The node can be a initializer list, in which
661     // case we need to find the 'precise' or 'precise' containing object node
662     // with the access chain record. In other cases, just need to traverse all
663     // the children nodes.
visitAggregate(glslang::TVisit,glslang::TIntermAggregate * node)664     bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate* node) override
665     {
666         if (!remained_accesschain_.empty() && node->getOp() == glslang::EOpConstructStruct) {
667             // This is a struct initializer node, and the remained
668             // access chain is not empty, we need to refer to the
669             // assignee_remained_access_chain_ to find the nested
670             // 'precise' object. And we don't need to visit other nodes in this
671             // aggregate node.
672 
673             // Gets the struct dereference index that leads to 'precise' object.
674             ObjectAccessChain precise_accesschain_index_str =
675                 getFrontElement(remained_accesschain_);
676             unsigned precise_accesschain_index = (unsigned)strtoul(precise_accesschain_index_str.c_str(), nullptr, 10);
677             // Gets the node pointed by the access chain index extracted before.
678             glslang::TIntermTyped* potential_precise_node =
679                 node->getSequence()[precise_accesschain_index]->getAsTyped();
680             assert(potential_precise_node);
681             // Pop the front access chain index from the path, and visit the nested node.
682             {
683                 ObjectAccessChain next_level_accesschain =
684                     subAccessChainFromSecondElement(remained_accesschain_);
685                 StateSettingGuard<ObjectAccessChain> setup_remained_accesschain_for_next_level(
686                     &remained_accesschain_, next_level_accesschain);
687                 potential_precise_node->traverse(this);
688             }
689             return false;
690         }
691         return true;
692     }
693 
694     // Visits a binary node. A binary node can be an object node, e.g. a dereference node.
695     // As only the top object nodes in the right side of an assignment needs to be visited
696     // and added to 'precise' work list, this traverser won't visit the children nodes of
697     // an object node. If the binary node does not represent an object node, it should
698     // go on to traverse its children nodes and if it is an arithmetic operation node, this
699     // operation should be marked as 'noContraction'.
visitBinary(glslang::TVisit,glslang::TIntermBinary * node)700     bool visitBinary(glslang::TVisit, glslang::TIntermBinary* node) override
701     {
702         if (isDereferenceOperation(node->getOp())) {
703             // This binary node is an object node. Need to update the precise
704             // object set with the access chain of this node + remained
705             // access chain .
706             ObjectAccessChain new_precise_accesschain = accesschain_mapping_.at(node);
707             if (remained_accesschain_.empty()) {
708                 node->getWritableType().getQualifier().noContraction = true;
709             } else {
710                 new_precise_accesschain += ObjectAccesschainDelimiter + remained_accesschain_;
711             }
712             // Cache the access chain as added precise object, so we won't add the
713             // same object to the work list again.
714             if (!added_precise_object_ids_.count(new_precise_accesschain)) {
715                 precise_objects_.insert(new_precise_accesschain);
716                 added_precise_object_ids_.insert(new_precise_accesschain);
717             }
718             // Only the upper-most object nodes should be visited, so do not
719             // visit children of this object node.
720             return false;
721         }
722         // If this is an arithmetic operation, marks this node as 'noContraction'.
723         if (isArithmeticOperation(node->getOp()) && node->getBasicType() != glslang::EbtInt) {
724             node->getWritableType().getQualifier().noContraction = true;
725         }
726         // As this node is not an object node, need to traverse the children nodes.
727         return true;
728     }
729 
730     // Visits a unary node. A unary node can not be an object node. If the operation
731     // is an arithmetic operation, need to mark this node as 'noContraction'.
visitUnary(glslang::TVisit,glslang::TIntermUnary * node)732     bool visitUnary(glslang::TVisit /* visit */, glslang::TIntermUnary* node) override
733     {
734         // If this is an arithmetic operation, marks this with 'noContraction'
735         if (isArithmeticOperation(node->getOp())) {
736             node->getWritableType().getQualifier().noContraction = true;
737         }
738         return true;
739     }
740 
741     // Visits a symbol node. A symbol node is always an object node. So we
742     // should always be able to find its in our collected mapping from object
743     // nodes to access chains.  As an object node, a symbol node can be either
744     // 'precise' or containing 'precise' objects according to unused
745     // access chain information we have when we visit this node.
visitSymbol(glslang::TIntermSymbol * node)746     void visitSymbol(glslang::TIntermSymbol* node) override
747     {
748         // Symbol nodes are object nodes and should always have an
749         // access chain collected before matches with it.
750         assert(accesschain_mapping_.count(node));
751         ObjectAccessChain new_precise_accesschain = accesschain_mapping_.at(node);
752         // If the unused access chain is empty, this symbol node should be
753         // marked as 'precise'.  Otherwise, the unused access chain should be
754         // appended to the symbol ID to build a new access chain which points to
755         // the nested 'precise' object in this symbol object.
756         if (remained_accesschain_.empty()) {
757             node->getWritableType().getQualifier().noContraction = true;
758         } else {
759             new_precise_accesschain += ObjectAccesschainDelimiter + remained_accesschain_;
760         }
761         // Add the new 'precise' access chain to the work list and make sure we
762         // don't visit it again.
763         if (!added_precise_object_ids_.count(new_precise_accesschain)) {
764             precise_objects_.insert(new_precise_accesschain);
765             added_precise_object_ids_.insert(new_precise_accesschain);
766         }
767     }
768 
769     // A set of precise objects, represented as access chains.
770     ObjectAccesschainSet& precise_objects_;
771     // Visited symbol nodes, should not revisit these nodes.
772     ObjectAccesschainSet added_precise_object_ids_;
773     // The left node of an assignment operation might be an parent of 'precise' objects.
774     // This means the left node might not be an 'precise' object node, but it may contains
775     // 'precise' qualifier which should be propagated to the corresponding child node in
776     // the right. So we need the path from the left node to its nested 'precise' node to
777     // tell us how to find the corresponding 'precise' node in the right.
778     ObjectAccessChain remained_accesschain_;
779     // A map from node pointers to their access chains.
780     const AccessChainMapping& accesschain_mapping_;
781 };
782 }
783 
784 namespace glslang {
785 
PropagateNoContraction(const glslang::TIntermediate & intermediate)786 void PropagateNoContraction(const glslang::TIntermediate& intermediate)
787 {
788     // First, traverses the AST, records symbols with their defining operations
789     // and collects the initial set of precise symbols (symbol nodes that marked
790     // as 'noContraction') and precise return nodes.
791     auto mappings_and_precise_objects =
792         getSymbolToDefinitionMappingAndPreciseSymbolIDs(intermediate);
793 
794     // The mapping of symbol node IDs to their defining nodes. This enables us
795     // to get the defining node directly from a given symbol ID without
796     // traversing the tree again.
797     NodeMapping& symbol_definition_mapping = std::get<0>(mappings_and_precise_objects);
798 
799     // The mapping of object nodes to their access chains recorded.
800     AccessChainMapping& accesschain_mapping = std::get<1>(mappings_and_precise_objects);
801 
802     // The initial set of 'precise' objects which are represented as the
803     // access chain toward them.
804     ObjectAccesschainSet& precise_object_accesschains = std::get<2>(mappings_and_precise_objects);
805 
806     // The set of 'precise' return nodes.
807     ReturnBranchNodeSet& precise_return_nodes = std::get<3>(mappings_and_precise_objects);
808 
809     // Second, uses the initial set of precise objects as a work list, pops an
810     // access chain, extract the symbol ID from it. Then:
811     //  1) Check the assignee object, see if it is 'precise' object node or
812     //  contains 'precise' object. Obtain the incremental access chain from the
813     //  assignee node to its nested 'precise' node (if any).
814     //  2) If the assignee object node is 'precise' or it contains 'precise'
815     //  objects, traverses the right side of the assignment operation
816     //  expression to mark arithmetic operations as 'noContration' and update
817     //  'precise' access chain work list with new found object nodes.
818     // Repeat above steps until the work list is empty.
819     TNoContractionAssigneeCheckingTraverser checker(accesschain_mapping);
820     TNoContractionPropagator propagator(&precise_object_accesschains, accesschain_mapping);
821 
822     // We have two initial precise work lists to handle:
823     //  1) precise return nodes
824     //  2) precise object access chains
825     // We should process the precise return nodes first and the involved
826     // objects in the return expression should be added to the precise object
827     // access chain set.
828     while (!precise_return_nodes.empty()) {
829         glslang::TIntermBranch* precise_return_node = *precise_return_nodes.begin();
830         propagator.propagateNoContractionInReturnNode(precise_return_node);
831         precise_return_nodes.erase(precise_return_node);
832     }
833 
834     while (!precise_object_accesschains.empty()) {
835         // Get the access chain of a precise object from the work list.
836         ObjectAccessChain precise_object_accesschain = *precise_object_accesschains.begin();
837         // Get the symbol id from the access chain.
838         ObjectAccessChain symbol_id = getFrontElement(precise_object_accesschain);
839         // Get all the defining nodes of that symbol ID.
840         std::pair<NodeMapping::iterator, NodeMapping::iterator> range =
841             symbol_definition_mapping.equal_range(symbol_id);
842         // Visits all the assignment nodes of that symbol ID and
843         //  1) Check if the assignee node is 'precise' or contains 'precise'
844         //  objects.
845         //  2) Propagate the 'precise' to the top layer object nodes
846         //  in the right side of the assignment operation, update the 'precise'
847         //  work list with new access chains representing the new 'precise'
848         //  objects, and mark arithmetic operations as 'noContraction'.
849         for (NodeMapping::iterator defining_node_iter = range.first;
850              defining_node_iter != range.second; defining_node_iter++) {
851             TIntermOperator* defining_node = defining_node_iter->second;
852             // Check the assignee node.
853             auto checker_result = checker.getPrecisenessAndRemainedAccessChain(
854                 defining_node, precise_object_accesschain);
855             bool& contain_precise = std::get<0>(checker_result);
856             ObjectAccessChain& remained_accesschain = std::get<1>(checker_result);
857             // If the assignee node is 'precise' or contains 'precise', propagate the
858             // 'precise' to the right. Otherwise just skip this assignment node.
859             if (contain_precise) {
860                 propagator.propagateNoContractionInOneExpression(defining_node,
861                                                                  remained_accesschain);
862             }
863         }
864         // Remove the last processed 'precise' object from the work list.
865         precise_object_accesschains.erase(precise_object_accesschain);
866     }
867 }
868 };
869 
870 #endif // GLSLANG_WEB
871