1 // Copyright 2010-2021 Google LLC
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 #include "ortools/algorithms/dynamic_partition.h"
15 
16 #include <algorithm>
17 #include <cstdint>
18 
19 #include "absl/strings/str_format.h"
20 #include "absl/strings/str_join.h"
21 #include "ortools/base/murmur.h"
22 
23 namespace operations_research {
24 
25 namespace {
FprintOfInt32(int i)26 uint64_t FprintOfInt32(int i) {
27   return util_hash::MurmurHash64(reinterpret_cast<const char*>(&i),
28                                  sizeof(int));
29 }
30 }  // namespace
31 
DynamicPartition(int num_elements)32 DynamicPartition::DynamicPartition(int num_elements) {
33   DCHECK_GE(num_elements, 0);
34   element_.assign(num_elements, -1);
35   index_of_.assign(num_elements, -1);
36   for (int i = 0; i < num_elements; ++i) {
37     element_[i] = i;
38     index_of_[i] = i;
39   }
40   part_of_.assign(num_elements, 0);
41   uint64_t fprint = 0;
42   for (int i = 0; i < num_elements; ++i) fprint ^= FprintOfInt32(i);
43   part_.push_back(Part(/*start_index=*/0, /*end_index=*/num_elements,
44                        /*parent_part=*/0,
45                        /*fprint=*/fprint));
46 }
47 
DynamicPartition(const std::vector<int> & initial_part_of_element)48 DynamicPartition::DynamicPartition(
49     const std::vector<int>& initial_part_of_element) {
50   if (initial_part_of_element.empty()) return;
51   part_of_ = initial_part_of_element;
52   const int n = part_of_.size();
53   const int num_parts = 1 + *std::max_element(part_of_.begin(), part_of_.end());
54   DCHECK_EQ(0, *std::min_element(part_of_.begin(), part_of_.end()));
55   part_.resize(num_parts);
56 
57   // Compute the part fingerprints.
58   for (int i = 0; i < n; ++i) part_[part_of_[i]].fprint ^= FprintOfInt32(i);
59 
60   // Compute the actual start indices of each part, knowing that we'll sort
61   // them as they were given implicitly in "initial_part_of_element".
62   // The code looks a bit weird to do it in-place, with no additional memory.
63   for (int p = 0; p < num_parts; ++p) {
64     part_[p].end_index = 0;  // Temporarily utilized as size_of_part.
65     part_[p].parent_part = p;
66   }
67   for (const int p : part_of_) ++part_[p].end_index;  // size_of_part
68   int sum_part_sizes = 0;
69   for (int p = 0; p < num_parts; ++p) {
70     part_[p].start_index = sum_part_sizes;
71     sum_part_sizes += part_[p].end_index;  // size of part.
72   }
73 
74   // Now that we have the correct start indices, we set the end indices to the
75   // start indices, and incrementally add all elements to their part, adjusting
76   // the end indices as we go.
77   for (Part& part : part_) part.end_index = part.start_index;
78   element_.assign(n, -1);
79   index_of_.assign(n, -1);
80   for (int element = 0; element < n; ++element) {
81     Part* const part = &part_[part_of_[element]];
82     element_[part->end_index] = element;
83     index_of_[element] = part->end_index;
84     ++part->end_index;
85   }
86 
87   // Verify that we did it right.
88   // TODO(user): either remove this or factor it out if it can be used
89   // elsewhere.
90   DCHECK_EQ(0, part_[0].start_index);
91   DCHECK_EQ(NumElements(), part_[NumParts() - 1].end_index);
92   for (int p = 1; p < NumParts(); ++p) {
93     DCHECK_EQ(part_[p - 1].end_index, part_[p].start_index);
94   }
95 }
96 
Refine(const std::vector<int> & distinguished_subset)97 void DynamicPartition::Refine(const std::vector<int>& distinguished_subset) {
98   // tmp_counter_of_part_[i] will contain the number of
99   // elements in distinguished_subset that were part of part #i.
100   tmp_counter_of_part_.resize(NumParts(), 0);
101   // We remember the Parts that were actually affected.
102   tmp_affected_parts_.clear();
103   for (const int element : distinguished_subset) {
104     DCHECK_GE(element, 0);
105     DCHECK_LT(element, NumElements());
106     const int part = part_of_[element];
107     const int num_distinguished_elements_in_part = ++tmp_counter_of_part_[part];
108     // Is this the first time that we touch this element's part?
109     if (num_distinguished_elements_in_part == 1) {
110       // TODO(user): optimize the common singleton case.
111       tmp_affected_parts_.push_back(part);
112     }
113     // Move the element to the end of its current Part.
114     const int old_index = index_of_[element];
115     const int new_index =
116         part_[part].end_index - num_distinguished_elements_in_part;
117     DCHECK_GE(new_index, old_index)
118         << "Duplicate element given to Refine(): " << element;
119     // Perform the swap, keeping index_of_ up to date.
120     index_of_[element] = new_index;
121     index_of_[element_[new_index]] = old_index;
122     std::swap(element_[old_index], element_[new_index]);
123   }
124 
125   // Sort affected parts. This is important to behave as advertised in the .h.
126   // TODO(user): automatically switch to an O(N) sort when it's faster
127   // than this one, which is O(K log K) with K = tmp_affected_parts_.size().
128   std::sort(tmp_affected_parts_.begin(), tmp_affected_parts_.end());
129 
130   // Iterate on each affected part and split it, or keep it intact if all
131   // of its elements were distinguished.
132   for (const int part : tmp_affected_parts_) {
133     const int start_index = part_[part].start_index;
134     const int end_index = part_[part].end_index;
135     const int split_index = end_index - tmp_counter_of_part_[part];
136     tmp_counter_of_part_[part] = 0;  // Clean up after us.
137     DCHECK_GE(split_index, start_index);
138     DCHECK_LT(split_index, end_index);
139 
140     // Do nothing if all elements were distinguished.
141     if (split_index == start_index) continue;
142 
143     // Compute the fingerprint of the new part.
144     uint64_t new_fprint = 0;
145     for (int i = split_index; i < end_index; ++i) {
146       new_fprint ^= FprintOfInt32(element_[i]);
147     }
148 
149     const int new_part = NumParts();
150 
151     // Perform the split.
152     part_[part].end_index = split_index;
153     part_[part].fprint ^= new_fprint;
154     part_.push_back(Part(/*start_index*/ split_index, /*end_index*/ end_index,
155                          /*parent_part*/ part, new_fprint));
156     for (const int element : ElementsInPart(new_part)) {
157       part_of_[element] = new_part;
158     }
159   }
160 }
161 
UndoRefineUntilNumPartsEqual(int original_num_parts)162 void DynamicPartition::UndoRefineUntilNumPartsEqual(int original_num_parts) {
163   DCHECK_GE(NumParts(), original_num_parts);
164   DCHECK_GE(original_num_parts, 1);
165   while (NumParts() > original_num_parts) {
166     const int part_index = NumParts() - 1;
167     const Part& part = part_[part_index];
168     const int parent_part_index = part.parent_part;
169     DCHECK_LT(parent_part_index, part_index) << "UndoRefineUntilNumPartsEqual()"
170                                                 " called with "
171                                                 "'original_num_parts' too low";
172 
173     // Update the part contents: actually merge "part" onto its parent.
174     for (const int element : ElementsInPart(part_index)) {
175       part_of_[element] = parent_part_index;
176     }
177     Part* const parent_part = &part_[parent_part_index];
178     DCHECK_EQ(part.start_index, parent_part->end_index);
179     parent_part->end_index = part.end_index;
180     parent_part->fprint ^= part.fprint;
181     part_.pop_back();
182   }
183 }
184 
DebugString(DebugStringSorting sorting) const185 std::string DynamicPartition::DebugString(DebugStringSorting sorting) const {
186   if (sorting != SORT_LEXICOGRAPHICALLY && sorting != SORT_BY_PART) {
187     return absl::StrFormat("Unsupported sorting: %d", sorting);
188   }
189   std::vector<std::vector<int>> parts;
190   for (int i = 0; i < NumParts(); ++i) {
191     IterablePart iterable_part = ElementsInPart(i);
192     parts.emplace_back(iterable_part.begin(), iterable_part.end());
193     std::sort(parts.back().begin(), parts.back().end());
194   }
195   if (sorting == SORT_LEXICOGRAPHICALLY) {
196     std::sort(parts.begin(), parts.end());
197   }
198   std::string out;
199   for (const std::vector<int>& part : parts) {
200     if (!out.empty()) out += " | ";
201     out += absl::StrJoin(part, " ");
202   }
203   return out;
204 }
205 
Reset(int num_nodes)206 void MergingPartition::Reset(int num_nodes) {
207   DCHECK_GE(num_nodes, 0);
208   part_size_.assign(num_nodes, 1);
209   parent_.assign(num_nodes, -1);
210   for (int i = 0; i < num_nodes; ++i) parent_[i] = i;
211   tmp_part_bit_.assign(num_nodes, false);
212 }
213 
MergePartsOf(int node1,int node2)214 int MergingPartition::MergePartsOf(int node1, int node2) {
215   DCHECK_GE(node1, 0);
216   DCHECK_GE(node2, 0);
217   DCHECK_LT(node1, NumNodes());
218   DCHECK_LT(node2, NumNodes());
219   int root1 = GetRoot(node1);
220   int root2 = GetRoot(node2);
221   if (root1 == root2) return -1;
222   int s1 = part_size_[root1];
223   int s2 = part_size_[root2];
224   // Attach the smaller part to the larger one. Break ties by root index.
225   if (s1 < s2 || (s1 == s2 && root1 > root2)) {
226     std::swap(root1, root2);
227     std::swap(s1, s2);
228   }
229 
230   // Update the part size. Don't change part_size_[root2]: it won't be used
231   // again by further merges.
232   part_size_[root1] += part_size_[root2];
233   SetParentAlongPathToRoot(node1, root1);
234   SetParentAlongPathToRoot(node2, root1);
235   return root2;
236 }
237 
GetRootAndCompressPath(int node)238 int MergingPartition::GetRootAndCompressPath(int node) {
239   DCHECK_GE(node, 0);
240   DCHECK_LT(node, NumNodes());
241   const int root = GetRoot(node);
242   SetParentAlongPathToRoot(node, root);
243   return root;
244 }
245 
KeepOnlyOneNodePerPart(std::vector<int> * nodes)246 void MergingPartition::KeepOnlyOneNodePerPart(std::vector<int>* nodes) {
247   int num_nodes_kept = 0;
248   for (const int node : *nodes) {
249     const int representative = GetRootAndCompressPath(node);
250     if (!tmp_part_bit_[representative]) {
251       tmp_part_bit_[representative] = true;
252       (*nodes)[num_nodes_kept++] = node;
253     }
254   }
255   nodes->resize(num_nodes_kept);
256 
257   // Clean up the tmp_part_bit_ vector. Since we've already compressed the
258   // paths (if backtracking was enabled), no need to do it again.
259   for (const int node : *nodes) tmp_part_bit_[GetRoot(node)] = false;
260 }
261 
FillEquivalenceClasses(std::vector<int> * node_equivalence_classes)262 int MergingPartition::FillEquivalenceClasses(
263     std::vector<int>* node_equivalence_classes) {
264   node_equivalence_classes->assign(NumNodes(), -1);
265   int num_roots = 0;
266   for (int node = 0; node < NumNodes(); ++node) {
267     const int root = GetRootAndCompressPath(node);
268     if ((*node_equivalence_classes)[root] < 0) {
269       (*node_equivalence_classes)[root] = num_roots;
270       ++num_roots;
271     }
272     (*node_equivalence_classes)[node] = (*node_equivalence_classes)[root];
273   }
274   return num_roots;
275 }
276 
DebugString()277 std::string MergingPartition::DebugString() {
278   std::vector<std::vector<int>> sorted_parts(NumNodes());
279   for (int i = 0; i < NumNodes(); ++i) {
280     sorted_parts[GetRootAndCompressPath(i)].push_back(i);
281   }
282   for (std::vector<int>& part : sorted_parts)
283     std::sort(part.begin(), part.end());
284   std::sort(sorted_parts.begin(), sorted_parts.end());
285   // Note: typically, a lot of elements of "sorted_parts" will be empty,
286   // but these won't be visible in the string that we construct below.
287   std::string out;
288   for (const std::vector<int>& part : sorted_parts) {
289     if (!out.empty()) out += " | ";
290     out += absl::StrJoin(part, " ");
291   }
292   return out;
293 }
294 
295 }  // namespace operations_research
296