1 // Copyright 2010-2021 Google LLC
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 //
15 // Expression constraints
16
17 #include <cstddef>
18 #include <cstdint>
19 #include <limits>
20 #include <set>
21 #include <string>
22 #include <vector>
23
24 #include "absl/strings/str_format.h"
25 #include "absl/strings/str_join.h"
26 #include "ortools/base/commandlineflags.h"
27 #include "ortools/base/integral_types.h"
28 #include "ortools/base/logging.h"
29 #include "ortools/base/stl_util.h"
30 #include "ortools/constraint_solver/constraint_solver.h"
31 #include "ortools/constraint_solver/constraint_solveri.h"
32 #include "ortools/util/saturated_arithmetic.h"
33 #include "ortools/util/sorted_interval_list.h"
34
// Command-line flag `--cache_initial_size` (default 1024). Per its help text,
// it sizes the initial array of the hash tables backing the solver's caches
// for objects such as Var(x == 3).
// NOTE(review): the flag is not read in this chunk; its consumer lives
// elsewhere — confirm there before changing the default.
ABSL_FLAG(int, cache_initial_size, 1024,
          "Initial size of the array of the hash "
          "table of caches for objects of type Var(x == 3)");
38
39 namespace operations_research {
40
41 //-----------------------------------------------------------------------------
42 // Equality
43
44 namespace {
45 class EqualityExprCst : public Constraint {
46 public:
47 EqualityExprCst(Solver* const s, IntExpr* const e, int64_t v);
~EqualityExprCst()48 ~EqualityExprCst() override {}
49 void Post() override;
50 void InitialPropagate() override;
Var()51 IntVar* Var() override {
52 return solver()->MakeIsEqualCstVar(expr_->Var(), value_);
53 }
54 std::string DebugString() const override;
55
Accept(ModelVisitor * const visitor) const56 void Accept(ModelVisitor* const visitor) const override {
57 visitor->BeginVisitConstraint(ModelVisitor::kEquality, this);
58 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
59 expr_);
60 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
61 visitor->EndVisitConstraint(ModelVisitor::kEquality, this);
62 }
63
64 private:
65 IntExpr* const expr_;
66 int64_t value_;
67 };
68
EqualityExprCst(Solver * const s,IntExpr * const e,int64_t v)69 EqualityExprCst::EqualityExprCst(Solver* const s, IntExpr* const e, int64_t v)
70 : Constraint(s), expr_(e), value_(v) {}
71
Post()72 void EqualityExprCst::Post() {
73 if (!expr_->IsVar()) {
74 Demon* d = solver()->MakeConstraintInitialPropagateCallback(this);
75 expr_->WhenRange(d);
76 }
77 }
78
InitialPropagate()79 void EqualityExprCst::InitialPropagate() { expr_->SetValue(value_); }
80
DebugString() const81 std::string EqualityExprCst::DebugString() const {
82 return absl::StrFormat("(%s == %d)", expr_->DebugString(), value_);
83 }
84 } // namespace
85
MakeEquality(IntExpr * const e,int64_t v)86 Constraint* Solver::MakeEquality(IntExpr* const e, int64_t v) {
87 CHECK_EQ(this, e->solver());
88 IntExpr* left = nullptr;
89 IntExpr* right = nullptr;
90 if (IsADifference(e, &left, &right)) {
91 return MakeEquality(left, MakeSum(right, v));
92 } else if (e->IsVar() && !e->Var()->Contains(v)) {
93 return MakeFalseConstraint();
94 } else if (e->Min() == e->Max() && e->Min() == v) {
95 return MakeTrueConstraint();
96 } else {
97 return RevAlloc(new EqualityExprCst(this, e, v));
98 }
99 }
100
MakeEquality(IntExpr * const e,int v)101 Constraint* Solver::MakeEquality(IntExpr* const e, int v) {
102 CHECK_EQ(this, e->solver());
103 IntExpr* left = nullptr;
104 IntExpr* right = nullptr;
105 if (IsADifference(e, &left, &right)) {
106 return MakeEquality(left, MakeSum(right, v));
107 } else if (e->IsVar() && !e->Var()->Contains(v)) {
108 return MakeFalseConstraint();
109 } else if (e->Min() == e->Max() && e->Min() == v) {
110 return MakeTrueConstraint();
111 } else {
112 return RevAlloc(new EqualityExprCst(this, e, v));
113 }
114 }
115
116 //-----------------------------------------------------------------------------
117 // Greater or equal constraint
118
119 namespace {
120 class GreaterEqExprCst : public Constraint {
121 public:
122 GreaterEqExprCst(Solver* const s, IntExpr* const e, int64_t v);
~GreaterEqExprCst()123 ~GreaterEqExprCst() override {}
124 void Post() override;
125 void InitialPropagate() override;
126 std::string DebugString() const override;
Var()127 IntVar* Var() override {
128 return solver()->MakeIsGreaterOrEqualCstVar(expr_->Var(), value_);
129 }
130
Accept(ModelVisitor * const visitor) const131 void Accept(ModelVisitor* const visitor) const override {
132 visitor->BeginVisitConstraint(ModelVisitor::kGreaterOrEqual, this);
133 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
134 expr_);
135 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
136 visitor->EndVisitConstraint(ModelVisitor::kGreaterOrEqual, this);
137 }
138
139 private:
140 IntExpr* const expr_;
141 int64_t value_;
142 Demon* demon_;
143 };
144
GreaterEqExprCst(Solver * const s,IntExpr * const e,int64_t v)145 GreaterEqExprCst::GreaterEqExprCst(Solver* const s, IntExpr* const e, int64_t v)
146 : Constraint(s), expr_(e), value_(v), demon_(nullptr) {}
147
Post()148 void GreaterEqExprCst::Post() {
149 if (!expr_->IsVar() && expr_->Min() < value_) {
150 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
151 expr_->WhenRange(demon_);
152 } else {
153 // Let's clean the demon in case the constraint is posted during search.
154 demon_ = nullptr;
155 }
156 }
157
InitialPropagate()158 void GreaterEqExprCst::InitialPropagate() {
159 expr_->SetMin(value_);
160 if (demon_ != nullptr && expr_->Min() >= value_) {
161 demon_->inhibit(solver());
162 }
163 }
164
DebugString() const165 std::string GreaterEqExprCst::DebugString() const {
166 return absl::StrFormat("(%s >= %d)", expr_->DebugString(), value_);
167 }
168 } // namespace
169
MakeGreaterOrEqual(IntExpr * const e,int64_t v)170 Constraint* Solver::MakeGreaterOrEqual(IntExpr* const e, int64_t v) {
171 CHECK_EQ(this, e->solver());
172 if (e->Min() >= v) {
173 return MakeTrueConstraint();
174 } else if (e->Max() < v) {
175 return MakeFalseConstraint();
176 } else {
177 return RevAlloc(new GreaterEqExprCst(this, e, v));
178 }
179 }
180
MakeGreaterOrEqual(IntExpr * const e,int v)181 Constraint* Solver::MakeGreaterOrEqual(IntExpr* const e, int v) {
182 CHECK_EQ(this, e->solver());
183 if (e->Min() >= v) {
184 return MakeTrueConstraint();
185 } else if (e->Max() < v) {
186 return MakeFalseConstraint();
187 } else {
188 return RevAlloc(new GreaterEqExprCst(this, e, v));
189 }
190 }
191
MakeGreater(IntExpr * const e,int64_t v)192 Constraint* Solver::MakeGreater(IntExpr* const e, int64_t v) {
193 CHECK_EQ(this, e->solver());
194 if (e->Min() > v) {
195 return MakeTrueConstraint();
196 } else if (e->Max() <= v) {
197 return MakeFalseConstraint();
198 } else {
199 return RevAlloc(new GreaterEqExprCst(this, e, v + 1));
200 }
201 }
202
MakeGreater(IntExpr * const e,int v)203 Constraint* Solver::MakeGreater(IntExpr* const e, int v) {
204 CHECK_EQ(this, e->solver());
205 if (e->Min() > v) {
206 return MakeTrueConstraint();
207 } else if (e->Max() <= v) {
208 return MakeFalseConstraint();
209 } else {
210 return RevAlloc(new GreaterEqExprCst(this, e, v + 1));
211 }
212 }
213
214 //-----------------------------------------------------------------------------
215 // Less or equal constraint
216
217 namespace {
218 class LessEqExprCst : public Constraint {
219 public:
220 LessEqExprCst(Solver* const s, IntExpr* const e, int64_t v);
~LessEqExprCst()221 ~LessEqExprCst() override {}
222 void Post() override;
223 void InitialPropagate() override;
224 std::string DebugString() const override;
Var()225 IntVar* Var() override {
226 return solver()->MakeIsLessOrEqualCstVar(expr_->Var(), value_);
227 }
Accept(ModelVisitor * const visitor) const228 void Accept(ModelVisitor* const visitor) const override {
229 visitor->BeginVisitConstraint(ModelVisitor::kLessOrEqual, this);
230 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
231 expr_);
232 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
233 visitor->EndVisitConstraint(ModelVisitor::kLessOrEqual, this);
234 }
235
236 private:
237 IntExpr* const expr_;
238 int64_t value_;
239 Demon* demon_;
240 };
241
LessEqExprCst(Solver * const s,IntExpr * const e,int64_t v)242 LessEqExprCst::LessEqExprCst(Solver* const s, IntExpr* const e, int64_t v)
243 : Constraint(s), expr_(e), value_(v), demon_(nullptr) {}
244
Post()245 void LessEqExprCst::Post() {
246 if (!expr_->IsVar() && expr_->Max() > value_) {
247 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
248 expr_->WhenRange(demon_);
249 } else {
250 // Let's clean the demon in case the constraint is posted during search.
251 demon_ = nullptr;
252 }
253 }
254
InitialPropagate()255 void LessEqExprCst::InitialPropagate() {
256 expr_->SetMax(value_);
257 if (demon_ != nullptr && expr_->Max() <= value_) {
258 demon_->inhibit(solver());
259 }
260 }
261
DebugString() const262 std::string LessEqExprCst::DebugString() const {
263 return absl::StrFormat("(%s <= %d)", expr_->DebugString(), value_);
264 }
265 } // namespace
266
MakeLessOrEqual(IntExpr * const e,int64_t v)267 Constraint* Solver::MakeLessOrEqual(IntExpr* const e, int64_t v) {
268 CHECK_EQ(this, e->solver());
269 if (e->Max() <= v) {
270 return MakeTrueConstraint();
271 } else if (e->Min() > v) {
272 return MakeFalseConstraint();
273 } else {
274 return RevAlloc(new LessEqExprCst(this, e, v));
275 }
276 }
277
MakeLessOrEqual(IntExpr * const e,int v)278 Constraint* Solver::MakeLessOrEqual(IntExpr* const e, int v) {
279 CHECK_EQ(this, e->solver());
280 if (e->Max() <= v) {
281 return MakeTrueConstraint();
282 } else if (e->Min() > v) {
283 return MakeFalseConstraint();
284 } else {
285 return RevAlloc(new LessEqExprCst(this, e, v));
286 }
287 }
288
MakeLess(IntExpr * const e,int64_t v)289 Constraint* Solver::MakeLess(IntExpr* const e, int64_t v) {
290 CHECK_EQ(this, e->solver());
291 if (e->Max() < v) {
292 return MakeTrueConstraint();
293 } else if (e->Min() >= v) {
294 return MakeFalseConstraint();
295 } else {
296 return RevAlloc(new LessEqExprCst(this, e, v - 1));
297 }
298 }
299
MakeLess(IntExpr * const e,int v)300 Constraint* Solver::MakeLess(IntExpr* const e, int v) {
301 CHECK_EQ(this, e->solver());
302 if (e->Max() < v) {
303 return MakeTrueConstraint();
304 } else if (e->Min() >= v) {
305 return MakeFalseConstraint();
306 } else {
307 return RevAlloc(new LessEqExprCst(this, e, v - 1));
308 }
309 }
310
311 //-----------------------------------------------------------------------------
312 // Different constraints
313
namespace {
// Constraint `var != value`, where `value` is a constant.
//
// If the domain width is at most 0xFFFFFF, InitialPropagate() removes the
// value directly. Otherwise a range demon (BoundPropagate) is registered.
// That demon only trims the bounds until the domain has shrunk enough for a
// direct RemoveValue().
class DiffCst : public Constraint {
 public:
  DiffCst(Solver* const s, IntVar* const var, int64_t value);
  ~DiffCst() override {}
  // Nothing is registered at post time; InitialPropagate() decides whether a
  // demon is needed.
  void Post() override {}
  void InitialPropagate() override;
  // Demon body used while the domain is large: bound-only propagation.
  void BoundPropagate();
  std::string DebugString() const override;
  // Boolean variable reifying `var != value`.
  IntVar* Var() override {
    return solver()->MakeIsDifferentCstVar(var_, value_);
  }
  void Accept(ModelVisitor* const visitor) const override {
    visitor->BeginVisitConstraint(ModelVisitor::kNonEqual, this);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
                                            var_);
    visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
    visitor->EndVisitConstraint(ModelVisitor::kNonEqual, this);
  }

