1 // Copyright 2010-2021 Google LLC
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 //
15 //  Expression constraints
16 
17 #include <cstddef>
18 #include <cstdint>
19 #include <limits>
20 #include <set>
21 #include <string>
22 #include <vector>
23 
24 #include "absl/strings/str_format.h"
25 #include "absl/strings/str_join.h"
26 #include "ortools/base/commandlineflags.h"
27 #include "ortools/base/integral_types.h"
28 #include "ortools/base/logging.h"
29 #include "ortools/base/stl_util.h"
30 #include "ortools/constraint_solver/constraint_solver.h"
31 #include "ortools/constraint_solver/constraint_solveri.h"
32 #include "ortools/util/saturated_arithmetic.h"
33 #include "ortools/util/sorted_interval_list.h"
34 
35 ABSL_FLAG(int, cache_initial_size, 1024,
36           "Initial size of the array of the hash "
37           "table of caches for objects of type Var(x == 3)");
38 
39 namespace operations_research {
40 
41 //-----------------------------------------------------------------------------
42 // Equality
43 
44 namespace {
45 class EqualityExprCst : public Constraint {
46  public:
47   EqualityExprCst(Solver* const s, IntExpr* const e, int64_t v);
~EqualityExprCst()48   ~EqualityExprCst() override {}
49   void Post() override;
50   void InitialPropagate() override;
Var()51   IntVar* Var() override {
52     return solver()->MakeIsEqualCstVar(expr_->Var(), value_);
53   }
54   std::string DebugString() const override;
55 
Accept(ModelVisitor * const visitor) const56   void Accept(ModelVisitor* const visitor) const override {
57     visitor->BeginVisitConstraint(ModelVisitor::kEquality, this);
58     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
59                                             expr_);
60     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
61     visitor->EndVisitConstraint(ModelVisitor::kEquality, this);
62   }
63 
64  private:
65   IntExpr* const expr_;
66   int64_t value_;
67 };
68 
EqualityExprCst(Solver * const s,IntExpr * const e,int64_t v)69 EqualityExprCst::EqualityExprCst(Solver* const s, IntExpr* const e, int64_t v)
70     : Constraint(s), expr_(e), value_(v) {}
71 
Post()72 void EqualityExprCst::Post() {
73   if (!expr_->IsVar()) {
74     Demon* d = solver()->MakeConstraintInitialPropagateCallback(this);
75     expr_->WhenRange(d);
76   }
77 }
78 
InitialPropagate()79 void EqualityExprCst::InitialPropagate() { expr_->SetValue(value_); }
80 
DebugString() const81 std::string EqualityExprCst::DebugString() const {
82   return absl::StrFormat("(%s == %d)", expr_->DebugString(), value_);
83 }
84 }  // namespace
85 
MakeEquality(IntExpr * const e,int64_t v)86 Constraint* Solver::MakeEquality(IntExpr* const e, int64_t v) {
87   CHECK_EQ(this, e->solver());
88   IntExpr* left = nullptr;
89   IntExpr* right = nullptr;
90   if (IsADifference(e, &left, &right)) {
91     return MakeEquality(left, MakeSum(right, v));
92   } else if (e->IsVar() && !e->Var()->Contains(v)) {
93     return MakeFalseConstraint();
94   } else if (e->Min() == e->Max() && e->Min() == v) {
95     return MakeTrueConstraint();
96   } else {
97     return RevAlloc(new EqualityExprCst(this, e, v));
98   }
99 }
100 
MakeEquality(IntExpr * const e,int v)101 Constraint* Solver::MakeEquality(IntExpr* const e, int v) {
102   CHECK_EQ(this, e->solver());
103   IntExpr* left = nullptr;
104   IntExpr* right = nullptr;
105   if (IsADifference(e, &left, &right)) {
106     return MakeEquality(left, MakeSum(right, v));
107   } else if (e->IsVar() && !e->Var()->Contains(v)) {
108     return MakeFalseConstraint();
109   } else if (e->Min() == e->Max() && e->Min() == v) {
110     return MakeTrueConstraint();
111   } else {
112     return RevAlloc(new EqualityExprCst(this, e, v));
113   }
114 }
115 
116 //-----------------------------------------------------------------------------
117 // Greater or equal constraint
118 
119 namespace {
120 class GreaterEqExprCst : public Constraint {
121  public:
122   GreaterEqExprCst(Solver* const s, IntExpr* const e, int64_t v);
~GreaterEqExprCst()123   ~GreaterEqExprCst() override {}
124   void Post() override;
125   void InitialPropagate() override;
126   std::string DebugString() const override;
Var()127   IntVar* Var() override {
128     return solver()->MakeIsGreaterOrEqualCstVar(expr_->Var(), value_);
129   }
130 
Accept(ModelVisitor * const visitor) const131   void Accept(ModelVisitor* const visitor) const override {
132     visitor->BeginVisitConstraint(ModelVisitor::kGreaterOrEqual, this);
133     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
134                                             expr_);
135     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
136     visitor->EndVisitConstraint(ModelVisitor::kGreaterOrEqual, this);
137   }
138 
139  private:
140   IntExpr* const expr_;
141   int64_t value_;
142   Demon* demon_;
143 };
144 
GreaterEqExprCst(Solver * const s,IntExpr * const e,int64_t v)145 GreaterEqExprCst::GreaterEqExprCst(Solver* const s, IntExpr* const e, int64_t v)
146     : Constraint(s), expr_(e), value_(v), demon_(nullptr) {}
147 
Post()148 void GreaterEqExprCst::Post() {
149   if (!expr_->IsVar() && expr_->Min() < value_) {
150     demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
151     expr_->WhenRange(demon_);
152   } else {
153     // Let's clean the demon in case the constraint is posted during search.
154     demon_ = nullptr;
155   }
156 }
157 
InitialPropagate()158 void GreaterEqExprCst::InitialPropagate() {
159   expr_->SetMin(value_);
160   if (demon_ != nullptr && expr_->Min() >= value_) {
161     demon_->inhibit(solver());
162   }
163 }
164 
DebugString() const165 std::string GreaterEqExprCst::DebugString() const {
166   return absl::StrFormat("(%s >= %d)", expr_->DebugString(), value_);
167 }
168 }  // namespace
169 
MakeGreaterOrEqual(IntExpr * const e,int64_t v)170 Constraint* Solver::MakeGreaterOrEqual(IntExpr* const e, int64_t v) {
171   CHECK_EQ(this, e->solver());
172   if (e->Min() >= v) {
173     return MakeTrueConstraint();
174   } else if (e->Max() < v) {
175     return MakeFalseConstraint();
176   } else {
177     return RevAlloc(new GreaterEqExprCst(this, e, v));
178   }
179 }
180 
MakeGreaterOrEqual(IntExpr * const e,int v)181 Constraint* Solver::MakeGreaterOrEqual(IntExpr* const e, int v) {
182   CHECK_EQ(this, e->solver());
183   if (e->Min() >= v) {
184     return MakeTrueConstraint();
185   } else if (e->Max() < v) {
186     return MakeFalseConstraint();
187   } else {
188     return RevAlloc(new GreaterEqExprCst(this, e, v));
189   }
190 }
191 
MakeGreater(IntExpr * const e,int64_t v)192 Constraint* Solver::MakeGreater(IntExpr* const e, int64_t v) {
193   CHECK_EQ(this, e->solver());
194   if (e->Min() > v) {
195     return MakeTrueConstraint();
196   } else if (e->Max() <= v) {
197     return MakeFalseConstraint();
198   } else {
199     return RevAlloc(new GreaterEqExprCst(this, e, v + 1));
200   }
201 }
202 
MakeGreater(IntExpr * const e,int v)203 Constraint* Solver::MakeGreater(IntExpr* const e, int v) {
204   CHECK_EQ(this, e->solver());
205   if (e->Min() > v) {
206     return MakeTrueConstraint();
207   } else if (e->Max() <= v) {
208     return MakeFalseConstraint();
209   } else {
210     return RevAlloc(new GreaterEqExprCst(this, e, v + 1));
211   }
212 }
213 
214 //-----------------------------------------------------------------------------
215 // Less or equal constraint
216 
217 namespace {
218 class LessEqExprCst : public Constraint {
219  public:
220   LessEqExprCst(Solver* const s, IntExpr* const e, int64_t v);
~LessEqExprCst()221   ~LessEqExprCst() override {}
222   void Post() override;
223   void InitialPropagate() override;
224   std::string DebugString() const override;
Var()225   IntVar* Var() override {
226     return solver()->MakeIsLessOrEqualCstVar(expr_->Var(), value_);
227   }
Accept(ModelVisitor * const visitor) const228   void Accept(ModelVisitor* const visitor) const override {
229     visitor->BeginVisitConstraint(ModelVisitor::kLessOrEqual, this);
230     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
231                                             expr_);
232     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
233     visitor->EndVisitConstraint(ModelVisitor::kLessOrEqual, this);
234   }
235 
236  private:
237   IntExpr* const expr_;
238   int64_t value_;
239   Demon* demon_;
240 };
241 
LessEqExprCst(Solver * const s,IntExpr * const e,int64_t v)242 LessEqExprCst::LessEqExprCst(Solver* const s, IntExpr* const e, int64_t v)
243     : Constraint(s), expr_(e), value_(v), demon_(nullptr) {}
244 
Post()245 void LessEqExprCst::Post() {
246   if (!expr_->IsVar() && expr_->Max() > value_) {
247     demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
248     expr_->WhenRange(demon_);
249   } else {
250     // Let's clean the demon in case the constraint is posted during search.
251     demon_ = nullptr;
252   }
253 }
254 
InitialPropagate()255 void LessEqExprCst::InitialPropagate() {
256   expr_->SetMax(value_);
257   if (demon_ != nullptr && expr_->Max() <= value_) {
258     demon_->inhibit(solver());
259   }
260 }
261 
DebugString() const262 std::string LessEqExprCst::DebugString() const {
263   return absl::StrFormat("(%s <= %d)", expr_->DebugString(), value_);
264 }
265 }  // namespace
266 
MakeLessOrEqual(IntExpr * const e,int64_t v)267 Constraint* Solver::MakeLessOrEqual(IntExpr* const e, int64_t v) {
268   CHECK_EQ(this, e->solver());
269   if (e->Max() <= v) {
270     return MakeTrueConstraint();
271   } else if (e->Min() > v) {
272     return MakeFalseConstraint();
273   } else {
274     return RevAlloc(new LessEqExprCst(this, e, v));
275   }
276 }
277 
MakeLessOrEqual(IntExpr * const e,int v)278 Constraint* Solver::MakeLessOrEqual(IntExpr* const e, int v) {
279   CHECK_EQ(this, e->solver());
280   if (e->Max() <= v) {
281     return MakeTrueConstraint();
282   } else if (e->Min() > v) {
283     return MakeFalseConstraint();
284   } else {
285     return RevAlloc(new LessEqExprCst(this, e, v));
286   }
287 }
288 
MakeLess(IntExpr * const e,int64_t v)289 Constraint* Solver::MakeLess(IntExpr* const e, int64_t v) {
290   CHECK_EQ(this, e->solver());
291   if (e->Max() < v) {
292     return MakeTrueConstraint();
293   } else if (e->Min() >= v) {
294     return MakeFalseConstraint();
295   } else {
296     return RevAlloc(new LessEqExprCst(this, e, v - 1));
297   }
298 }
299 
MakeLess(IntExpr * const e,int v)300 Constraint* Solver::MakeLess(IntExpr* const e, int v) {
301   CHECK_EQ(this, e->solver());
302   if (e->Max() < v) {
303     return MakeTrueConstraint();
304   } else if (e->Min() >= v) {
305     return MakeFalseConstraint();
306   } else {
307     return RevAlloc(new LessEqExprCst(this, e, v - 1));
308   }
309 }
310 
311 //-----------------------------------------------------------------------------
312 // Different constraints
313 
314 namespace {
315 class DiffCst : public Constraint {
316  public:
317   DiffCst(Solver* const s, IntVar* const var, int64_t value);
~DiffCst()318   ~DiffCst() override {}
Post()319   void Post() override {}
320   void InitialPropagate() override;
321   void BoundPropagate();
322   std::string DebugString() const override;
Var()323   IntVar* Var() override {
324     return solver()->MakeIsDifferentCstVar(var_, value_);
325   }
Accept(ModelVisitor * const visitor) const326   void Accept(ModelVisitor* const visitor) const override {
327     visitor->BeginVisitConstraint(ModelVisitor::kNonEqual, this);
328     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
329                                             var_);
330     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
331     visitor->EndVisitConstraint(ModelVisitor::kNonEqual, this);
332   }
333 
334  private:
335   bool HasLargeDomain(IntVar* var);
336 
337   IntVar* const var_;
338   int64_t value_;
339   Demon* demon_;
340 };
341 
DiffCst(Solver * const s,IntVar * const var,int64_t value)342 DiffCst::DiffCst(Solver* const s, IntVar* const var, int64_t value)
343     : Constraint(s), var_(var), value_(value), demon_(nullptr) {}
344 
InitialPropagate()345 void DiffCst::InitialPropagate() {
346   if (HasLargeDomain(var_)) {
347     demon_ = MakeConstraintDemon0(solver(), this, &DiffCst::BoundPropagate,
348                                   "BoundPropagate");
349     var_->WhenRange(demon_);
350   } else {
351     var_->RemoveValue(value_);
352   }
353 }
354 
BoundPropagate()355 void DiffCst::BoundPropagate() {
356   const int64_t var_min = var_->Min();
357   const int64_t var_max = var_->Max();
358   if (var_min > value_ || var_max < value_) {
359     demon_->inhibit(solver());
360   } else if (var_min == value_) {
361     var_->SetMin(value_ + 1);
362   } else if (var_max == value_) {
363     var_->SetMax(value_ - 1);
364   } else if (!HasLargeDomain(var_)) {
365     demon_->inhibit(solver());
366     var_->RemoveValue(value_);
367   }
368 }
369 
DebugString() const370 std::string DiffCst::DebugString() const {
371   return absl::StrFormat("(%s != %d)", var_->DebugString(), value_);
372 }
373 
HasLargeDomain(IntVar * var)374 bool DiffCst::HasLargeDomain(IntVar* var) {
375   return CapSub(var->Max(), var->Min()) > 0xFFFFFF;
376 }
377 }  // namespace
378 
MakeNonEquality(IntExpr * const e,int64_t v)379 Constraint* Solver::MakeNonEquality(IntExpr* const e, int64_t v) {
380   CHECK_EQ(this, e->solver());
381   IntExpr* left = nullptr;
382   IntExpr* right = nullptr;
383   if (IsADifference(e, &left, &right)) {
384     return MakeNonEquality(left, MakeSum(right, v));
385   } else if (e->IsVar() && !e->Var()->Contains(v)) {
386     return MakeTrueConstraint();
387   } else if (e->Bound() && e->Min() == v) {
388     return MakeFalseConstraint();
389   } else {
390     return RevAlloc(new DiffCst(this, e->Var(), v));
391   }
392 }
393 
MakeNonEquality(IntExpr * const e,int v)394 Constraint* Solver::MakeNonEquality(IntExpr* const e, int v) {
395   CHECK_EQ(this, e->solver());
396   IntExpr* left = nullptr;
397   IntExpr* right = nullptr;
398   if (IsADifference(e, &left, &right)) {
399     return MakeNonEquality(left, MakeSum(right, v));
400   } else if (e->IsVar() && !e->Var()->Contains(v)) {
401     return MakeTrueConstraint();
402   } else if (e->Bound() && e->Min() == v) {
403     return MakeFalseConstraint();
404   } else {
405     return RevAlloc(new DiffCst(this, e->Var(), v));
406   }
407 }
408 // ----- is_equal_cst Constraint -----
409 
410 namespace {
411 class IsEqualCstCt : public CastConstraint {
412  public:
IsEqualCstCt(Solver * const s,IntVar * const v,int64_t c,IntVar * const b)413   IsEqualCstCt(Solver* const s, IntVar* const v, int64_t c, IntVar* const b)
414       : CastConstraint(s, b), var_(v), cst_(c), demon_(nullptr) {}
Post()415   void Post() override {
416     demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
417     var_->WhenDomain(demon_);
418     target_var_->WhenBound(demon_);
419   }
InitialPropagate()420   void InitialPropagate() override {
421     bool inhibit = var_->Bound();
422     int64_t u = var_->Contains(cst_);
423     int64_t l = inhibit ? u : 0;
424     target_var_->SetRange(l, u);
425     if (target_var_->Bound()) {
426       if (target_var_->Min() == 0) {
427         if (var_->Size() <= 0xFFFFFF) {
428           var_->RemoveValue(cst_);
429           inhibit = true;
430         }
431       } else {
432         var_->SetValue(cst_);
433         inhibit = true;
434       }
435     }
436     if (inhibit) {
437       demon_->inhibit(solver());
438     }
439   }
DebugString() const440   std::string DebugString() const override {
441     return absl::StrFormat("IsEqualCstCt(%s, %d, %s)", var_->DebugString(),
442                            cst_, target_var_->DebugString());
443   }
444 
Accept(ModelVisitor * const visitor) const445   void Accept(ModelVisitor* const visitor) const override {
446     visitor->BeginVisitConstraint(ModelVisitor::kIsEqual, this);
447     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
448                                             var_);
449     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
450     visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
451                                             target_var_);
452     visitor->EndVisitConstraint(ModelVisitor::kIsEqual, this);
453   }
454 
455  private:
456   IntVar* const var_;
457   int64_t cst_;
458   Demon* demon_;
459 };
460 }  // namespace
461 
MakeIsEqualCstVar(IntExpr * const var,int64_t value)462 IntVar* Solver::MakeIsEqualCstVar(IntExpr* const var, int64_t value) {
463   IntExpr* left = nullptr;
464   IntExpr* right = nullptr;
465   if (IsADifference(var, &left, &right)) {
466     return MakeIsEqualVar(left, MakeSum(right, value));
467   }
468   if (CapSub(var->Max(), var->Min()) == 1) {
469     if (value == var->Min()) {
470       return MakeDifference(value + 1, var)->Var();
471     } else if (value == var->Max()) {
472       return MakeSum(var, -value + 1)->Var();
473     } else {
474       return MakeIntConst(0);
475     }
476   }
477   if (var->IsVar()) {
478     return var->Var()->IsEqual(value);
479   } else {
480     IntVar* const boolvar =
481         MakeBoolVar(absl::StrFormat("Is(%s == %d)", var->DebugString(), value));
482     AddConstraint(MakeIsEqualCstCt(var, value, boolvar));
483     return boolvar;
484   }
485 }
486 
MakeIsEqualCstCt(IntExpr * const var,int64_t value,IntVar * const boolvar)487 Constraint* Solver::MakeIsEqualCstCt(IntExpr* const var, int64_t value,
488                                      IntVar* const boolvar) {
489   CHECK_EQ(this, var->solver());
490   CHECK_EQ(this, boolvar->solver());
491   if (value == var->Min()) {
492     if (CapSub(var->Max(), var->Min()) == 1) {
493       return MakeEquality(MakeDifference(value + 1, var), boolvar);
494     }
495     return MakeIsLessOrEqualCstCt(var, value, boolvar);
496   }
497   if (value == var->Max()) {
498     if (CapSub(var->Max(), var->Min()) == 1) {
499       return MakeEquality(MakeSum(var, -value + 1), boolvar);
500     }
501     return MakeIsGreaterOrEqualCstCt(var, value, boolvar);
502   }
503   if (boolvar->Bound()) {
504     if (boolvar->Min() == 0) {
505       return MakeNonEquality(var, value);
506     } else {
507       return MakeEquality(var, value);
508     }
509   }
510   // TODO(user) : what happens if the constraint is not posted?
511   // The cache becomes tainted.
512   model_cache_->InsertExprConstantExpression(
513       boolvar, var, value, ModelCache::EXPR_CONSTANT_IS_EQUAL);
514   IntExpr* left = nullptr;
515   IntExpr* right = nullptr;
516   if (IsADifference(var, &left, &right)) {
517     return MakeIsEqualCt(left, MakeSum(right, value), boolvar);
518   } else {
519     return RevAlloc(new IsEqualCstCt(this, var->Var(), value, boolvar));
520   }
521 }
522 
523 // ----- is_diff_cst Constraint -----
524 
525 namespace {
526 class IsDiffCstCt : public CastConstraint {
527  public:
IsDiffCstCt(Solver * const s,IntVar * const v,int64_t c,IntVar * const b)528   IsDiffCstCt(Solver* const s, IntVar* const v, int64_t c, IntVar* const b)
529       : CastConstraint(s, b), var_(v), cst_(c), demon_(nullptr) {}
530 
Post()531   void Post() override {
532     demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
533     var_->WhenDomain(demon_);
534     target_var_->WhenBound(demon_);
535   }
536 
InitialPropagate()537   void InitialPropagate() override {
538     bool inhibit = var_->Bound();
539     int64_t l = 1 - var_->Contains(cst_);
540     int64_t u = inhibit ? l : 1;
541     target_var_->SetRange(l, u);
542     if (target_var_->Bound()) {
543       if (target_var_->Min() == 1) {
544         if (var_->Size() <= 0xFFFFFF) {
545           var_->RemoveValue(cst_);
546           inhibit = true;
547         }
548       } else {
549         var_->SetValue(cst_);
550         inhibit = true;
551       }
552     }
553     if (inhibit) {
554       demon_->inhibit(solver());
555     }
556   }
557 
DebugString() const558   std::string DebugString() const override {
559     return absl::StrFormat("IsDiffCstCt(%s, %d, %s)", var_->DebugString(), cst_,
560                            target_var_->DebugString());
561   }
562 
Accept(ModelVisitor * const visitor) const563   void Accept(ModelVisitor* const visitor) const override {
564     visitor->BeginVisitConstraint(ModelVisitor::kIsDifferent, this);
565     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
566                                             var_);
567     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
568     visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
569                                             target_var_);
570     visitor->EndVisitConstraint(ModelVisitor::kIsDifferent, this);
571   }
572 
573  private:
574   IntVar* const var_;
575   int64_t cst_;
576   Demon* demon_;
577 };
578 }  // namespace
579 
MakeIsDifferentCstVar(IntExpr * const var,int64_t value)580 IntVar* Solver::MakeIsDifferentCstVar(IntExpr* const var, int64_t value) {
581   IntExpr* left = nullptr;
582   IntExpr* right = nullptr;
583   if (IsADifference(var, &left, &right)) {
584     return MakeIsDifferentVar(left, MakeSum(right, value));
585   }
586   return var->Var()->IsDifferent(value);
587 }
588 
MakeIsDifferentCstCt(IntExpr * const var,int64_t value,IntVar * const boolvar)589 Constraint* Solver::MakeIsDifferentCstCt(IntExpr* const var, int64_t value,
590                                          IntVar* const boolvar) {
591   CHECK_EQ(this, var->solver());
592   CHECK_EQ(this, boolvar->solver());
593   if (value == var->Min()) {
594     return MakeIsGreaterOrEqualCstCt(var, value + 1, boolvar);
595   }
596   if (value == var->Max()) {
597     return MakeIsLessOrEqualCstCt(var, value - 1, boolvar);
598   }
599   if (var->IsVar() && !var->Var()->Contains(value)) {
600     return MakeEquality(boolvar, int64_t{1});
601   }
602   if (var->Bound() && var->Min() == value) {
603     return MakeEquality(boolvar, Zero());
604   }
605   if (boolvar->Bound()) {
606     if (boolvar->Min() == 0) {
607       return MakeEquality(var, value);
608     } else {
609       return MakeNonEquality(var, value);
610     }
611   }
612   model_cache_->InsertExprConstantExpression(
613       boolvar, var, value, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
614   IntExpr* left = nullptr;
615   IntExpr* right = nullptr;
616   if (IsADifference(var, &left, &right)) {
617     return MakeIsDifferentCt(left, MakeSum(right, value), boolvar);
618   } else {
619     return RevAlloc(new IsDiffCstCt(this, var->Var(), value, boolvar));
620   }
621 }
622 
623 // ----- is_greater_equal_cst Constraint -----
624 
625 namespace {
626 class IsGreaterEqualCstCt : public CastConstraint {
627  public:
IsGreaterEqualCstCt(Solver * const s,IntExpr * const v,int64_t c,IntVar * const b)628   IsGreaterEqualCstCt(Solver* const s, IntExpr* const v, int64_t c,
629                       IntVar* const b)
630       : CastConstraint(s, b), expr_(v), cst_(c), demon_(nullptr) {}
Post()631   void Post() override {
632     demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
633     expr_->WhenRange(demon_);
634     target_var_->WhenBound(demon_);
635   }
InitialPropagate()636   void InitialPropagate() override {
637     bool inhibit = false;
638     int64_t u = expr_->Max() >= cst_;
639     int64_t l = expr_->Min() >= cst_;
640     target_var_->SetRange(l, u);
641     if (target_var_->Bound()) {
642       inhibit = true;
643       if (target_var_->Min() == 0) {
644         expr_->SetMax(cst_ - 1);
645       } else {
646         expr_->SetMin(cst_);
647       }
648     }
649     if (inhibit && ((target_var_->Max() == 0 && expr_->Max() < cst_) ||
650                     (target_var_->Min() == 1 && expr_->Min() >= cst_))) {
651       // Can we safely inhibit? Sometimes an expression is not
652       // persistent, just monotonic.
653       demon_->inhibit(solver());
654     }
655   }
DebugString() const656   std::string DebugString() const override {
657     return absl::StrFormat("IsGreaterEqualCstCt(%s, %d, %s)",
658                            expr_->DebugString(), cst_,
659                            target_var_->DebugString());
660   }
661 
Accept(ModelVisitor * const visitor) const662   void Accept(ModelVisitor* const visitor) const override {
663     visitor->BeginVisitConstraint(ModelVisitor::kIsGreaterOrEqual, this);
664     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
665                                             expr_);
666     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
667     visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
668                                             target_var_);
669     visitor->EndVisitConstraint(ModelVisitor::kIsGreaterOrEqual, this);
670   }
671 
672  private:
673   IntExpr* const expr_;
674   int64_t cst_;
675   Demon* demon_;
676 };
677 }  // namespace
678 
MakeIsGreaterOrEqualCstVar(IntExpr * const var,int64_t value)679 IntVar* Solver::MakeIsGreaterOrEqualCstVar(IntExpr* const var, int64_t value) {
680   if (var->Min() >= value) {
681     return MakeIntConst(int64_t{1});
682   }
683   if (var->Max() < value) {
684     return MakeIntConst(int64_t{0});
685   }
686   if (var->IsVar()) {
687     return var->Var()->IsGreaterOrEqual(value);
688   } else {
689     IntVar* const boolvar =
690         MakeBoolVar(absl::StrFormat("Is(%s >= %d)", var->DebugString(), value));
691     AddConstraint(MakeIsGreaterOrEqualCstCt(var, value, boolvar));
692     return boolvar;
693   }
694 }
695 
MakeIsGreaterCstVar(IntExpr * const var,int64_t value)696 IntVar* Solver::MakeIsGreaterCstVar(IntExpr* const var, int64_t value) {
697   return MakeIsGreaterOrEqualCstVar(var, value + 1);
698 }
699 
MakeIsGreaterOrEqualCstCt(IntExpr * const var,int64_t value,IntVar * const boolvar)700 Constraint* Solver::MakeIsGreaterOrEqualCstCt(IntExpr* const var, int64_t value,
701                                               IntVar* const boolvar) {
702   if (boolvar->Bound()) {
703     if (boolvar->Min() == 0) {
704       return MakeLess(var, value);
705     } else {
706       return MakeGreaterOrEqual(var, value);
707     }
708   }
709   CHECK_EQ(this, var->solver());
710   CHECK_EQ(this, boolvar->solver());
711   model_cache_->InsertExprConstantExpression(
712       boolvar, var, value, ModelCache::EXPR_CONSTANT_IS_GREATER_OR_EQUAL);
713   return RevAlloc(new IsGreaterEqualCstCt(this, var, value, boolvar));
714 }
715 
MakeIsGreaterCstCt(IntExpr * const v,int64_t c,IntVar * const b)716 Constraint* Solver::MakeIsGreaterCstCt(IntExpr* const v, int64_t c,
717                                        IntVar* const b) {
718   return MakeIsGreaterOrEqualCstCt(v, c + 1, b);
719 }
720 
721 // ----- is_lesser_equal_cst Constraint -----
722 
723 namespace {
724 class IsLessEqualCstCt : public CastConstraint {
725  public:
IsLessEqualCstCt(Solver * const s,IntExpr * const v,int64_t c,IntVar * const b)726   IsLessEqualCstCt(Solver* const s, IntExpr* const v, int64_t c,
727                    IntVar* const b)
728       : CastConstraint(s, b), expr_(v), cst_(c), demon_(nullptr) {}
729 
Post()730   void Post() override {
731     demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
732     expr_->WhenRange(demon_);
733     target_var_->WhenBound(demon_);
734   }
735 
InitialPropagate()736   void InitialPropagate() override {
737     bool inhibit = false;
738     int64_t u = expr_->Min() <= cst_;
739     int64_t l = expr_->Max() <= cst_;
740     target_var_->SetRange(l, u);
741     if (target_var_->Bound()) {
742       inhibit = true;
743       if (target_var_->Min() == 0) {
744         expr_->SetMin(cst_ + 1);
745       } else {
746         expr_->SetMax(cst_);
747       }
748     }
749     if (inhibit && ((target_var_->Max() == 0 && expr_->Min() > cst_) ||
750                     (target_var_->Min() == 1 && expr_->Max() <= cst_))) {
751       // Can we safely inhibit? Sometimes an expression is not
752       // persistent, just monotonic.
753       demon_->inhibit(solver());
754     }
755   }
756 
DebugString() const757   std::string DebugString() const override {
758     return absl::StrFormat("IsLessEqualCstCt(%s, %d, %s)", expr_->DebugString(),
759                            cst_, target_var_->DebugString());
760   }
761 
Accept(ModelVisitor * const visitor) const762   void Accept(ModelVisitor* const visitor) const override {
763     visitor->BeginVisitConstraint(ModelVisitor::kIsLessOrEqual, this);
764     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
765                                             expr_);
766     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
767     visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
768                                             target_var_);
769     visitor->EndVisitConstraint(ModelVisitor::kIsLessOrEqual, this);
770   }
771 
772  private:
773   IntExpr* const expr_;
774   int64_t cst_;
775   Demon* demon_;
776 };
777 }  // namespace
778 
MakeIsLessOrEqualCstVar(IntExpr * const var,int64_t value)779 IntVar* Solver::MakeIsLessOrEqualCstVar(IntExpr* const var, int64_t value) {
780   if (var->Max() <= value) {
781     return MakeIntConst(int64_t{1});
782   }
783   if (var->Min() > value) {
784     return MakeIntConst(int64_t{0});
785   }
786   if (var->IsVar()) {
787     return var->Var()->IsLessOrEqual(value);
788   } else {
789     IntVar* const boolvar =
790         MakeBoolVar(absl::StrFormat("Is(%s <= %d)", var->DebugString(), value));
791     AddConstraint(MakeIsLessOrEqualCstCt(var, value, boolvar));
792     return boolvar;
793   }
794 }
795 
MakeIsLessCstVar(IntExpr * const var,int64_t value)796 IntVar* Solver::MakeIsLessCstVar(IntExpr* const var, int64_t value) {
797   return MakeIsLessOrEqualCstVar(var, value - 1);
798 }
799 
MakeIsLessOrEqualCstCt(IntExpr * const var,int64_t value,IntVar * const boolvar)800 Constraint* Solver::MakeIsLessOrEqualCstCt(IntExpr* const var, int64_t value,
801                                            IntVar* const boolvar) {
802   if (boolvar->Bound()) {
803     if (boolvar->Min() == 0) {
804       return MakeGreater(var, value);
805     } else {
806       return MakeLessOrEqual(var, value);
807     }
808   }
809   CHECK_EQ(this, var->solver());
810   CHECK_EQ(this, boolvar->solver());
811   model_cache_->InsertExprConstantExpression(
812       boolvar, var, value, ModelCache::EXPR_CONSTANT_IS_LESS_OR_EQUAL);
813   return RevAlloc(new IsLessEqualCstCt(this, var, value, boolvar));
814 }
815 
MakeIsLessCstCt(IntExpr * const v,int64_t c,IntVar * const b)816 Constraint* Solver::MakeIsLessCstCt(IntExpr* const v, int64_t c,
817                                     IntVar* const b) {
818   return MakeIsLessOrEqualCstCt(v, c - 1, b);
819 }
820 
821 // ----- BetweenCt -----
822 
823 namespace {
824 class BetweenCt : public Constraint {
825  public:
BetweenCt(Solver * const s,IntExpr * const v,int64_t l,int64_t u)826   BetweenCt(Solver* const s, IntExpr* const v, int64_t l, int64_t u)
827       : Constraint(s), expr_(v), min_(l), max_(u), demon_(nullptr) {}
828 
Post()829   void Post() override {
830     if (!expr_->IsVar()) {
831       demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
832       expr_->WhenRange(demon_);
833     }
834   }
835 
InitialPropagate()836   void InitialPropagate() override {
837     expr_->SetRange(min_, max_);
838     int64_t emin = 0;
839     int64_t emax = 0;
840     expr_->Range(&emin, &emax);
841     if (demon_ != nullptr && emin >= min_ && emax <= max_) {
842       demon_->inhibit(solver());
843     }
844   }
845 
DebugString() const846   std::string DebugString() const override {
847     return absl::StrFormat("BetweenCt(%s, %d, %d)", expr_->DebugString(), min_,
848                            max_);
849   }
850 
Accept(ModelVisitor * const visitor) const851   void Accept(ModelVisitor* const visitor) const override {
852     visitor->BeginVisitConstraint(ModelVisitor::kBetween, this);
853     visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, min_);
854     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
855                                             expr_);
856     visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, max_);
857     visitor->EndVisitConstraint(ModelVisitor::kBetween, this);
858   }
859 
860  private:
861   IntExpr* const expr_;
862   int64_t min_;
863   int64_t max_;
864   Demon* demon_;
865 };
866 
// ----- NotBetween constraint -----
868 
869 class NotBetweenCt : public Constraint {
870  public:
NotBetweenCt(Solver * const s,IntExpr * const v,int64_t l,int64_t u)871   NotBetweenCt(Solver* const s, IntExpr* const v, int64_t l, int64_t u)
872       : Constraint(s), expr_(v), min_(l), max_(u), demon_(nullptr) {}
873 
Post()874   void Post() override {
875     demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
876     expr_->WhenRange(demon_);
877   }
878 
InitialPropagate()879   void InitialPropagate() override {
880     int64_t emin = 0;
881     int64_t emax = 0;
882     expr_->Range(&emin, &emax);
883     if (emin >= min_) {
884       expr_->SetMin(max_ + 1);
885     } else if (emax <= max_) {
886       expr_->SetMax(min_ - 1);
887     }
888 
889     if (!expr_->IsVar() && (emax < min_ || emin > max_)) {
890       demon_->inhibit(solver());
891     }
892   }
893 
DebugString() const894   std::string DebugString() const override {
895     return absl::StrFormat("NotBetweenCt(%s, %d, %d)", expr_->DebugString(),
896                            min_, max_);
897   }
898 
Accept(ModelVisitor * const visitor) const899   void Accept(ModelVisitor* const visitor) const override {
900     visitor->BeginVisitConstraint(ModelVisitor::kNotBetween, this);
901     visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, min_);
902     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
903                                             expr_);
904     visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, max_);
905     visitor->EndVisitConstraint(ModelVisitor::kBetween, this);
906   }
907 
908  private:
909   IntExpr* const expr_;
910   int64_t min_;
911   int64_t max_;
912   Demon* demon_;
913 };
914 
ExtractExprProductCoeff(IntExpr ** expr)915 int64_t ExtractExprProductCoeff(IntExpr** expr) {
916   int64_t prod = 1;
917   int64_t coeff = 1;
918   while ((*expr)->solver()->IsProduct(*expr, expr, &coeff)) prod *= coeff;
919   return prod;
920 }
921 }  // namespace
922 
MakeBetweenCt(IntExpr * expr,int64_t l,int64_t u)923 Constraint* Solver::MakeBetweenCt(IntExpr* expr, int64_t l, int64_t u) {
924   DCHECK_EQ(this, expr->solver());
925   // Catch empty and singleton intervals.
926   if (l >= u) {
927     if (l > u) return MakeFalseConstraint();
928     return MakeEquality(expr, l);
929   }
930   int64_t emin = 0;
931   int64_t emax = 0;
932   expr->Range(&emin, &emax);
933   // Catch the trivial cases first.
934   if (emax < l || emin > u) return MakeFalseConstraint();
935   if (emin >= l && emax <= u) return MakeTrueConstraint();
936   // Catch one-sided constraints.
937   if (emax <= u) return MakeGreaterOrEqual(expr, l);
938   if (emin >= l) return MakeLessOrEqual(expr, u);
939   // Simplify the common factor, if any.
940   int64_t coeff = ExtractExprProductCoeff(&expr);
941   if (coeff != 1) {
942     CHECK_NE(coeff, 0);  // Would have been caught by the trivial cases already.
943     if (coeff < 0) {
944       std::swap(u, l);
945       u = -u;
946       l = -l;
947       coeff = -coeff;
948     }
949     return MakeBetweenCt(expr, PosIntDivUp(l, coeff), PosIntDivDown(u, coeff));
950   } else {
951     // No further reduction is possible.
952     return RevAlloc(new BetweenCt(this, expr, l, u));
953   }
954 }
955 
MakeNotBetweenCt(IntExpr * expr,int64_t l,int64_t u)956 Constraint* Solver::MakeNotBetweenCt(IntExpr* expr, int64_t l, int64_t u) {
957   DCHECK_EQ(this, expr->solver());
958   // Catch empty interval.
959   if (l > u) {
960     return MakeTrueConstraint();
961   }
962 
963   int64_t emin = 0;
964   int64_t emax = 0;
965   expr->Range(&emin, &emax);
966   // Catch the trivial cases first.
967   if (emax < l || emin > u) return MakeTrueConstraint();
968   if (emin >= l && emax <= u) return MakeFalseConstraint();
969   // Catch one-sided constraints.
970   if (emin >= l) return MakeGreater(expr, u);
971   if (emax <= u) return MakeLess(expr, l);
972   // TODO(user): Add back simplification code if expr is constant *
973   // other_expr.
974   return RevAlloc(new NotBetweenCt(this, expr, l, u));
975 }
976 
977 // ----- is_between_cst Constraint -----
978 
979 namespace {
980 class IsBetweenCt : public Constraint {
981  public:
IsBetweenCt(Solver * const s,IntExpr * const e,int64_t l,int64_t u,IntVar * const b)982   IsBetweenCt(Solver* const s, IntExpr* const e, int64_t l, int64_t u,
983               IntVar* const b)
984       : Constraint(s),
985         expr_(e),
986         min_(l),
987         max_(u),
988         boolvar_(b),
989         demon_(nullptr) {}
990 
Post()991   void Post() override {
992     demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
993     expr_->WhenRange(demon_);
994     boolvar_->WhenBound(demon_);
995   }
996 
InitialPropagate()997   void InitialPropagate() override {
998     bool inhibit = false;
999     int64_t emin = 0;
1000     int64_t emax = 0;
1001     expr_->Range(&emin, &emax);
1002     int64_t u = 1 - (emin > max_ || emax < min_);
1003     int64_t l = emax <= max_ && emin >= min_;
1004     boolvar_->SetRange(l, u);
1005     if (boolvar_->Bound()) {
1006       inhibit = true;
1007       if (boolvar_->Min() == 0) {
1008         if (expr_->IsVar()) {
1009           expr_->Var()->RemoveInterval(min_, max_);
1010           inhibit = true;
1011         } else if (emin > min_) {
1012           expr_->SetMin(max_ + 1);
1013         } else if (emax < max_) {
1014           expr_->SetMax(min_ - 1);
1015         }
1016       } else {
1017         expr_->SetRange(min_, max_);
1018         inhibit = true;
1019       }
1020       if (inhibit && expr_->IsVar()) {
1021         demon_->inhibit(solver());
1022       }
1023     }
1024   }
1025 
DebugString() const1026   std::string DebugString() const override {
1027     return absl::StrFormat("IsBetweenCt(%s, %d, %d, %s)", expr_->DebugString(),
1028                            min_, max_, boolvar_->DebugString());
1029   }
1030 
Accept(ModelVisitor * const visitor) const1031   void Accept(ModelVisitor* const visitor) const override {
1032     visitor->BeginVisitConstraint(ModelVisitor::kIsBetween, this);
1033     visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, min_);
1034     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1035                                             expr_);
1036     visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, max_);
1037     visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1038                                             boolvar_);
1039     visitor->EndVisitConstraint(ModelVisitor::kIsBetween, this);
1040   }
1041 
1042  private:
1043   IntExpr* const expr_;
1044   int64_t min_;
1045   int64_t max_;
1046   IntVar* const boolvar_;
1047   Demon* demon_;
1048 };
1049 }  // namespace
1050 
MakeIsBetweenCt(IntExpr * expr,int64_t l,int64_t u,IntVar * const b)1051 Constraint* Solver::MakeIsBetweenCt(IntExpr* expr, int64_t l, int64_t u,
1052                                     IntVar* const b) {
1053   CHECK_EQ(this, expr->solver());
1054   CHECK_EQ(this, b->solver());
1055   // Catch empty and singleton intervals.
1056   if (l >= u) {
1057     if (l > u) return MakeEquality(b, Zero());
1058     return MakeIsEqualCstCt(expr, l, b);
1059   }
1060   int64_t emin = 0;
1061   int64_t emax = 0;
1062   expr->Range(&emin, &emax);
1063   // Catch the trivial cases first.
1064   if (emax < l || emin > u) return MakeEquality(b, Zero());
1065   if (emin >= l && emax <= u) return MakeEquality(b, 1);
1066   // Catch one-sided constraints.
1067   if (emax <= u) return MakeIsGreaterOrEqualCstCt(expr, l, b);
1068   if (emin >= l) return MakeIsLessOrEqualCstCt(expr, u, b);
1069   // Simplify the common factor, if any.
1070   int64_t coeff = ExtractExprProductCoeff(&expr);
1071   if (coeff != 1) {
1072     CHECK_NE(coeff, 0);  // Would have been caught by the trivial cases already.
1073     if (coeff < 0) {
1074       std::swap(u, l);
1075       u = -u;
1076       l = -l;
1077       coeff = -coeff;
1078     }
1079     return MakeIsBetweenCt(expr, PosIntDivUp(l, coeff), PosIntDivDown(u, coeff),
1080                            b);
1081   } else {
1082     // No further reduction is possible.
1083     return RevAlloc(new IsBetweenCt(this, expr, l, u, b));
1084   }
1085 }
1086 
MakeIsBetweenVar(IntExpr * const v,int64_t l,int64_t u)1087 IntVar* Solver::MakeIsBetweenVar(IntExpr* const v, int64_t l, int64_t u) {
1088   CHECK_EQ(this, v->solver());
1089   IntVar* const b = MakeBoolVar();
1090   AddConstraint(MakeIsBetweenCt(v, l, u, b));
1091   return b;
1092 }
1093 
1094 // ---------- Member ----------
1095 
1096 // ----- Member(IntVar, IntSet) -----
1097 
1098 namespace {
1099 // TODO(user): Do not create holes on expressions.
1100 class MemberCt : public Constraint {
1101  public:
MemberCt(Solver * const s,IntVar * const v,const std::vector<int64_t> & sorted_values)1102   MemberCt(Solver* const s, IntVar* const v,
1103            const std::vector<int64_t>& sorted_values)
1104       : Constraint(s), var_(v), values_(sorted_values) {
1105     DCHECK(v != nullptr);
1106     DCHECK(s != nullptr);
1107   }
1108 
Post()1109   void Post() override {}
1110 
InitialPropagate()1111   void InitialPropagate() override { var_->SetValues(values_); }
1112 
DebugString() const1113   std::string DebugString() const override {
1114     return absl::StrFormat("Member(%s, %s)", var_->DebugString(),
1115                            absl::StrJoin(values_, ", "));
1116   }
1117 
Accept(ModelVisitor * const visitor) const1118   void Accept(ModelVisitor* const visitor) const override {
1119     visitor->BeginVisitConstraint(ModelVisitor::kMember, this);
1120     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1121                                             var_);
1122     visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
1123     visitor->EndVisitConstraint(ModelVisitor::kMember, this);
1124   }
1125 
1126  private:
1127   IntVar* const var_;
1128   const std::vector<int64_t> values_;
1129 };
1130 
1131 class NotMemberCt : public Constraint {
1132  public:
NotMemberCt(Solver * const s,IntVar * const v,const std::vector<int64_t> & sorted_values)1133   NotMemberCt(Solver* const s, IntVar* const v,
1134               const std::vector<int64_t>& sorted_values)
1135       : Constraint(s), var_(v), values_(sorted_values) {
1136     DCHECK(v != nullptr);
1137     DCHECK(s != nullptr);
1138   }
1139 
Post()1140   void Post() override {}
1141 
InitialPropagate()1142   void InitialPropagate() override { var_->RemoveValues(values_); }
1143 
DebugString() const1144   std::string DebugString() const override {
1145     return absl::StrFormat("NotMember(%s, %s)", var_->DebugString(),
1146                            absl::StrJoin(values_, ", "));
1147   }
1148 
Accept(ModelVisitor * const visitor) const1149   void Accept(ModelVisitor* const visitor) const override {
1150     visitor->BeginVisitConstraint(ModelVisitor::kMember, this);
1151     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1152                                             var_);
1153     visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
1154     visitor->EndVisitConstraint(ModelVisitor::kMember, this);
1155   }
1156 
1157  private:
1158   IntVar* const var_;
1159   const std::vector<int64_t> values_;
1160 };
1161 }  // namespace
1162 
MakeMemberCt(IntExpr * expr,const std::vector<int64_t> & values)1163 Constraint* Solver::MakeMemberCt(IntExpr* expr,
1164                                  const std::vector<int64_t>& values) {
1165   const int64_t coeff = ExtractExprProductCoeff(&expr);
1166   if (coeff == 0) {
1167     return std::find(values.begin(), values.end(), 0) == values.end()
1168                ? MakeFalseConstraint()
1169                : MakeTrueConstraint();
1170   }
1171   std::vector<int64_t> copied_values = values;
1172   // If the expression is a non-trivial product, we filter out the values that
1173   // aren't multiples of "coeff", and divide them.
1174   if (coeff != 1) {
1175     int num_kept = 0;
1176     for (const int64_t v : copied_values) {
1177       if (v % coeff == 0) copied_values[num_kept++] = v / coeff;
1178     }
1179     copied_values.resize(num_kept);
1180   }
1181   // Filter out the values that are outside the [Min, Max] interval.
1182   int num_kept = 0;
1183   int64_t emin;
1184   int64_t emax;
1185   expr->Range(&emin, &emax);
1186   for (const int64_t v : copied_values) {
1187     if (v >= emin && v <= emax) copied_values[num_kept++] = v;
1188   }
1189   copied_values.resize(num_kept);
1190   // Catch empty set.
1191   if (copied_values.empty()) return MakeFalseConstraint();
1192   // Sort and remove duplicates.
1193   gtl::STLSortAndRemoveDuplicates(&copied_values);
1194   // Special case for singleton.
1195   if (copied_values.size() == 1) return MakeEquality(expr, copied_values[0]);
1196   // Catch contiguous intervals.
1197   if (copied_values.size() ==
1198       copied_values.back() - copied_values.front() + 1) {
1199     // Note: MakeBetweenCt() has a fast-track for trivially true constraints.
1200     return MakeBetweenCt(expr, copied_values.front(), copied_values.back());
1201   }
1202   // If the set of values in [expr.Min(), expr.Max()] that are *not* in
1203   // "values" is smaller than "values", then it's more efficient to use
1204   // NotMemberCt. Catch that case here.
1205   if (emax - emin < 2 * copied_values.size()) {
1206     // Convert "copied_values" to list the values *not* allowed.
1207     std::vector<bool> is_among_input_values(emax - emin + 1, false);
1208     for (const int64_t v : copied_values)
1209       is_among_input_values[v - emin] = true;
1210     // We use the zero valued indices of is_among_input_values to build the
1211     // complement of copied_values.
1212     copied_values.clear();
1213     for (int64_t v_off = 0; v_off < is_among_input_values.size(); ++v_off) {
1214       if (!is_among_input_values[v_off]) copied_values.push_back(v_off + emin);
1215     }
1216     // The empty' case (all values in range [expr.Min(), expr.Max()] are in the
1217     // "values" input) was caught earlier, by the "contiguous interval" case.
1218     DCHECK_GE(copied_values.size(), 1);
1219     if (copied_values.size() == 1) {
1220       return MakeNonEquality(expr, copied_values[0]);
1221     }
1222     return RevAlloc(new NotMemberCt(this, expr->Var(), copied_values));
1223   }
1224   // Otherwise, just use MemberCt. No further reduction is possible.
1225   return RevAlloc(new MemberCt(this, expr->Var(), copied_values));
1226 }
1227 
MakeMemberCt(IntExpr * const expr,const std::vector<int> & values)1228 Constraint* Solver::MakeMemberCt(IntExpr* const expr,
1229                                  const std::vector<int>& values) {
1230   return MakeMemberCt(expr, ToInt64Vector(values));
1231 }
1232 
MakeNotMemberCt(IntExpr * expr,const std::vector<int64_t> & values)1233 Constraint* Solver::MakeNotMemberCt(IntExpr* expr,
1234                                     const std::vector<int64_t>& values) {
1235   const int64_t coeff = ExtractExprProductCoeff(&expr);
1236   if (coeff == 0) {
1237     return std::find(values.begin(), values.end(), 0) == values.end()
1238                ? MakeTrueConstraint()
1239                : MakeFalseConstraint();
1240   }
1241   std::vector<int64_t> copied_values = values;
1242   // If the expression is a non-trivial product, we filter out the values that
1243   // aren't multiples of "coeff", and divide them.
1244   if (coeff != 1) {
1245     int num_kept = 0;
1246     for (const int64_t v : copied_values) {
1247       if (v % coeff == 0) copied_values[num_kept++] = v / coeff;
1248     }
1249     copied_values.resize(num_kept);
1250   }
1251   // Filter out the values that are outside the [Min, Max] interval.
1252   int num_kept = 0;
1253   int64_t emin;
1254   int64_t emax;
1255   expr->Range(&emin, &emax);
1256   for (const int64_t v : copied_values) {
1257     if (v >= emin && v <= emax) copied_values[num_kept++] = v;
1258   }
1259   copied_values.resize(num_kept);
1260   // Catch empty set.
1261   if (copied_values.empty()) return MakeTrueConstraint();
1262   // Sort and remove duplicates.
1263   gtl::STLSortAndRemoveDuplicates(&copied_values);
1264   // Special case for singleton.
1265   if (copied_values.size() == 1) return MakeNonEquality(expr, copied_values[0]);
1266   // Catch contiguous intervals.
1267   if (copied_values.size() ==
1268       copied_values.back() - copied_values.front() + 1) {
1269     return MakeNotBetweenCt(expr, copied_values.front(), copied_values.back());
1270   }
1271   // If the set of values in [expr.Min(), expr.Max()] that are *not* in
1272   // "values" is smaller than "values", then it's more efficient to use
1273   // MemberCt. Catch that case here.
1274   if (emax - emin < 2 * copied_values.size()) {
1275     // Convert "copied_values" to a dense boolean vector.
1276     std::vector<bool> is_among_input_values(emax - emin + 1, false);
1277     for (const int64_t v : copied_values)
1278       is_among_input_values[v - emin] = true;
1279     // Use zero valued indices for is_among_input_values to build the
1280     // complement of copied_values.
1281     copied_values.clear();
1282     for (int64_t v_off = 0; v_off < is_among_input_values.size(); ++v_off) {
1283       if (!is_among_input_values[v_off]) copied_values.push_back(v_off + emin);
1284     }
1285     // The empty' case (all values in range [expr.Min(), expr.Max()] are in the
1286     // "values" input) was caught earlier, by the "contiguous interval" case.
1287     DCHECK_GE(copied_values.size(), 1);
1288     if (copied_values.size() == 1) {
1289       return MakeEquality(expr, copied_values[0]);
1290     }
1291     return RevAlloc(new MemberCt(this, expr->Var(), copied_values));
1292   }
1293   // Otherwise, just use NotMemberCt. No further reduction is possible.
1294   return RevAlloc(new NotMemberCt(this, expr->Var(), copied_values));
1295 }
1296 
MakeNotMemberCt(IntExpr * const expr,const std::vector<int> & values)1297 Constraint* Solver::MakeNotMemberCt(IntExpr* const expr,
1298                                     const std::vector<int>& values) {
1299   return MakeNotMemberCt(expr, ToInt64Vector(values));
1300 }
1301 
1302 // ----- IsMemberCt -----
1303 
1304 namespace {
// Reified membership: boolvar_ == (var_ takes a value in values_).
// The constraint keeps two cached witnesses:
//  - support_: index into values_ of a value last seen in var_'s domain
//    (a "positive support": boolvar_ may still be 1).
//  - neg_support_: a value not in values_as_set_; when it is in var_'s domain
//    it acts as a "negative support": boolvar_ may still be 0.
// When no positive support remains, boolvar_ is set to 0. When var_ is bound
// to a member or no negative support remains, boolvar_ is set to 1. Once
// boolvar_ is fixed, var_ is restricted to, or stripped of, values_.
class IsMemberCt : public Constraint {
 public:
  IsMemberCt(Solver* const s, IntVar* const v,
             const std::vector<int64_t>& sorted_values, IntVar* const b)
      : Constraint(s),
        var_(v),
        values_as_set_(sorted_values.begin(), sorted_values.end()),
        values_(sorted_values),
        boolvar_(b),
        support_(0),
        demon_(nullptr),
        // NOTE(review): the `true` argument is presumably the "reversible"
        // flag of MakeDomainIterator() -- confirm in constraint_solver.h.
        domain_(var_->MakeDomainIterator(true)),
        neg_support_(std::numeric_limits<int64_t>::min()) {
    DCHECK(v != nullptr);
    DCHECK(s != nullptr);
    DCHECK(b != nullptr);
    // Start neg_support_ at the smallest int64_t that is not in values_.
    while (values_as_set_.contains(neg_support_)) {
      neg_support_++;
    }
  }

  // Registers VarDomain() on domain changes of var_ (unless var_ is already
  // bound) and TargetBound() on boolvar_ being bound (unless already bound).
  void Post() override {
    demon_ = MakeConstraintDemon0(solver(), this, &IsMemberCt::VarDomain,
                                  "VarDomain");
    if (!var_->Bound()) {
      var_->WhenDomain(demon_);
    }
    if (!boolvar_->Bound()) {
      Demon* const bdemon = MakeConstraintDemon0(
          solver(), this, &IsMemberCt::TargetBound, "TargetBound");
      boolvar_->WhenBound(bdemon);
    }
  }

  void InitialPropagate() override {
    boolvar_->SetRange(0, 1);
    if (boolvar_->Bound()) {
      TargetBound();
    } else {
      VarDomain();
    }
  }

  std::string DebugString() const override {
    return absl::StrFormat("IsMemberCt(%s, %s, %s)", var_->DebugString(),
                           absl::StrJoin(values_, ", "),
                           boolvar_->DebugString());
  }

  void Accept(ModelVisitor* const visitor) const override {
    visitor->BeginVisitConstraint(ModelVisitor::kIsMember, this);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
                                            var_);
    visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
                                            boolvar_);
    visitor->EndVisitConstraint(ModelVisitor::kIsMember, this);
  }

 private:
  // Reacts to a change of var_'s domain. If boolvar_ is already fixed this
  // delegates to TargetBound(). Otherwise it re-establishes the positive and
  // negative supports and fixes boolvar_ when one of them is gone.
  void VarDomain() {
    if (boolvar_->Bound()) {
      TargetBound();
    } else {
      // Scan values_ circularly, starting from the last known support.
      for (int offset = 0; offset < values_.size(); ++offset) {
        const int candidate = (support_ + offset) % values_.size();
        if (var_->Contains(values_[candidate])) {
          support_ = candidate;
          // var_ is bound to a member of values_: the answer is 1.
          if (var_->Bound()) {
            demon_->inhibit(solver());
            boolvar_->SetValue(1);
            return;
          }
          // We have found a positive support. Let's check the
          // negative support.
          if (var_->Contains(neg_support_)) {
            return;
          } else {
            // Look for a new negative support.
            for (const int64_t value : InitAndGetValues(domain_)) {
              if (!values_as_set_.contains(value)) {
                neg_support_ = value;
                return;
              }
            }
          }
          // No negative support, setting boolvar to true.
          demon_->inhibit(solver());
          boolvar_->SetValue(1);
          return;
        }
      }
      // No positive support, setting boolvar to false.
      demon_->inhibit(solver());
      boolvar_->SetValue(0);
    }
  }

  // Called once boolvar_ is bound: stops the VarDomain demon and restricts
  // var_ to values_ (boolvar_ == 1) or removes values_ from it (boolvar_ == 0).
  void TargetBound() {
    DCHECK(boolvar_->Bound());
    if (boolvar_->Min() == 1LL) {
      demon_->inhibit(solver());
      var_->SetValues(values_);
    } else {
      demon_->inhibit(solver());
      var_->RemoveValues(values_);
    }
  }

  IntVar* const var_;
  // Same content as values_, for O(1) membership tests.
  absl::flat_hash_set<int64_t> values_as_set_;
  std::vector<int64_t> values_;
  IntVar* const boolvar_;
  // Index into values_ of the last value found in var_'s domain.
  int support_;
  // The VarDomain demon, created in Post().
  Demon* demon_;
  // Iterator over var_'s domain, used to look for a new negative support.
  IntVarIterator* const domain_;
  // A value not in values_ (candidate witness that boolvar_ may be 0).
  int64_t neg_support_;
};
1423 
1424 template <class T>
BuildIsMemberCt(Solver * const solver,IntExpr * const expr,const std::vector<T> & values,IntVar * const boolvar)1425 Constraint* BuildIsMemberCt(Solver* const solver, IntExpr* const expr,
1426                             const std::vector<T>& values,
1427                             IntVar* const boolvar) {
1428   // TODO(user): optimize this by copying the code from MakeMemberCt.
1429   // Simplify and filter if expr is a product.
1430   IntExpr* sub = nullptr;
1431   int64_t coef = 1;
1432   if (solver->IsProduct(expr, &sub, &coef) && coef != 0 && coef != 1) {
1433     std::vector<int64_t> new_values;
1434     new_values.reserve(values.size());
1435     for (const int64_t value : values) {
1436       if (value % coef == 0) {
1437         new_values.push_back(value / coef);
1438       }
1439     }
1440     return BuildIsMemberCt(solver, sub, new_values, boolvar);
1441   }
1442 
1443   std::set<T> set_of_values(values.begin(), values.end());
1444   std::vector<int64_t> filtered_values;
1445   bool all_values = false;
1446   if (expr->IsVar()) {
1447     IntVar* const var = expr->Var();
1448     for (const T value : set_of_values) {
1449       if (var->Contains(value)) {
1450         filtered_values.push_back(value);
1451       }
1452     }
1453     all_values = (filtered_values.size() == var->Size());
1454   } else {
1455     int64_t emin = 0;
1456     int64_t emax = 0;
1457     expr->Range(&emin, &emax);
1458     for (const T value : set_of_values) {
1459       if (value >= emin && value <= emax) {
1460         filtered_values.push_back(value);
1461       }
1462     }
1463     all_values = (filtered_values.size() == emax - emin + 1);
1464   }
1465   if (filtered_values.empty()) {
1466     return solver->MakeEquality(boolvar, Zero());
1467   } else if (all_values) {
1468     return solver->MakeEquality(boolvar, 1);
1469   } else if (filtered_values.size() == 1) {
1470     return solver->MakeIsEqualCstCt(expr, filtered_values.back(), boolvar);
1471   } else if (filtered_values.back() ==
1472              filtered_values.front() + filtered_values.size() - 1) {
1473     // Contiguous
1474     return solver->MakeIsBetweenCt(expr, filtered_values.front(),
1475                                    filtered_values.back(), boolvar);
1476   } else {
1477     return solver->RevAlloc(
1478         new IsMemberCt(solver, expr->Var(), filtered_values, boolvar));
1479   }
1480 }
1481 }  // namespace
1482 
MakeIsMemberCt(IntExpr * const expr,const std::vector<int64_t> & values,IntVar * const boolvar)1483 Constraint* Solver::MakeIsMemberCt(IntExpr* const expr,
1484                                    const std::vector<int64_t>& values,
1485                                    IntVar* const boolvar) {
1486   return BuildIsMemberCt(this, expr, values, boolvar);
1487 }
1488 
MakeIsMemberCt(IntExpr * const expr,const std::vector<int> & values,IntVar * const boolvar)1489 Constraint* Solver::MakeIsMemberCt(IntExpr* const expr,
1490                                    const std::vector<int>& values,
1491                                    IntVar* const boolvar) {
1492   return BuildIsMemberCt(this, expr, values, boolvar);
1493 }
1494 
MakeIsMemberVar(IntExpr * const expr,const std::vector<int64_t> & values)1495 IntVar* Solver::MakeIsMemberVar(IntExpr* const expr,
1496                                 const std::vector<int64_t>& values) {
1497   IntVar* const b = MakeBoolVar();
1498   AddConstraint(MakeIsMemberCt(expr, values, b));
1499   return b;
1500 }
1501 
MakeIsMemberVar(IntExpr * const expr,const std::vector<int> & values)1502 IntVar* Solver::MakeIsMemberVar(IntExpr* const expr,
1503                                 const std::vector<int>& values) {
1504   IntVar* const b = MakeBoolVar();
1505   AddConstraint(MakeIsMemberCt(expr, values, b));
1506   return b;
1507 }
1508 
1509 namespace {
1510 class SortedDisjointForbiddenIntervalsConstraint : public Constraint {
1511  public:
SortedDisjointForbiddenIntervalsConstraint(Solver * const solver,IntVar * const var,SortedDisjointIntervalList intervals)1512   SortedDisjointForbiddenIntervalsConstraint(
1513       Solver* const solver, IntVar* const var,
1514       SortedDisjointIntervalList intervals)
1515       : Constraint(solver), var_(var), intervals_(std::move(intervals)) {}
1516 
~SortedDisjointForbiddenIntervalsConstraint()1517   ~SortedDisjointForbiddenIntervalsConstraint() override {}
1518 
Post()1519   void Post() override {
1520     Demon* const demon = solver()->MakeConstraintInitialPropagateCallback(this);
1521     var_->WhenRange(demon);
1522   }
1523 
InitialPropagate()1524   void InitialPropagate() override {
1525     const int64_t vmin = var_->Min();
1526     const int64_t vmax = var_->Max();
1527     const auto first_interval_it = intervals_.FirstIntervalGreaterOrEqual(vmin);
1528     if (first_interval_it == intervals_.end()) {
1529       // No interval intersects the variable's range. Nothing to do.
1530       return;
1531     }
1532     const auto last_interval_it = intervals_.LastIntervalLessOrEqual(vmax);
1533     if (last_interval_it == intervals_.end()) {
1534       // No interval intersects the variable's range. Nothing to do.
1535       return;
1536     }
1537     // TODO(user): Quick fail if first_interval_it == last_interval_it, which
1538     // would imply that the interval contains the entire range of the variable?
1539     if (vmin >= first_interval_it->start) {
1540       // The variable's minimum is inside a forbidden interval. Move it to the
1541       // interval's end.
1542       var_->SetMin(CapAdd(first_interval_it->end, 1));
1543     }
1544     if (vmax <= last_interval_it->end) {
1545       // Ditto, on the other side.
1546       var_->SetMax(CapSub(last_interval_it->start, 1));
1547     }
1548   }
1549 
DebugString() const1550   std::string DebugString() const override {
1551     return absl::StrFormat("ForbiddenIntervalCt(%s, %s)", var_->DebugString(),
1552                            intervals_.DebugString());
1553   }
1554 
Accept(ModelVisitor * const visitor) const1555   void Accept(ModelVisitor* const visitor) const override {
1556     visitor->BeginVisitConstraint(ModelVisitor::kNotMember, this);
1557     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1558                                             var_);
1559     std::vector<int64_t> starts;
1560     std::vector<int64_t> ends;
1561     for (auto& interval : intervals_) {
1562       starts.push_back(interval.start);
1563       ends.push_back(interval.end);
1564     }
1565     visitor->VisitIntegerArrayArgument(ModelVisitor::kStartsArgument, starts);
1566     visitor->VisitIntegerArrayArgument(ModelVisitor::kEndsArgument, ends);
1567     visitor->EndVisitConstraint(ModelVisitor::kNotMember, this);
1568   }
1569 
1570  private:
1571   IntVar* const var_;
1572   const SortedDisjointIntervalList intervals_;
1573 };
1574 }  // namespace
1575 
MakeNotMemberCt(IntExpr * const expr,std::vector<int64_t> starts,std::vector<int64_t> ends)1576 Constraint* Solver::MakeNotMemberCt(IntExpr* const expr,
1577                                     std::vector<int64_t> starts,
1578                                     std::vector<int64_t> ends) {
1579   return RevAlloc(new SortedDisjointForbiddenIntervalsConstraint(
1580       this, expr->Var(), {starts, ends}));
1581 }
1582 
MakeNotMemberCt(IntExpr * const expr,std::vector<int> starts,std::vector<int> ends)1583 Constraint* Solver::MakeNotMemberCt(IntExpr* const expr,
1584                                     std::vector<int> starts,
1585                                     std::vector<int> ends) {
1586   return RevAlloc(new SortedDisjointForbiddenIntervalsConstraint(
1587       this, expr->Var(), {starts, ends}));
1588 }
1589 
MakeNotMemberCt(IntExpr * expr,SortedDisjointIntervalList intervals)1590 Constraint* Solver::MakeNotMemberCt(IntExpr* expr,
1591                                     SortedDisjointIntervalList intervals) {
1592   return RevAlloc(new SortedDisjointForbiddenIntervalsConstraint(
1593       this, expr->Var(), std::move(intervals)));
1594 }
1595 }  // namespace operations_research
1596