//===- ADT/SCCIterator.h - Strongly Connected Comp. Iter. -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
///
/// This builds on the llvm/ADT/GraphTraits.h file to find the strongly
/// connected components (SCCs) of a graph in O(N+E) time using Tarjan's DFS
/// algorithm.
///
/// The SCC iterator has the important property that if a node in SCC S1 has an
/// edge to a node in SCC S2, then it visits S1 *after* S2.
///
/// To visit S1 *before* S2, use the scc_iterator on the Inverse graph. (NOTE:
/// This requires some simple wrappers and is not supported yet.)
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_ADT_SCCITERATOR_H
#define LLVM_ADT_SCCITERATOR_H

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/iterator.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <queue>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace llvm {

/// Enumerate the SCCs of a directed graph in reverse topological order
/// of the SCC DAG.
///
/// This is implemented using Tarjan's DFS algorithm using an internal stack to
/// build up a vector of nodes in a particular SCC. Note that it is a forward
/// iterator and thus you cannot backtrack or re-visit nodes.
46 template <class GraphT, class GT = GraphTraits<GraphT>> 47 class scc_iterator : public iterator_facade_base< 48 scc_iterator<GraphT, GT>, std::forward_iterator_tag, 49 const std::vector<typename GT::NodeRef>, ptrdiff_t> { 50 using NodeRef = typename GT::NodeRef; 51 using ChildItTy = typename GT::ChildIteratorType; 52 using SccTy = std::vector<NodeRef>; 53 using reference = typename scc_iterator::reference; 54 55 /// Element of VisitStack during DFS. 56 struct StackElement { 57 NodeRef Node; ///< The current node pointer. 58 ChildItTy NextChild; ///< The next child, modified inplace during DFS. 59 unsigned MinVisited; ///< Minimum uplink value of all children of Node. 60 61 StackElement(NodeRef Node, const ChildItTy &Child, unsigned Min) 62 : Node(Node), NextChild(Child), MinVisited(Min) {} 63 64 bool operator==(const StackElement &Other) const { 65 return Node == Other.Node && 66 NextChild == Other.NextChild && 67 MinVisited == Other.MinVisited; 68 } 69 }; 70 71 /// The visit counters used to detect when a complete SCC is on the stack. 72 /// visitNum is the global counter. 73 /// 74 /// nodeVisitNumbers are per-node visit numbers, also used as DFS flags. 75 unsigned visitNum; 76 DenseMap<NodeRef, unsigned> nodeVisitNumbers; 77 78 /// Stack holding nodes of the SCC. 79 std::vector<NodeRef> SCCNodeStack; 80 81 /// The current SCC, retrieved using operator*(). 82 SccTy CurrentSCC; 83 84 /// DFS stack, Used to maintain the ordering. The top contains the current 85 /// node, the next child to visit, and the minimum uplink value of all child 86 std::vector<StackElement> VisitStack; 87 88 /// A single "visit" within the non-recursive DFS traversal. 89 void DFSVisitOne(NodeRef N); 90 91 /// The stack-based DFS traversal; defined below. 92 void DFSVisitChildren(); 93 94 /// Compute the next SCC using the DFS traversal. 
95 void GetNextSCC(); 96 97 scc_iterator(NodeRef entryN) : visitNum(0) { 98 DFSVisitOne(entryN); 99 GetNextSCC(); 100 } 101 102 /// End is when the DFS stack is empty. 103 scc_iterator() = default; 104 105 public: 106 static scc_iterator begin(const GraphT &G) { 107 return scc_iterator(GT::getEntryNode(G)); 108 } 109 static scc_iterator end(const GraphT &) { return scc_iterator(); } 110 111 /// Direct loop termination test which is more efficient than 112 /// comparison with \c end(). 113 bool isAtEnd() const { 114 assert(!CurrentSCC.empty() || VisitStack.empty()); 115 return CurrentSCC.empty(); 116 } 117 118 bool operator==(const scc_iterator &x) const { 119 return VisitStack == x.VisitStack && CurrentSCC == x.CurrentSCC; 120 } 121 122 scc_iterator &operator++() { 123 GetNextSCC(); 124 return *this; 125 } 126 127 reference operator*() const { 128 assert(!CurrentSCC.empty() && "Dereferencing END SCC iterator!"); 129 return CurrentSCC; 130 } 131 132 /// Test if the current SCC has a cycle. 133 /// 134 /// If the SCC has more than one node, this is trivially true. If not, it may 135 /// still contain a cycle if the node has an edge back to itself. 136 bool hasCycle() const; 137 138 /// This informs the \c scc_iterator that the specified \c Old node 139 /// has been deleted, and \c New is to be used in its place. 140 void ReplaceNode(NodeRef Old, NodeRef New) { 141 assert(nodeVisitNumbers.count(Old) && "Old not in scc_iterator?"); 142 // Do the assignment in two steps, in case 'New' is not yet in the map, and 143 // inserting it causes the map to grow. 
144 auto tempVal = nodeVisitNumbers[Old]; 145 nodeVisitNumbers[New] = tempVal; 146 nodeVisitNumbers.erase(Old); 147 } 148 }; 149 150 template <class GraphT, class GT> 151 void scc_iterator<GraphT, GT>::DFSVisitOne(NodeRef N) { 152 ++visitNum; 153 nodeVisitNumbers[N] = visitNum; 154 SCCNodeStack.push_back(N); 155 VisitStack.push_back(StackElement(N, GT::child_begin(N), visitNum)); 156 #if 0 // Enable if needed when debugging. 157 dbgs() << "TarjanSCC: Node " << N << 158 " : visitNum = " << visitNum << "\n"; 159 #endif 160 } 161 162 template <class GraphT, class GT> 163 void scc_iterator<GraphT, GT>::DFSVisitChildren() { 164 assert(!VisitStack.empty()); 165 while (VisitStack.back().NextChild != GT::child_end(VisitStack.back().Node)) { 166 // TOS has at least one more child so continue DFS 167 NodeRef childN = *VisitStack.back().NextChild++; 168 typename DenseMap<NodeRef, unsigned>::iterator Visited = 169 nodeVisitNumbers.find(childN); 170 if (Visited == nodeVisitNumbers.end()) { 171 // this node has never been seen. 172 DFSVisitOne(childN); 173 continue; 174 } 175 176 unsigned childNum = Visited->second; 177 if (VisitStack.back().MinVisited > childNum) 178 VisitStack.back().MinVisited = childNum; 179 } 180 } 181 182 template <class GraphT, class GT> void scc_iterator<GraphT, GT>::GetNextSCC() { 183 CurrentSCC.clear(); // Prepare to compute the next SCC 184 while (!VisitStack.empty()) { 185 DFSVisitChildren(); 186 187 // Pop the leaf on top of the VisitStack. 188 NodeRef visitingN = VisitStack.back().Node; 189 unsigned minVisitNum = VisitStack.back().MinVisited; 190 assert(VisitStack.back().NextChild == GT::child_end(visitingN)); 191 VisitStack.pop_back(); 192 193 // Propagate MinVisitNum to parent so we can detect the SCC starting node. 194 if (!VisitStack.empty() && VisitStack.back().MinVisited > minVisitNum) 195 VisitStack.back().MinVisited = minVisitNum; 196 197 #if 0 // Enable if needed when debugging. 
198 dbgs() << "TarjanSCC: Popped node " << visitingN << 199 " : minVisitNum = " << minVisitNum << "; Node visit num = " << 200 nodeVisitNumbers[visitingN] << "\n"; 201 #endif 202 203 if (minVisitNum != nodeVisitNumbers[visitingN]) 204 continue; 205 206 // A full SCC is on the SCCNodeStack! It includes all nodes below 207 // visitingN on the stack. Copy those nodes to CurrentSCC, 208 // reset their minVisit values, and return (this suspends 209 // the DFS traversal till the next ++). 210 do { 211 CurrentSCC.push_back(SCCNodeStack.back()); 212 SCCNodeStack.pop_back(); 213 nodeVisitNumbers[CurrentSCC.back()] = ~0U; 214 } while (CurrentSCC.back() != visitingN); 215 return; 216 } 217 } 218 219 template <class GraphT, class GT> 220 bool scc_iterator<GraphT, GT>::hasCycle() const { 221 assert(!CurrentSCC.empty() && "Dereferencing END SCC iterator!"); 222 if (CurrentSCC.size() > 1) 223 return true; 224 NodeRef N = CurrentSCC.front(); 225 for (ChildItTy CI = GT::child_begin(N), CE = GT::child_end(N); CI != CE; 226 ++CI) 227 if (*CI == N) 228 return true; 229 return false; 230 } 231 232 /// Construct the begin iterator for a deduced graph type T. 233 template <class T> scc_iterator<T> scc_begin(const T &G) { 234 return scc_iterator<T>::begin(G); 235 } 236 237 /// Construct the end iterator for a deduced graph type T. 238 template <class T> scc_iterator<T> scc_end(const T &G) { 239 return scc_iterator<T>::end(G); 240 } 241 242 /// Sort the nodes of a directed SCC in the decreasing order of the edge 243 /// weights. The instantiating GraphT type should have weighted edge type 244 /// declared in its graph traits in order to use this iterator. 245 /// 246 /// This is implemented using Kruskal's minimal spanning tree algorithm followed 247 /// by Kahn's algorithm to compute a topological order on the MST. First a 248 /// maximum spanning tree (forest) is built based on all edges within the SCC 249 /// collection. 
Then a topological walk is initiated on tree nodes that do not 250 /// have a predecessor and then applied to all nodes of the SCC. Such order 251 /// ensures that high-weighted edges are visited first during the traversal. 252 template <class GraphT, class GT = GraphTraits<GraphT>> 253 class scc_member_iterator { 254 using NodeType = typename GT::NodeType; 255 using EdgeType = typename GT::EdgeType; 256 using NodesType = std::vector<NodeType *>; 257 258 // Auxilary node information used during the MST calculation. 259 struct NodeInfo { 260 NodeInfo *Group = this; 261 uint32_t Rank = 0; 262 bool Visited = false; 263 DenseSet<const EdgeType *> IncomingMSTEdges; 264 }; 265 266 // Find the root group of the node and compress the path from node to the 267 // root. 268 NodeInfo *find(NodeInfo *Node) { 269 if (Node->Group != Node) 270 Node->Group = find(Node->Group); 271 return Node->Group; 272 } 273 274 // Union the source and target node into the same group and return true. 275 // Returns false if they are already in the same group. 276 bool unionGroups(const EdgeType *Edge) { 277 NodeInfo *G1 = find(&NodeInfoMap[Edge->Source]); 278 NodeInfo *G2 = find(&NodeInfoMap[Edge->Target]); 279 280 // If the edge forms a cycle, do not add it to MST 281 if (G1 == G2) 282 return false; 283 284 // Make the smaller rank tree a direct child or the root of high rank tree. 285 if (G1->Rank < G1->Rank) 286 G1->Group = G2; 287 else { 288 G2->Group = G1; 289 // If the ranks are the same, increment root of one tree by one. 
290 if (G1->Rank == G2->Rank) 291 G2->Rank++; 292 } 293 return true; 294 } 295 296 std::unordered_map<NodeType *, NodeInfo> NodeInfoMap; 297 NodesType Nodes; 298 299 public: 300 scc_member_iterator(const NodesType &InputNodes); 301 302 NodesType &operator*() { return Nodes; } 303 }; 304 305 template <class GraphT, class GT> 306 scc_member_iterator<GraphT, GT>::scc_member_iterator( 307 const NodesType &InputNodes) { 308 if (InputNodes.size() <= 1) { 309 Nodes = InputNodes; 310 return; 311 } 312 313 // Initialize auxilary node information. 314 NodeInfoMap.clear(); 315 for (auto *Node : InputNodes) { 316 // This is specifically used to construct a `NodeInfo` object in place. An 317 // insert operation will involve a copy construction which invalidate the 318 // initial value of the `Group` field which should be `this`. 319 (void)NodeInfoMap[Node].Group; 320 } 321 322 // Sort edges by weights. 323 struct EdgeComparer { 324 bool operator()(const EdgeType *L, const EdgeType *R) const { 325 return L->Weight > R->Weight; 326 } 327 }; 328 329 std::multiset<const EdgeType *, EdgeComparer> SortedEdges; 330 for (auto *Node : InputNodes) { 331 for (auto &Edge : Node->Edges) { 332 if (NodeInfoMap.count(Edge.Target)) 333 SortedEdges.insert(&Edge); 334 } 335 } 336 337 // Traverse all the edges and compute the Maximum Weight Spanning Tree 338 // using Kruskal's algorithm. 339 std::unordered_set<const EdgeType *> MSTEdges; 340 for (auto *Edge : SortedEdges) { 341 if (unionGroups(Edge)) 342 MSTEdges.insert(Edge); 343 } 344 345 // Run Kahn's algorithm on MST to compute a topological traversal order. 346 // The algorithm starts from nodes that have no incoming edge. These nodes are 347 // "roots" of the MST forest. This ensures that nodes are visited before their 348 // descendants are, thus ensures hot edges are processed before cold edges, 349 // based on how MST is computed. 
350 std::queue<NodeType *> Queue; 351 for (const auto *Edge : MSTEdges) 352 NodeInfoMap[Edge->Target].IncomingMSTEdges.insert(Edge); 353 354 // Walk through SortedEdges to initialize the queue, instead of using NodeInfoMap 355 // to ensure an ordered deterministic push. 356 for (auto *Edge : SortedEdges) { 357 if (!NodeInfoMap[Edge->Source].Visited && 358 NodeInfoMap[Edge->Source].IncomingMSTEdges.empty()) { 359 Queue.push(Edge->Source); 360 NodeInfoMap[Edge->Source].Visited = true; 361 } 362 } 363 364 while (!Queue.empty()) { 365 auto *Node = Queue.front(); 366 Queue.pop(); 367 Nodes.push_back(Node); 368 for (auto &Edge : Node->Edges) { 369 NodeInfoMap[Edge.Target].IncomingMSTEdges.erase(&Edge); 370 if (MSTEdges.count(&Edge) && 371 NodeInfoMap[Edge.Target].IncomingMSTEdges.empty()) { 372 Queue.push(Edge.Target); 373 } 374 } 375 } 376 377 assert(InputNodes.size() == Nodes.size() && "missing nodes in MST"); 378 std::reverse(Nodes.begin(), Nodes.end()); 379 } 380 } // end namespace llvm 381 382 #endif // LLVM_ADT_SCCITERATOR_H 383