 private:
  // True when the (saturated) width Max() - Min() exceeds 0xFFFFFF.
  bool HasLargeDomain(IntVar* var);

  IntVar* const var_;
  int64_t value_;
  // Range demon created by InitialPropagate() for large domains; nullptr
  // otherwise.
  Demon* demon_;
};

DiffCst::DiffCst(Solver* const s, IntVar* const var, int64_t value)
    : Constraint(s), var_(var), value_(value), demon_(nullptr) {}

// Small domains: remove value_ immediately. Large domains: defer to
// BoundPropagate() on every range change.
// NOTE(review): in the large-domain branch no pruning happens here even if
// Min() or Max() already equals value_; it waits for the next range event.
void DiffCst::InitialPropagate() {
  if (HasLargeDomain(var_)) {
    demon_ = MakeConstraintDemon0(solver(), this, &DiffCst::BoundPropagate,
                                  "BoundPropagate");
    var_->WhenRange(demon_);
  } else {
    var_->RemoveValue(value_);
  }
}

void DiffCst::BoundPropagate() {
  const int64_t var_min = var_->Min();
  const int64_t var_max = var_->Max();
  if (var_min > value_ || var_max < value_) {
    // value_ lies outside [min, max]; nothing left to prune.
    demon_->inhibit(solver());
  } else if (var_min == value_) {
    // NOTE(review): value_ + 1 overflows if value_ == INT64_MAX.
    var_->SetMin(value_ + 1);
  } else if (var_max == value_) {
    // NOTE(review): value_ - 1 overflows if value_ == INT64_MIN.
    var_->SetMax(value_ - 1);
  } else if (!HasLargeDomain(var_)) {
    // value_ is strictly inside a domain that is now small enough: remove it
    // directly and stop listening.
    demon_->inhibit(solver());
    var_->RemoveValue(value_);
  }
}

std::string DiffCst::DebugString() const {
  return absl::StrFormat("(%s != %d)", var_->DebugString(), value_);
}

bool DiffCst::HasLargeDomain(IntVar* var) {
  // CapSub saturates, so a full-range domain does not overflow here.
  return CapSub(var->Max(), var->Min()) > 0xFFFFFF;
}
}  // namespace
378
MakeNonEquality(IntExpr * const e,int64_t v)379 Constraint* Solver::MakeNonEquality(IntExpr* const e, int64_t v) {
380 CHECK_EQ(this, e->solver());
381 IntExpr* left = nullptr;
382 IntExpr* right = nullptr;
383 if (IsADifference(e, &left, &right)) {
384 return MakeNonEquality(left, MakeSum(right, v));
385 } else if (e->IsVar() && !e->Var()->Contains(v)) {
386 return MakeTrueConstraint();
387 } else if (e->Bound() && e->Min() == v) {
388 return MakeFalseConstraint();
389 } else {
390 return RevAlloc(new DiffCst(this, e->Var(), v));
391 }
392 }
393
MakeNonEquality(IntExpr * const e,int v)394 Constraint* Solver::MakeNonEquality(IntExpr* const e, int v) {
395 CHECK_EQ(this, e->solver());
396 IntExpr* left = nullptr;
397 IntExpr* right = nullptr;
398 if (IsADifference(e, &left, &right)) {
399 return MakeNonEquality(left, MakeSum(right, v));
400 } else if (e->IsVar() && !e->Var()->Contains(v)) {
401 return MakeTrueConstraint();
402 } else if (e->Bound() && e->Min() == v) {
403 return MakeFalseConstraint();
404 } else {
405 return RevAlloc(new DiffCst(this, e->Var(), v));
406 }
407 }
408 // ----- is_equal_cst Constraint -----
409
namespace {
// Reified constraint: boolean target_var_ (held by CastConstraint) is 1 iff
// var_ == cst_.
class IsEqualCstCt : public CastConstraint {
 public:
  IsEqualCstCt(Solver* const s, IntVar* const v, int64_t c, IntVar* const b)
      : CastConstraint(s, b), var_(v), cst_(c), demon_(nullptr) {}
  // Re-runs InitialPropagate() whenever var_'s domain changes or the target
  // becomes bound.
  void Post() override {
    demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
    var_->WhenDomain(demon_);
    target_var_->WhenBound(demon_);
  }
  void InitialPropagate() override {
    // Once var_ is bound, the target is fully determined.
    bool inhibit = var_->Bound();
    // The target can be 1 only while cst_ is still in var_'s domain.
    int64_t u = var_->Contains(cst_);
    // var_ bound: target == u exactly; otherwise target in [0, u].
    int64_t l = inhibit ? u : 0;
    target_var_->SetRange(l, u);
    if (target_var_->Bound()) {
      if (target_var_->Min() == 0) {
        // Removing the value is only done for domains of size <= 0xFFFFFF.
        // Otherwise the demon stays active and retries on the next domain
        // event.
        if (var_->Size() <= 0xFFFFFF) {
          var_->RemoveValue(cst_);
          inhibit = true;
        }
      } else {
        var_->SetValue(cst_);
        inhibit = true;
      }
    }
    if (inhibit) {
      demon_->inhibit(solver());
    }
  }
  std::string DebugString() const override {
    return absl::StrFormat("IsEqualCstCt(%s, %d, %s)", var_->DebugString(),
                           cst_, target_var_->DebugString());
  }

