1 //===- llvm/ADT/DirectedGraph.h - Directed Graph ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file defines the interface and a base class implementation for a
11 /// directed graph.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_ADT_DIRECTEDGRAPH_H
16 #define LLVM_ADT_DIRECTEDGRAPH_H
17 
18 #include "llvm/ADT/GraphTraits.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/raw_ostream.h"
23 
24 namespace llvm {
25 
/// Represent an edge in the directed graph.
/// The edge contains the target node it connects to.
///
/// The target is held by pointer rather than by reference so that copy
/// assignment and setTargetNode() re-point the edge. Assigning through a
/// reference member would instead copy-assign the *node object* the edge
/// currently targets, silently modifying that node.
template <class NodeType, class EdgeType> class DGEdge {
public:
  DGEdge() = delete;
  /// Create an edge pointing to the given node \p N.
  explicit DGEdge(NodeType &N) : TargetNode(&N) {}
  explicit DGEdge(const DGEdge<NodeType, EdgeType> &E)
      : TargetNode(E.TargetNode) {}
  /// Make this edge point at the same target node as \p E. Neither node
  /// object is modified.
  DGEdge<NodeType, EdgeType> &operator=(const DGEdge<NodeType, EdgeType> &E) {
    TargetNode = E.TargetNode;
    return *this;
  }

  /// Static polymorphism: delegate implementation (via isEqualTo) to the
  /// derived class.
  bool operator==(const DGEdge &E) const {
    return getDerived().isEqualTo(E.getDerived());
  }
  bool operator!=(const DGEdge &E) const { return !operator==(E); }

  /// Retrieve the target node this edge connects to.
  const NodeType &getTargetNode() const { return *TargetNode; }
  NodeType &getTargetNode() { return *TargetNode; }

  /// Re-point this edge at node \p N. The previously targeted node is left
  /// untouched.
  /// NOTE: \p N must not refer to an object that is itself declared const,
  /// because the non-const getTargetNode() hands out a mutable reference to
  /// the stored target.
  void setTargetNode(const NodeType &N) {
    TargetNode = const_cast<NodeType *>(&N);
  }

protected:
  // Default equality: two edges are equal only if they are the same object
  // (address comparison). Derived classes may provide their own isEqualTo.
  bool isEqualTo(const EdgeType &E) const { return this == &E; }

  // Cast the 'this' pointer to the derived type and return a reference.
  EdgeType &getDerived() { return *static_cast<EdgeType *>(this); }
  const EdgeType &getDerived() const {
    return *static_cast<const EdgeType *>(this);
  }

  // The target node this edge connects to. Never null: every constructor
  // and setter initializes it from a reference.
  NodeType *TargetNode;
};
70 
71 /// Represent a node in the directed graph.
72 /// The node has a (possibly empty) list of outgoing edges.
73 template <class NodeType, class EdgeType> class DGNode {
74 public:
75   using EdgeListTy = SetVector<EdgeType *>;
76   using iterator = typename EdgeListTy::iterator;
77   using const_iterator = typename EdgeListTy::const_iterator;
78 
79   /// Create a node with a single outgoing edge \p E.
80   explicit DGNode(EdgeType &E) : Edges() { Edges.insert(&E); }
81   DGNode() = default;
82 
83   explicit DGNode(const DGNode<NodeType, EdgeType> &N) : Edges(N.Edges) {}
84   DGNode(DGNode<NodeType, EdgeType> &&N) : Edges(std::move(N.Edges)) {}
85 
86   DGNode<NodeType, EdgeType> &operator=(const DGNode<NodeType, EdgeType> &N) {
87     Edges = N.Edges;
88     return *this;
89   }
90   DGNode<NodeType, EdgeType> &operator=(const DGNode<NodeType, EdgeType> &&N) {
91     Edges = std::move(N.Edges);
92     return *this;
93   }
94 
95   /// Static polymorphism: delegate implementation (via isEqualTo) to the
96   /// derived class.
97   friend bool operator==(const NodeType &M, const NodeType &N) {
98     return M.isEqualTo(N);
99   }
100   friend bool operator!=(const NodeType &M, const NodeType &N) {
101     return !(M == N);
102   }
103 
104   const_iterator begin() const { return Edges.begin(); }
105   const_iterator end() const { return Edges.end(); }
106   iterator begin() { return Edges.begin(); }
107   iterator end() { return Edges.end(); }
108   const EdgeType &front() const { return *Edges.front(); }
109   EdgeType &front() { return *Edges.front(); }
110   const EdgeType &back() const { return *Edges.back(); }
111   EdgeType &back() { return *Edges.back(); }
112 
113   /// Collect in \p EL, all the edges from this node to \p N.
114   /// Return true if at least one edge was found, and false otherwise.
115   /// Note that this implementation allows more than one edge to connect
116   /// a given pair of nodes.
117   bool findEdgesTo(const NodeType &N, SmallVectorImpl<EdgeType *> &EL) const {
118     assert(EL.empty() && "Expected the list of edges to be empty.");
119     for (auto *E : Edges)
120       if (E->getTargetNode() == N)
121         EL.push_back(E);
122     return !EL.empty();
123   }
124 
125   /// Add the given edge \p E to this node, if it doesn't exist already. Returns
126   /// true if the edge is added and false otherwise.
127   bool addEdge(EdgeType &E) { return Edges.insert(&E); }
128 
129   /// Remove the given edge \p E from this node, if it exists.
130   void removeEdge(EdgeType &E) { Edges.remove(&E); }
131 
132   /// Test whether there is an edge that goes from this node to \p N.
133   bool hasEdgeTo(const NodeType &N) const {
134     return (findEdgeTo(N) != Edges.end());
135   }
136 
137   /// Retrieve the outgoing edges for the node.
138   const EdgeListTy &getEdges() const { return Edges; }
139   EdgeListTy &getEdges() {
140     return const_cast<EdgeListTy &>(
141         static_cast<const DGNode<NodeType, EdgeType> &>(*this).Edges);
142   }
143 
144   /// Clear the outgoing edges.
145   void clear() { Edges.clear(); }
146 
147 protected:
148   // As the default implementation use address comparison for equality.
149   bool isEqualTo(const NodeType &N) const { return this == &N; }
150 
151   // Cast the 'this' pointer to the derived type and return a reference.
152   NodeType &getDerived() { return *static_cast<NodeType *>(this); }
153   const NodeType &getDerived() const {
154     return *static_cast<const NodeType *>(this);
155   }
156 
157   /// Find an edge to \p N. If more than one edge exists, this will return
158   /// the first one in the list of edges.
159   const_iterator findEdgeTo(const NodeType &N) const {
160     return llvm::find_if(
161         Edges, [&N](const EdgeType *E) { return E->getTargetNode() == N; });
162   }
163 
164   // The list of outgoing edges.
165   EdgeListTy Edges;
166 };
167 
168 /// Directed graph
169 ///
170 /// The graph is represented by a table of nodes.
171 /// Each node contains a (possibly empty) list of outgoing edges.
172 /// Each edge contains the target node it connects to.
173 template <class NodeType, class EdgeType> class DirectedGraph {
174 protected:
175   using NodeListTy = SmallVector<NodeType *, 10>;
176   using EdgeListTy = SmallVector<EdgeType *, 10>;
177 public:
178   using iterator = typename NodeListTy::iterator;
179   using const_iterator = typename NodeListTy::const_iterator;
180   using DGraphType = DirectedGraph<NodeType, EdgeType>;
181 
182   DirectedGraph() = default;
183   explicit DirectedGraph(NodeType &N) : Nodes() { addNode(N); }
184   DirectedGraph(const DGraphType &G) : Nodes(G.Nodes) {}
185   DirectedGraph(DGraphType &&RHS) : Nodes(std::move(RHS.Nodes)) {}
186   DGraphType &operator=(const DGraphType &G) {
187     Nodes = G.Nodes;
188     return *this;
189   }
190   DGraphType &operator=(const DGraphType &&G) {
191     Nodes = std::move(G.Nodes);
192     return *this;
193   }
194 
195   const_iterator begin() const { return Nodes.begin(); }
196   const_iterator end() const { return Nodes.end(); }
197   iterator begin() { return Nodes.begin(); }
198   iterator end() { return Nodes.end(); }
199   const NodeType &front() const { return *Nodes.front(); }
200   NodeType &front() { return *Nodes.front(); }
201   const NodeType &back() const { return *Nodes.back(); }
202   NodeType &back() { return *Nodes.back(); }
203 
204   size_t size() const { return Nodes.size(); }
205 
206   /// Find the given node \p N in the table.
207   const_iterator findNode(const NodeType &N) const {
208     return llvm::find_if(Nodes,
209                          [&N](const NodeType *Node) { return *Node == N; });
210   }
211   iterator findNode(const NodeType &N) {
212     return const_cast<iterator>(
213         static_cast<const DGraphType &>(*this).findNode(N));
214   }
215 
216   /// Add the given node \p N to the graph if it is not already present.
217   bool addNode(NodeType &N) {
218     if (findNode(N) != Nodes.end())
219       return false;
220     Nodes.push_back(&N);
221     return true;
222   }
223 
224   /// Collect in \p EL all edges that are coming into node \p N. Return true
225   /// if at least one edge was found, and false otherwise.
226   bool findIncomingEdgesToNode(const NodeType &N, SmallVectorImpl<EdgeType*> &EL) const {
227     assert(EL.empty() && "Expected the list of edges to be empty.");
228     EdgeListTy TempList;
229     for (auto *Node : Nodes) {
230       if (*Node == N)
231         continue;
232       Node->findEdgesTo(N, TempList);
233       llvm::append_range(EL, TempList);
234       TempList.clear();
235     }
236     return !EL.empty();
237   }
238 
239   /// Remove the given node \p N from the graph. If the node has incoming or
240   /// outgoing edges, they are also removed. Return true if the node was found
241   /// and then removed, and false if the node was not found in the graph to
242   /// begin with.
243   bool removeNode(NodeType &N) {
244     iterator IT = findNode(N);
245     if (IT == Nodes.end())
246       return false;
247     // Remove incoming edges.
248     EdgeListTy EL;
249     for (auto *Node : Nodes) {
250       if (*Node == N)
251         continue;
252       Node->findEdgesTo(N, EL);
253       for (auto *E : EL)
254         Node->removeEdge(*E);
255       EL.clear();
256     }
257     N.clear();
258     Nodes.erase(IT);
259     return true;
260   }
261 
262   /// Assuming nodes \p Src and \p Dst are already in the graph, connect node \p
263   /// Src to node \p Dst using the provided edge \p E. Return true if \p Src is
264   /// not already connected to \p Dst via \p E, and false otherwise.
265   bool connect(NodeType &Src, NodeType &Dst, EdgeType &E) {
266     assert(findNode(Src) != Nodes.end() && "Src node should be present.");
267     assert(findNode(Dst) != Nodes.end() && "Dst node should be present.");
268     assert((E.getTargetNode() == Dst) &&
269            "Target of the given edge does not match Dst.");
270     return Src.addEdge(E);
271   }
272 
273 protected:
274   // The list of nodes in the graph.
275   NodeListTy Nodes;
276 };
277 
278 } // namespace llvm
279 
280 #endif // LLVM_ADT_DIRECTEDGRAPH_H
281