1 // Copyright 2010-2021 Google LLC
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 //
15 //  AllDifferent constraints
16 
17 #include <algorithm>
18 #include <cstdint>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/strings/str_format.h"
25 #include "ortools/base/integral_types.h"
26 #include "ortools/base/logging.h"
27 #include "ortools/constraint_solver/constraint_solver.h"
28 #include "ortools/constraint_solver/constraint_solveri.h"
29 #include "ortools/util/string_array.h"
30 
31 namespace operations_research {
32 namespace {
33 
34 class BaseAllDifferent : public Constraint {
35  public:
BaseAllDifferent(Solver * const s,const std::vector<IntVar * > & vars)36   BaseAllDifferent(Solver* const s, const std::vector<IntVar*>& vars)
37       : Constraint(s), vars_(vars) {}
~BaseAllDifferent()38   ~BaseAllDifferent() override {}
DebugStringInternal(const std::string & name) const39   std::string DebugStringInternal(const std::string& name) const {
40     return absl::StrFormat("%s(%s)", name, JoinDebugStringPtr(vars_, ", "));
41   }
42 
43  protected:
44   const std::vector<IntVar*> vars_;
size() const45   int64_t size() const { return vars_.size(); }
46 };
47 
48 //-----------------------------------------------------------------------------
49 // ValueAllDifferent
50 
51 class ValueAllDifferent : public BaseAllDifferent {
52  public:
ValueAllDifferent(Solver * const s,const std::vector<IntVar * > & vars)53   ValueAllDifferent(Solver* const s, const std::vector<IntVar*>& vars)
54       : BaseAllDifferent(s, vars) {}
~ValueAllDifferent()55   ~ValueAllDifferent() override {}
56 
57   void Post() override;
58   void InitialPropagate() override;
59   void OneMove(int index);
60   bool AllMoves();
61 
DebugString() const62   std::string DebugString() const override {
63     return DebugStringInternal("ValueAllDifferent");
64   }
Accept(ModelVisitor * const visitor) const65   void Accept(ModelVisitor* const visitor) const override {
66     visitor->BeginVisitConstraint(ModelVisitor::kAllDifferent, this);
67     visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
68                                                vars_);
69     visitor->VisitIntegerArgument(ModelVisitor::kRangeArgument, 0);
70     visitor->EndVisitConstraint(ModelVisitor::kAllDifferent, this);
71   }
72 
73  private:
74   RevSwitch all_instantiated_;
75 };
76 
Post()77 void ValueAllDifferent::Post() {
78   for (int i = 0; i < size(); ++i) {
79     IntVar* var = vars_[i];
80     Demon* d = MakeConstraintDemon1(solver(), this, &ValueAllDifferent::OneMove,
81                                     "OneMove", i);
82     var->WhenBound(d);
83   }
84 }
85 
InitialPropagate()86 void ValueAllDifferent::InitialPropagate() {
87   for (int i = 0; i < size(); ++i) {
88     if (vars_[i]->Bound()) {
89       OneMove(i);
90     }
91   }
92 }
93 
OneMove(int index)94 void ValueAllDifferent::OneMove(int index) {
95   if (!AllMoves()) {
96     const int64_t val = vars_[index]->Value();
97     for (int j = 0; j < size(); ++j) {
98       if (index != j) {
99         if (vars_[j]->Size() < 0xFFFFFF) {
100           vars_[j]->RemoveValue(val);
101         } else {
102           solver()->AddConstraint(solver()->MakeNonEquality(vars_[j], val));
103         }
104       }
105     }
106   }
107 }
108 
AllMoves()109 bool ValueAllDifferent::AllMoves() {
110   if (all_instantiated_.Switched() || size() == 0) {
111     return true;
112   }
113   for (int i = 0; i < size(); ++i) {
114     if (!vars_[i]->Bound()) {
115       return false;
116     }
117   }
118   std::unique_ptr<int64_t[]> values(new int64_t[size()]);
119   for (int i = 0; i < size(); ++i) {
120     values[i] = vars_[i]->Value();
121   }
122   std::sort(values.get(), values.get() + size());
123   for (int i = 0; i < size() - 1; ++i) {
124     if (values[i] == values[i + 1]) {
125       values.reset();  // prevent leaks (solver()->Fail() won't return)
126       solver()->Fail();
127     }
128   }
129   all_instantiated_.Switch(solver());
130   return true;
131 }
132 
133 // ---------- Bounds All Different ----------
134 // See http://www.cs.uwaterloo.ca/~cquimper/Papers/ijcai03_TR.pdf for details.
135 
136 class RangeBipartiteMatching {
137  public:
138   struct Interval {
139     int64_t min;
140     int64_t max;
141     int min_rank;
142     int max_rank;
143   };
144 
RangeBipartiteMatching(Solver * const solver,int size)145   RangeBipartiteMatching(Solver* const solver, int size)
146       : solver_(solver),
147         size_(size),
148         intervals_(new Interval[size + 1]),
149         min_sorted_(new Interval*[size]),
150         max_sorted_(new Interval*[size]),
151         bounds_(new int64_t[2 * size + 2]),
152         tree_(new int[2 * size + 2]),
153         diff_(new int64_t[2 * size + 2]),
154         hall_(new int[2 * size + 2]),
155         active_size_(0) {
156     for (int i = 0; i < size; ++i) {
157       max_sorted_[i] = &intervals_[i];
158       min_sorted_[i] = max_sorted_[i];
159     }
160   }
161 
SetRange(int index,int64_t imin,int64_t imax)162   void SetRange(int index, int64_t imin, int64_t imax) {
163     intervals_[index].min = imin;
164     intervals_[index].max = imax;
165   }
166 
Propagate()167   bool Propagate() {
168     SortArray();
169 
170     const bool modified1 = PropagateMin();
171     const bool modified2 = PropagateMax();
172     return modified1 || modified2;
173   }
174 
Min(int index) const175   int64_t Min(int index) const { return intervals_[index].min; }
176 
Max(int index) const177   int64_t Max(int index) const { return intervals_[index].max; }
178 
179  private:
180   // This method sorts the min_sorted_ and max_sorted_ arrays and fill
181   // the bounds_ array (and set the active_size_ counter).
SortArray()182   void SortArray() {
183     std::sort(min_sorted_.get(), min_sorted_.get() + size_,
184               CompareIntervalMin());
185     std::sort(max_sorted_.get(), max_sorted_.get() + size_,
186               CompareIntervalMax());
187 
188     int64_t min = min_sorted_[0]->min;
189     int64_t max = max_sorted_[0]->max + 1;
190     int64_t last = min - 2;
191     bounds_[0] = last;
192 
193     int i = 0;
194     int j = 0;
195     int nb = 0;
196     for (;;) {  // merge min_sorted_[] and max_sorted_[] into bounds_[].
197       if (i < size_ && min <= max) {  // make sure min_sorted_ exhausted first.
198         if (min != last) {
199           last = min;
200           bounds_[++nb] = last;
201         }
202         min_sorted_[i]->min_rank = nb;
203         if (++i < size_) {
204           min = min_sorted_[i]->min;
205         }
206       } else {
207         if (max != last) {
208           last = max;
209           bounds_[++nb] = last;
210         }
211         max_sorted_[j]->max_rank = nb;
212         if (++j == size_) {
213           break;
214         }
215         max = max_sorted_[j]->max + 1;
216       }
217     }
218     active_size_ = nb;
219     bounds_[nb + 1] = bounds_[nb] + 2;
220   }
221 
222   // These two methods will actually do the new bounds computation.
PropagateMin()223   bool PropagateMin() {
224     bool modified = false;
225 
226     for (int i = 1; i <= active_size_ + 1; ++i) {
227       hall_[i] = i - 1;
228       tree_[i] = i - 1;
229       diff_[i] = bounds_[i] - bounds_[i - 1];
230     }
231     // visit intervals in increasing max order
232     for (int i = 0; i < size_; ++i) {
233       const int x = max_sorted_[i]->min_rank;
234       const int y = max_sorted_[i]->max_rank;
235       int z = PathMax(tree_.get(), x + 1);
236       int j = tree_[z];
237       if (--diff_[z] == 0) {
238         tree_[z] = z + 1;
239         z = PathMax(tree_.get(), z + 1);
240         tree_[z] = j;
241       }
242       PathSet(x + 1, z, z, tree_.get());  // path compression
243       if (diff_[z] < bounds_[z] - bounds_[y]) {
244         solver_->Fail();
245       }
246       if (hall_[x] > x) {
247         int w = PathMax(hall_.get(), hall_[x]);
248         max_sorted_[i]->min = bounds_[w];
249         PathSet(x, w, w, hall_.get());  // path compression
250         modified = true;
251       }
252       if (diff_[z] == bounds_[z] - bounds_[y]) {
253         PathSet(hall_[y], j - 1, y, hall_.get());  // mark hall interval
254         hall_[y] = j - 1;
255       }
256     }
257     return modified;
258   }
259 
PropagateMax()260   bool PropagateMax() {
261     bool modified = false;
262 
263     for (int i = 0; i <= active_size_; i++) {
264       tree_[i] = i + 1;
265       hall_[i] = i + 1;
266       diff_[i] = bounds_[i + 1] - bounds_[i];
267     }
268     // visit intervals in decreasing min order
269     for (int i = size_ - 1; i >= 0; --i) {
270       const int x = min_sorted_[i]->max_rank;
271       const int y = min_sorted_[i]->min_rank;
272       int z = PathMin(tree_.get(), x - 1);
273       int j = tree_[z];
274       if (--diff_[z] == 0) {
275         tree_[z] = z - 1;
276         z = PathMin(tree_.get(), z - 1);
277         tree_[z] = j;
278       }
279       PathSet(x - 1, z, z, tree_.get());
280       if (diff_[z] < bounds_[y] - bounds_[z]) {
281         solver_->Fail();
282         // useless. Should have been caught by the PropagateMin() method.
283       }
284       if (hall_[x] < x) {
285         int w = PathMin(hall_.get(), hall_[x]);
286         min_sorted_[i]->max = bounds_[w] - 1;
287         PathSet(x, w, w, hall_.get());
288         modified = true;
289       }
290       if (diff_[z] == bounds_[y] - bounds_[z]) {
291         PathSet(hall_[y], j + 1, y, hall_.get());
292         hall_[y] = j + 1;
293       }
294     }
295     return modified;
296   }
297 
298   // TODO(user) : use better sort, use bounding boxes of modifications to
299   //                 improve the sorting (only modified vars).
300 
301   // This method is used by the STL sort.
302   struct CompareIntervalMin {
operator ()operations_research::__anon780a12590111::RangeBipartiteMatching::CompareIntervalMin303     bool operator()(const Interval* i1, const Interval* i2) const {
304       return (i1->min < i2->min);
305     }
306   };
307 
308   // This method is used by the STL sort.
309   struct CompareIntervalMax {
operator ()operations_research::__anon780a12590111::RangeBipartiteMatching::CompareIntervalMax310     bool operator()(const Interval* i1, const Interval* i2) const {
311       return (i1->max < i2->max);
312     }
313   };
314 
PathSet(int start,int end,int to,int * const tree)315   void PathSet(int start, int end, int to, int* const tree) {
316     int l = start;
317     while (l != end) {
318       int k = l;
319       l = tree[k];
320       tree[k] = to;
321     }
322   }
323 
PathMin(const int * const tree,int index)324   int PathMin(const int* const tree, int index) {
325     int i = index;
326     while (tree[i] < i) {
327       i = tree[i];
328     }
329     return i;
330   }
331 
PathMax(const int * const tree,int index)332   int PathMax(const int* const tree, int index) {
333     int i = index;
334     while (tree[i] > i) {
335       i = tree[i];
336     }
337     return i;
338   }
339 
340   Solver* const solver_;
341   const int size_;
342   std::unique_ptr<Interval[]> intervals_;
343   std::unique_ptr<Interval*[]> min_sorted_;
344   std::unique_ptr<Interval*[]> max_sorted_;
345   // bounds_[1..active_size_] hold set of min & max in the n intervals_
346   // while bounds_[0] and bounds_[active_size_ + 1] allow sentinels.
347   std::unique_ptr<int64_t[]> bounds_;
348   std::unique_ptr<int[]> tree_;      // tree links.
349   std::unique_ptr<int64_t[]> diff_;  // diffs between critical capacities.
350   std::unique_ptr<int[]> hall_;      // hall interval links.
351   int active_size_;
352 };
353 
354 class BoundsAllDifferent : public BaseAllDifferent {
355  public:
BoundsAllDifferent(Solver * const s,const std::vector<IntVar * > & vars)356   BoundsAllDifferent(Solver* const s, const std::vector<IntVar*>& vars)
357       : BaseAllDifferent(s, vars), matching_(s, vars.size()) {}
358 
~BoundsAllDifferent()359   ~BoundsAllDifferent() override {}
360 
Post()361   void Post() override {
362     Demon* range = MakeDelayedConstraintDemon0(
363         solver(), this, &BoundsAllDifferent::IncrementalPropagate,
364         "IncrementalPropagate");
365 
366     for (int i = 0; i < size(); ++i) {
367       vars_[i]->WhenRange(range);
368       Demon* bound = MakeConstraintDemon1(solver(), this,
369                                           &BoundsAllDifferent::PropagateValue,
370                                           "PropagateValue", i);
371       vars_[i]->WhenBound(bound);
372     }
373   }
374 
InitialPropagate()375   void InitialPropagate() override {
376     IncrementalPropagate();
377     for (int i = 0; i < size(); ++i) {
378       if (vars_[i]->Bound()) {
379         PropagateValue(i);
380       }
381     }
382   }
383 
IncrementalPropagate()384   virtual void IncrementalPropagate() {
385     for (int i = 0; i < size(); ++i) {
386       matching_.SetRange(i, vars_[i]->Min(), vars_[i]->Max());
387     }
388 
389     if (matching_.Propagate()) {
390       for (int i = 0; i < size(); ++i) {
391         vars_[i]->SetRange(matching_.Min(i), matching_.Max(i));
392       }
393     }
394   }
395 
PropagateValue(int index)396   void PropagateValue(int index) {
397     const int64_t to_remove = vars_[index]->Value();
398     for (int j = 0; j < index; j++) {
399       if (vars_[j]->Size() < 0xFFFFFF) {
400         vars_[j]->RemoveValue(to_remove);
401       } else {
402         solver()->AddConstraint(solver()->MakeNonEquality(vars_[j], to_remove));
403       }
404     }
405     for (int j = index + 1; j < size(); j++) {
406       if (vars_[j]->Size() < 0xFFFFFF) {
407         vars_[j]->RemoveValue(to_remove);
408       } else {
409         solver()->AddConstraint(solver()->MakeNonEquality(vars_[j], to_remove));
410       }
411     }
412   }
413 
DebugString() const414   std::string DebugString() const override {
415     return DebugStringInternal("BoundsAllDifferent");
416   }
417 
Accept(ModelVisitor * const visitor) const418   void Accept(ModelVisitor* const visitor) const override {
419     visitor->BeginVisitConstraint(ModelVisitor::kAllDifferent, this);
420     visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
421                                                vars_);
422     visitor->VisitIntegerArgument(ModelVisitor::kRangeArgument, 1);
423     visitor->EndVisitConstraint(ModelVisitor::kAllDifferent, this);
424   }
425 
426  private:
427   RangeBipartiteMatching matching_;
428 };
429 
430 class SortConstraint : public Constraint {
431  public:
SortConstraint(Solver * const solver,const std::vector<IntVar * > & original_vars,const std::vector<IntVar * > & sorted_vars)432   SortConstraint(Solver* const solver,
433                  const std::vector<IntVar*>& original_vars,
434                  const std::vector<IntVar*>& sorted_vars)
435       : Constraint(solver),
436         ovars_(original_vars),
437         svars_(sorted_vars),
438         mins_(original_vars.size(), 0),
439         maxs_(original_vars.size(), 0),
440         matching_(solver, original_vars.size()) {}
441 
~SortConstraint()442   ~SortConstraint() override {}
443 
Post()444   void Post() override {
445     Demon* const demon =
446         solver()->MakeDelayedConstraintInitialPropagateCallback(this);
447     for (int i = 0; i < size(); ++i) {
448       ovars_[i]->WhenRange(demon);
449       svars_[i]->WhenRange(demon);
450     }
451   }
452 
InitialPropagate()453   void InitialPropagate() override {
454     for (int i = 0; i < size(); ++i) {
455       int64_t vmin = 0;
456       int64_t vmax = 0;
457       ovars_[i]->Range(&vmin, &vmax);
458       mins_[i] = vmin;
459       maxs_[i] = vmax;
460     }
461     // Propagates from variables to sorted variables.
462     std::sort(mins_.begin(), mins_.end());
463     std::sort(maxs_.begin(), maxs_.end());
464     for (int i = 0; i < size(); ++i) {
465       svars_[i]->SetRange(mins_[i], maxs_[i]);
466     }
467     // Maintains sortedness.
468     for (int i = 0; i < size() - 1; ++i) {
469       svars_[i + 1]->SetMin(svars_[i]->Min());
470     }
471     for (int i = size() - 1; i > 0; --i) {
472       svars_[i - 1]->SetMax(svars_[i]->Max());
473     }
474     // Reverse propagation.
475     for (int i = 0; i < size(); ++i) {
476       int64_t imin = 0;
477       int64_t imax = 0;
478       FindIntersectionRange(i, &imin, &imax);
479       matching_.SetRange(i, imin, imax);
480     }
481     matching_.Propagate();
482     for (int i = 0; i < size(); ++i) {
483       const int64_t vmin = svars_[matching_.Min(i)]->Min();
484       const int64_t vmax = svars_[matching_.Max(i)]->Max();
485       ovars_[i]->SetRange(vmin, vmax);
486     }
487   }
488 
Accept(ModelVisitor * const visitor) const489   void Accept(ModelVisitor* const visitor) const override {
490     visitor->BeginVisitConstraint(ModelVisitor::kSortingConstraint, this);
491     visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
492                                                ovars_);
493     visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kTargetArgument,
494                                                svars_);
495     visitor->EndVisitConstraint(ModelVisitor::kSortingConstraint, this);
496   }
497 
DebugString() const498   std::string DebugString() const override {
499     return absl::StrFormat("Sort(%s, %s)", JoinDebugStringPtr(ovars_, ", "),
500                            JoinDebugStringPtr(svars_, ", "));
501   }
502 
503  private:
size() const504   int64_t size() const { return ovars_.size(); }
505 
FindIntersectionRange(int index,int64_t * const range_min,int64_t * const range_max) const506   void FindIntersectionRange(int index, int64_t* const range_min,
507                              int64_t* const range_max) const {
508     // Naive version.
509     // TODO(user): Implement log(n) version.
510     int64_t imin = 0;
511     while (imin < size() && NotIntersect(index, imin)) {
512       imin++;
513     }
514     if (imin == size()) {
515       solver()->Fail();
516     }
517     int64_t imax = size() - 1;
518     while (imax > imin && NotIntersect(index, imax)) {
519       imax--;
520     }
521     *range_min = imin;
522     *range_max = imax;
523   }
524 
NotIntersect(int oindex,int sindex) const525   bool NotIntersect(int oindex, int sindex) const {
526     return ovars_[oindex]->Min() > svars_[sindex]->Max() ||
527            ovars_[oindex]->Max() < svars_[sindex]->Min();
528   }
529 
530   const std::vector<IntVar*> ovars_;
531   const std::vector<IntVar*> svars_;
532   std::vector<int64_t> mins_;
533   std::vector<int64_t> maxs_;
534   RangeBipartiteMatching matching_;
535 };
536 
537 // All variables are pairwise different, unless they are assigned to
538 // the escape value.
539 class AllDifferentExcept : public Constraint {
540  public:
AllDifferentExcept(Solver * const s,std::vector<IntVar * > vars,int64_t escape_value)541   AllDifferentExcept(Solver* const s, std::vector<IntVar*> vars,
542                      int64_t escape_value)
543       : Constraint(s), vars_(std::move(vars)), escape_value_(escape_value) {}
544 
~AllDifferentExcept()545   ~AllDifferentExcept() override {}
546 
Post()547   void Post() override {
548     for (int i = 0; i < vars_.size(); ++i) {
549       IntVar* const var = vars_[i];
550       Demon* const d = MakeConstraintDemon1(
551           solver(), this, &AllDifferentExcept::Propagate, "Propagate", i);
552       var->WhenBound(d);
553     }
554   }
555 
InitialPropagate()556   void InitialPropagate() override {
557     for (int i = 0; i < vars_.size(); ++i) {
558       if (vars_[i]->Bound()) {
559         Propagate(i);
560       }
561     }
562   }
563 
Propagate(int index)564   void Propagate(int index) {
565     const int64_t val = vars_[index]->Value();
566     if (val != escape_value_) {
567       for (int j = 0; j < vars_.size(); ++j) {
568         if (index != j) {
569           vars_[j]->RemoveValue(val);
570         }
571       }
572     }
573   }
574 
DebugString() const575   std::string DebugString() const override {
576     return absl::StrFormat("AllDifferentExcept([%s], %d",
577                            JoinDebugStringPtr(vars_, ", "), escape_value_);
578   }
579 
Accept(ModelVisitor * const visitor) const580   void Accept(ModelVisitor* const visitor) const override {
581     visitor->BeginVisitConstraint(ModelVisitor::kAllDifferent, this);
582     visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
583                                                vars_);
584     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, escape_value_);
585     visitor->EndVisitConstraint(ModelVisitor::kAllDifferent, this);
586   }
587 
588  private:
589   std::vector<IntVar*> vars_;
590   const int64_t escape_value_;
591 };
592 
593 // Creates a constraint that states that all variables in the first
594 // vector are different from all variables from the second group,
595 // unless they are assigned to the escape value if it is defined. Thus
596 // the set of values in the first vector minus the escape value does
597 // not intersect the set of values in the second vector.
598 class NullIntersectArrayExcept : public Constraint {
599  public:
NullIntersectArrayExcept(Solver * const s,std::vector<IntVar * > first_vars,std::vector<IntVar * > second_vars,int64_t escape_value)600   NullIntersectArrayExcept(Solver* const s, std::vector<IntVar*> first_vars,
601                            std::vector<IntVar*> second_vars,
602                            int64_t escape_value)
603       : Constraint(s),
604         first_vars_(std::move(first_vars)),
605         second_vars_(std::move(second_vars)),
606         escape_value_(escape_value),
607         has_escape_value_(true) {}
608 
NullIntersectArrayExcept(Solver * const s,std::vector<IntVar * > first_vars,std::vector<IntVar * > second_vars)609   NullIntersectArrayExcept(Solver* const s, std::vector<IntVar*> first_vars,
610                            std::vector<IntVar*> second_vars)
611       : Constraint(s),
612         first_vars_(std::move(first_vars)),
613         second_vars_(std::move(second_vars)),
614         escape_value_(0),
615         has_escape_value_(false) {}
616 
~NullIntersectArrayExcept()617   ~NullIntersectArrayExcept() override {}
618 
Post()619   void Post() override {
620     for (int i = 0; i < first_vars_.size(); ++i) {
621       IntVar* const var = first_vars_[i];
622       Demon* const d = MakeConstraintDemon1(
623           solver(), this, &NullIntersectArrayExcept::PropagateFirst,
624           "PropagateFirst", i);
625       var->WhenBound(d);
626     }
627     for (int i = 0; i < second_vars_.size(); ++i) {
628       IntVar* const var = second_vars_[i];
629       Demon* const d = MakeConstraintDemon1(
630           solver(), this, &NullIntersectArrayExcept::PropagateSecond,
631           "PropagateSecond", i);
632       var->WhenBound(d);
633     }
634   }
635 
InitialPropagate()636   void InitialPropagate() override {
637     for (int i = 0; i < first_vars_.size(); ++i) {
638       if (first_vars_[i]->Bound()) {
639         PropagateFirst(i);
640       }
641     }
642     for (int i = 0; i < second_vars_.size(); ++i) {
643       if (second_vars_[i]->Bound()) {
644         PropagateSecond(i);
645       }
646     }
647   }
648 
PropagateFirst(int index)649   void PropagateFirst(int index) {
650     const int64_t val = first_vars_[index]->Value();
651     if (!has_escape_value_ || val != escape_value_) {
652       for (int j = 0; j < second_vars_.size(); ++j) {
653         second_vars_[j]->RemoveValue(val);
654       }
655     }
656   }
657 
PropagateSecond(int index)658   void PropagateSecond(int index) {
659     const int64_t val = second_vars_[index]->Value();
660     if (!has_escape_value_ || val != escape_value_) {
661       for (int j = 0; j < first_vars_.size(); ++j) {
662         first_vars_[j]->RemoveValue(val);
663       }
664     }
665   }
666 
DebugString() const667   std::string DebugString() const override {
668     return absl::StrFormat("NullIntersectArray([%s], [%s], escape = %d",
669                            JoinDebugStringPtr(first_vars_, ", "),
670                            JoinDebugStringPtr(second_vars_, ", "),
671                            escape_value_);
672   }
673 
Accept(ModelVisitor * const visitor) const674   void Accept(ModelVisitor* const visitor) const override {
675     visitor->BeginVisitConstraint(ModelVisitor::kNullIntersect, this);
676     visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kLeftArgument,
677                                                first_vars_);
678     visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kRightArgument,
679                                                second_vars_);
680     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, escape_value_);
681     visitor->EndVisitConstraint(ModelVisitor::kNullIntersect, this);
682   }
683 
684  private:
685   std::vector<IntVar*> first_vars_;
686   std::vector<IntVar*> second_vars_;
687   const int64_t escape_value_;
688   const bool has_escape_value_;
689 };
690 }  // namespace
691 
MakeAllDifferent(const std::vector<IntVar * > & vars)692 Constraint* Solver::MakeAllDifferent(const std::vector<IntVar*>& vars) {
693   return MakeAllDifferent(vars, true);
694 }
695 
MakeAllDifferent(const std::vector<IntVar * > & vars,bool stronger_propagation)696 Constraint* Solver::MakeAllDifferent(const std::vector<IntVar*>& vars,
697                                      bool stronger_propagation) {
698   const int size = vars.size();
699   for (int i = 0; i < size; ++i) {
700     CHECK_EQ(this, vars[i]->solver());
701   }
702   if (size < 2) {
703     return MakeTrueConstraint();
704   } else if (size == 2) {
705     return MakeNonEquality(const_cast<IntVar* const>(vars[0]),
706                            const_cast<IntVar* const>(vars[1]));
707   } else {
708     if (stronger_propagation) {
709       return RevAlloc(new BoundsAllDifferent(this, vars));
710     } else {
711       return RevAlloc(new ValueAllDifferent(this, vars));
712     }
713   }
714 }
715 
MakeSortingConstraint(const std::vector<IntVar * > & vars,const std::vector<IntVar * > & sorted)716 Constraint* Solver::MakeSortingConstraint(const std::vector<IntVar*>& vars,
717                                           const std::vector<IntVar*>& sorted) {
718   CHECK_EQ(vars.size(), sorted.size());
719   return RevAlloc(new SortConstraint(this, vars, sorted));
720 }
721 
MakeAllDifferentExcept(const std::vector<IntVar * > & vars,int64_t escape_value)722 Constraint* Solver::MakeAllDifferentExcept(const std::vector<IntVar*>& vars,
723                                            int64_t escape_value) {
724   int escape_candidates = 0;
725   for (int i = 0; i < vars.size(); ++i) {
726     escape_candidates += (vars[i]->Contains(escape_value));
727   }
728   if (escape_candidates <= 1) {
729     return MakeAllDifferent(vars);
730   } else {
731     return RevAlloc(new AllDifferentExcept(this, vars, escape_value));
732   }
733 }
734 
MakeNullIntersect(const std::vector<IntVar * > & first_vars,const std::vector<IntVar * > & second_vars)735 Constraint* Solver::MakeNullIntersect(const std::vector<IntVar*>& first_vars,
736                                       const std::vector<IntVar*>& second_vars) {
737   return RevAlloc(new NullIntersectArrayExcept(this, first_vars, second_vars));
738 }
739 
MakeNullIntersectExcept(const std::vector<IntVar * > & first_vars,const std::vector<IntVar * > & second_vars,int64_t escape_value)740 Constraint* Solver::MakeNullIntersectExcept(
741     const std::vector<IntVar*>& first_vars,
742     const std::vector<IntVar*>& second_vars, int64_t escape_value) {
743   int first_escape_candidates = 0;
744   for (int i = 0; i < first_vars.size(); ++i) {
745     first_escape_candidates += (first_vars[i]->Contains(escape_value));
746   }
747   int second_escape_candidates = 0;
748   for (int i = 0; i < second_vars.size(); ++i) {
749     second_escape_candidates += (second_vars[i]->Contains(escape_value));
750   }
751   if (first_escape_candidates == 0 || second_escape_candidates == 0) {
752     return RevAlloc(
753         new NullIntersectArrayExcept(this, first_vars, second_vars));
754   } else {
755     return RevAlloc(new NullIntersectArrayExcept(this, first_vars, second_vars,
756                                                  escape_value));
757   }
758 }
759 }  // namespace operations_research
760