  void Accept(ModelVisitor* const visitor) const override {
    visitor->BeginVisitConstraint(ModelVisitor::kIsEqual, this);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
                                            var_);
    visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
                                            target_var_);
    visitor->EndVisitConstraint(ModelVisitor::kIsEqual, this);
  }

 private:
  IntVar* const var_;
  int64_t cst_;
  // Callback registered in Post(); inhibited once the constraint is decided.
  Demon* demon_;
};
}  // namespace
461
MakeIsEqualCstVar(IntExpr * const var,int64_t value)462 IntVar* Solver::MakeIsEqualCstVar(IntExpr* const var, int64_t value) {
463 IntExpr* left = nullptr;
464 IntExpr* right = nullptr;
465 if (IsADifference(var, &left, &right)) {
466 return MakeIsEqualVar(left, MakeSum(right, value));
467 }
468 if (CapSub(var->Max(), var->Min()) == 1) {
469 if (value == var->Min()) {
470 return MakeDifference(value + 1, var)->Var();
471 } else if (value == var->Max()) {
472 return MakeSum(var, -value + 1)->Var();
473 } else {
474 return MakeIntConst(0);
475 }
476 }
477 if (var->IsVar()) {
478 return var->Var()->IsEqual(value);
479 } else {
480 IntVar* const boolvar =
481 MakeBoolVar(absl::StrFormat("Is(%s == %d)", var->DebugString(), value));
482 AddConstraint(MakeIsEqualCstCt(var, value, boolvar));
483 return boolvar;
484 }
485 }
486
MakeIsEqualCstCt(IntExpr * const var,int64_t value,IntVar * const boolvar)487 Constraint* Solver::MakeIsEqualCstCt(IntExpr* const var, int64_t value,
488 IntVar* const boolvar) {
489 CHECK_EQ(this, var->solver());
490 CHECK_EQ(this, boolvar->solver());
491 if (value == var->Min()) {
492 if (CapSub(var->Max(), var->Min()) == 1) {
493 return MakeEquality(MakeDifference(value + 1, var), boolvar);
494 }
495 return MakeIsLessOrEqualCstCt(var, value, boolvar);
496 }
497 if (value == var->Max()) {
498 if (CapSub(var->Max(), var->Min()) == 1) {
499 return MakeEquality(MakeSum(var, -value + 1), boolvar);
500 }
501 return MakeIsGreaterOrEqualCstCt(var, value, boolvar);
502 }
503 if (boolvar->Bound()) {
504 if (boolvar->Min() == 0) {
505 return MakeNonEquality(var, value);
506 } else {
507 return MakeEquality(var, value);
508 }
509 }
510 // TODO(user) : what happens if the constraint is not posted?
511 // The cache becomes tainted.
512 model_cache_->InsertExprConstantExpression(
513 boolvar, var, value, ModelCache::EXPR_CONSTANT_IS_EQUAL);
514 IntExpr* left = nullptr;
515 IntExpr* right = nullptr;
516 if (IsADifference(var, &left, &right)) {
517 return MakeIsEqualCt(left, MakeSum(right, value), boolvar);
518 } else {
519 return RevAlloc(new IsEqualCstCt(this, var->Var(), value, boolvar));
520 }
521 }
522
523 // ----- is_diff_cst Constraint -----
524
namespace {
// Reified constraint: boolean target_var_ (held by CastConstraint) is 1 iff
// var_ != cst_.
class IsDiffCstCt : public CastConstraint {
 public:
  IsDiffCstCt(Solver* const s, IntVar* const v, int64_t c, IntVar* const b)
      : CastConstraint(s, b), var_(v), cst_(c), demon_(nullptr) {}

  // Re-runs InitialPropagate() whenever var_'s domain changes or the target
  // becomes bound.
  void Post() override {
    demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
    var_->WhenDomain(demon_);
    target_var_->WhenBound(demon_);
  }

  void InitialPropagate() override {
    // Once var_ is bound, the target is fully determined.
    bool inhibit = var_->Bound();
    // The target must be 1 once cst_ has left var_'s domain.
    int64_t l = 1 - var_->Contains(cst_);
    // var_ bound: target == l exactly; otherwise target in [l, 1].
    int64_t u = inhibit ? l : 1;
    target_var_->SetRange(l, u);
    if (target_var_->Bound()) {
      if (target_var_->Min() == 1) {
        // Removing the value is only done for domains of size <= 0xFFFFFF.
        // Otherwise the demon stays active and retries on the next domain
        // event.
        if (var_->Size() <= 0xFFFFFF) {
          var_->RemoveValue(cst_);
          inhibit = true;
        }
      } else {
        var_->SetValue(cst_);
        inhibit = true;
      }
    }
    if (inhibit) {
      demon_->inhibit(solver());
    }
  }

  std::string DebugString() const override {
    return absl::StrFormat("IsDiffCstCt(%s, %d, %s)", var_->DebugString(), cst_,
                           target_var_->DebugString());
  }

