//===- ADT/SCCIterator.h - Strongly Connected Comp. Iter. -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
///
/// This builds on the llvm/ADT/GraphTraits.h file to find the strongly
/// connected components (SCCs) of a graph in O(N+E) time using Tarjan's DFS
/// algorithm.
///
/// The SCC iterator has the important property that if a node in SCC S1 has an
/// edge to a node in SCC S2, then it visits S1 *after* S2.
///
/// To visit S1 *before* S2, use the scc_iterator on the Inverse graph. (NOTE:
/// This requires some simple wrappers and is not supported yet.)
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_ADT_SCCITERATOR_H
#define LLVM_ADT_SCCITERATOR_H

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/iterator.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <queue>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace llvm {

/// Enumerate the SCCs of a directed graph in reverse topological order
/// of the SCC DAG.
///
/// This is implemented using Tarjan's DFS algorithm using an internal stack to
/// build up a vector of nodes in a particular SCC. Note that it is a forward
/// iterator and thus you cannot backtrack or re-visit nodes.
45 template <class GraphT, class GT = GraphTraits<GraphT>> 46 class scc_iterator : public iterator_facade_base< 47 scc_iterator<GraphT, GT>, std::forward_iterator_tag, 48 const std::vector<typename GT::NodeRef>, ptrdiff_t> { 49 using NodeRef = typename GT::NodeRef; 50 using ChildItTy = typename GT::ChildIteratorType; 51 using SccTy = std::vector<NodeRef>; 52 using reference = typename scc_iterator::reference; 53 54 /// Element of VisitStack during DFS. 55 struct StackElement { 56 NodeRef Node; ///< The current node pointer. 57 ChildItTy NextChild; ///< The next child, modified inplace during DFS. 58 unsigned MinVisited; ///< Minimum uplink value of all children of Node. 59 60 StackElement(NodeRef Node, const ChildItTy &Child, unsigned Min) 61 : Node(Node), NextChild(Child), MinVisited(Min) {} 62 63 bool operator==(const StackElement &Other) const { 64 return Node == Other.Node && 65 NextChild == Other.NextChild && 66 MinVisited == Other.MinVisited; 67 } 68 }; 69 70 /// The visit counters used to detect when a complete SCC is on the stack. 71 /// visitNum is the global counter. 72 /// 73 /// nodeVisitNumbers are per-node visit numbers, also used as DFS flags. 74 unsigned visitNum; 75 DenseMap<NodeRef, unsigned> nodeVisitNumbers; 76 77 /// Stack holding nodes of the SCC. 78 std::vector<NodeRef> SCCNodeStack; 79 80 /// The current SCC, retrieved using operator*(). 81 SccTy CurrentSCC; 82 83 /// DFS stack, Used to maintain the ordering. The top contains the current 84 /// node, the next child to visit, and the minimum uplink value of all child 85 std::vector<StackElement> VisitStack; 86 87 /// A single "visit" within the non-recursive DFS traversal. 88 void DFSVisitOne(NodeRef N); 89 90 /// The stack-based DFS traversal; defined below. 91 void DFSVisitChildren(); 92 93 /// Compute the next SCC using the DFS traversal. 
94 void GetNextSCC(); 95 96 scc_iterator(NodeRef entryN) : visitNum(0) { 97 DFSVisitOne(entryN); 98 GetNextSCC(); 99 } 100 101 /// End is when the DFS stack is empty. 102 scc_iterator() = default; 103 104 public: 105 static scc_iterator begin(const GraphT &G) { 106 return scc_iterator(GT::getEntryNode(G)); 107 } 108 static scc_iterator end(const GraphT &) { return scc_iterator(); } 109 110 /// Direct loop termination test which is more efficient than 111 /// comparison with \c end(). 112 bool isAtEnd() const { 113 assert(!CurrentSCC.empty() || VisitStack.empty()); 114 return CurrentSCC.empty(); 115 } 116 117 bool operator==(const scc_iterator &x) const { 118 return VisitStack == x.VisitStack && CurrentSCC == x.CurrentSCC; 119 } 120 121 scc_iterator &operator++() { 122 GetNextSCC(); 123 return *this; 124 } 125 126 reference operator*() const { 127 assert(!CurrentSCC.empty() && "Dereferencing END SCC iterator!"); 128 return CurrentSCC; 129 } 130 131 /// Test if the current SCC has a cycle. 132 /// 133 /// If the SCC has more than one node, this is trivially true. If not, it may 134 /// still contain a cycle if the node has an edge back to itself. 135 bool hasCycle() const; 136 137 /// This informs the \c scc_iterator that the specified \c Old node 138 /// has been deleted, and \c New is to be used in its place. 139 void ReplaceNode(NodeRef Old, NodeRef New) { 140 assert(nodeVisitNumbers.count(Old) && "Old not in scc_iterator?"); 141 // Do the assignment in two steps, in case 'New' is not yet in the map, and 142 // inserting it causes the map to grow. 
143 auto tempVal = nodeVisitNumbers[Old]; 144 nodeVisitNumbers[New] = tempVal; 145 nodeVisitNumbers.erase(Old); 146 } 147 }; 148 149 template <class GraphT, class GT> 150 void scc_iterator<GraphT, GT>::DFSVisitOne(NodeRef N) { 151 ++visitNum; 152 nodeVisitNumbers[N] = visitNum; 153 SCCNodeStack.push_back(N); 154 VisitStack.push_back(StackElement(N, GT::child_begin(N), visitNum)); 155 #if 0 // Enable if needed when debugging. 156 dbgs() << "TarjanSCC: Node " << N << 157 " : visitNum = " << visitNum << "\n"; 158 #endif 159 } 160 161 template <class GraphT, class GT> 162 void scc_iterator<GraphT, GT>::DFSVisitChildren() { 163 assert(!VisitStack.empty()); 164 while (VisitStack.back().NextChild != GT::child_end(VisitStack.back().Node)) { 165 // TOS has at least one more child so continue DFS 166 NodeRef childN = *VisitStack.back().NextChild++; 167 typename DenseMap<NodeRef, unsigned>::iterator Visited = 168 nodeVisitNumbers.find(childN); 169 if (Visited == nodeVisitNumbers.end()) { 170 // this node has never been seen. 171 DFSVisitOne(childN); 172 continue; 173 } 174 175 unsigned childNum = Visited->second; 176 if (VisitStack.back().MinVisited > childNum) 177 VisitStack.back().MinVisited = childNum; 178 } 179 } 180 181 template <class GraphT, class GT> void scc_iterator<GraphT, GT>::GetNextSCC() { 182 CurrentSCC.clear(); // Prepare to compute the next SCC 183 while (!VisitStack.empty()) { 184 DFSVisitChildren(); 185 186 // Pop the leaf on top of the VisitStack. 187 NodeRef visitingN = VisitStack.back().Node; 188 unsigned minVisitNum = VisitStack.back().MinVisited; 189 assert(VisitStack.back().NextChild == GT::child_end(visitingN)); 190 VisitStack.pop_back(); 191 192 // Propagate MinVisitNum to parent so we can detect the SCC starting node. 193 if (!VisitStack.empty() && VisitStack.back().MinVisited > minVisitNum) 194 VisitStack.back().MinVisited = minVisitNum; 195 196 #if 0 // Enable if needed when debugging. 
197 dbgs() << "TarjanSCC: Popped node " << visitingN << 198 " : minVisitNum = " << minVisitNum << "; Node visit num = " << 199 nodeVisitNumbers[visitingN] << "\n"; 200 #endif 201 202 if (minVisitNum != nodeVisitNumbers[visitingN]) 203 continue; 204 205 // A full SCC is on the SCCNodeStack! It includes all nodes below 206 // visitingN on the stack. Copy those nodes to CurrentSCC, 207 // reset their minVisit values, and return (this suspends 208 // the DFS traversal till the next ++). 209 do { 210 CurrentSCC.push_back(SCCNodeStack.back()); 211 SCCNodeStack.pop_back(); 212 nodeVisitNumbers[CurrentSCC.back()] = ~0U; 213 } while (CurrentSCC.back() != visitingN); 214 return; 215 } 216 } 217 218 template <class GraphT, class GT> 219 bool scc_iterator<GraphT, GT>::hasCycle() const { 220 assert(!CurrentSCC.empty() && "Dereferencing END SCC iterator!"); 221 if (CurrentSCC.size() > 1) 222 return true; 223 NodeRef N = CurrentSCC.front(); 224 for (ChildItTy CI = GT::child_begin(N), CE = GT::child_end(N); CI != CE; 225 ++CI) 226 if (*CI == N) 227 return true; 228 return false; 229 } 230 231 /// Construct the begin iterator for a deduced graph type T. 232 template <class T> scc_iterator<T> scc_begin(const T &G) { 233 return scc_iterator<T>::begin(G); 234 } 235 236 /// Construct the end iterator for a deduced graph type T. 237 template <class T> scc_iterator<T> scc_end(const T &G) { 238 return scc_iterator<T>::end(G); 239 } 240 241 /// Sort the nodes of a directed SCC in the decreasing order of the edge 242 /// weights. The instantiating GraphT type should have weighted edge type 243 /// declared in its graph traits in order to use this iterator. 244 /// 245 /// This is implemented using Kruskal's minimal spanning tree algorithm followed 246 /// by a BFS walk. First a maximum spanning tree (forest) is built based on all 247 /// edges within the SCC collection. Then a BFS walk is initiated on tree nodes 248 /// that do not have a predecessor. 
Finally, the BFS order computed is the 249 /// traversal order of the nodes of the SCC. Such order ensures that 250 /// high-weighted edges are visited first during the tranversal. 251 template <class GraphT, class GT = GraphTraits<GraphT>> 252 class scc_member_iterator { 253 using NodeType = typename GT::NodeType; 254 using EdgeType = typename GT::EdgeType; 255 using NodesType = std::vector<NodeType *>; 256 257 // Auxilary node information used during the MST calculation. 258 struct NodeInfo { 259 NodeInfo *Group = this; 260 uint32_t Rank = 0; 261 bool Visited = true; 262 }; 263 264 // Find the root group of the node and compress the path from node to the 265 // root. 266 NodeInfo *find(NodeInfo *Node) { 267 if (Node->Group != Node) 268 Node->Group = find(Node->Group); 269 return Node->Group; 270 } 271 272 // Union the source and target node into the same group and return true. 273 // Returns false if they are already in the same group. 274 bool unionGroups(const EdgeType *Edge) { 275 NodeInfo *G1 = find(&NodeInfoMap[Edge->Source]); 276 NodeInfo *G2 = find(&NodeInfoMap[Edge->Target]); 277 278 // If the edge forms a cycle, do not add it to MST 279 if (G1 == G2) 280 return false; 281 282 // Make the smaller rank tree a direct child or the root of high rank tree. 283 if (G1->Rank < G1->Rank) 284 G1->Group = G2; 285 else { 286 G2->Group = G1; 287 // If the ranks are the same, increment root of one tree by one. 288 if (G1->Rank == G2->Rank) 289 G2->Rank++; 290 } 291 return true; 292 } 293 294 std::unordered_map<NodeType *, NodeInfo> NodeInfoMap; 295 NodesType Nodes; 296 297 public: 298 scc_member_iterator(const NodesType &InputNodes); 299 300 NodesType &operator*() { return Nodes; } 301 }; 302 303 template <class GraphT, class GT> 304 scc_member_iterator<GraphT, GT>::scc_member_iterator( 305 const NodesType &InputNodes) { 306 if (InputNodes.size() <= 1) { 307 Nodes = InputNodes; 308 return; 309 } 310 311 // Initialize auxilary node information. 
312 NodeInfoMap.clear(); 313 for (auto *Node : InputNodes) { 314 // This is specifically used to construct a `NodeInfo` object in place. An 315 // insert operation will involve a copy construction which invalidate the 316 // initial value of the `Group` field which should be `this`. 317 (void)NodeInfoMap[Node].Group; 318 } 319 320 // Sort edges by weights. 321 struct EdgeComparer { 322 bool operator()(const EdgeType *L, const EdgeType *R) const { 323 return L->Weight > R->Weight; 324 } 325 }; 326 327 std::multiset<const EdgeType *, EdgeComparer> SortedEdges; 328 for (auto *Node : InputNodes) { 329 for (auto &Edge : Node->Edges) { 330 if (NodeInfoMap.count(Edge.Target)) 331 SortedEdges.insert(&Edge); 332 } 333 } 334 335 // Traverse all the edges and compute the Maximum Weight Spanning Tree 336 // using Kruskal's algorithm. 337 std::unordered_set<const EdgeType *> MSTEdges; 338 for (auto *Edge : SortedEdges) { 339 if (unionGroups(Edge)) 340 MSTEdges.insert(Edge); 341 } 342 343 // Do BFS on MST, starting from nodes that have no incoming edge. These nodes 344 // are "roots" of the MST forest. This ensures that nodes are visited before 345 // their decsendents are, thus ensures hot edges are processed before cold 346 // edges, based on how MST is computed. 347 for (const auto *Edge : MSTEdges) 348 NodeInfoMap[Edge->Target].Visited = false; 349 350 std::queue<NodeType *> Queue; 351 // Initialze the queue with MST roots. Note that walking through SortedEdges 352 // instead of NodeInfoMap ensures an ordered deterministic push. 
353 for (auto *Edge : SortedEdges) { 354 if (NodeInfoMap[Edge->Source].Visited) { 355 Queue.push(Edge->Source); 356 NodeInfoMap[Edge->Source].Visited = false; 357 } 358 } 359 360 while (!Queue.empty()) { 361 auto *Node = Queue.front(); 362 Queue.pop(); 363 Nodes.push_back(Node); 364 for (auto &Edge : Node->Edges) { 365 if (MSTEdges.count(&Edge) && !NodeInfoMap[Edge.Target].Visited) { 366 NodeInfoMap[Edge.Target].Visited = true; 367 Queue.push(Edge.Target); 368 } 369 } 370 } 371 372 assert(InputNodes.size() == Nodes.size() && "missing nodes in MST"); 373 std::reverse(Nodes.begin(), Nodes.end()); 374 } 375 } // end namespace llvm 376 377 #endif // LLVM_ADT_SCCITERATOR_H 378