1 // Copyright 2010-2021 Google LLC
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 #ifndef OR_TOOLS_GRAPH_MINIMUM_SPANNING_TREE_H_
15 #define OR_TOOLS_GRAPH_MINIMUM_SPANNING_TREE_H_
16 
#include <algorithm>
#include <limits>
#include <queue>
#include <type_traits>
#include <vector>

#include "ortools/base/adjustable_priority_queue-inl.h"
#include "ortools/base/adjustable_priority_queue.h"
#include "ortools/base/integral_types.h"
#include "ortools/graph/connected_components.h"
#include "ortools/util/vector_or_function.h"
25 
26 namespace operations_research {
27 
// Implementation of Kruskal's minimum spanning tree algorithm (cf.
// https://en.wikipedia.org/wiki/Kruskal%27s_algorithm).
// Returns the indices of the arcs appearing in the tree; returns a forest if
// the graph is disconnected. Nodes without any arcs are ignored.
// Each arc of the graph is interpreted as an undirected arc.
// Complexity of the algorithm is O(E * log(E)) where E is the number of arcs
// in the graph. Memory usage is O(E + V), where V is the number of nodes.
35 
36 // TODO(user): Add a global Minimum Spanning Tree API automatically switching
37 // between Prim and Kruskal depending on problem size.
38 
39 // Version taking sorted graph arcs. Allows somewhat incremental recomputation
40 // of minimum spanning trees as most of the processing time is spent sorting
41 // arcs.
42 // Usage:
43 //  ListGraph<int, int> graph(...);
44 //  std::vector<int> sorted_arcs = ...;
45 //  std::vector<int> mst = BuildKruskalMinimumSpanningTreeFromSortedArcs(
46 //      graph, sorted_arcs);
47 //
48 template <typename Graph>
49 std::vector<typename Graph::ArcIndex>
BuildKruskalMinimumSpanningTreeFromSortedArcs(const Graph & graph,const std::vector<typename Graph::ArcIndex> & sorted_arcs)50 BuildKruskalMinimumSpanningTreeFromSortedArcs(
51     const Graph& graph,
52     const std::vector<typename Graph::ArcIndex>& sorted_arcs) {
53   using ArcIndex = typename Graph::ArcIndex;
54   using NodeIndex = typename Graph::NodeIndex;
55   const int num_arcs = graph.num_arcs();
56   int arc_index = 0;
57   std::vector<ArcIndex> tree_arcs;
58   if (graph.num_nodes() == 0) {
59     return tree_arcs;
60   }
61   const int expected_tree_size = graph.num_nodes() - 1;
62   tree_arcs.reserve(expected_tree_size);
63   DenseConnectedComponentsFinder components;
64   components.SetNumberOfNodes(graph.num_nodes());
65   while (tree_arcs.size() != expected_tree_size && arc_index < num_arcs) {
66     const ArcIndex arc = sorted_arcs[arc_index];
67     const auto tail = graph.Tail(arc);
68     const auto head = graph.Head(arc);
69     if (!components.Connected(tail, head)) {
70       components.AddEdge(tail, head);
71       tree_arcs.push_back(arc);
72     }
73     ++arc_index;
74   }
75   return tree_arcs;
76 }
77 
// Convenience overload that orders the arcs itself: every forward arc of
// `graph` is sorted with `arc_comparator` (a strict weak ordering, as required
// by std::sort, where `arc_comparator(a, b)` means arc `a` should be tried
// before arc `b`) and the ordered list is handed to
// BuildKruskalMinimumSpanningTreeFromSortedArcs().
// Usage:
//  ListGraph<int, int> graph(...);
//  const auto arc_cost = [&graph](int arc) {
//                           return f(graph.Tail(arc), graph.Head(arc));
//                        };
//  std::vector<int> mst = BuildKruskalMinimumSpanningTree(
//      graph,
//      [&arc_cost](int a, int b) { return arc_cost(a) < arc_cost(b); });
//
template <typename Graph, typename ArcComparator>
std::vector<typename Graph::ArcIndex> BuildKruskalMinimumSpanningTree(
    const Graph& graph, const ArcComparator& arc_comparator) {
  using ArcIndex = typename Graph::ArcIndex;
  // Start from the identity ordering: slot `a` holds arc `a`...
  std::vector<ArcIndex> arcs_by_cost(graph.num_arcs());
  for (const ArcIndex a : graph.AllForwardArcs()) arcs_by_cost[a] = a;
  // ...then move the arcs the comparator ranks first to the front.
  std::sort(arcs_by_cost.begin(), arcs_by_cost.end(), arc_comparator);
  return BuildKruskalMinimumSpanningTreeFromSortedArcs(graph, arcs_by_cost);
}
99 
100 // Implementation of Prim's mininumum spanning tree algorithm (c.f.
101 // https://en.wikipedia.org/wiki/Prim's_algorithm) on undirected connected
102 // graphs.
103 // Returns the index of the arcs appearing in the tree.
104 // Complexity of the algorithm is O(E * log(V)) where E is the number of arcs
105 // in the graph, V is the number of vertices. Memory usage is O(V) + memory
106 // taken by the graph.
107 // Usage:
108 //  ListGraph<int, int> graph(...);
109 //  const auto arc_cost = [&graph](int arc) -> int64_t {
110 //                           return f(graph.Tail(arc), graph.Head(arc));
111 //                        };
112 //  std::vector<int> mst = BuildPrimMinimumSpanningTree(graph, arc_cost);
113 //
114 template <typename Graph, typename ArcValue>
BuildPrimMinimumSpanningTree(const Graph & graph,const ArcValue & arc_value)115 std::vector<typename Graph::ArcIndex> BuildPrimMinimumSpanningTree(
116     const Graph& graph, const ArcValue& arc_value) {
117   using ArcIndex = typename Graph::ArcIndex;
118   using NodeIndex = typename Graph::NodeIndex;
119   using ArcValueType = decltype(arc_value(0));
120   std::vector<ArcIndex> tree_arcs;
121   if (graph.num_nodes() == 0) {
122     return tree_arcs;
123   }
124   const int expected_tree_size = graph.num_nodes() - 1;
125   tree_arcs.reserve(expected_tree_size);
126   std::vector<ArcIndex> node_neighbor(graph.num_nodes(), Graph::kNilArc);
127   std::vector<bool> node_active(graph.num_nodes(), true);
128 
129   // This struct represents entries in the adjustable priority queue which
130   // maintains active nodes (not added to the tree yet) in decreasing insertion
131   // cost order. AdjustablePriorityQueue requires the existence of the
132   // SetHeapIndex and GetHeapIndex methods.
133   struct Entry {
134     void SetHeapIndex(int index) { heap_index = index; }
135     int GetHeapIndex() const { return heap_index; }
136     bool operator<(const Entry& other) const { return value > other.value; }
137 
138     NodeIndex node;
139     ArcValueType value;
140     int heap_index;
141   };
142 
143   AdjustablePriorityQueue<Entry> pq;
144   std::vector<Entry> entries;
145   std::vector<bool> touched_entry(graph.num_nodes(), false);
146   for (NodeIndex node : graph.AllNodes()) {
147     entries.push_back({node, std::numeric_limits<ArcValueType>::max(), -1});
148   }
149   entries[0].value = 0;
150   pq.Add(&entries[0]);
151   while (!pq.IsEmpty() && tree_arcs.size() != expected_tree_size) {
152     const Entry* best = pq.Top();
153     const NodeIndex node = best->node;
154     pq.Pop();
155     node_active[node] = false;
156     if (node_neighbor[node] != Graph::kNilArc) {
157       tree_arcs.push_back(node_neighbor[node]);
158     }
159     for (const ArcIndex arc : graph.OutgoingArcs(node)) {
160       const NodeIndex neighbor = graph.Head(arc);
161       if (node_active[neighbor]) {
162         const ArcValueType value = arc_value(arc);
163         Entry& entry = entries[neighbor];
164         if (value < entry.value || !touched_entry[neighbor]) {
165           node_neighbor[neighbor] = arc;
166           entry.value = value;
167           touched_entry[neighbor] = true;
168           if (pq.Contains(&entry)) {
169             pq.NoteChangedPriority(&entry);
170           } else {
171             pq.Add(&entry);
172           }
173         }
174       }
175     }
176   }
177   return tree_arcs;
178 }
179 
180 }  // namespace operations_research
181 #endif  // OR_TOOLS_GRAPH_MINIMUM_SPANNING_TREE_H_
182