  void Accept(ModelVisitor* const visitor) const override {
    visitor->BeginVisitConstraint(ModelVisitor::kIsDifferent, this);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
                                            var_);
    visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
                                            target_var_);
    visitor->EndVisitConstraint(ModelVisitor::kIsDifferent, this);
  }

 private:
  IntVar* const var_;
  int64_t cst_;
  // Callback registered in Post(); inhibited once the constraint is decided.
  Demon* demon_;
};
}  // namespace
579
MakeIsDifferentCstVar(IntExpr * const var,int64_t value)580 IntVar* Solver::MakeIsDifferentCstVar(IntExpr* const var, int64_t value) {
581 IntExpr* left = nullptr;
582 IntExpr* right = nullptr;
583 if (IsADifference(var, &left, &right)) {
584 return MakeIsDifferentVar(left, MakeSum(right, value));
585 }
586 return var->Var()->IsDifferent(value);
587 }
588
MakeIsDifferentCstCt(IntExpr * const var,int64_t value,IntVar * const boolvar)589 Constraint* Solver::MakeIsDifferentCstCt(IntExpr* const var, int64_t value,
590 IntVar* const boolvar) {
591 CHECK_EQ(this, var->solver());
592 CHECK_EQ(this, boolvar->solver());
593 if (value == var->Min()) {
594 return MakeIsGreaterOrEqualCstCt(var, value + 1, boolvar);
595 }
596 if (value == var->Max()) {
597 return MakeIsLessOrEqualCstCt(var, value - 1, boolvar);
598 }
599 if (var->IsVar() && !var->Var()->Contains(value)) {
600 return MakeEquality(boolvar, int64_t{1});
601 }
602 if (var->Bound() && var->Min() == value) {
603 return MakeEquality(boolvar, Zero());
604 }
605 if (boolvar->Bound()) {
606 if (boolvar->Min() == 0) {
607 return MakeEquality(var, value);
608 } else {
609 return MakeNonEquality(var, value);
610 }
611 }
612 model_cache_->InsertExprConstantExpression(
613 boolvar, var, value, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
614 IntExpr* left = nullptr;
615 IntExpr* right = nullptr;
616 if (IsADifference(var, &left, &right)) {
617 return MakeIsDifferentCt(left, MakeSum(right, value), boolvar);
618 } else {
619 return RevAlloc(new IsDiffCstCt(this, var->Var(), value, boolvar));
620 }
621 }
622
623 // ----- is_greater_equal_cst Constraint -----
624
namespace {
// Reified constraint: boolean target_var_ (held by CastConstraint) is 1 iff
// expr_ >= cst_. Works on general expressions, listening to range events.
class IsGreaterEqualCstCt : public CastConstraint {
 public:
  IsGreaterEqualCstCt(Solver* const s, IntExpr* const v, int64_t c,
                      IntVar* const b)
      : CastConstraint(s, b), expr_(v), cst_(c), demon_(nullptr) {}
  // Re-runs InitialPropagate() whenever expr_'s range changes or the target
  // becomes bound.
  void Post() override {
    demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
    expr_->WhenRange(demon_);
    target_var_->WhenBound(demon_);
  }
  void InitialPropagate() override {
    bool inhibit = false;
    // Target may be 1 only if Max() can still reach cst_; it must be 1 if
    // Min() already does.
    int64_t u = expr_->Max() >= cst_;
    int64_t l = expr_->Min() >= cst_;
    target_var_->SetRange(l, u);
    if (target_var_->Bound()) {
      inhibit = true;
      // Push the decided direction back onto the expression.
      if (target_var_->Min() == 0) {
        expr_->SetMax(cst_ - 1);
      } else {
        expr_->SetMin(cst_);
      }
    }
    if (inhibit && ((target_var_->Max() == 0 && expr_->Max() < cst_) ||
                    (target_var_->Min() == 1 && expr_->Min() >= cst_))) {
      // Can we safely inhibit? Sometimes an expression is not
      // persistent, just monotonic.
      demon_->inhibit(solver());
    }
  }
  std::string DebugString() const override {
    return absl::StrFormat("IsGreaterEqualCstCt(%s, %d, %s)",
                           expr_->DebugString(), cst_,
                           target_var_->DebugString());
  }

  void Accept(ModelVisitor* const visitor) const override {
    visitor->BeginVisitConstraint(ModelVisitor::kIsGreaterOrEqual, this);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
                                            expr_);
    visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
                                            target_var_);
    visitor->EndVisitConstraint(ModelVisitor::kIsGreaterOrEqual, this);
  }

 private:
  IntExpr* const expr_;
  int64_t cst_;
  // Callback registered in Post(); inhibited once the constraint is decided.
  Demon* demon_;
};
}  // namespace
678
MakeIsGreaterOrEqualCstVar(IntExpr * const var,int64_t value)679 IntVar* Solver::MakeIsGreaterOrEqualCstVar(IntExpr* const var, int64_t value) {
680 if (var->Min() >= value) {
681 return MakeIntConst(int64_t{1});
682 }
683 if (var->Max() < value) {
684 return MakeIntConst(int64_t{0});
685 }
686 if (var->IsVar()) {
687 return var->Var()->IsGreaterOrEqual(value);
688 } else {
689 IntVar* const boolvar =
690 MakeBoolVar(absl::StrFormat("Is(%s >= %d)", var->DebugString(), value));
691 AddConstraint(MakeIsGreaterOrEqualCstCt(var, value, boolvar));
692 return boolvar;
693 }
694 }
695
MakeIsGreaterCstVar(IntExpr * const var,int64_t value)696 IntVar* Solver::MakeIsGreaterCstVar(IntExpr* const var, int64_t value) {
697 return MakeIsGreaterOrEqualCstVar(var, value + 1);
698 }
699
MakeIsGreaterOrEqualCstCt(IntExpr * const var,int64_t value,IntVar * const boolvar)700 Constraint* Solver::MakeIsGreaterOrEqualCstCt(IntExpr* const var, int64_t value,
701 IntVar* const boolvar) {
702 if (boolvar->Bound()) {
703 if (boolvar->Min() == 0) {
704 return MakeLess(var, value);
705 } else {
706 return MakeGreaterOrEqual(var, value);
707 }
708 }
709 CHECK_EQ(this, var->solver());
710 CHECK_EQ(this, boolvar->solver());
711 model_cache_->InsertExprConstantExpression(
712 boolvar, var, value, ModelCache::EXPR_CONSTANT_IS_GREATER_OR_EQUAL);
713 return RevAlloc(new IsGreaterEqualCstCt(this, var, value, boolvar));
714 }
715
MakeIsGreaterCstCt(IntExpr * const v,int64_t c,IntVar * const b)716 Constraint* Solver::MakeIsGreaterCstCt(IntExpr* const v, int64_t c,
717 IntVar* const b) {
718 return MakeIsGreaterOrEqualCstCt(v, c + 1, b);
719 }
720
// ----- is_less_equal_cst Constraint -----
722
namespace {
// Reified constraint: boolean target_var_ (held by CastConstraint) is 1 iff
// expr_ <= cst_. Works on general expressions, listening to range events.
class IsLessEqualCstCt : public CastConstraint {
 public:
  IsLessEqualCstCt(Solver* const s, IntExpr* const v, int64_t c,
                   IntVar* const b)
      : CastConstraint(s, b), expr_(v), cst_(c), demon_(nullptr) {}

  // Re-runs InitialPropagate() whenever expr_'s range changes or the target
  // becomes bound.
  void Post() override {
    demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
    expr_->WhenRange(demon_);
    target_var_->WhenBound(demon_);
  }

  void InitialPropagate() override {
    bool inhibit = false;
    // Target may be 1 only if Min() can still be <= cst_; it must be 1 if
    // Max() already is.
    int64_t u = expr_->Min() <= cst_;
    int64_t l = expr_->Max() <= cst_;
    target_var_->SetRange(l, u);
    if (target_var_->Bound()) {
      inhibit = true;
      // Push the decided direction back onto the expression.
      if (target_var_->Min() == 0) {
        expr_->SetMin(cst_ + 1);
      } else {
        expr_->SetMax(cst_);
      }
    }
    if (inhibit && ((target_var_->Max() == 0 && expr_->Min() > cst_) ||
                    (target_var_->Min() == 1 && expr_->Max() <= cst_))) {
      // Can we safely inhibit? Sometimes an expression is not
      // persistent, just monotonic.
      demon_->inhibit(solver());
    }
  }

  std::string DebugString() const override {
    return absl::StrFormat("IsLessEqualCstCt(%s, %d, %s)", expr_->DebugString(),
                           cst_, target_var_->DebugString());
  }

