// Copyright (c) 2018 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASI,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_
#define SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iosfwd>
#include <memory>
#include <string>
#include <vector>

#include "source/opt/tree_iterator.h"

namespace spvtools {
namespace opt {

class Loop;
class ScalarEvolutionAnalysis;
class SEConstantNode;
class SERecurrentNode;
class SEAddNode;
class SEMultiplyNode;
class SENegative;
class SEValueUnknown;
class SECantCompute;

// Abstract class representing a node in the scalar evolution DAG. Each node
// contains a vector of pointers to its children and each subclass of SENode
// implements GetType and an As method to allow casting. SENodes can be hashed
// using the SENodeHash functor. The vector of children is sorted when a node is
// added. This is important as it allows the hash of X+Y to be the same as Y+X.
43 class SENode { 44 public: 45 enum SENodeType { 46 Constant, 47 RecurrentAddExpr, 48 Add, 49 Multiply, 50 Negative, 51 ValueUnknown, 52 CanNotCompute 53 }; 54 55 using ChildContainerType = std::vector<SENode*>; 56 SENode(ScalarEvolutionAnalysis * parent_analysis)57 explicit SENode(ScalarEvolutionAnalysis* parent_analysis) 58 : parent_analysis_(parent_analysis), unique_id_(++NumberOfNodes) {} 59 60 virtual SENodeType GetType() const = 0; 61 ~SENode()62 virtual ~SENode() {} 63 AddChild(SENode * child)64 virtual inline void AddChild(SENode* child) { 65 // If this is a constant node, assert. 66 if (AsSEConstantNode()) { 67 assert(false && "Trying to add a child node to a constant!"); 68 } 69 70 // Find the first point in the vector where |child| is greater than the node 71 // currently in the vector. 72 auto find_first_less_than = [child](const SENode* node) { 73 return child->unique_id_ <= node->unique_id_; 74 }; 75 76 auto position = std::find_if_not(children_.begin(), children_.end(), 77 find_first_less_than); 78 // Children are sorted so the hashing and equality operator will be the same 79 // for a node with the same children. X+Y should be the same as Y+X. 80 children_.insert(position, child); 81 } 82 83 // Get the type as an std::string. This is used to represent the node in the 84 // dot output and is used to hash the type as well. 85 std::string AsString() const; 86 87 // Dump the SENode and its immediate children, if |recurse| is true then it 88 // will recurse through all children to print the DAG starting from this node 89 // as a root. 90 void DumpDot(std::ostream& out, bool recurse = false) const; 91 92 // Checks if two nodes are the same by hashing them. 93 bool operator==(const SENode& other) const; 94 95 // Checks if two nodes are not the same by comparing the hashes. 96 bool operator!=(const SENode& other) const; 97 98 // Return the child node at |index|. 
GetChild(size_t index)99 inline SENode* GetChild(size_t index) { return children_[index]; } GetChild(size_t index)100 inline const SENode* GetChild(size_t index) const { return children_[index]; } 101 102 // Iterator to iterate over the child nodes. 103 using iterator = ChildContainerType::iterator; 104 using const_iterator = ChildContainerType::const_iterator; 105 106 // Iterate over immediate child nodes. begin()107 iterator begin() { return children_.begin(); } end()108 iterator end() { return children_.end(); } 109 110 // Constant overloads for iterating over immediate child nodes. begin()111 const_iterator begin() const { return children_.cbegin(); } end()112 const_iterator end() const { return children_.cend(); } cbegin()113 const_iterator cbegin() { return children_.cbegin(); } cend()114 const_iterator cend() { return children_.cend(); } 115 116 // Collect all the recurrent nodes in this SENode CollectRecurrentNodes()117 std::vector<SERecurrentNode*> CollectRecurrentNodes() { 118 std::vector<SERecurrentNode*> recurrent_nodes{}; 119 120 if (auto recurrent_node = AsSERecurrentNode()) { 121 recurrent_nodes.push_back(recurrent_node); 122 } 123 124 for (auto child : GetChildren()) { 125 auto child_recurrent_nodes = child->CollectRecurrentNodes(); 126 recurrent_nodes.insert(recurrent_nodes.end(), 127 child_recurrent_nodes.begin(), 128 child_recurrent_nodes.end()); 129 } 130 131 return recurrent_nodes; 132 } 133 134 // Collect all the value unknown nodes in this SENode CollectValueUnknownNodes()135 std::vector<SEValueUnknown*> CollectValueUnknownNodes() { 136 std::vector<SEValueUnknown*> value_unknown_nodes{}; 137 138 if (auto value_unknown_node = AsSEValueUnknown()) { 139 value_unknown_nodes.push_back(value_unknown_node); 140 } 141 142 for (auto child : GetChildren()) { 143 auto child_value_unknown_nodes = child->CollectValueUnknownNodes(); 144 value_unknown_nodes.insert(value_unknown_nodes.end(), 145 child_value_unknown_nodes.begin(), 146 
child_value_unknown_nodes.end()); 147 } 148 149 return value_unknown_nodes; 150 } 151 152 // Iterator to iterate over the entire DAG. Even though we are using the tree 153 // iterator it should still be safe to iterate over. However, nodes with 154 // multiple parents will be visited multiple times, unlike in a tree. 155 using dag_iterator = TreeDFIterator<SENode>; 156 using const_dag_iterator = TreeDFIterator<const SENode>; 157 158 // Iterate over all child nodes in the graph. graph_begin()159 dag_iterator graph_begin() { return dag_iterator(this); } graph_end()160 dag_iterator graph_end() { return dag_iterator(); } graph_begin()161 const_dag_iterator graph_begin() const { return graph_cbegin(); } graph_end()162 const_dag_iterator graph_end() const { return graph_cend(); } graph_cbegin()163 const_dag_iterator graph_cbegin() const { return const_dag_iterator(this); } graph_cend()164 const_dag_iterator graph_cend() const { return const_dag_iterator(); } 165 166 // Return the vector of immediate children. GetChildren()167 const ChildContainerType& GetChildren() const { return children_; } GetChildren()168 ChildContainerType& GetChildren() { return children_; } 169 170 // Return true if this node is a cant compute node. IsCantCompute()171 bool IsCantCompute() const { return GetType() == CanNotCompute; } 172 173 // Implements a casting method for each type. 174 // clang-format off 175 #define DeclareCastMethod(target) \ 176 virtual target* As##target() { return nullptr; } \ 177 virtual const target* As##target() const { return nullptr; } 178 DeclareCastMethod(SEConstantNode) DeclareCastMethod(SERecurrentNode)179 DeclareCastMethod(SERecurrentNode) 180 DeclareCastMethod(SEAddNode) 181 DeclareCastMethod(SEMultiplyNode) 182 DeclareCastMethod(SENegative) 183 DeclareCastMethod(SEValueUnknown) 184 DeclareCastMethod(SECantCompute) 185 #undef DeclareCastMethod 186 187 // Get the analysis which has this node in its cache. 
188 inline ScalarEvolutionAnalysis* GetParentAnalysis() const { 189 return parent_analysis_; 190 } 191 192 protected: 193 ChildContainerType children_; 194 195 ScalarEvolutionAnalysis* parent_analysis_; 196 197 // The unique id of this node, assigned on creation by incrementing the static 198 // node count. 199 uint32_t unique_id_; 200 201 // The number of nodes created. 202 static uint32_t NumberOfNodes; 203 }; 204 // clang-format on 205 206 // Function object to handle the hashing of SENodes. Hashing algorithm hashes 207 // the type (as a string), the literal value of any constants, and the child 208 // pointers which are assumed to be unique. 209 struct SENodeHash { 210 size_t operator()(const std::unique_ptr<SENode>& node) const; 211 size_t operator()(const SENode* node) const; 212 }; 213 214 // A node representing a constant integer. 215 class SEConstantNode : public SENode { 216 public: SEConstantNode(ScalarEvolutionAnalysis * parent_analysis,int64_t value)217 SEConstantNode(ScalarEvolutionAnalysis* parent_analysis, int64_t value) 218 : SENode(parent_analysis), literal_value_(value) {} 219 GetType()220 SENodeType GetType() const final { return Constant; } 221 FoldToSingleValue()222 int64_t FoldToSingleValue() const { return literal_value_; } 223 AsSEConstantNode()224 SEConstantNode* AsSEConstantNode() override { return this; } AsSEConstantNode()225 const SEConstantNode* AsSEConstantNode() const override { return this; } 226 AddChild(SENode *)227 inline void AddChild(SENode*) final { 228 assert(false && "Attempting to add a child to a constant node!"); 229 } 230 231 protected: 232 int64_t literal_value_; 233 }; 234 235 // A node representing a recurrent expression in the code. A recurrent 236 // expression is an expression whose value can be expressed as a linear 237 // expression of the loop iterations. Such as an induction variable. 
The actual 238 // value of a recurrent expression is coefficent_ * iteration + offset_, hence 239 // an induction variable i=0, i++ becomes a recurrent expression with an offset 240 // of zero and a coefficient of one. 241 class SERecurrentNode : public SENode { 242 public: SERecurrentNode(ScalarEvolutionAnalysis * parent_analysis,const Loop * loop)243 SERecurrentNode(ScalarEvolutionAnalysis* parent_analysis, const Loop* loop) 244 : SENode(parent_analysis), loop_(loop) {} 245 GetType()246 SENodeType GetType() const final { return RecurrentAddExpr; } 247 AddCoefficient(SENode * child)248 inline void AddCoefficient(SENode* child) { 249 coefficient_ = child; 250 SENode::AddChild(child); 251 } 252 AddOffset(SENode * child)253 inline void AddOffset(SENode* child) { 254 offset_ = child; 255 SENode::AddChild(child); 256 } 257 GetCoefficient()258 inline const SENode* GetCoefficient() const { return coefficient_; } GetCoefficient()259 inline SENode* GetCoefficient() { return coefficient_; } 260 GetOffset()261 inline const SENode* GetOffset() const { return offset_; } GetOffset()262 inline SENode* GetOffset() { return offset_; } 263 264 // Return the loop which this recurrent expression is recurring within. GetLoop()265 const Loop* GetLoop() const { return loop_; } 266 AsSERecurrentNode()267 SERecurrentNode* AsSERecurrentNode() override { return this; } AsSERecurrentNode()268 const SERecurrentNode* AsSERecurrentNode() const override { return this; } 269 270 private: 271 SENode* coefficient_; 272 SENode* offset_; 273 const Loop* loop_; 274 }; 275 276 // A node representing an addition operation between child nodes. 
277 class SEAddNode : public SENode { 278 public: SEAddNode(ScalarEvolutionAnalysis * parent_analysis)279 explicit SEAddNode(ScalarEvolutionAnalysis* parent_analysis) 280 : SENode(parent_analysis) {} 281 GetType()282 SENodeType GetType() const final { return Add; } 283 AsSEAddNode()284 SEAddNode* AsSEAddNode() override { return this; } AsSEAddNode()285 const SEAddNode* AsSEAddNode() const override { return this; } 286 }; 287 288 // A node representing a multiply operation between child nodes. 289 class SEMultiplyNode : public SENode { 290 public: SEMultiplyNode(ScalarEvolutionAnalysis * parent_analysis)291 explicit SEMultiplyNode(ScalarEvolutionAnalysis* parent_analysis) 292 : SENode(parent_analysis) {} 293 GetType()294 SENodeType GetType() const final { return Multiply; } 295 AsSEMultiplyNode()296 SEMultiplyNode* AsSEMultiplyNode() override { return this; } AsSEMultiplyNode()297 const SEMultiplyNode* AsSEMultiplyNode() const override { return this; } 298 }; 299 300 // A node representing a unary negative operation. 301 class SENegative : public SENode { 302 public: SENegative(ScalarEvolutionAnalysis * parent_analysis)303 explicit SENegative(ScalarEvolutionAnalysis* parent_analysis) 304 : SENode(parent_analysis) {} 305 GetType()306 SENodeType GetType() const final { return Negative; } 307 AsSENegative()308 SENegative* AsSENegative() override { return this; } AsSENegative()309 const SENegative* AsSENegative() const override { return this; } 310 }; 311 312 // A node representing a value which we do not know the value of, such as a load 313 // instruction. 314 class SEValueUnknown : public SENode { 315 public: 316 // SEValueUnknowns must come from an instruction |unique_id| is the unique id 317 // of that instruction. This is so we cancompare value unknowns and have a 318 // unique value unknown for each instruction. 
SEValueUnknown(ScalarEvolutionAnalysis * parent_analysis,uint32_t result_id)319 SEValueUnknown(ScalarEvolutionAnalysis* parent_analysis, uint32_t result_id) 320 : SENode(parent_analysis), result_id_(result_id) {} 321 GetType()322 SENodeType GetType() const final { return ValueUnknown; } 323 AsSEValueUnknown()324 SEValueUnknown* AsSEValueUnknown() override { return this; } AsSEValueUnknown()325 const SEValueUnknown* AsSEValueUnknown() const override { return this; } 326 ResultId()327 inline uint32_t ResultId() const { return result_id_; } 328 329 private: 330 uint32_t result_id_; 331 }; 332 333 // A node which we cannot reason about at all. 334 class SECantCompute : public SENode { 335 public: SECantCompute(ScalarEvolutionAnalysis * parent_analysis)336 explicit SECantCompute(ScalarEvolutionAnalysis* parent_analysis) 337 : SENode(parent_analysis) {} 338 GetType()339 SENodeType GetType() const final { return CanNotCompute; } 340 AsSECantCompute()341 SECantCompute* AsSECantCompute() override { return this; } AsSECantCompute()342 const SECantCompute* AsSECantCompute() const override { return this; } 343 }; 344 345 } // namespace opt 346 } // namespace spvtools 347 #endif // SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_ 348