1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_
16 #define SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_
17 
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iosfwd>
#include <memory>
#include <string>
#include <vector>
22 
23 #include "source/opt/tree_iterator.h"
24 
25 namespace spvtools {
26 namespace opt {
27 
28 class Loop;
29 class ScalarEvolutionAnalysis;
30 class SEConstantNode;
31 class SERecurrentNode;
32 class SEAddNode;
33 class SEMultiplyNode;
34 class SENegative;
35 class SEValueUnknown;
36 class SECantCompute;
37 
38 // Abstract class representing a node in the scalar evolution DAG. Each node
39 // contains a vector of pointers to its children and each subclass of SENode
40 // implements GetType and an As method to allow casting. SENodes can be hashed
41 // using the SENodeHash functor. The vector of children is sorted when a node is
42 // added. This is important as it allows the hash of X+Y to be the same as Y+X.
43 class SENode {
44  public:
45   enum SENodeType {
46     Constant,
47     RecurrentAddExpr,
48     Add,
49     Multiply,
50     Negative,
51     ValueUnknown,
52     CanNotCompute
53   };
54 
55   using ChildContainerType = std::vector<SENode*>;
56 
SENode(ScalarEvolutionAnalysis * parent_analysis)57   explicit SENode(ScalarEvolutionAnalysis* parent_analysis)
58       : parent_analysis_(parent_analysis), unique_id_(++NumberOfNodes) {}
59 
60   virtual SENodeType GetType() const = 0;
61 
~SENode()62   virtual ~SENode() {}
63 
AddChild(SENode * child)64   virtual inline void AddChild(SENode* child) {
65     // If this is a constant node, assert.
66     if (AsSEConstantNode()) {
67       assert(false && "Trying to add a child node to a constant!");
68     }
69 
70     // Find the first point in the vector where |child| is greater than the node
71     // currently in the vector.
72     auto find_first_less_than = [child](const SENode* node) {
73       return child->unique_id_ <= node->unique_id_;
74     };
75 
76     auto position = std::find_if_not(children_.begin(), children_.end(),
77                                      find_first_less_than);
78     // Children are sorted so the hashing and equality operator will be the same
79     // for a node with the same children. X+Y should be the same as Y+X.
80     children_.insert(position, child);
81   }
82 
83   // Get the type as an std::string. This is used to represent the node in the
84   // dot output and is used to hash the type as well.
85   std::string AsString() const;
86 
87   // Dump the SENode and its immediate children, if |recurse| is true then it
88   // will recurse through all children to print the DAG starting from this node
89   // as a root.
90   void DumpDot(std::ostream& out, bool recurse = false) const;
91 
92   // Checks if two nodes are the same by hashing them.
93   bool operator==(const SENode& other) const;
94 
95   // Checks if two nodes are not the same by comparing the hashes.
96   bool operator!=(const SENode& other) const;
97 
98   // Return the child node at |index|.
GetChild(size_t index)99   inline SENode* GetChild(size_t index) { return children_[index]; }
GetChild(size_t index)100   inline const SENode* GetChild(size_t index) const { return children_[index]; }
101 
102   // Iterator to iterate over the child nodes.
103   using iterator = ChildContainerType::iterator;
104   using const_iterator = ChildContainerType::const_iterator;
105 
106   // Iterate over immediate child nodes.
begin()107   iterator begin() { return children_.begin(); }
end()108   iterator end() { return children_.end(); }
109 
110   // Constant overloads for iterating over immediate child nodes.
begin()111   const_iterator begin() const { return children_.cbegin(); }
end()112   const_iterator end() const { return children_.cend(); }
cbegin()113   const_iterator cbegin() { return children_.cbegin(); }
cend()114   const_iterator cend() { return children_.cend(); }
115 
116   // Collect all the recurrent nodes in this SENode
CollectRecurrentNodes()117   std::vector<SERecurrentNode*> CollectRecurrentNodes() {
118     std::vector<SERecurrentNode*> recurrent_nodes{};
119 
120     if (auto recurrent_node = AsSERecurrentNode()) {
121       recurrent_nodes.push_back(recurrent_node);
122     }
123 
124     for (auto child : GetChildren()) {
125       auto child_recurrent_nodes = child->CollectRecurrentNodes();
126       recurrent_nodes.insert(recurrent_nodes.end(),
127                              child_recurrent_nodes.begin(),
128                              child_recurrent_nodes.end());
129     }
130 
131     return recurrent_nodes;
132   }
133 
134   // Collect all the value unknown nodes in this SENode
CollectValueUnknownNodes()135   std::vector<SEValueUnknown*> CollectValueUnknownNodes() {
136     std::vector<SEValueUnknown*> value_unknown_nodes{};
137 
138     if (auto value_unknown_node = AsSEValueUnknown()) {
139       value_unknown_nodes.push_back(value_unknown_node);
140     }
141 
142     for (auto child : GetChildren()) {
143       auto child_value_unknown_nodes = child->CollectValueUnknownNodes();
144       value_unknown_nodes.insert(value_unknown_nodes.end(),
145                                  child_value_unknown_nodes.begin(),
146                                  child_value_unknown_nodes.end());
147     }
148 
149     return value_unknown_nodes;
150   }
151 
152   // Iterator to iterate over the entire DAG. Even though we are using the tree
153   // iterator it should still be safe to iterate over. However, nodes with
154   // multiple parents will be visited multiple times, unlike in a tree.
155   using dag_iterator = TreeDFIterator<SENode>;
156   using const_dag_iterator = TreeDFIterator<const SENode>;
157 
158   // Iterate over all child nodes in the graph.
graph_begin()159   dag_iterator graph_begin() { return dag_iterator(this); }
graph_end()160   dag_iterator graph_end() { return dag_iterator(); }
graph_begin()161   const_dag_iterator graph_begin() const { return graph_cbegin(); }
graph_end()162   const_dag_iterator graph_end() const { return graph_cend(); }
graph_cbegin()163   const_dag_iterator graph_cbegin() const { return const_dag_iterator(this); }
graph_cend()164   const_dag_iterator graph_cend() const { return const_dag_iterator(); }
165 
166   // Return the vector of immediate children.
GetChildren()167   const ChildContainerType& GetChildren() const { return children_; }
GetChildren()168   ChildContainerType& GetChildren() { return children_; }
169 
170   // Return true if this node is a cant compute node.
IsCantCompute()171   bool IsCantCompute() const { return GetType() == CanNotCompute; }
172 
173 // Implements a casting method for each type.
174 // clang-format off
175 #define DeclareCastMethod(target)                  \
176   virtual target* As##target() { return nullptr; } \
177   virtual const target* As##target() const { return nullptr; }
178   DeclareCastMethod(SEConstantNode)
DeclareCastMethod(SERecurrentNode)179   DeclareCastMethod(SERecurrentNode)
180   DeclareCastMethod(SEAddNode)
181   DeclareCastMethod(SEMultiplyNode)
182   DeclareCastMethod(SENegative)
183   DeclareCastMethod(SEValueUnknown)
184   DeclareCastMethod(SECantCompute)
185 #undef DeclareCastMethod
186 
187   // Get the analysis which has this node in its cache.
188   inline ScalarEvolutionAnalysis* GetParentAnalysis() const {
189     return parent_analysis_;
190   }
191 
192  protected:
193   ChildContainerType children_;
194 
195   ScalarEvolutionAnalysis* parent_analysis_;
196 
197   // The unique id of this node, assigned on creation by incrementing the static
198   // node count.
199   uint32_t unique_id_;
200 
201   // The number of nodes created.
202   static uint32_t NumberOfNodes;
203 };
204 // clang-format on
205 
206 // Function object to handle the hashing of SENodes. Hashing algorithm hashes
207 // the type (as a string), the literal value of any constants, and the child
208 // pointers which are assumed to be unique.
209 struct SENodeHash {
210   size_t operator()(const std::unique_ptr<SENode>& node) const;
211   size_t operator()(const SENode* node) const;
212 };
213 
214 // A node representing a constant integer.
215 class SEConstantNode : public SENode {
216  public:
SEConstantNode(ScalarEvolutionAnalysis * parent_analysis,int64_t value)217   SEConstantNode(ScalarEvolutionAnalysis* parent_analysis, int64_t value)
218       : SENode(parent_analysis), literal_value_(value) {}
219 
GetType()220   SENodeType GetType() const final { return Constant; }
221 
FoldToSingleValue()222   int64_t FoldToSingleValue() const { return literal_value_; }
223 
AsSEConstantNode()224   SEConstantNode* AsSEConstantNode() override { return this; }
AsSEConstantNode()225   const SEConstantNode* AsSEConstantNode() const override { return this; }
226 
AddChild(SENode *)227   inline void AddChild(SENode*) final {
228     assert(false && "Attempting to add a child to a constant node!");
229   }
230 
231  protected:
232   int64_t literal_value_;
233 };
234 
235 // A node representing a recurrent expression in the code. A recurrent
236 // expression is an expression whose value can be expressed as a linear
237 // expression of the loop iterations. Such as an induction variable. The actual
238 // value of a recurrent expression is coefficent_ * iteration + offset_, hence
239 // an induction variable i=0, i++ becomes a recurrent expression with an offset
240 // of zero and a coefficient of one.
241 class SERecurrentNode : public SENode {
242  public:
SERecurrentNode(ScalarEvolutionAnalysis * parent_analysis,const Loop * loop)243   SERecurrentNode(ScalarEvolutionAnalysis* parent_analysis, const Loop* loop)
244       : SENode(parent_analysis), loop_(loop) {}
245 
GetType()246   SENodeType GetType() const final { return RecurrentAddExpr; }
247 
AddCoefficient(SENode * child)248   inline void AddCoefficient(SENode* child) {
249     coefficient_ = child;
250     SENode::AddChild(child);
251   }
252 
AddOffset(SENode * child)253   inline void AddOffset(SENode* child) {
254     offset_ = child;
255     SENode::AddChild(child);
256   }
257 
GetCoefficient()258   inline const SENode* GetCoefficient() const { return coefficient_; }
GetCoefficient()259   inline SENode* GetCoefficient() { return coefficient_; }
260 
GetOffset()261   inline const SENode* GetOffset() const { return offset_; }
GetOffset()262   inline SENode* GetOffset() { return offset_; }
263 
264   // Return the loop which this recurrent expression is recurring within.
GetLoop()265   const Loop* GetLoop() const { return loop_; }
266 
AsSERecurrentNode()267   SERecurrentNode* AsSERecurrentNode() override { return this; }
AsSERecurrentNode()268   const SERecurrentNode* AsSERecurrentNode() const override { return this; }
269 
270  private:
271   SENode* coefficient_;
272   SENode* offset_;
273   const Loop* loop_;
274 };
275 
276 // A node representing an addition operation between child nodes.
277 class SEAddNode : public SENode {
278  public:
SEAddNode(ScalarEvolutionAnalysis * parent_analysis)279   explicit SEAddNode(ScalarEvolutionAnalysis* parent_analysis)
280       : SENode(parent_analysis) {}
281 
GetType()282   SENodeType GetType() const final { return Add; }
283 
AsSEAddNode()284   SEAddNode* AsSEAddNode() override { return this; }
AsSEAddNode()285   const SEAddNode* AsSEAddNode() const override { return this; }
286 };
287 
288 // A node representing a multiply operation between child nodes.
289 class SEMultiplyNode : public SENode {
290  public:
SEMultiplyNode(ScalarEvolutionAnalysis * parent_analysis)291   explicit SEMultiplyNode(ScalarEvolutionAnalysis* parent_analysis)
292       : SENode(parent_analysis) {}
293 
GetType()294   SENodeType GetType() const final { return Multiply; }
295 
AsSEMultiplyNode()296   SEMultiplyNode* AsSEMultiplyNode() override { return this; }
AsSEMultiplyNode()297   const SEMultiplyNode* AsSEMultiplyNode() const override { return this; }
298 };
299 
300 // A node representing a unary negative operation.
301 class SENegative : public SENode {
302  public:
SENegative(ScalarEvolutionAnalysis * parent_analysis)303   explicit SENegative(ScalarEvolutionAnalysis* parent_analysis)
304       : SENode(parent_analysis) {}
305 
GetType()306   SENodeType GetType() const final { return Negative; }
307 
AsSENegative()308   SENegative* AsSENegative() override { return this; }
AsSENegative()309   const SENegative* AsSENegative() const override { return this; }
310 };
311 
312 // A node representing a value which we do not know the value of, such as a load
313 // instruction.
314 class SEValueUnknown : public SENode {
315  public:
316   // SEValueUnknowns must come from an instruction |unique_id| is the unique id
317   // of that instruction. This is so we cancompare value unknowns and have a
318   // unique value unknown for each instruction.
SEValueUnknown(ScalarEvolutionAnalysis * parent_analysis,uint32_t result_id)319   SEValueUnknown(ScalarEvolutionAnalysis* parent_analysis, uint32_t result_id)
320       : SENode(parent_analysis), result_id_(result_id) {}
321 
GetType()322   SENodeType GetType() const final { return ValueUnknown; }
323 
AsSEValueUnknown()324   SEValueUnknown* AsSEValueUnknown() override { return this; }
AsSEValueUnknown()325   const SEValueUnknown* AsSEValueUnknown() const override { return this; }
326 
ResultId()327   inline uint32_t ResultId() const { return result_id_; }
328 
329  private:
330   uint32_t result_id_;
331 };
332 
333 // A node which we cannot reason about at all.
334 class SECantCompute : public SENode {
335  public:
SECantCompute(ScalarEvolutionAnalysis * parent_analysis)336   explicit SECantCompute(ScalarEvolutionAnalysis* parent_analysis)
337       : SENode(parent_analysis) {}
338 
GetType()339   SENodeType GetType() const final { return CanNotCompute; }
340 
AsSECantCompute()341   SECantCompute* AsSECantCompute() override { return this; }
AsSECantCompute()342   const SECantCompute* AsSECantCompute() const override { return this; }
343 };
344 
345 }  // namespace opt
346 }  // namespace spvtools
347 #endif  // SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_
348