  void Accept(ModelVisitor* const visitor) const override {
    visitor->BeginVisitConstraint(ModelVisitor::kIsLessOrEqual, this);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
                                            expr_);
    visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
                                            target_var_);
    visitor->EndVisitConstraint(ModelVisitor::kIsLessOrEqual, this);
  }

 private:
  IntExpr* const expr_;
  int64_t cst_;
  // Callback registered in Post(); inhibited once the constraint is decided.
  Demon* demon_;
};
}  // namespace
778
MakeIsLessOrEqualCstVar(IntExpr * const var,int64_t value)779 IntVar* Solver::MakeIsLessOrEqualCstVar(IntExpr* const var, int64_t value) {
780 if (var->Max() <= value) {
781 return MakeIntConst(int64_t{1});
782 }
783 if (var->Min() > value) {
784 return MakeIntConst(int64_t{0});
785 }
786 if (var->IsVar()) {
787 return var->Var()->IsLessOrEqual(value);
788 } else {
789 IntVar* const boolvar =
790 MakeBoolVar(absl::StrFormat("Is(%s <= %d)", var->DebugString(), value));
791 AddConstraint(MakeIsLessOrEqualCstCt(var, value, boolvar));
792 return boolvar;
793 }
794 }
795
MakeIsLessCstVar(IntExpr * const var,int64_t value)796 IntVar* Solver::MakeIsLessCstVar(IntExpr* const var, int64_t value) {
797 return MakeIsLessOrEqualCstVar(var, value - 1);
798 }
799
MakeIsLessOrEqualCstCt(IntExpr * const var,int64_t value,IntVar * const boolvar)800 Constraint* Solver::MakeIsLessOrEqualCstCt(IntExpr* const var, int64_t value,
801 IntVar* const boolvar) {
802 if (boolvar->Bound()) {
803 if (boolvar->Min() == 0) {
804 return MakeGreater(var, value);
805 } else {
806 return MakeLessOrEqual(var, value);
807 }
808 }
809 CHECK_EQ(this, var->solver());
810 CHECK_EQ(this, boolvar->solver());
811 model_cache_->InsertExprConstantExpression(
812 boolvar, var, value, ModelCache::EXPR_CONSTANT_IS_LESS_OR_EQUAL);
813 return RevAlloc(new IsLessEqualCstCt(this, var, value, boolvar));
814 }
815
MakeIsLessCstCt(IntExpr * const v,int64_t c,IntVar * const b)816 Constraint* Solver::MakeIsLessCstCt(IntExpr* const v, int64_t c,
817 IntVar* const b) {
818 return MakeIsLessOrEqualCstCt(v, c - 1, b);
819 }
820
821 // ----- BetweenCt -----
822
823 namespace {
824 class BetweenCt : public Constraint {
825 public:
BetweenCt(Solver * const s,IntExpr * const v,int64_t l,int64_t u)826 BetweenCt(Solver* const s, IntExpr* const v, int64_t l, int64_t u)
827 : Constraint(s), expr_(v), min_(l), max_(u), demon_(nullptr) {}
828
Post()829 void Post() override {
830 if (!expr_->IsVar()) {
831 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
832 expr_->WhenRange(demon_);
833 }
834 }
835
InitialPropagate()836 void InitialPropagate() override {
837 expr_->SetRange(min_, max_);
838 int64_t emin = 0;
839 int64_t emax = 0;
840 expr_->Range(&emin, &emax);
841 if (demon_ != nullptr && emin >= min_ && emax <= max_) {
842 demon_->inhibit(solver());
843 }
844 }
845
DebugString() const846 std::string DebugString() const override {
847 return absl::StrFormat("BetweenCt(%s, %d, %d)", expr_->DebugString(), min_,
848 max_);
849 }
850
Accept(ModelVisitor * const visitor) const851 void Accept(ModelVisitor* const visitor) const override {
852 visitor->BeginVisitConstraint(ModelVisitor::kBetween, this);
853 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, min_);
854 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
855 expr_);
856 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, max_);
857 visitor->EndVisitConstraint(ModelVisitor::kBetween, this);
858 }
859
860 private:
861 IntExpr* const expr_;
862 int64_t min_;
863 int64_t max_;
864 Demon* demon_;
865 };
866
867 // ----- NonMember constraint -----
868
869 class NotBetweenCt : public Constraint {
870 public:
NotBetweenCt(Solver * const s,IntExpr * const v,int64_t l,int64_t u)871 NotBetweenCt(Solver* const s, IntExpr* const v, int64_t l, int64_t u)
872 : Constraint(s), expr_(v), min_(l), max_(u), demon_(nullptr) {}
873
Post()874 void Post() override {
875 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
876 expr_->WhenRange(demon_);
877 }
878
InitialPropagate()879 void InitialPropagate() override {
880 int64_t emin = 0;
881 int64_t emax = 0;
882 expr_->Range(&emin, &emax);
883 if (emin >= min_) {
884 expr_->SetMin(max_ + 1);
885 } else if (emax <= max_) {
886 expr_->SetMax(min_ - 1);
887 }
888
889 if (!expr_->IsVar() && (emax < min_ || emin > max_)) {
890 demon_->inhibit(solver());
891 }
892 }
893
DebugString() const894 std::string DebugString() const override {
895 return absl::StrFormat("NotBetweenCt(%s, %d, %d)", expr_->DebugString(),
896 min_, max_);
897 }
898
Accept(ModelVisitor * const visitor) const899 void Accept(ModelVisitor* const visitor) const override {
900 visitor->BeginVisitConstraint(ModelVisitor::kNotBetween, this);
901 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, min_);
902 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
903 expr_);
904 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, max_);
905 visitor->EndVisitConstraint(ModelVisitor::kBetween, this);
906 }
907
908 private:
909 IntExpr* const expr_;
910 int64_t min_;
911 int64_t max_;
912 Demon* demon_;
913 };
914
ExtractExprProductCoeff(IntExpr ** expr)915 int64_t ExtractExprProductCoeff(IntExpr** expr) {
916 int64_t prod = 1;
917 int64_t coeff = 1;
918 while ((*expr)->solver()->IsProduct(*expr, expr, &coeff)) prod *= coeff;
919 return prod;
920 }
921 } // namespace
922
MakeBetweenCt(IntExpr * expr,int64_t l,int64_t u)923 Constraint* Solver::MakeBetweenCt(IntExpr* expr, int64_t l, int64_t u) {
924 DCHECK_EQ(this, expr->solver());
925 // Catch empty and singleton intervals.
926 if (l >= u) {
927 if (l > u) return MakeFalseConstraint();
928 return MakeEquality(expr, l);
929 }
930 int64_t emin = 0;
931 int64_t emax = 0;
932 expr->Range(&emin, &emax);
933 // Catch the trivial cases first.
934 if (emax < l || emin > u) return MakeFalseConstraint();
935 if (emin >= l && emax <= u) return MakeTrueConstraint();
936 // Catch one-sided constraints.
937 if (emax <= u) return MakeGreaterOrEqual(expr, l);
938 if (emin >= l) return MakeLessOrEqual(expr, u);
939 // Simplify the common factor, if any.
940 int64_t coeff = ExtractExprProductCoeff(&expr);
941 if (coeff != 1) {
942 CHECK_NE(coeff, 0); // Would have been caught by the trivial cases already.
943 if (coeff < 0) {
944 std::swap(u, l);
945 u = -u;
946 l = -l;
947 coeff = -coeff;
948 }
949 return MakeBetweenCt(expr, PosIntDivUp(l, coeff), PosIntDivDown(u, coeff));
950 } else {
951 // No further reduction is possible.
952 return RevAlloc(new BetweenCt(this, expr, l, u));
953 }
954 }
955
MakeNotBetweenCt(IntExpr * expr,int64_t l,int64_t u)956 Constraint* Solver::MakeNotBetweenCt(IntExpr* expr, int64_t l, int64_t u) {
957 DCHECK_EQ(this, expr->solver());
958 // Catch empty interval.
959 if (l > u) {
960 return MakeTrueConstraint();
961 }
962
963 int64_t emin = 0;
964 int64_t emax = 0;
965 expr->Range(&emin, &emax);
966 // Catch the trivial cases first.
967 if (emax < l || emin > u) return MakeTrueConstraint();
968 if (emin >= l && emax <= u) return MakeFalseConstraint();
969 // Catch one-sided constraints.
970 if (emin >= l) return MakeGreater(expr, u);
971 if (emax <= u) return MakeLess(expr, l);
972 // TODO(user): Add back simplification code if expr is constant *
973 // other_expr.
974 return RevAlloc(new NotBetweenCt(this, expr, l, u));
975 }
976
977 // ----- is_between_cst Constraint -----
978
979 namespace {
980 class IsBetweenCt : public Constraint {
981 public:
IsBetweenCt(Solver * const s,IntExpr * const e,int64_t l,int64_t u,IntVar * const b)982 IsBetweenCt(Solver* const s, IntExpr* const e, int64_t l, int64_t u,
983 IntVar* const b)
984 : Constraint(s),
985 expr_(e),
986 min_(l),
987 max_(u),
988 boolvar_(b),
989 demon_(nullptr) {}
990
Post()991 void Post() override {
992 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
993 expr_->WhenRange(demon_);
994 boolvar_->WhenBound(demon_);
995 }
996
InitialPropagate()997 void InitialPropagate() override {
998 bool inhibit = false;
999 int64_t emin = 0;
1000 int64_t emax = 0;
1001 expr_->Range(&emin, &emax);
1002 int64_t u = 1 - (emin > max_ || emax < min_);
1003 int64_t l = emax <= max_ && emin >= min_;
1004 boolvar_->SetRange(l, u);
1005 if (boolvar_->Bound()) {
1006 inhibit = true;
1007 if (boolvar_->Min() == 0) {
1008 if (expr_->IsVar()) {
1009 expr_->Var()->RemoveInterval(min_, max_);
1010 inhibit = true;
1011 } else if (emin > min_) {
1012 expr_->SetMin(max_ + 1);
1013 } else if (emax < max_) {
1014 expr_->SetMax(min_ - 1);
1015 }
1016 } else {
1017 expr_->SetRange(min_, max_);
1018 inhibit = true;
1019 }
1020 if (inhibit && expr_->IsVar()) {
1021 demon_->inhibit(solver());
1022 }
1023 }
1024 }
1025
DebugString() const1026 std::string DebugString() const override {
1027 return absl::StrFormat("IsBetweenCt(%s, %d, %d, %s)", expr_->DebugString(),
1028 min_, max_, boolvar_->DebugString());
1029 }
1030
Accept(ModelVisitor * const visitor) const1031 void Accept(ModelVisitor* const visitor) const override {
1032 visitor->BeginVisitConstraint(ModelVisitor::kIsBetween, this);
1033 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, min_);
1034 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1035 expr_);
1036 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, max_);
1037 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1038 boolvar_);
1039 visitor->EndVisitConstraint(ModelVisitor::kIsBetween, this);
1040 }
1041
1042 private:
1043 IntExpr* const expr_;
1044 int64_t min_;
1045 int64_t max_;
1046 IntVar* const boolvar_;
1047 Demon* demon_;
1048 };
1049 } // namespace
1050
MakeIsBetweenCt(IntExpr * expr,int64_t l,int64_t u,IntVar * const b)1051 Constraint* Solver::MakeIsBetweenCt(IntExpr* expr, int64_t l, int64_t u,
1052 IntVar* const b) {
1053 CHECK_EQ(this, expr->solver());
1054 CHECK_EQ(this, b->solver());
1055 // Catch empty and singleton intervals.
1056 if (l >= u) {
1057 if (l > u) return MakeEquality(b, Zero());
1058 return MakeIsEqualCstCt(expr, l, b);
1059 }
1060 int64_t emin = 0;
1061 int64_t emax = 0;
1062 expr->Range(&emin, &emax);
1063 // Catch the trivial cases first.
1064 if (emax < l || emin > u) return MakeEquality(b, Zero());
1065 if (emin >= l && emax <= u) return MakeEquality(b, 1);
1066 // Catch one-sided constraints.
1067 if (emax <= u) return MakeIsGreaterOrEqualCstCt(expr, l, b);
1068 if (emin >= l) return MakeIsLessOrEqualCstCt(expr, u, b);
1069 // Simplify the common factor, if any.
1070 int64_t coeff = ExtractExprProductCoeff(&expr);
1071 if (coeff != 1) {
1072 CHECK_NE(coeff, 0); // Would have been caught by the trivial cases already.
1073 if (coeff < 0) {
1074 std::swap(u, l);
1075 u = -u;
1076 l = -l;
1077 coeff = -coeff;
1078 }
1079 return MakeIsBetweenCt(expr, PosIntDivUp(l, coeff), PosIntDivDown(u, coeff),
1080 b);
1081 } else {
1082 // No further reduction is possible.
1083 return RevAlloc(new IsBetweenCt(this, expr, l, u, b));
1084 }
1085 }
1086
MakeIsBetweenVar(IntExpr * const v,int64_t l,int64_t u)1087 IntVar* Solver::MakeIsBetweenVar(IntExpr* const v, int64_t l, int64_t u) {
1088 CHECK_EQ(this, v->solver());
1089 IntVar* const b = MakeBoolVar();
1090 AddConstraint(MakeIsBetweenCt(v, l, u, b));
1091 return b;
1092 }
1093
1094 // ---------- Member ----------
1095
1096 // ----- Member(IntVar, IntSet) -----
1097
1098 namespace {
1099 // TODO(user): Do not create holes on expressions.
1100 class MemberCt : public Constraint {
1101 public:
MemberCt(Solver * const s,IntVar * const v,const std::vector<int64_t> & sorted_values)1102 MemberCt(Solver* const s, IntVar* const v,
1103 const std::vector<int64_t>& sorted_values)
1104 : Constraint(s), var_(v), values_(sorted_values) {
1105 DCHECK(v != nullptr);
1106 DCHECK(s != nullptr);
1107 }
1108
Post()1109 void Post() override {}
1110
InitialPropagate()1111 void InitialPropagate() override { var_->SetValues(values_); }
1112
DebugString() const1113 std::string DebugString() const override {
1114 return absl::StrFormat("Member(%s, %s)", var_->DebugString(),
1115 absl::StrJoin(values_, ", "));
1116 }
1117
Accept(ModelVisitor * const visitor) const1118 void Accept(ModelVisitor* const visitor) const override {
1119 visitor->BeginVisitConstraint(ModelVisitor::kMember, this);
1120 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1121 var_);
1122 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
1123 visitor->EndVisitConstraint(ModelVisitor::kMember, this);
1124 }
1125
1126 private:
1127 IntVar* const var_;
1128 const std::vector<int64_t> values_;
1129 };
1130
1131 class NotMemberCt : public Constraint {
1132 public:
NotMemberCt(Solver * const s,IntVar * const v,const std::vector<int64_t> & sorted_values)1133 NotMemberCt(Solver* const s, IntVar* const v,
1134 const std::vector<int64_t>& sorted_values)
1135 : Constraint(s), var_(v), values_(sorted_values) {
1136 DCHECK(v != nullptr);
1137 DCHECK(s != nullptr);
1138 }
1139
Post()1140 void Post() override {}
1141
InitialPropagate()1142 void InitialPropagate() override { var_->RemoveValues(values_); }
1143
DebugString() const1144 std::string DebugString() const override {
1145 return absl::StrFormat("NotMember(%s, %s)", var_->DebugString(),
1146 absl::StrJoin(values_, ", "));
1147 }
1148
Accept(ModelVisitor * const visitor) const1149 void Accept(ModelVisitor* const visitor) const override {
1150 visitor->BeginVisitConstraint(ModelVisitor::kMember, this);
1151 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1152 var_);
1153 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
1154 visitor->EndVisitConstraint(ModelVisitor::kMember, this);
1155 }
1156
1157 private:
1158 IntVar* const var_;
1159 const std::vector<int64_t> values_;
1160 };
1161 } // namespace
1162
MakeMemberCt(IntExpr * expr,const std::vector<int64_t> & values)1163 Constraint* Solver::MakeMemberCt(IntExpr* expr,
1164 const std::vector<int64_t>& values) {
1165 const int64_t coeff = ExtractExprProductCoeff(&expr);
1166 if (coeff == 0) {
1167 return std::find(values.begin(), values.end(), 0) == values.end()
1168 ? MakeFalseConstraint()
1169 : MakeTrueConstraint();
1170 }
1171 std::vector<int64_t> copied_values = values;
1172 // If the expression is a non-trivial product, we filter out the values that
1173 // aren't multiples of "coeff", and divide them.
1174 if (coeff != 1) {
1175 int num_kept = 0;
1176 for (const int64_t v : copied_values) {
1177 if (v % coeff == 0) copied_values[num_kept++] = v / coeff;
1178 }
1179 copied_values.resize(num_kept);
1180 }
1181 // Filter out the values that are outside the [Min, Max] interval.
1182 int num_kept = 0;
1183 int64_t emin;
1184 int64_t emax;
1185 expr->Range(&emin, &emax);
1186 for (const int64_t v : copied_values) {
1187 if (v >= emin && v <= emax) copied_values[num_kept++] = v;
1188 }
1189 copied_values.resize(num_kept);
1190 // Catch empty set.
1191 if (copied_values.empty()) return MakeFalseConstraint();
1192 // Sort and remove duplicates.
1193 gtl::STLSortAndRemoveDuplicates(&copied_values);
1194 // Special case for singleton.
1195 if (copied_values.size() == 1) return MakeEquality(expr, copied_values[0]);
1196 // Catch contiguous intervals.
1197 if (copied_values.size() ==
1198 copied_values.back() - copied_values.front() + 1) {
1199 // Note: MakeBetweenCt() has a fast-track for trivially true constraints.
1200 return MakeBetweenCt(expr, copied_values.front(), copied_values.back());
1201 }
1202 // If the set of values in [expr.Min(), expr.Max()] that are *not* in
1203 // "values" is smaller than "values", then it's more efficient to use
1204 // NotMemberCt. Catch that case here.
1205 if (emax - emin < 2 * copied_values.size()) {
1206 // Convert "copied_values" to list the values *not* allowed.
1207 std::vector<bool> is_among_input_values(emax - emin + 1, false);
1208 for (const int64_t v : copied_values)
1209 is_among_input_values[v - emin] = true;
1210 // We use the zero valued indices of is_among_input_values to build the
1211 // complement of copied_values.
1212 copied_values.clear();
1213 for (int64_t v_off = 0; v_off < is_among_input_values.size(); ++v_off) {
1214 if (!is_among_input_values[v_off]) copied_values.push_back(v_off + emin);
1215 }
1216 // The empty' case (all values in range [expr.Min(), expr.Max()] are in the
1217 // "values" input) was caught earlier, by the "contiguous interval" case.
1218 DCHECK_GE(copied_values.size(), 1);
1219 if (copied_values.size() == 1) {
1220 return MakeNonEquality(expr, copied_values[0]);
1221 }
1222 return RevAlloc(new NotMemberCt(this, expr->Var(), copied_values));
1223 }
1224 // Otherwise, just use MemberCt. No further reduction is possible.
1225 return RevAlloc(new MemberCt(this, expr->Var(), copied_values));
1226 }
1227
MakeMemberCt(IntExpr * const expr,const std::vector<int> & values)1228 Constraint* Solver::MakeMemberCt(IntExpr* const expr,
1229 const std::vector<int>& values) {
1230 return MakeMemberCt(expr, ToInt64Vector(values));
1231 }
1232
MakeNotMemberCt(IntExpr * expr,const std::vector<int64_t> & values)1233 Constraint* Solver::MakeNotMemberCt(IntExpr* expr,
1234 const std::vector<int64_t>& values) {
1235 const int64_t coeff = ExtractExprProductCoeff(&expr);
1236 if (coeff == 0) {
1237 return std::find(values.begin(), values.end(), 0) == values.end()
1238 ? MakeTrueConstraint()
1239 : MakeFalseConstraint();
1240 }
1241 std::vector<int64_t> copied_values = values;
1242 // If the expression is a non-trivial product, we filter out the values that
1243 // aren't multiples of "coeff", and divide them.
1244 if (coeff != 1) {
1245 int num_kept = 0;
1246 for (const int64_t v : copied_values) {
1247 if (v % coeff == 0) copied_values[num_kept++] = v / coeff;
1248 }
1249 copied_values.resize(num_kept);
1250 }
1251 // Filter out the values that are outside the [Min, Max] interval.
1252 int num_kept = 0;
1253 int64_t emin;
1254 int64_t emax;
1255 expr->Range(&emin, &emax);
1256 for (const int64_t v : copied_values) {
1257 if (v >= emin && v <= emax) copied_values[num_kept++] = v;
1258 }
1259 copied_values.resize(num_kept);
1260 // Catch empty set.
1261 if (copied_values.empty()) return MakeTrueConstraint();
1262 // Sort and remove duplicates.
1263 gtl::STLSortAndRemoveDuplicates(&copied_values);
1264 // Special case for singleton.
1265 if (copied_values.size() == 1) return MakeNonEquality(expr, copied_values[0]);
1266 // Catch contiguous intervals.
1267 if (copied_values.size() ==
1268 copied_values.back() - copied_values.front() + 1) {
1269 return MakeNotBetweenCt(expr, copied_values.front(), copied_values.back());
1270 }
1271 // If the set of values in [expr.Min(), expr.Max()] that are *not* in
1272 // "values" is smaller than "values", then it's more efficient to use
1273 // MemberCt. Catch that case here.
1274 if (emax - emin < 2 * copied_values.size()) {
1275 // Convert "copied_values" to a dense boolean vector.
1276 std::vector<bool> is_among_input_values(emax - emin + 1, false);
1277 for (const int64_t v : copied_values)
1278 is_among_input_values[v - emin] = true;
1279 // Use zero valued indices for is_among_input_values to build the
1280 // complement of copied_values.
1281 copied_values.clear();
1282 for (int64_t v_off = 0; v_off < is_among_input_values.size(); ++v_off) {
1283 if (!is_among_input_values[v_off]) copied_values.push_back(v_off + emin);
1284 }
1285 // The empty' case (all values in range [expr.Min(), expr.Max()] are in the
1286 // "values" input) was caught earlier, by the "contiguous interval" case.
1287 DCHECK_GE(copied_values.size(), 1);
1288 if (copied_values.size() == 1) {
1289 return MakeEquality(expr, copied_values[0]);
1290 }
1291 return RevAlloc(new MemberCt(this, expr->Var(), copied_values));
1292 }
1293 // Otherwise, just use NotMemberCt. No further reduction is possible.
1294 return RevAlloc(new NotMemberCt(this, expr->Var(), copied_values));
1295 }
1296
MakeNotMemberCt(IntExpr * const expr,const std::vector<int> & values)1297 Constraint* Solver::MakeNotMemberCt(IntExpr* const expr,
1298 const std::vector<int>& values) {
1299 return MakeNotMemberCt(expr, ToInt64Vector(values));
1300 }
1301
1302 // ----- IsMemberCt -----
1303
1304 namespace {
1305 class IsMemberCt : public Constraint {
1306 public:
IsMemberCt(Solver * const s,IntVar * const v,const std::vector<int64_t> & sorted_values,IntVar * const b)1307 IsMemberCt(Solver* const s, IntVar* const v,
1308 const std::vector<int64_t>& sorted_values, IntVar* const b)
1309 : Constraint(s),
1310 var_(v),
1311 values_as_set_(sorted_values.begin(), sorted_values.end()),
1312 values_(sorted_values),
1313 boolvar_(b),
1314 support_(0),
1315 demon_(nullptr),
1316 domain_(var_->MakeDomainIterator(true)),
1317 neg_support_(std::numeric_limits<int64_t>::min()) {
1318 DCHECK(v != nullptr);
1319 DCHECK(s != nullptr);
1320 DCHECK(b != nullptr);
1321 while (values_as_set_.contains(neg_support_)) {
1322 neg_support_++;
1323 }
1324 }
1325
Post()1326 void Post() override {
1327 demon_ = MakeConstraintDemon0(solver(), this, &IsMemberCt::VarDomain,
1328 "VarDomain");
1329 if (!var_->Bound()) {
1330 var_->WhenDomain(demon_);
1331 }
1332 if (!boolvar_->Bound()) {
1333 Demon* const bdemon = MakeConstraintDemon0(
1334 solver(), this, &IsMemberCt::TargetBound, "TargetBound");
1335 boolvar_->WhenBound(bdemon);
1336 }
1337 }
1338
InitialPropagate()1339 void InitialPropagate() override {
1340 boolvar_->SetRange(0, 1);
1341 if (boolvar_->Bound()) {
1342 TargetBound();
1343 } else {
1344 VarDomain();
1345 }
1346 }
1347
DebugString() const1348 std::string DebugString() const override {
1349 return absl::StrFormat("IsMemberCt(%s, %s, %s)", var_->DebugString(),
1350 absl::StrJoin(values_, ", "),
1351 boolvar_->DebugString());
1352 }
1353
Accept(ModelVisitor * const visitor) const1354 void Accept(ModelVisitor* const visitor) const override {
1355 visitor->BeginVisitConstraint(ModelVisitor::kIsMember, this);
1356 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1357 var_);
1358 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
1359 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1360 boolvar_);
1361 visitor->EndVisitConstraint(ModelVisitor::kIsMember, this);
1362 }
1363
1364 private:
VarDomain()1365 void VarDomain() {
1366 if (boolvar_->Bound()) {
1367 TargetBound();
1368 } else {
1369 for (int offset = 0; offset < values_.size(); ++offset) {
1370 const int candidate = (support_ + offset) % values_.size();
1371 if (var_->Contains(values_[candidate])) {
1372 support_ = candidate;
1373 if (var_->Bound()) {
1374 demon_->inhibit(solver());
1375 boolvar_->SetValue(1);
1376 return;
1377 }
1378 // We have found a positive support. Let's check the
1379 // negative support.
1380 if (var_->Contains(neg_support_)) {
1381 return;
1382 } else {
1383 // Look for a new negative support.
1384 for (const int64_t value : InitAndGetValues(domain_)) {
1385 if (!values_as_set_.contains(value)) {
1386 neg_support_ = value;
1387 return;
1388 }
1389 }
1390 }
1391 // No negative support, setting boolvar to true.
1392 demon_->inhibit(solver());
1393 boolvar_->SetValue(1);
1394 return;
1395 }
1396 }
1397 // No positive support, setting boolvar to false.
1398 demon_->inhibit(solver());
1399 boolvar_->SetValue(0);
1400 }
1401 }
1402
TargetBound()1403 void TargetBound() {
1404 DCHECK(boolvar_->Bound());
1405 if (boolvar_->Min() == 1LL) {
1406 demon_->inhibit(solver());
1407 var_->SetValues(values_);
1408 } else {
1409 demon_->inhibit(solver());
1410 var_->RemoveValues(values_);
1411 }
1412 }
1413
1414 IntVar* const var_;
1415 absl::flat_hash_set<int64_t> values_as_set_;
1416 std::vector<int64_t> values_;
1417 IntVar* const boolvar_;
1418 int support_;
1419 Demon* demon_;
1420 IntVarIterator* const domain_;
1421 int64_t neg_support_;
1422 };
1423
1424 template <class T>
BuildIsMemberCt(Solver * const solver,IntExpr * const expr,const std::vector<T> & values,IntVar * const boolvar)1425 Constraint* BuildIsMemberCt(Solver* const solver, IntExpr* const expr,
1426 const std::vector<T>& values,
1427 IntVar* const boolvar) {
1428 // TODO(user): optimize this by copying the code from MakeMemberCt.
1429 // Simplify and filter if expr is a product.
1430 IntExpr* sub = nullptr;
1431 int64_t coef = 1;
1432 if (solver->IsProduct(expr, &sub, &coef) && coef != 0 && coef != 1) {
1433 std::vector<int64_t> new_values;
1434 new_values.reserve(values.size());
1435 for (const int64_t value : values) {
1436 if (value % coef == 0) {
1437 new_values.push_back(value / coef);
1438 }
1439 }
1440 return BuildIsMemberCt(solver, sub, new_values, boolvar);
1441 }
1442
1443 std::set<T> set_of_values(values.begin(), values.end());
1444 std::vector<int64_t> filtered_values;
1445 bool all_values = false;
1446 if (expr->IsVar()) {
1447 IntVar* const var = expr->Var();
1448 for (const T value : set_of_values) {
1449 if (var->Contains(value)) {
1450 filtered_values.push_back(value);
1451 }
1452 }
1453 all_values = (filtered_values.size() == var->Size());
1454 } else {
1455 int64_t emin = 0;
1456 int64_t emax = 0;
1457 expr->Range(&emin, &emax);
1458 for (const T value : set_of_values) {
1459 if (value >= emin && value <= emax) {
1460 filtered_values.push_back(value);
1461 }
1462 }
1463 all_values = (filtered_values.size() == emax - emin + 1);
1464 }
1465 if (filtered_values.empty()) {
1466 return solver->MakeEquality(boolvar, Zero());
1467 } else if (all_values) {
1468 return solver->MakeEquality(boolvar, 1);
1469 } else if (filtered_values.size() == 1) {
1470 return solver->MakeIsEqualCstCt(expr, filtered_values.back(), boolvar);
1471 } else if (filtered_values.back() ==
1472 filtered_values.front() + filtered_values.size() - 1) {
1473 // Contiguous
1474 return solver->MakeIsBetweenCt(expr, filtered_values.front(),
1475 filtered_values.back(), boolvar);
1476 } else {
1477 return solver->RevAlloc(
1478 new IsMemberCt(solver, expr->Var(), filtered_values, boolvar));
1479 }
1480 }
1481 } // namespace
1482
MakeIsMemberCt(IntExpr * const expr,const std::vector<int64_t> & values,IntVar * const boolvar)1483 Constraint* Solver::MakeIsMemberCt(IntExpr* const expr,
1484 const std::vector<int64_t>& values,
1485 IntVar* const boolvar) {
1486 return BuildIsMemberCt(this, expr, values, boolvar);
1487 }
1488
MakeIsMemberCt(IntExpr * const expr,const std::vector<int> & values,IntVar * const boolvar)1489 Constraint* Solver::MakeIsMemberCt(IntExpr* const expr,
1490 const std::vector<int>& values,
1491 IntVar* const boolvar) {
1492 return BuildIsMemberCt(this, expr, values, boolvar);
1493 }
1494
MakeIsMemberVar(IntExpr * const expr,const std::vector<int64_t> & values)1495 IntVar* Solver::MakeIsMemberVar(IntExpr* const expr,
1496 const std::vector<int64_t>& values) {
1497 IntVar* const b = MakeBoolVar();
1498 AddConstraint(MakeIsMemberCt(expr, values, b));
1499 return b;
1500 }
1501
MakeIsMemberVar(IntExpr * const expr,const std::vector<int> & values)1502 IntVar* Solver::MakeIsMemberVar(IntExpr* const expr,
1503 const std::vector<int>& values) {
1504 IntVar* const b = MakeBoolVar();
1505 AddConstraint(MakeIsMemberCt(expr, values, b));
1506 return b;
1507 }
1508
1509 namespace {
1510 class SortedDisjointForbiddenIntervalsConstraint : public Constraint {
1511 public:
SortedDisjointForbiddenIntervalsConstraint(Solver * const solver,IntVar * const var,SortedDisjointIntervalList intervals)1512 SortedDisjointForbiddenIntervalsConstraint(
1513 Solver* const solver, IntVar* const var,
1514 SortedDisjointIntervalList intervals)
1515 : Constraint(solver), var_(var), intervals_(std::move(intervals)) {}
1516
~SortedDisjointForbiddenIntervalsConstraint()1517 ~SortedDisjointForbiddenIntervalsConstraint() override {}
1518
Post()1519 void Post() override {
1520 Demon* const demon = solver()->MakeConstraintInitialPropagateCallback(this);
1521 var_->WhenRange(demon);
1522 }
1523
InitialPropagate()1524 void InitialPropagate() override {
1525 const int64_t vmin = var_->Min();
1526 const int64_t vmax = var_->Max();
1527 const auto first_interval_it = intervals_.FirstIntervalGreaterOrEqual(vmin);
1528 if (first_interval_it == intervals_.end()) {
1529 // No interval intersects the variable's range. Nothing to do.
1530 return;
1531 }
1532 const auto last_interval_it = intervals_.LastIntervalLessOrEqual(vmax);
1533 if (last_interval_it == intervals_.end()) {
1534 // No interval intersects the variable's range. Nothing to do.
1535 return;
1536 }
1537 // TODO(user): Quick fail if first_interval_it == last_interval_it, which
1538 // would imply that the interval contains the entire range of the variable?
1539 if (vmin >= first_interval_it->start) {
1540 // The variable's minimum is inside a forbidden interval. Move it to the
1541 // interval's end.
1542 var_->SetMin(CapAdd(first_interval_it->end, 1));
1543 }
1544 if (vmax <= last_interval_it->end) {
1545 // Ditto, on the other side.
1546 var_->SetMax(CapSub(last_interval_it->start, 1));
1547 }
1548 }
1549
DebugString() const1550 std::string DebugString() const override {
1551 return absl::StrFormat("ForbiddenIntervalCt(%s, %s)", var_->DebugString(),
1552 intervals_.DebugString());
1553 }
1554
Accept(ModelVisitor * const visitor) const1555 void Accept(ModelVisitor* const visitor) const override {
1556 visitor->BeginVisitConstraint(ModelVisitor::kNotMember, this);
1557 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1558 var_);
1559 std::vector<int64_t> starts;
1560 std::vector<int64_t> ends;
1561 for (auto& interval : intervals_) {
1562 starts.push_back(interval.start);
1563 ends.push_back(interval.end);
1564 }
1565 visitor->VisitIntegerArrayArgument(ModelVisitor::kStartsArgument, starts);
1566 visitor->VisitIntegerArrayArgument(ModelVisitor::kEndsArgument, ends);
1567 visitor->EndVisitConstraint(ModelVisitor::kNotMember, this);
1568 }
1569
1570 private:
1571 IntVar* const var_;
1572 const SortedDisjointIntervalList intervals_;
1573 };
1574 } // namespace
1575
MakeNotMemberCt(IntExpr * const expr,std::vector<int64_t> starts,std::vector<int64_t> ends)1576 Constraint* Solver::MakeNotMemberCt(IntExpr* const expr,
1577 std::vector<int64_t> starts,
1578 std::vector<int64_t> ends) {
1579 return RevAlloc(new SortedDisjointForbiddenIntervalsConstraint(
1580 this, expr->Var(), {starts, ends}));
1581 }
1582
MakeNotMemberCt(IntExpr * const expr,std::vector<int> starts,std::vector<int> ends)1583 Constraint* Solver::MakeNotMemberCt(IntExpr* const expr,
1584 std::vector<int> starts,
1585 std::vector<int> ends) {
1586 return RevAlloc(new SortedDisjointForbiddenIntervalsConstraint(
1587 this, expr->Var(), {starts, ends}));
1588 }
1589
MakeNotMemberCt(IntExpr * expr,SortedDisjointIntervalList intervals)1590 Constraint* Solver::MakeNotMemberCt(IntExpr* expr,
1591 SortedDisjointIntervalList intervals) {
1592 return RevAlloc(new SortedDisjointForbiddenIntervalsConstraint(
1593 this, expr->Var(), std::move(intervals)));
1594 }
1595 } // namespace operations_research
1596