1 // Copyright 2010-2021 Google LLC
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 //
15 // AllDifferent constraints
16
17 #include <algorithm>
18 #include <cstdint>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/strings/str_format.h"
25 #include "ortools/base/integral_types.h"
26 #include "ortools/base/logging.h"
27 #include "ortools/constraint_solver/constraint_solver.h"
28 #include "ortools/constraint_solver/constraint_solveri.h"
29 #include "ortools/util/string_array.h"
30
31 namespace operations_research {
32 namespace {
33
34 class BaseAllDifferent : public Constraint {
35 public:
BaseAllDifferent(Solver * const s,const std::vector<IntVar * > & vars)36 BaseAllDifferent(Solver* const s, const std::vector<IntVar*>& vars)
37 : Constraint(s), vars_(vars) {}
~BaseAllDifferent()38 ~BaseAllDifferent() override {}
DebugStringInternal(const std::string & name) const39 std::string DebugStringInternal(const std::string& name) const {
40 return absl::StrFormat("%s(%s)", name, JoinDebugStringPtr(vars_, ", "));
41 }
42
43 protected:
44 const std::vector<IntVar*> vars_;
size() const45 int64_t size() const { return vars_.size(); }
46 };
47
48 //-----------------------------------------------------------------------------
49 // ValueAllDifferent
50
51 class ValueAllDifferent : public BaseAllDifferent {
52 public:
ValueAllDifferent(Solver * const s,const std::vector<IntVar * > & vars)53 ValueAllDifferent(Solver* const s, const std::vector<IntVar*>& vars)
54 : BaseAllDifferent(s, vars) {}
~ValueAllDifferent()55 ~ValueAllDifferent() override {}
56
57 void Post() override;
58 void InitialPropagate() override;
59 void OneMove(int index);
60 bool AllMoves();
61
DebugString() const62 std::string DebugString() const override {
63 return DebugStringInternal("ValueAllDifferent");
64 }
Accept(ModelVisitor * const visitor) const65 void Accept(ModelVisitor* const visitor) const override {
66 visitor->BeginVisitConstraint(ModelVisitor::kAllDifferent, this);
67 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
68 vars_);
69 visitor->VisitIntegerArgument(ModelVisitor::kRangeArgument, 0);
70 visitor->EndVisitConstraint(ModelVisitor::kAllDifferent, this);
71 }
72
73 private:
74 RevSwitch all_instantiated_;
75 };
76
Post()77 void ValueAllDifferent::Post() {
78 for (int i = 0; i < size(); ++i) {
79 IntVar* var = vars_[i];
80 Demon* d = MakeConstraintDemon1(solver(), this, &ValueAllDifferent::OneMove,
81 "OneMove", i);
82 var->WhenBound(d);
83 }
84 }
85
InitialPropagate()86 void ValueAllDifferent::InitialPropagate() {
87 for (int i = 0; i < size(); ++i) {
88 if (vars_[i]->Bound()) {
89 OneMove(i);
90 }
91 }
92 }
93
OneMove(int index)94 void ValueAllDifferent::OneMove(int index) {
95 if (!AllMoves()) {
96 const int64_t val = vars_[index]->Value();
97 for (int j = 0; j < size(); ++j) {
98 if (index != j) {
99 if (vars_[j]->Size() < 0xFFFFFF) {
100 vars_[j]->RemoveValue(val);
101 } else {
102 solver()->AddConstraint(solver()->MakeNonEquality(vars_[j], val));
103 }
104 }
105 }
106 }
107 }
108
AllMoves()109 bool ValueAllDifferent::AllMoves() {
110 if (all_instantiated_.Switched() || size() == 0) {
111 return true;
112 }
113 for (int i = 0; i < size(); ++i) {
114 if (!vars_[i]->Bound()) {
115 return false;
116 }
117 }
118 std::unique_ptr<int64_t[]> values(new int64_t[size()]);
119 for (int i = 0; i < size(); ++i) {
120 values[i] = vars_[i]->Value();
121 }
122 std::sort(values.get(), values.get() + size());
123 for (int i = 0; i < size() - 1; ++i) {
124 if (values[i] == values[i + 1]) {
125 values.reset(); // prevent leaks (solver()->Fail() won't return)
126 solver()->Fail();
127 }
128 }
129 all_instantiated_.Switch(solver());
130 return true;
131 }
132
133 // ---------- Bounds All Different ----------
134 // See http://www.cs.uwaterloo.ca/~cquimper/Papers/ijcai03_TR.pdf for details.
135
136 class RangeBipartiteMatching {
137 public:
138 struct Interval {
139 int64_t min;
140 int64_t max;
141 int min_rank;
142 int max_rank;
143 };
144
RangeBipartiteMatching(Solver * const solver,int size)145 RangeBipartiteMatching(Solver* const solver, int size)
146 : solver_(solver),
147 size_(size),
148 intervals_(new Interval[size + 1]),
149 min_sorted_(new Interval*[size]),
150 max_sorted_(new Interval*[size]),
151 bounds_(new int64_t[2 * size + 2]),
152 tree_(new int[2 * size + 2]),
153 diff_(new int64_t[2 * size + 2]),
154 hall_(new int[2 * size + 2]),
155 active_size_(0) {
156 for (int i = 0; i < size; ++i) {
157 max_sorted_[i] = &intervals_[i];
158 min_sorted_[i] = max_sorted_[i];
159 }
160 }
161
SetRange(int index,int64_t imin,int64_t imax)162 void SetRange(int index, int64_t imin, int64_t imax) {
163 intervals_[index].min = imin;
164 intervals_[index].max = imax;
165 }
166
Propagate()167 bool Propagate() {
168 SortArray();
169
170 const bool modified1 = PropagateMin();
171 const bool modified2 = PropagateMax();
172 return modified1 || modified2;
173 }
174
Min(int index) const175 int64_t Min(int index) const { return intervals_[index].min; }
176
Max(int index) const177 int64_t Max(int index) const { return intervals_[index].max; }
178
179 private:
180 // This method sorts the min_sorted_ and max_sorted_ arrays and fill
181 // the bounds_ array (and set the active_size_ counter).
SortArray()182 void SortArray() {
183 std::sort(min_sorted_.get(), min_sorted_.get() + size_,
184 CompareIntervalMin());
185 std::sort(max_sorted_.get(), max_sorted_.get() + size_,
186 CompareIntervalMax());
187
188 int64_t min = min_sorted_[0]->min;
189 int64_t max = max_sorted_[0]->max + 1;
190 int64_t last = min - 2;
191 bounds_[0] = last;
192
193 int i = 0;
194 int j = 0;
195 int nb = 0;
196 for (;;) { // merge min_sorted_[] and max_sorted_[] into bounds_[].
197 if (i < size_ && min <= max) { // make sure min_sorted_ exhausted first.
198 if (min != last) {
199 last = min;
200 bounds_[++nb] = last;
201 }
202 min_sorted_[i]->min_rank = nb;
203 if (++i < size_) {
204 min = min_sorted_[i]->min;
205 }
206 } else {
207 if (max != last) {
208 last = max;
209 bounds_[++nb] = last;
210 }
211 max_sorted_[j]->max_rank = nb;
212 if (++j == size_) {
213 break;
214 }
215 max = max_sorted_[j]->max + 1;
216 }
217 }
218 active_size_ = nb;
219 bounds_[nb + 1] = bounds_[nb] + 2;
220 }
221
222 // These two methods will actually do the new bounds computation.
PropagateMin()223 bool PropagateMin() {
224 bool modified = false;
225
226 for (int i = 1; i <= active_size_ + 1; ++i) {
227 hall_[i] = i - 1;
228 tree_[i] = i - 1;
229 diff_[i] = bounds_[i] - bounds_[i - 1];
230 }
231 // visit intervals in increasing max order
232 for (int i = 0; i < size_; ++i) {
233 const int x = max_sorted_[i]->min_rank;
234 const int y = max_sorted_[i]->max_rank;
235 int z = PathMax(tree_.get(), x + 1);
236 int j = tree_[z];
237 if (--diff_[z] == 0) {
238 tree_[z] = z + 1;
239 z = PathMax(tree_.get(), z + 1);
240 tree_[z] = j;
241 }
242 PathSet(x + 1, z, z, tree_.get()); // path compression
243 if (diff_[z] < bounds_[z] - bounds_[y]) {
244 solver_->Fail();
245 }
246 if (hall_[x] > x) {
247 int w = PathMax(hall_.get(), hall_[x]);
248 max_sorted_[i]->min = bounds_[w];
249 PathSet(x, w, w, hall_.get()); // path compression
250 modified = true;
251 }
252 if (diff_[z] == bounds_[z] - bounds_[y]) {
253 PathSet(hall_[y], j - 1, y, hall_.get()); // mark hall interval
254 hall_[y] = j - 1;
255 }
256 }
257 return modified;
258 }
259
PropagateMax()260 bool PropagateMax() {
261 bool modified = false;
262
263 for (int i = 0; i <= active_size_; i++) {
264 tree_[i] = i + 1;
265 hall_[i] = i + 1;
266 diff_[i] = bounds_[i + 1] - bounds_[i];
267 }
268 // visit intervals in decreasing min order
269 for (int i = size_ - 1; i >= 0; --i) {
270 const int x = min_sorted_[i]->max_rank;
271 const int y = min_sorted_[i]->min_rank;
272 int z = PathMin(tree_.get(), x - 1);
273 int j = tree_[z];
274 if (--diff_[z] == 0) {
275 tree_[z] = z - 1;
276 z = PathMin(tree_.get(), z - 1);
277 tree_[z] = j;
278 }
279 PathSet(x - 1, z, z, tree_.get());
280 if (diff_[z] < bounds_[y] - bounds_[z]) {
281 solver_->Fail();
282 // useless. Should have been caught by the PropagateMin() method.
283 }
284 if (hall_[x] < x) {
285 int w = PathMin(hall_.get(), hall_[x]);
286 min_sorted_[i]->max = bounds_[w] - 1;
287 PathSet(x, w, w, hall_.get());
288 modified = true;
289 }
290 if (diff_[z] == bounds_[y] - bounds_[z]) {
291 PathSet(hall_[y], j + 1, y, hall_.get());
292 hall_[y] = j + 1;
293 }
294 }
295 return modified;
296 }
297
298 // TODO(user) : use better sort, use bounding boxes of modifications to
299 // improve the sorting (only modified vars).
300
301 // This method is used by the STL sort.
302 struct CompareIntervalMin {
operator ()operations_research::__anon780a12590111::RangeBipartiteMatching::CompareIntervalMin303 bool operator()(const Interval* i1, const Interval* i2) const {
304 return (i1->min < i2->min);
305 }
306 };
307
308 // This method is used by the STL sort.
309 struct CompareIntervalMax {
operator ()operations_research::__anon780a12590111::RangeBipartiteMatching::CompareIntervalMax310 bool operator()(const Interval* i1, const Interval* i2) const {
311 return (i1->max < i2->max);
312 }
313 };
314
PathSet(int start,int end,int to,int * const tree)315 void PathSet(int start, int end, int to, int* const tree) {
316 int l = start;
317 while (l != end) {
318 int k = l;
319 l = tree[k];
320 tree[k] = to;
321 }
322 }
323
PathMin(const int * const tree,int index)324 int PathMin(const int* const tree, int index) {
325 int i = index;
326 while (tree[i] < i) {
327 i = tree[i];
328 }
329 return i;
330 }
331
PathMax(const int * const tree,int index)332 int PathMax(const int* const tree, int index) {
333 int i = index;
334 while (tree[i] > i) {
335 i = tree[i];
336 }
337 return i;
338 }
339
340 Solver* const solver_;
341 const int size_;
342 std::unique_ptr<Interval[]> intervals_;
343 std::unique_ptr<Interval*[]> min_sorted_;
344 std::unique_ptr<Interval*[]> max_sorted_;
345 // bounds_[1..active_size_] hold set of min & max in the n intervals_
346 // while bounds_[0] and bounds_[active_size_ + 1] allow sentinels.
347 std::unique_ptr<int64_t[]> bounds_;
348 std::unique_ptr<int[]> tree_; // tree links.
349 std::unique_ptr<int64_t[]> diff_; // diffs between critical capacities.
350 std::unique_ptr<int[]> hall_; // hall interval links.
351 int active_size_;
352 };
353
354 class BoundsAllDifferent : public BaseAllDifferent {
355 public:
BoundsAllDifferent(Solver * const s,const std::vector<IntVar * > & vars)356 BoundsAllDifferent(Solver* const s, const std::vector<IntVar*>& vars)
357 : BaseAllDifferent(s, vars), matching_(s, vars.size()) {}
358
~BoundsAllDifferent()359 ~BoundsAllDifferent() override {}
360
Post()361 void Post() override {
362 Demon* range = MakeDelayedConstraintDemon0(
363 solver(), this, &BoundsAllDifferent::IncrementalPropagate,
364 "IncrementalPropagate");
365
366 for (int i = 0; i < size(); ++i) {
367 vars_[i]->WhenRange(range);
368 Demon* bound = MakeConstraintDemon1(solver(), this,
369 &BoundsAllDifferent::PropagateValue,
370 "PropagateValue", i);
371 vars_[i]->WhenBound(bound);
372 }
373 }
374
InitialPropagate()375 void InitialPropagate() override {
376 IncrementalPropagate();
377 for (int i = 0; i < size(); ++i) {
378 if (vars_[i]->Bound()) {
379 PropagateValue(i);
380 }
381 }
382 }
383
IncrementalPropagate()384 virtual void IncrementalPropagate() {
385 for (int i = 0; i < size(); ++i) {
386 matching_.SetRange(i, vars_[i]->Min(), vars_[i]->Max());
387 }
388
389 if (matching_.Propagate()) {
390 for (int i = 0; i < size(); ++i) {
391 vars_[i]->SetRange(matching_.Min(i), matching_.Max(i));
392 }
393 }
394 }
395
PropagateValue(int index)396 void PropagateValue(int index) {
397 const int64_t to_remove = vars_[index]->Value();
398 for (int j = 0; j < index; j++) {
399 if (vars_[j]->Size() < 0xFFFFFF) {
400 vars_[j]->RemoveValue(to_remove);
401 } else {
402 solver()->AddConstraint(solver()->MakeNonEquality(vars_[j], to_remove));
403 }
404 }
405 for (int j = index + 1; j < size(); j++) {
406 if (vars_[j]->Size() < 0xFFFFFF) {
407 vars_[j]->RemoveValue(to_remove);
408 } else {
409 solver()->AddConstraint(solver()->MakeNonEquality(vars_[j], to_remove));
410 }
411 }
412 }
413
DebugString() const414 std::string DebugString() const override {
415 return DebugStringInternal("BoundsAllDifferent");
416 }
417
Accept(ModelVisitor * const visitor) const418 void Accept(ModelVisitor* const visitor) const override {
419 visitor->BeginVisitConstraint(ModelVisitor::kAllDifferent, this);
420 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
421 vars_);
422 visitor->VisitIntegerArgument(ModelVisitor::kRangeArgument, 1);
423 visitor->EndVisitConstraint(ModelVisitor::kAllDifferent, this);
424 }
425
426 private:
427 RangeBipartiteMatching matching_;
428 };
429
430 class SortConstraint : public Constraint {
431 public:
SortConstraint(Solver * const solver,const std::vector<IntVar * > & original_vars,const std::vector<IntVar * > & sorted_vars)432 SortConstraint(Solver* const solver,
433 const std::vector<IntVar*>& original_vars,
434 const std::vector<IntVar*>& sorted_vars)
435 : Constraint(solver),
436 ovars_(original_vars),
437 svars_(sorted_vars),
438 mins_(original_vars.size(), 0),
439 maxs_(original_vars.size(), 0),
440 matching_(solver, original_vars.size()) {}
441
~SortConstraint()442 ~SortConstraint() override {}
443
Post()444 void Post() override {
445 Demon* const demon =
446 solver()->MakeDelayedConstraintInitialPropagateCallback(this);
447 for (int i = 0; i < size(); ++i) {
448 ovars_[i]->WhenRange(demon);
449 svars_[i]->WhenRange(demon);
450 }
451 }
452
InitialPropagate()453 void InitialPropagate() override {
454 for (int i = 0; i < size(); ++i) {
455 int64_t vmin = 0;
456 int64_t vmax = 0;
457 ovars_[i]->Range(&vmin, &vmax);
458 mins_[i] = vmin;
459 maxs_[i] = vmax;
460 }
461 // Propagates from variables to sorted variables.
462 std::sort(mins_.begin(), mins_.end());
463 std::sort(maxs_.begin(), maxs_.end());
464 for (int i = 0; i < size(); ++i) {
465 svars_[i]->SetRange(mins_[i], maxs_[i]);
466 }
467 // Maintains sortedness.
468 for (int i = 0; i < size() - 1; ++i) {
469 svars_[i + 1]->SetMin(svars_[i]->Min());
470 }
471 for (int i = size() - 1; i > 0; --i) {
472 svars_[i - 1]->SetMax(svars_[i]->Max());
473 }
474 // Reverse propagation.
475 for (int i = 0; i < size(); ++i) {
476 int64_t imin = 0;
477 int64_t imax = 0;
478 FindIntersectionRange(i, &imin, &imax);
479 matching_.SetRange(i, imin, imax);
480 }
481 matching_.Propagate();
482 for (int i = 0; i < size(); ++i) {
483 const int64_t vmin = svars_[matching_.Min(i)]->Min();
484 const int64_t vmax = svars_[matching_.Max(i)]->Max();
485 ovars_[i]->SetRange(vmin, vmax);
486 }
487 }
488
Accept(ModelVisitor * const visitor) const489 void Accept(ModelVisitor* const visitor) const override {
490 visitor->BeginVisitConstraint(ModelVisitor::kSortingConstraint, this);
491 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
492 ovars_);
493 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kTargetArgument,
494 svars_);
495 visitor->EndVisitConstraint(ModelVisitor::kSortingConstraint, this);
496 }
497
DebugString() const498 std::string DebugString() const override {
499 return absl::StrFormat("Sort(%s, %s)", JoinDebugStringPtr(ovars_, ", "),
500 JoinDebugStringPtr(svars_, ", "));
501 }
502
503 private:
size() const504 int64_t size() const { return ovars_.size(); }
505
FindIntersectionRange(int index,int64_t * const range_min,int64_t * const range_max) const506 void FindIntersectionRange(int index, int64_t* const range_min,
507 int64_t* const range_max) const {
508 // Naive version.
509 // TODO(user): Implement log(n) version.
510 int64_t imin = 0;
511 while (imin < size() && NotIntersect(index, imin)) {
512 imin++;
513 }
514 if (imin == size()) {
515 solver()->Fail();
516 }
517 int64_t imax = size() - 1;
518 while (imax > imin && NotIntersect(index, imax)) {
519 imax--;
520 }
521 *range_min = imin;
522 *range_max = imax;
523 }
524
NotIntersect(int oindex,int sindex) const525 bool NotIntersect(int oindex, int sindex) const {
526 return ovars_[oindex]->Min() > svars_[sindex]->Max() ||
527 ovars_[oindex]->Max() < svars_[sindex]->Min();
528 }
529
530 const std::vector<IntVar*> ovars_;
531 const std::vector<IntVar*> svars_;
532 std::vector<int64_t> mins_;
533 std::vector<int64_t> maxs_;
534 RangeBipartiteMatching matching_;
535 };
536
537 // All variables are pairwise different, unless they are assigned to
538 // the escape value.
539 class AllDifferentExcept : public Constraint {
540 public:
AllDifferentExcept(Solver * const s,std::vector<IntVar * > vars,int64_t escape_value)541 AllDifferentExcept(Solver* const s, std::vector<IntVar*> vars,
542 int64_t escape_value)
543 : Constraint(s), vars_(std::move(vars)), escape_value_(escape_value) {}
544
~AllDifferentExcept()545 ~AllDifferentExcept() override {}
546
Post()547 void Post() override {
548 for (int i = 0; i < vars_.size(); ++i) {
549 IntVar* const var = vars_[i];
550 Demon* const d = MakeConstraintDemon1(
551 solver(), this, &AllDifferentExcept::Propagate, "Propagate", i);
552 var->WhenBound(d);
553 }
554 }
555
InitialPropagate()556 void InitialPropagate() override {
557 for (int i = 0; i < vars_.size(); ++i) {
558 if (vars_[i]->Bound()) {
559 Propagate(i);
560 }
561 }
562 }
563
Propagate(int index)564 void Propagate(int index) {
565 const int64_t val = vars_[index]->Value();
566 if (val != escape_value_) {
567 for (int j = 0; j < vars_.size(); ++j) {
568 if (index != j) {
569 vars_[j]->RemoveValue(val);
570 }
571 }
572 }
573 }
574
DebugString() const575 std::string DebugString() const override {
576 return absl::StrFormat("AllDifferentExcept([%s], %d",
577 JoinDebugStringPtr(vars_, ", "), escape_value_);
578 }
579
Accept(ModelVisitor * const visitor) const580 void Accept(ModelVisitor* const visitor) const override {
581 visitor->BeginVisitConstraint(ModelVisitor::kAllDifferent, this);
582 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
583 vars_);
584 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, escape_value_);
585 visitor->EndVisitConstraint(ModelVisitor::kAllDifferent, this);
586 }
587
588 private:
589 std::vector<IntVar*> vars_;
590 const int64_t escape_value_;
591 };
592
593 // Creates a constraint that states that all variables in the first
594 // vector are different from all variables from the second group,
595 // unless they are assigned to the escape value if it is defined. Thus
596 // the set of values in the first vector minus the escape value does
597 // not intersect the set of values in the second vector.
598 class NullIntersectArrayExcept : public Constraint {
599 public:
NullIntersectArrayExcept(Solver * const s,std::vector<IntVar * > first_vars,std::vector<IntVar * > second_vars,int64_t escape_value)600 NullIntersectArrayExcept(Solver* const s, std::vector<IntVar*> first_vars,
601 std::vector<IntVar*> second_vars,
602 int64_t escape_value)
603 : Constraint(s),
604 first_vars_(std::move(first_vars)),
605 second_vars_(std::move(second_vars)),
606 escape_value_(escape_value),
607 has_escape_value_(true) {}
608
NullIntersectArrayExcept(Solver * const s,std::vector<IntVar * > first_vars,std::vector<IntVar * > second_vars)609 NullIntersectArrayExcept(Solver* const s, std::vector<IntVar*> first_vars,
610 std::vector<IntVar*> second_vars)
611 : Constraint(s),
612 first_vars_(std::move(first_vars)),
613 second_vars_(std::move(second_vars)),
614 escape_value_(0),
615 has_escape_value_(false) {}
616
~NullIntersectArrayExcept()617 ~NullIntersectArrayExcept() override {}
618
Post()619 void Post() override {
620 for (int i = 0; i < first_vars_.size(); ++i) {
621 IntVar* const var = first_vars_[i];
622 Demon* const d = MakeConstraintDemon1(
623 solver(), this, &NullIntersectArrayExcept::PropagateFirst,
624 "PropagateFirst", i);
625 var->WhenBound(d);
626 }
627 for (int i = 0; i < second_vars_.size(); ++i) {
628 IntVar* const var = second_vars_[i];
629 Demon* const d = MakeConstraintDemon1(
630 solver(), this, &NullIntersectArrayExcept::PropagateSecond,
631 "PropagateSecond", i);
632 var->WhenBound(d);
633 }
634 }
635
InitialPropagate()636 void InitialPropagate() override {
637 for (int i = 0; i < first_vars_.size(); ++i) {
638 if (first_vars_[i]->Bound()) {
639 PropagateFirst(i);
640 }
641 }
642 for (int i = 0; i < second_vars_.size(); ++i) {
643 if (second_vars_[i]->Bound()) {
644 PropagateSecond(i);
645 }
646 }
647 }
648
PropagateFirst(int index)649 void PropagateFirst(int index) {
650 const int64_t val = first_vars_[index]->Value();
651 if (!has_escape_value_ || val != escape_value_) {
652 for (int j = 0; j < second_vars_.size(); ++j) {
653 second_vars_[j]->RemoveValue(val);
654 }
655 }
656 }
657
PropagateSecond(int index)658 void PropagateSecond(int index) {
659 const int64_t val = second_vars_[index]->Value();
660 if (!has_escape_value_ || val != escape_value_) {
661 for (int j = 0; j < first_vars_.size(); ++j) {
662 first_vars_[j]->RemoveValue(val);
663 }
664 }
665 }
666
DebugString() const667 std::string DebugString() const override {
668 return absl::StrFormat("NullIntersectArray([%s], [%s], escape = %d",
669 JoinDebugStringPtr(first_vars_, ", "),
670 JoinDebugStringPtr(second_vars_, ", "),
671 escape_value_);
672 }
673
Accept(ModelVisitor * const visitor) const674 void Accept(ModelVisitor* const visitor) const override {
675 visitor->BeginVisitConstraint(ModelVisitor::kNullIntersect, this);
676 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kLeftArgument,
677 first_vars_);
678 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kRightArgument,
679 second_vars_);
680 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, escape_value_);
681 visitor->EndVisitConstraint(ModelVisitor::kNullIntersect, this);
682 }
683
684 private:
685 std::vector<IntVar*> first_vars_;
686 std::vector<IntVar*> second_vars_;
687 const int64_t escape_value_;
688 const bool has_escape_value_;
689 };
690 } // namespace
691
MakeAllDifferent(const std::vector<IntVar * > & vars)692 Constraint* Solver::MakeAllDifferent(const std::vector<IntVar*>& vars) {
693 return MakeAllDifferent(vars, true);
694 }
695
MakeAllDifferent(const std::vector<IntVar * > & vars,bool stronger_propagation)696 Constraint* Solver::MakeAllDifferent(const std::vector<IntVar*>& vars,
697 bool stronger_propagation) {
698 const int size = vars.size();
699 for (int i = 0; i < size; ++i) {
700 CHECK_EQ(this, vars[i]->solver());
701 }
702 if (size < 2) {
703 return MakeTrueConstraint();
704 } else if (size == 2) {
705 return MakeNonEquality(const_cast<IntVar* const>(vars[0]),
706 const_cast<IntVar* const>(vars[1]));
707 } else {
708 if (stronger_propagation) {
709 return RevAlloc(new BoundsAllDifferent(this, vars));
710 } else {
711 return RevAlloc(new ValueAllDifferent(this, vars));
712 }
713 }
714 }
715
MakeSortingConstraint(const std::vector<IntVar * > & vars,const std::vector<IntVar * > & sorted)716 Constraint* Solver::MakeSortingConstraint(const std::vector<IntVar*>& vars,
717 const std::vector<IntVar*>& sorted) {
718 CHECK_EQ(vars.size(), sorted.size());
719 return RevAlloc(new SortConstraint(this, vars, sorted));
720 }
721
MakeAllDifferentExcept(const std::vector<IntVar * > & vars,int64_t escape_value)722 Constraint* Solver::MakeAllDifferentExcept(const std::vector<IntVar*>& vars,
723 int64_t escape_value) {
724 int escape_candidates = 0;
725 for (int i = 0; i < vars.size(); ++i) {
726 escape_candidates += (vars[i]->Contains(escape_value));
727 }
728 if (escape_candidates <= 1) {
729 return MakeAllDifferent(vars);
730 } else {
731 return RevAlloc(new AllDifferentExcept(this, vars, escape_value));
732 }
733 }
734
MakeNullIntersect(const std::vector<IntVar * > & first_vars,const std::vector<IntVar * > & second_vars)735 Constraint* Solver::MakeNullIntersect(const std::vector<IntVar*>& first_vars,
736 const std::vector<IntVar*>& second_vars) {
737 return RevAlloc(new NullIntersectArrayExcept(this, first_vars, second_vars));
738 }
739
MakeNullIntersectExcept(const std::vector<IntVar * > & first_vars,const std::vector<IntVar * > & second_vars,int64_t escape_value)740 Constraint* Solver::MakeNullIntersectExcept(
741 const std::vector<IntVar*>& first_vars,
742 const std::vector<IntVar*>& second_vars, int64_t escape_value) {
743 int first_escape_candidates = 0;
744 for (int i = 0; i < first_vars.size(); ++i) {
745 first_escape_candidates += (first_vars[i]->Contains(escape_value));
746 }
747 int second_escape_candidates = 0;
748 for (int i = 0; i < second_vars.size(); ++i) {
749 second_escape_candidates += (second_vars[i]->Contains(escape_value));
750 }
751 if (first_escape_candidates == 0 || second_escape_candidates == 0) {
752 return RevAlloc(
753 new NullIntersectArrayExcept(this, first_vars, second_vars));
754 } else {
755 return RevAlloc(new NullIntersectArrayExcept(this, first_vars, second_vars,
756 escape_value));
757 }
758 }
759 } // namespace operations_research
760