1 // Copyright 2010-2021 Google LLC
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 #include <algorithm>
15 #include <cmath>
16 #include <cstdint>
17 #include <limits>
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "ortools/base/commandlineflags.h"
27 #include "ortools/base/integral_types.h"
28 #include "ortools/base/logging.h"
29 #include "ortools/base/map_util.h"
30 #include "ortools/base/mathutil.h"
31 #include "ortools/base/stl_util.h"
32 #include "ortools/constraint_solver/constraint_solver.h"
33 #include "ortools/constraint_solver/constraint_solveri.h"
34 #include "ortools/util/bitset.h"
35 #include "ortools/util/saturated_arithmetic.h"
36 #include "ortools/util/string_array.h"
37
// Flag: when true, skips the special-case optimizations normally applied
// while creating expressions (see help string).
ABSL_FLAG(bool, cp_disable_expression_optimization, false,
          "Disable special optimization when creating expressions.");
// Flag: when true (the default), IntConst objects holding the same value are
// shared (see help string).
ABSL_FLAG(bool, cp_share_int_consts, true,
          "Share IntConst's with the same value.");
42
43 #if defined(_MSC_VER)
44 #pragma warning(disable : 4351 4355)
45 #endif
46
47 namespace operations_research {
48
49 // ---------- IntExpr ----------
50
VarWithName(const std::string & name)51 IntVar* IntExpr::VarWithName(const std::string& name) {
52 IntVar* const var = Var();
53 var->set_name(name);
54 return var;
55 }
56
57 // ---------- IntVar ----------
58
IntVar(Solver * const s)59 IntVar::IntVar(Solver* const s) : IntExpr(s), index_(s->GetNewIntVarIndex()) {}
60
IntVar(Solver * const s,const std::string & name)61 IntVar::IntVar(Solver* const s, const std::string& name)
62 : IntExpr(s), index_(s->GetNewIntVarIndex()) {
63 set_name(name);
64 }
65
66 // ----- Boolean variable -----
67
68 const int BooleanVar::kUnboundBooleanVarValue = 2;
69
SetMin(int64_t m)70 void BooleanVar::SetMin(int64_t m) {
71 if (m <= 0) return;
72 if (m > 1) solver()->Fail();
73 SetValue(1);
74 }
75
SetMax(int64_t m)76 void BooleanVar::SetMax(int64_t m) {
77 if (m >= 1) return;
78 if (m < 0) solver()->Fail();
79 SetValue(0);
80 }
81
SetRange(int64_t mi,int64_t ma)82 void BooleanVar::SetRange(int64_t mi, int64_t ma) {
83 if (mi > 1 || ma < 0 || mi > ma) {
84 solver()->Fail();
85 }
86 if (mi == 1) {
87 SetValue(1);
88 } else if (ma == 0) {
89 SetValue(0);
90 }
91 }
92
RemoveValue(int64_t v)93 void BooleanVar::RemoveValue(int64_t v) {
94 if (value_ == kUnboundBooleanVarValue) {
95 if (v == 0) {
96 SetValue(1);
97 } else if (v == 1) {
98 SetValue(0);
99 }
100 } else if (v == value_) {
101 solver()->Fail();
102 }
103 }
104
RemoveInterval(int64_t l,int64_t u)105 void BooleanVar::RemoveInterval(int64_t l, int64_t u) {
106 if (u < l) return;
107 if (l <= 0 && u >= 1) {
108 solver()->Fail();
109 } else if (l == 1) {
110 SetValue(0);
111 } else if (u == 0) {
112 SetValue(1);
113 }
114 }
115
WhenBound(Demon * d)116 void BooleanVar::WhenBound(Demon* d) {
117 if (value_ == kUnboundBooleanVarValue) {
118 if (d->priority() == Solver::DELAYED_PRIORITY) {
119 delayed_bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
120 } else {
121 bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
122 }
123 }
124 }
125
Size() const126 uint64_t BooleanVar::Size() const {
127 return (1 + (value_ == kUnboundBooleanVarValue));
128 }
129
Contains(int64_t v) const130 bool BooleanVar::Contains(int64_t v) const {
131 return ((v == 0 && value_ != 1) || (v == 1 && value_ != 0));
132 }
133
IsEqual(int64_t constant)134 IntVar* BooleanVar::IsEqual(int64_t constant) {
135 if (constant > 1 || constant < 0) {
136 return solver()->MakeIntConst(0);
137 }
138 if (constant == 1) {
139 return this;
140 } else { // constant == 0.
141 return solver()->MakeDifference(1, this)->Var();
142 }
143 }
144
IsDifferent(int64_t constant)145 IntVar* BooleanVar::IsDifferent(int64_t constant) {
146 if (constant > 1 || constant < 0) {
147 return solver()->MakeIntConst(1);
148 }
149 if (constant == 1) {
150 return solver()->MakeDifference(1, this)->Var();
151 } else { // constant == 0.
152 return this;
153 }
154 }
155
IsGreaterOrEqual(int64_t constant)156 IntVar* BooleanVar::IsGreaterOrEqual(int64_t constant) {
157 if (constant > 1) {
158 return solver()->MakeIntConst(0);
159 } else if (constant <= 0) {
160 return solver()->MakeIntConst(1);
161 } else {
162 return this;
163 }
164 }
165
IsLessOrEqual(int64_t constant)166 IntVar* BooleanVar::IsLessOrEqual(int64_t constant) {
167 if (constant < 0) {
168 return solver()->MakeIntConst(0);
169 } else if (constant >= 1) {
170 return solver()->MakeIntConst(1);
171 } else {
172 return IsEqual(0);
173 }
174 }
175
DebugString() const176 std::string BooleanVar::DebugString() const {
177 std::string out;
178 const std::string& var_name = name();
179 if (!var_name.empty()) {
180 out = var_name + "(";
181 } else {
182 out = "BooleanVar(";
183 }
184 switch (value_) {
185 case 0:
186 out += "0";
187 break;
188 case 1:
189 out += "1";
190 break;
191 case kUnboundBooleanVarValue:
192 out += "0 .. 1";
193 break;
194 }
195 out += ")";
196 return out;
197 }
198
199 namespace {
200 // ---------- Subclasses of IntVar ----------
201
202 // ----- Domain Int Var: base class for variables -----
// It contains bounds and a bitset representation of possible values.
204 class DomainIntVar : public IntVar {
205 public:
206 // Utility classes
207 class BitSetIterator : public BaseObject {
208 public:
BitSetIterator(uint64_t * const bitset,int64_t omin)209 BitSetIterator(uint64_t* const bitset, int64_t omin)
210 : bitset_(bitset),
211 omin_(omin),
212 max_(std::numeric_limits<int64_t>::min()),
213 current_(std::numeric_limits<int64_t>::max()) {}
214
~BitSetIterator()215 ~BitSetIterator() override {}
216
Init(int64_t min,int64_t max)217 void Init(int64_t min, int64_t max) {
218 max_ = max;
219 current_ = min;
220 }
221
Ok() const222 bool Ok() const { return current_ <= max_; }
223
Value() const224 int64_t Value() const { return current_; }
225
Next()226 void Next() {
227 if (++current_ <= max_) {
228 current_ = UnsafeLeastSignificantBitPosition64(
229 bitset_, current_ - omin_, max_ - omin_) +
230 omin_;
231 }
232 }
233
DebugString() const234 std::string DebugString() const override { return "BitSetIterator"; }
235
236 private:
237 uint64_t* const bitset_;
238 const int64_t omin_;
239 int64_t max_;
240 int64_t current_;
241 };
242
243 class BitSet : public BaseObject {
244 public:
BitSet(Solver * const s)245 explicit BitSet(Solver* const s) : solver_(s), holes_stamp_(0) {}
~BitSet()246 ~BitSet() override {}
247
248 virtual int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) = 0;
249 virtual int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) = 0;
250 virtual bool Contains(int64_t val) const = 0;
251 virtual bool SetValue(int64_t val) = 0;
252 virtual bool RemoveValue(int64_t val) = 0;
253 virtual uint64_t Size() const = 0;
254 virtual void DelayRemoveValue(int64_t val) = 0;
255 virtual void ApplyRemovedValues(DomainIntVar* var) = 0;
256 virtual void ClearRemovedValues() = 0;
257 virtual std::string pretty_DebugString(int64_t min, int64_t max) const = 0;
258 virtual BitSetIterator* MakeIterator() = 0;
259
InitHoles()260 void InitHoles() {
261 const uint64_t current_stamp = solver_->stamp();
262 if (holes_stamp_ < current_stamp) {
263 holes_.clear();
264 holes_stamp_ = current_stamp;
265 }
266 }
267
ClearHoles()268 virtual void ClearHoles() { holes_.clear(); }
269
Holes()270 const std::vector<int64_t>& Holes() { return holes_; }
271
AddHole(int64_t value)272 void AddHole(int64_t value) { holes_.push_back(value); }
273
NumHoles() const274 int NumHoles() const {
275 return holes_stamp_ < solver_->stamp() ? 0 : holes_.size();
276 }
277
278 protected:
279 Solver* const solver_;
280
281 private:
282 std::vector<int64_t> holes_;
283 uint64_t holes_stamp_;
284 };
285
286 class QueueHandler : public Demon {
287 public:
QueueHandler(DomainIntVar * const var)288 explicit QueueHandler(DomainIntVar* const var) : var_(var) {}
~QueueHandler()289 ~QueueHandler() override {}
Run(Solver * const s)290 void Run(Solver* const s) override {
291 s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
292 var_->Process();
293 s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
294 }
priority() const295 Solver::DemonPriority priority() const override {
296 return Solver::VAR_PRIORITY;
297 }
DebugString() const298 std::string DebugString() const override {
299 return absl::StrFormat("Handler(%s)", var_->DebugString());
300 }
301
302 private:
303 DomainIntVar* const var_;
304 };
305
306 // Bounds and Value watchers
307
308 // This class stores the watchers variables attached to values. It is
309 // reversible and it helps maintaining the set of 'active' watchers
310 // (variables not bound to a single value).
311 template <class T>
312 class RevIntPtrMap {
313 public:
RevIntPtrMap(Solver * const solver,int64_t rmin,int64_t rmax)314 RevIntPtrMap(Solver* const solver, int64_t rmin, int64_t rmax)
315 : solver_(solver), range_min_(rmin), start_(0) {}
316
~RevIntPtrMap()317 ~RevIntPtrMap() {}
318
Empty() const319 bool Empty() const { return start_.Value() == elements_.size(); }
320
SortActive()321 void SortActive() { std::sort(elements_.begin(), elements_.end()); }
322
323 // Access with value API.
324
325 // Add the pointer to the map attached to the given value.
UnsafeRevInsert(int64_t value,T * elem)326 void UnsafeRevInsert(int64_t value, T* elem) {
327 elements_.push_back(std::make_pair(value, elem));
328 if (solver_->state() != Solver::OUTSIDE_SEARCH) {
329 solver_->AddBacktrackAction(
330 [this, value](Solver* s) { Uninsert(value); }, false);
331 }
332 }
333
FindPtrOrNull(int64_t value,int * position)334 T* FindPtrOrNull(int64_t value, int* position) {
335 for (int pos = start_.Value(); pos < elements_.size(); ++pos) {
336 if (elements_[pos].first == value) {
337 if (position != nullptr) *position = pos;
338 return At(pos).second;
339 }
340 }
341 return nullptr;
342 }
343
344 // Access map through the underlying vector.
RemoveAt(int position)345 void RemoveAt(int position) {
346 const int start = start_.Value();
347 DCHECK_GE(position, start);
348 DCHECK_LT(position, elements_.size());
349 if (position > start) {
350 // Swap the current element with the one at the start position, and
351 // increase start.
352 const std::pair<int64_t, T*> copy = elements_[start];
353 elements_[start] = elements_[position];
354 elements_[position] = copy;
355 }
356 start_.Incr(solver_);
357 }
358
At(int position) const359 const std::pair<int64_t, T*>& At(int position) const {
360 DCHECK_GE(position, start_.Value());
361 DCHECK_LT(position, elements_.size());
362 return elements_[position];
363 }
364
RemoveAll()365 void RemoveAll() { start_.SetValue(solver_, elements_.size()); }
366
start() const367 int start() const { return start_.Value(); }
end() const368 int end() const { return elements_.size(); }
369 // Number of active elements.
Size() const370 int Size() const { return elements_.size() - start_.Value(); }
371
372 // Removes the object permanently from the map.
Uninsert(int64_t value)373 void Uninsert(int64_t value) {
374 for (int pos = 0; pos < elements_.size(); ++pos) {
375 if (elements_[pos].first == value) {
376 DCHECK_GE(pos, start_.Value());
377 const int last = elements_.size() - 1;
378 if (pos != last) { // Swap the current with the last.
379 elements_[pos] = elements_.back();
380 }
381 elements_.pop_back();
382 return;
383 }
384 }
385 LOG(FATAL) << "The element should have been removed";
386 }
387
388 private:
389 Solver* const solver_;
390 const int64_t range_min_;
391 NumericalRev<int> start_;
392 std::vector<std::pair<int64_t, T*>> elements_;
393 };
394
395 // Base class for value watchers
396 class BaseValueWatcher : public Constraint {
397 public:
BaseValueWatcher(Solver * const solver)398 explicit BaseValueWatcher(Solver* const solver) : Constraint(solver) {}
399
~BaseValueWatcher()400 ~BaseValueWatcher() override {}
401
402 virtual IntVar* GetOrMakeValueWatcher(int64_t value) = 0;
403
404 virtual void SetValueWatcher(IntVar* const boolvar, int64_t value) = 0;
405 };
406
407 // This class monitors the domain of the variable and updates the
408 // IsEqual/IsDifferent boolean variables accordingly.
409 class ValueWatcher : public BaseValueWatcher {
410 public:
411 class WatchDemon : public Demon {
412 public:
WatchDemon(ValueWatcher * const watcher,int64_t value,IntVar * var)413 WatchDemon(ValueWatcher* const watcher, int64_t value, IntVar* var)
414 : value_watcher_(watcher), value_(value), var_(var) {}
~WatchDemon()415 ~WatchDemon() override {}
416
Run(Solver * const solver)417 void Run(Solver* const solver) override {
418 value_watcher_->ProcessValueWatcher(value_, var_);
419 }
420
421 private:
422 ValueWatcher* const value_watcher_;
423 const int64_t value_;
424 IntVar* const var_;
425 };
426
427 class VarDemon : public Demon {
428 public:
VarDemon(ValueWatcher * const watcher)429 explicit VarDemon(ValueWatcher* const watcher)
430 : value_watcher_(watcher) {}
431
~VarDemon()432 ~VarDemon() override {}
433
Run(Solver * const solver)434 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
435
436 private:
437 ValueWatcher* const value_watcher_;
438 };
439
ValueWatcher(Solver * const solver,DomainIntVar * const variable)440 ValueWatcher(Solver* const solver, DomainIntVar* const variable)
441 : BaseValueWatcher(solver),
442 variable_(variable),
443 hole_iterator_(variable_->MakeHoleIterator(true)),
444 var_demon_(nullptr),
445 watchers_(solver, variable->Min(), variable->Max()) {}
446
~ValueWatcher()447 ~ValueWatcher() override {}
448
GetOrMakeValueWatcher(int64_t value)449 IntVar* GetOrMakeValueWatcher(int64_t value) override {
450 IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
451 if (watcher != nullptr) return watcher;
452 if (variable_->Contains(value)) {
453 if (variable_->Bound()) {
454 return solver()->MakeIntConst(1);
455 } else {
456 const std::string vname = variable_->HasName()
457 ? variable_->name()
458 : variable_->DebugString();
459 const std::string bname =
460 absl::StrFormat("Watch<%s == %d>", vname, value);
461 IntVar* const boolvar = solver()->MakeBoolVar(bname);
462 watchers_.UnsafeRevInsert(value, boolvar);
463 if (posted_.Switched()) {
464 boolvar->WhenBound(
465 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
466 var_demon_->desinhibit(solver());
467 }
468 return boolvar;
469 }
470 } else {
471 return variable_->solver()->MakeIntConst(0);
472 }
473 }
474
SetValueWatcher(IntVar * const boolvar,int64_t value)475 void SetValueWatcher(IntVar* const boolvar, int64_t value) override {
476 CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
477 if (!boolvar->Bound()) {
478 watchers_.UnsafeRevInsert(value, boolvar);
479 if (posted_.Switched() && !boolvar->Bound()) {
480 boolvar->WhenBound(
481 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
482 var_demon_->desinhibit(solver());
483 }
484 }
485 }
486
Post()487 void Post() override {
488 var_demon_ = solver()->RevAlloc(new VarDemon(this));
489 variable_->WhenDomain(var_demon_);
490 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
491 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
492 const int64_t value = w.first;
493 IntVar* const boolvar = w.second;
494 if (!boolvar->Bound() && variable_->Contains(value)) {
495 boolvar->WhenBound(
496 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
497 }
498 }
499 posted_.Switch(solver());
500 }
501
InitialPropagate()502 void InitialPropagate() override {
503 if (variable_->Bound()) {
504 VariableBound();
505 } else {
506 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
507 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
508 const int64_t value = w.first;
509 IntVar* const boolvar = w.second;
510 if (!variable_->Contains(value)) {
511 boolvar->SetValue(0);
512 watchers_.RemoveAt(pos);
513 } else {
514 if (boolvar->Bound()) {
515 ProcessValueWatcher(value, boolvar);
516 watchers_.RemoveAt(pos);
517 }
518 }
519 }
520 CheckInhibit();
521 }
522 }
523
ProcessValueWatcher(int64_t value,IntVar * boolvar)524 void ProcessValueWatcher(int64_t value, IntVar* boolvar) {
525 if (boolvar->Min() == 0) {
526 if (variable_->Size() < 0xFFFFFF) {
527 variable_->RemoveValue(value);
528 } else {
529 // Delay removal.
530 solver()->AddConstraint(solver()->MakeNonEquality(variable_, value));
531 }
532 } else {
533 variable_->SetValue(value);
534 }
535 }
536
ProcessVar()537 void ProcessVar() {
538 const int kSmallList = 16;
539 if (variable_->Bound()) {
540 VariableBound();
541 } else if (watchers_.Size() <= kSmallList ||
542 variable_->Min() != variable_->OldMin() ||
543 variable_->Max() != variable_->OldMax()) {
544 // Brute force loop for small numbers of watchers, or if the bounds have
545 // changed, which would have required a sort (n log(n)) anyway to take
546 // advantage of.
547 ScanWatchers();
548 CheckInhibit();
549 } else {
550 // If there is no bitset, then there are no holes.
551 // In that case, the two loops above should have performed all
552 // propagation. Otherwise, scan the remaining watchers.
553 BitSet* const bitset = variable_->bitset();
554 if (bitset != nullptr && !watchers_.Empty()) {
555 if (bitset->NumHoles() * 2 < watchers_.Size()) {
556 for (const int64_t hole : InitAndGetValues(hole_iterator_)) {
557 int pos = 0;
558 IntVar* const boolvar = watchers_.FindPtrOrNull(hole, &pos);
559 if (boolvar != nullptr) {
560 boolvar->SetValue(0);
561 watchers_.RemoveAt(pos);
562 }
563 }
564 } else {
565 ScanWatchers();
566 }
567 }
568 CheckInhibit();
569 }
570 }
571
572 // Optimized case if the variable is bound.
VariableBound()573 void VariableBound() {
574 DCHECK(variable_->Bound());
575 const int64_t value = variable_->Min();
576 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
577 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
578 w.second->SetValue(w.first == value);
579 }
580 watchers_.RemoveAll();
581 var_demon_->inhibit(solver());
582 }
583
584 // Scans all the watchers to check and assign them.
ScanWatchers()585 void ScanWatchers() {
586 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
587 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
588 if (!variable_->Contains(w.first)) {
589 IntVar* const boolvar = w.second;
590 boolvar->SetValue(0);
591 watchers_.RemoveAt(pos);
592 }
593 }
594 }
595
596 // If the set of active watchers is empty, we can inhibit the demon on the
597 // main variable.
CheckInhibit()598 void CheckInhibit() {
599 if (watchers_.Empty()) {
600 var_demon_->inhibit(solver());
601 }
602 }
603
Accept(ModelVisitor * const visitor) const604 void Accept(ModelVisitor* const visitor) const override {
605 visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
606 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
607 variable_);
608 std::vector<int64_t> all_coefficients;
609 std::vector<IntVar*> all_bool_vars;
610 for (int position = watchers_.start(); position < watchers_.end();
611 ++position) {
612 const std::pair<int64_t, IntVar*>& w = watchers_.At(position);
613 all_coefficients.push_back(w.first);
614 all_bool_vars.push_back(w.second);
615 }
616 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
617 all_bool_vars);
618 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
619 all_coefficients);
620 visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
621 }
622
DebugString() const623 std::string DebugString() const override {
624 return absl::StrFormat("ValueWatcher(%s)", variable_->DebugString());
625 }
626
627 private:
628 DomainIntVar* const variable_;
629 IntVarIterator* const hole_iterator_;
630 RevSwitch posted_;
631 Demon* var_demon_;
632 RevIntPtrMap<IntVar> watchers_;
633 };
634
635 // Optimized case for small maps.
636 class DenseValueWatcher : public BaseValueWatcher {
637 public:
638 class WatchDemon : public Demon {
639 public:
WatchDemon(DenseValueWatcher * const watcher,int64_t value,IntVar * var)640 WatchDemon(DenseValueWatcher* const watcher, int64_t value, IntVar* var)
641 : value_watcher_(watcher), value_(value), var_(var) {}
~WatchDemon()642 ~WatchDemon() override {}
643
Run(Solver * const solver)644 void Run(Solver* const solver) override {
645 value_watcher_->ProcessValueWatcher(value_, var_);
646 }
647
648 private:
649 DenseValueWatcher* const value_watcher_;
650 const int64_t value_;
651 IntVar* const var_;
652 };
653
654 class VarDemon : public Demon {
655 public:
VarDemon(DenseValueWatcher * const watcher)656 explicit VarDemon(DenseValueWatcher* const watcher)
657 : value_watcher_(watcher) {}
658
~VarDemon()659 ~VarDemon() override {}
660
Run(Solver * const solver)661 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
662
663 private:
664 DenseValueWatcher* const value_watcher_;
665 };
666
DenseValueWatcher(Solver * const solver,DomainIntVar * const variable)667 DenseValueWatcher(Solver* const solver, DomainIntVar* const variable)
668 : BaseValueWatcher(solver),
669 variable_(variable),
670 hole_iterator_(variable_->MakeHoleIterator(true)),
671 var_demon_(nullptr),
672 offset_(variable->Min()),
673 watchers_(variable->Max() - variable->Min() + 1, nullptr),
674 active_watchers_(0) {}
675
~DenseValueWatcher()676 ~DenseValueWatcher() override {}
677
GetOrMakeValueWatcher(int64_t value)678 IntVar* GetOrMakeValueWatcher(int64_t value) override {
679 const int64_t var_max = offset_ + watchers_.size() - 1; // Bad cast.
680 if (value < offset_ || value > var_max) {
681 return solver()->MakeIntConst(0);
682 }
683 const int index = value - offset_;
684 IntVar* const watcher = watchers_[index];
685 if (watcher != nullptr) return watcher;
686 if (variable_->Contains(value)) {
687 if (variable_->Bound()) {
688 return solver()->MakeIntConst(1);
689 } else {
690 const std::string vname = variable_->HasName()
691 ? variable_->name()
692 : variable_->DebugString();
693 const std::string bname =
694 absl::StrFormat("Watch<%s == %d>", vname, value);
695 IntVar* const boolvar = solver()->MakeBoolVar(bname);
696 RevInsert(index, boolvar);
697 if (posted_.Switched()) {
698 boolvar->WhenBound(
699 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
700 var_demon_->desinhibit(solver());
701 }
702 return boolvar;
703 }
704 } else {
705 return variable_->solver()->MakeIntConst(0);
706 }
707 }
708
SetValueWatcher(IntVar * const boolvar,int64_t value)709 void SetValueWatcher(IntVar* const boolvar, int64_t value) override {
710 const int index = value - offset_;
711 CHECK(watchers_[index] == nullptr);
712 if (!boolvar->Bound()) {
713 RevInsert(index, boolvar);
714 if (posted_.Switched() && !boolvar->Bound()) {
715 boolvar->WhenBound(
716 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
717 var_demon_->desinhibit(solver());
718 }
719 }
720 }
721
Post()722 void Post() override {
723 var_demon_ = solver()->RevAlloc(new VarDemon(this));
724 variable_->WhenDomain(var_demon_);
725 for (int pos = 0; pos < watchers_.size(); ++pos) {
726 const int64_t value = pos + offset_;
727 IntVar* const boolvar = watchers_[pos];
728 if (boolvar != nullptr && !boolvar->Bound() &&
729 variable_->Contains(value)) {
730 boolvar->WhenBound(
731 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
732 }
733 }
734 posted_.Switch(solver());
735 }
736
InitialPropagate()737 void InitialPropagate() override {
738 if (variable_->Bound()) {
739 VariableBound();
740 } else {
741 for (int pos = 0; pos < watchers_.size(); ++pos) {
742 IntVar* const boolvar = watchers_[pos];
743 if (boolvar == nullptr) continue;
744 const int64_t value = pos + offset_;
745 if (!variable_->Contains(value)) {
746 boolvar->SetValue(0);
747 RevRemove(pos);
748 } else if (boolvar->Bound()) {
749 ProcessValueWatcher(value, boolvar);
750 RevRemove(pos);
751 }
752 }
753 if (active_watchers_.Value() == 0) {
754 var_demon_->inhibit(solver());
755 }
756 }
757 }
758
ProcessValueWatcher(int64_t value,IntVar * boolvar)759 void ProcessValueWatcher(int64_t value, IntVar* boolvar) {
760 if (boolvar->Min() == 0) {
761 variable_->RemoveValue(value);
762 } else {
763 variable_->SetValue(value);
764 }
765 }
766
ProcessVar()767 void ProcessVar() {
768 if (variable_->Bound()) {
769 VariableBound();
770 } else {
771 // Brute force loop for small numbers of watchers.
772 ScanWatchers();
773 if (active_watchers_.Value() == 0) {
774 var_demon_->inhibit(solver());
775 }
776 }
777 }
778
779 // Optimized case if the variable is bound.
VariableBound()780 void VariableBound() {
781 DCHECK(variable_->Bound());
782 const int64_t value = variable_->Min();
783 for (int pos = 0; pos < watchers_.size(); ++pos) {
784 IntVar* const boolvar = watchers_[pos];
785 if (boolvar != nullptr) {
786 boolvar->SetValue(pos + offset_ == value);
787 RevRemove(pos);
788 }
789 }
790 var_demon_->inhibit(solver());
791 }
792
793 // Scans all the watchers to check and assign them.
ScanWatchers()794 void ScanWatchers() {
795 const int64_t old_min_index = variable_->OldMin() - offset_;
796 const int64_t old_max_index = variable_->OldMax() - offset_;
797 const int64_t min_index = variable_->Min() - offset_;
798 const int64_t max_index = variable_->Max() - offset_;
799 for (int pos = old_min_index; pos < min_index; ++pos) {
800 IntVar* const boolvar = watchers_[pos];
801 if (boolvar != nullptr) {
802 boolvar->SetValue(0);
803 RevRemove(pos);
804 }
805 }
806 for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
807 IntVar* const boolvar = watchers_[pos];
808 if (boolvar != nullptr) {
809 boolvar->SetValue(0);
810 RevRemove(pos);
811 }
812 }
813 BitSet* const bitset = variable_->bitset();
814 if (bitset != nullptr) {
815 if (bitset->NumHoles() * 2 < active_watchers_.Value()) {
816 for (const int64_t hole : InitAndGetValues(hole_iterator_)) {
817 IntVar* const boolvar = watchers_[hole - offset_];
818 if (boolvar != nullptr) {
819 boolvar->SetValue(0);
820 RevRemove(hole - offset_);
821 }
822 }
823 } else {
824 for (int pos = min_index + 1; pos < max_index; ++pos) {
825 IntVar* const boolvar = watchers_[pos];
826 if (boolvar != nullptr && !variable_->Contains(offset_ + pos)) {
827 boolvar->SetValue(0);
828 RevRemove(pos);
829 }
830 }
831 }
832 }
833 }
834
RevRemove(int pos)835 void RevRemove(int pos) {
836 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
837 watchers_[pos] = nullptr;
838 active_watchers_.Decr(solver());
839 }
840
RevInsert(int pos,IntVar * boolvar)841 void RevInsert(int pos, IntVar* boolvar) {
842 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
843 watchers_[pos] = boolvar;
844 active_watchers_.Incr(solver());
845 }
846
Accept(ModelVisitor * const visitor) const847 void Accept(ModelVisitor* const visitor) const override {
848 visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
849 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
850 variable_);
851 std::vector<int64_t> all_coefficients;
852 std::vector<IntVar*> all_bool_vars;
853 for (int position = 0; position < watchers_.size(); ++position) {
854 if (watchers_[position] != nullptr) {
855 all_coefficients.push_back(position + offset_);
856 all_bool_vars.push_back(watchers_[position]);
857 }
858 }
859 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
860 all_bool_vars);
861 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
862 all_coefficients);
863 visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
864 }
865
DebugString() const866 std::string DebugString() const override {
867 return absl::StrFormat("DenseValueWatcher(%s)", variable_->DebugString());
868 }
869
870 private:
871 DomainIntVar* const variable_;
872 IntVarIterator* const hole_iterator_;
873 RevSwitch posted_;
874 Demon* var_demon_;
875 const int64_t offset_;
876 std::vector<IntVar*> watchers_;
877 NumericalRev<int> active_watchers_;
878 };
879
880 class BaseUpperBoundWatcher : public Constraint {
881 public:
BaseUpperBoundWatcher(Solver * const solver)882 explicit BaseUpperBoundWatcher(Solver* const solver) : Constraint(solver) {}
883
~BaseUpperBoundWatcher()884 ~BaseUpperBoundWatcher() override {}
885
886 virtual IntVar* GetOrMakeUpperBoundWatcher(int64_t value) = 0;
887
888 virtual void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) = 0;
889 };
890
891 // This class watches the bounds of the variable and updates the
892 // IsGreater/IsGreaterOrEqual/IsLess/IsLessOrEqual demons
893 // accordingly.
894 class UpperBoundWatcher : public BaseUpperBoundWatcher {
895 public:
896 class WatchDemon : public Demon {
897 public:
WatchDemon(UpperBoundWatcher * const watcher,int64_t index,IntVar * const var)898 WatchDemon(UpperBoundWatcher* const watcher, int64_t index,
899 IntVar* const var)
900 : value_watcher_(watcher), index_(index), var_(var) {}
~WatchDemon()901 ~WatchDemon() override {}
902
Run(Solver * const solver)903 void Run(Solver* const solver) override {
904 value_watcher_->ProcessUpperBoundWatcher(index_, var_);
905 }
906
907 private:
908 UpperBoundWatcher* const value_watcher_;
909 const int64_t index_;
910 IntVar* const var_;
911 };
912
913 class VarDemon : public Demon {
914 public:
VarDemon(UpperBoundWatcher * const watcher)915 explicit VarDemon(UpperBoundWatcher* const watcher)
916 : value_watcher_(watcher) {}
~VarDemon()917 ~VarDemon() override {}
918
Run(Solver * const solver)919 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
920
921 private:
922 UpperBoundWatcher* const value_watcher_;
923 };
924
UpperBoundWatcher(Solver * const solver,DomainIntVar * const variable)925 UpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
926 : BaseUpperBoundWatcher(solver),
927 variable_(variable),
928 var_demon_(nullptr),
929 watchers_(solver, variable->Min(), variable->Max()),
930 start_(0),
931 end_(0),
932 sorted_(false) {}
933
~UpperBoundWatcher()934 ~UpperBoundWatcher() override {}
935
GetOrMakeUpperBoundWatcher(int64_t value)936 IntVar* GetOrMakeUpperBoundWatcher(int64_t value) override {
937 IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
938 if (watcher != nullptr) {
939 return watcher;
940 }
941 if (variable_->Max() >= value) {
942 if (variable_->Min() >= value) {
943 return solver()->MakeIntConst(1);
944 } else {
945 const std::string vname = variable_->HasName()
946 ? variable_->name()
947 : variable_->DebugString();
948 const std::string bname =
949 absl::StrFormat("Watch<%s >= %d>", vname, value);
950 IntVar* const boolvar = solver()->MakeBoolVar(bname);
951 watchers_.UnsafeRevInsert(value, boolvar);
952 if (posted_.Switched()) {
953 boolvar->WhenBound(
954 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
955 var_demon_->desinhibit(solver());
956 sorted_ = false;
957 }
958 return boolvar;
959 }
960 } else {
961 return variable_->solver()->MakeIntConst(0);
962 }
963 }
964
SetUpperBoundWatcher(IntVar * const boolvar,int64_t value)965 void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) override {
966 CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
967 watchers_.UnsafeRevInsert(value, boolvar);
968 if (posted_.Switched() && !boolvar->Bound()) {
969 boolvar->WhenBound(
970 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
971 var_demon_->desinhibit(solver());
972 sorted_ = false;
973 }
974 }
975
Post()976 void Post() override {
977 const int kTooSmallToSort = 8;
978 var_demon_ = solver()->RevAlloc(new VarDemon(this));
979 variable_->WhenRange(var_demon_);
980
981 if (watchers_.Size() > kTooSmallToSort) {
982 watchers_.SortActive();
983 sorted_ = true;
984 start_.SetValue(solver(), watchers_.start());
985 end_.SetValue(solver(), watchers_.end() - 1);
986 }
987
988 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
989 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
990 IntVar* const boolvar = w.second;
991 const int64_t value = w.first;
992 if (!boolvar->Bound() && value > variable_->Min() &&
993 value <= variable_->Max()) {
994 boolvar->WhenBound(
995 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
996 }
997 }
998 posted_.Switch(solver());
999 }
1000
InitialPropagate()1001 void InitialPropagate() override {
1002 const int64_t var_min = variable_->Min();
1003 const int64_t var_max = variable_->Max();
1004 if (sorted_) {
1005 while (start_.Value() <= end_.Value()) {
1006 const std::pair<int64_t, IntVar*>& w = watchers_.At(start_.Value());
1007 if (w.first <= var_min) {
1008 w.second->SetValue(1);
1009 start_.Incr(solver());
1010 } else {
1011 break;
1012 }
1013 }
1014 while (end_.Value() >= start_.Value()) {
1015 const std::pair<int64_t, IntVar*>& w = watchers_.At(end_.Value());
1016 if (w.first > var_max) {
1017 w.second->SetValue(0);
1018 end_.Decr(solver());
1019 } else {
1020 break;
1021 }
1022 }
1023 for (int i = start_.Value(); i <= end_.Value(); ++i) {
1024 const std::pair<int64_t, IntVar*>& w = watchers_.At(i);
1025 if (w.second->Bound()) {
1026 ProcessUpperBoundWatcher(w.first, w.second);
1027 }
1028 }
1029 if (start_.Value() > end_.Value()) {
1030 var_demon_->inhibit(solver());
1031 }
1032 } else {
1033 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1034 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
1035 const int64_t value = w.first;
1036 IntVar* const boolvar = w.second;
1037
1038 if (value <= var_min) {
1039 boolvar->SetValue(1);
1040 watchers_.RemoveAt(pos);
1041 } else if (value > var_max) {
1042 boolvar->SetValue(0);
1043 watchers_.RemoveAt(pos);
1044 } else if (boolvar->Bound()) {
1045 ProcessUpperBoundWatcher(value, boolvar);
1046 watchers_.RemoveAt(pos);
1047 }
1048 }
1049 }
1050 }
1051
Accept(ModelVisitor * const visitor) const1052 void Accept(ModelVisitor* const visitor) const override {
1053 visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1054 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1055 variable_);
1056 std::vector<int64_t> all_coefficients;
1057 std::vector<IntVar*> all_bool_vars;
1058 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1059 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
1060 all_coefficients.push_back(w.first);
1061 all_bool_vars.push_back(w.second);
1062 }
1063 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1064 all_bool_vars);
1065 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1066 all_coefficients);
1067 visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1068 }
1069
DebugString() const1070 std::string DebugString() const override {
1071 return absl::StrFormat("UpperBoundWatcher(%s)", variable_->DebugString());
1072 }
1073
1074 private:
ProcessUpperBoundWatcher(int64_t value,IntVar * const boolvar)1075 void ProcessUpperBoundWatcher(int64_t value, IntVar* const boolvar) {
1076 if (boolvar->Min() == 0) {
1077 variable_->SetMax(value - 1);
1078 } else {
1079 variable_->SetMin(value);
1080 }
1081 }
1082
ProcessVar()1083 void ProcessVar() {
1084 const int64_t var_min = variable_->Min();
1085 const int64_t var_max = variable_->Max();
1086 if (sorted_) {
1087 while (start_.Value() <= end_.Value()) {
1088 const std::pair<int64_t, IntVar*>& w = watchers_.At(start_.Value());
1089 if (w.first <= var_min) {
1090 w.second->SetValue(1);
1091 start_.Incr(solver());
1092 } else {
1093 break;
1094 }
1095 }
1096 while (end_.Value() >= start_.Value()) {
1097 const std::pair<int64_t, IntVar*>& w = watchers_.At(end_.Value());
1098 if (w.first > var_max) {
1099 w.second->SetValue(0);
1100 end_.Decr(solver());
1101 } else {
1102 break;
1103 }
1104 }
1105 if (start_.Value() > end_.Value()) {
1106 var_demon_->inhibit(solver());
1107 }
1108 } else {
1109 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1110 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
1111 const int64_t value = w.first;
1112 IntVar* const boolvar = w.second;
1113
1114 if (value <= var_min) {
1115 boolvar->SetValue(1);
1116 watchers_.RemoveAt(pos);
1117 } else if (value > var_max) {
1118 boolvar->SetValue(0);
1119 watchers_.RemoveAt(pos);
1120 }
1121 }
1122 if (watchers_.Empty()) {
1123 var_demon_->inhibit(solver());
1124 }
1125 }
1126 }
1127
1128 DomainIntVar* const variable_;
1129 RevSwitch posted_;
1130 Demon* var_demon_;
1131 RevIntPtrMap<IntVar> watchers_;
1132 NumericalRev<int> start_;
1133 NumericalRev<int> end_;
1134 bool sorted_;
1135 };
1136
1137 // Optimized case for small maps.
1138 class DenseUpperBoundWatcher : public BaseUpperBoundWatcher {
1139 public:
1140 class WatchDemon : public Demon {
1141 public:
WatchDemon(DenseUpperBoundWatcher * const watcher,int64_t value,IntVar * var)1142 WatchDemon(DenseUpperBoundWatcher* const watcher, int64_t value,
1143 IntVar* var)
1144 : value_watcher_(watcher), value_(value), var_(var) {}
~WatchDemon()1145 ~WatchDemon() override {}
1146
Run(Solver * const solver)1147 void Run(Solver* const solver) override {
1148 value_watcher_->ProcessUpperBoundWatcher(value_, var_);
1149 }
1150
1151 private:
1152 DenseUpperBoundWatcher* const value_watcher_;
1153 const int64_t value_;
1154 IntVar* const var_;
1155 };
1156
1157 class VarDemon : public Demon {
1158 public:
VarDemon(DenseUpperBoundWatcher * const watcher)1159 explicit VarDemon(DenseUpperBoundWatcher* const watcher)
1160 : value_watcher_(watcher) {}
1161
~VarDemon()1162 ~VarDemon() override {}
1163
Run(Solver * const solver)1164 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
1165
1166 private:
1167 DenseUpperBoundWatcher* const value_watcher_;
1168 };
1169
DenseUpperBoundWatcher(Solver * const solver,DomainIntVar * const variable)1170 DenseUpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
1171 : BaseUpperBoundWatcher(solver),
1172 variable_(variable),
1173 var_demon_(nullptr),
1174 offset_(variable->Min()),
1175 watchers_(variable->Max() - variable->Min() + 1, nullptr),
1176 active_watchers_(0) {}
1177
~DenseUpperBoundWatcher()1178 ~DenseUpperBoundWatcher() override {}
1179
GetOrMakeUpperBoundWatcher(int64_t value)1180 IntVar* GetOrMakeUpperBoundWatcher(int64_t value) override {
1181 if (variable_->Max() >= value) {
1182 if (variable_->Min() >= value) {
1183 return solver()->MakeIntConst(1);
1184 } else {
1185 const std::string vname = variable_->HasName()
1186 ? variable_->name()
1187 : variable_->DebugString();
1188 const std::string bname =
1189 absl::StrFormat("Watch<%s >= %d>", vname, value);
1190 IntVar* const boolvar = solver()->MakeBoolVar(bname);
1191 RevInsert(value - offset_, boolvar);
1192 if (posted_.Switched()) {
1193 boolvar->WhenBound(
1194 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1195 var_demon_->desinhibit(solver());
1196 }
1197 return boolvar;
1198 }
1199 } else {
1200 return variable_->solver()->MakeIntConst(0);
1201 }
1202 }
1203
SetUpperBoundWatcher(IntVar * const boolvar,int64_t value)1204 void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) override {
1205 const int index = value - offset_;
1206 CHECK(watchers_[index] == nullptr);
1207 if (!boolvar->Bound()) {
1208 RevInsert(index, boolvar);
1209 if (posted_.Switched() && !boolvar->Bound()) {
1210 boolvar->WhenBound(
1211 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1212 var_demon_->desinhibit(solver());
1213 }
1214 }
1215 }
1216
Post()1217 void Post() override {
1218 var_demon_ = solver()->RevAlloc(new VarDemon(this));
1219 variable_->WhenRange(var_demon_);
1220 for (int pos = 0; pos < watchers_.size(); ++pos) {
1221 const int64_t value = pos + offset_;
1222 IntVar* const boolvar = watchers_[pos];
1223 if (boolvar != nullptr && !boolvar->Bound() &&
1224 value > variable_->Min() && value <= variable_->Max()) {
1225 boolvar->WhenBound(
1226 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1227 }
1228 }
1229 posted_.Switch(solver());
1230 }
1231
InitialPropagate()1232 void InitialPropagate() override {
1233 for (int pos = 0; pos < watchers_.size(); ++pos) {
1234 IntVar* const boolvar = watchers_[pos];
1235 if (boolvar == nullptr) continue;
1236 const int64_t value = pos + offset_;
1237 if (value <= variable_->Min()) {
1238 boolvar->SetValue(1);
1239 RevRemove(pos);
1240 } else if (value > variable_->Max()) {
1241 boolvar->SetValue(0);
1242 RevRemove(pos);
1243 } else if (boolvar->Bound()) {
1244 ProcessUpperBoundWatcher(value, boolvar);
1245 RevRemove(pos);
1246 }
1247 }
1248 if (active_watchers_.Value() == 0) {
1249 var_demon_->inhibit(solver());
1250 }
1251 }
1252
ProcessUpperBoundWatcher(int64_t value,IntVar * boolvar)1253 void ProcessUpperBoundWatcher(int64_t value, IntVar* boolvar) {
1254 if (boolvar->Min() == 0) {
1255 variable_->SetMax(value - 1);
1256 } else {
1257 variable_->SetMin(value);
1258 }
1259 }
1260
ProcessVar()1261 void ProcessVar() {
1262 const int64_t old_min_index = variable_->OldMin() - offset_;
1263 const int64_t old_max_index = variable_->OldMax() - offset_;
1264 const int64_t min_index = variable_->Min() - offset_;
1265 const int64_t max_index = variable_->Max() - offset_;
1266 for (int pos = old_min_index; pos <= min_index; ++pos) {
1267 IntVar* const boolvar = watchers_[pos];
1268 if (boolvar != nullptr) {
1269 boolvar->SetValue(1);
1270 RevRemove(pos);
1271 }
1272 }
1273
1274 for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
1275 IntVar* const boolvar = watchers_[pos];
1276 if (boolvar != nullptr) {
1277 boolvar->SetValue(0);
1278 RevRemove(pos);
1279 }
1280 }
1281 if (active_watchers_.Value() == 0) {
1282 var_demon_->inhibit(solver());
1283 }
1284 }
1285
RevRemove(int pos)1286 void RevRemove(int pos) {
1287 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1288 watchers_[pos] = nullptr;
1289 active_watchers_.Decr(solver());
1290 }
1291
RevInsert(int pos,IntVar * boolvar)1292 void RevInsert(int pos, IntVar* boolvar) {
1293 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1294 watchers_[pos] = boolvar;
1295 active_watchers_.Incr(solver());
1296 }
1297
Accept(ModelVisitor * const visitor) const1298 void Accept(ModelVisitor* const visitor) const override {
1299 visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1300 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1301 variable_);
1302 std::vector<int64_t> all_coefficients;
1303 std::vector<IntVar*> all_bool_vars;
1304 for (int position = 0; position < watchers_.size(); ++position) {
1305 if (watchers_[position] != nullptr) {
1306 all_coefficients.push_back(position + offset_);
1307 all_bool_vars.push_back(watchers_[position]);
1308 }
1309 }
1310 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1311 all_bool_vars);
1312 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1313 all_coefficients);
1314 visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1315 }
1316
DebugString() const1317 std::string DebugString() const override {
1318 return absl::StrFormat("DenseUpperBoundWatcher(%s)",
1319 variable_->DebugString());
1320 }
1321
1322 private:
1323 DomainIntVar* const variable_;
1324 RevSwitch posted_;
1325 Demon* var_demon_;
1326 const int64_t offset_;
1327 std::vector<IntVar*> watchers_;
1328 NumericalRev<int> active_watchers_;
1329 };
1330
1331 // ----- Main Class -----
1332 DomainIntVar(Solver* const s, int64_t vmin, int64_t vmax,
1333 const std::string& name);
1334 DomainIntVar(Solver* const s, const std::vector<int64_t>& sorted_values,
1335 const std::string& name);
1336 ~DomainIntVar() override;
1337
Min() const1338 int64_t Min() const override { return min_.Value(); }
1339 void SetMin(int64_t m) override;
Max() const1340 int64_t Max() const override { return max_.Value(); }
1341 void SetMax(int64_t m) override;
1342 void SetRange(int64_t mi, int64_t ma) override;
1343 void SetValue(int64_t v) override;
Bound() const1344 bool Bound() const override { return (min_.Value() == max_.Value()); }
Value() const1345 int64_t Value() const override {
1346 CHECK_EQ(min_.Value(), max_.Value())
1347 << " variable " << DebugString() << " is not bound.";
1348 return min_.Value();
1349 }
1350 void RemoveValue(int64_t v) override;
1351 void RemoveInterval(int64_t l, int64_t u) override;
1352 void CreateBits();
WhenBound(Demon * d)1353 void WhenBound(Demon* d) override {
1354 if (min_.Value() != max_.Value()) {
1355 if (d->priority() == Solver::DELAYED_PRIORITY) {
1356 delayed_bound_demons_.PushIfNotTop(solver(),
1357 solver()->RegisterDemon(d));
1358 } else {
1359 bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1360 }
1361 }
1362 }
WhenRange(Demon * d)1363 void WhenRange(Demon* d) override {
1364 if (min_.Value() != max_.Value()) {
1365 if (d->priority() == Solver::DELAYED_PRIORITY) {
1366 delayed_range_demons_.PushIfNotTop(solver(),
1367 solver()->RegisterDemon(d));
1368 } else {
1369 range_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1370 }
1371 }
1372 }
WhenDomain(Demon * d)1373 void WhenDomain(Demon* d) override {
1374 if (min_.Value() != max_.Value()) {
1375 if (d->priority() == Solver::DELAYED_PRIORITY) {
1376 delayed_domain_demons_.PushIfNotTop(solver(),
1377 solver()->RegisterDemon(d));
1378 } else {
1379 domain_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1380 }
1381 }
1382 }
1383
IsEqual(int64_t constant)1384 IntVar* IsEqual(int64_t constant) override {
1385 Solver* const s = solver();
1386 if (constant == min_.Value() && value_watcher_ == nullptr) {
1387 return s->MakeIsLessOrEqualCstVar(this, constant);
1388 }
1389 if (constant == max_.Value() && value_watcher_ == nullptr) {
1390 return s->MakeIsGreaterOrEqualCstVar(this, constant);
1391 }
1392 if (!Contains(constant)) {
1393 return s->MakeIntConst(int64_t{0});
1394 }
1395 if (Bound() && min_.Value() == constant) {
1396 return s->MakeIntConst(int64_t{1});
1397 }
1398 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1399 this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1400 if (cache != nullptr) {
1401 return cache->Var();
1402 } else {
1403 if (value_watcher_ == nullptr) {
1404 if (CapSub(Max(), Min()) <= 256) {
1405 solver()->SaveAndSetValue(
1406 reinterpret_cast<void**>(&value_watcher_),
1407 reinterpret_cast<void*>(
1408 solver()->RevAlloc(new DenseValueWatcher(solver(), this))));
1409
1410 } else {
1411 solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1412 reinterpret_cast<void*>(solver()->RevAlloc(
1413 new ValueWatcher(solver(), this))));
1414 }
1415 solver()->AddConstraint(value_watcher_);
1416 }
1417 IntVar* const boolvar = value_watcher_->GetOrMakeValueWatcher(constant);
1418 s->Cache()->InsertExprConstantExpression(
1419 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1420 return boolvar;
1421 }
1422 }
1423
SetIsEqual(const std::vector<int64_t> & values,const std::vector<IntVar * > & vars)1424 Constraint* SetIsEqual(const std::vector<int64_t>& values,
1425 const std::vector<IntVar*>& vars) {
1426 if (value_watcher_ == nullptr) {
1427 solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1428 reinterpret_cast<void*>(solver()->RevAlloc(
1429 new ValueWatcher(solver(), this))));
1430 for (int i = 0; i < vars.size(); ++i) {
1431 value_watcher_->SetValueWatcher(vars[i], values[i]);
1432 }
1433 }
1434 return value_watcher_;
1435 }
1436
IsDifferent(int64_t constant)1437 IntVar* IsDifferent(int64_t constant) override {
1438 Solver* const s = solver();
1439 if (constant == min_.Value() && value_watcher_ == nullptr) {
1440 return s->MakeIsGreaterOrEqualCstVar(this, constant + 1);
1441 }
1442 if (constant == max_.Value() && value_watcher_ == nullptr) {
1443 return s->MakeIsLessOrEqualCstVar(this, constant - 1);
1444 }
1445 if (!Contains(constant)) {
1446 return s->MakeIntConst(int64_t{1});
1447 }
1448 if (Bound() && min_.Value() == constant) {
1449 return s->MakeIntConst(int64_t{0});
1450 }
1451 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1452 this, constant, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
1453 if (cache != nullptr) {
1454 return cache->Var();
1455 } else {
1456 IntVar* const boolvar = s->MakeDifference(1, IsEqual(constant))->Var();
1457 s->Cache()->InsertExprConstantExpression(
1458 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
1459 return boolvar;
1460 }
1461 }
1462
IsGreaterOrEqual(int64_t constant)1463 IntVar* IsGreaterOrEqual(int64_t constant) override {
1464 Solver* const s = solver();
1465 if (max_.Value() < constant) {
1466 return s->MakeIntConst(int64_t{0});
1467 }
1468 if (min_.Value() >= constant) {
1469 return s->MakeIntConst(int64_t{1});
1470 }
1471 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1472 this, constant, ModelCache::EXPR_CONSTANT_IS_GREATER_OR_EQUAL);
1473 if (cache != nullptr) {
1474 return cache->Var();
1475 } else {
1476 if (bound_watcher_ == nullptr) {
1477 if (CapSub(Max(), Min()) <= 256) {
1478 solver()->SaveAndSetValue(
1479 reinterpret_cast<void**>(&bound_watcher_),
1480 reinterpret_cast<void*>(solver()->RevAlloc(
1481 new DenseUpperBoundWatcher(solver(), this))));
1482 solver()->AddConstraint(bound_watcher_);
1483 } else {
1484 solver()->SaveAndSetValue(
1485 reinterpret_cast<void**>(&bound_watcher_),
1486 reinterpret_cast<void*>(
1487 solver()->RevAlloc(new UpperBoundWatcher(solver(), this))));
1488 solver()->AddConstraint(bound_watcher_);
1489 }
1490 }
1491 IntVar* const boolvar =
1492 bound_watcher_->GetOrMakeUpperBoundWatcher(constant);
1493 s->Cache()->InsertExprConstantExpression(
1494 boolvar, this, constant,
1495 ModelCache::EXPR_CONSTANT_IS_GREATER_OR_EQUAL);
1496 return boolvar;
1497 }
1498 }
1499
SetIsGreaterOrEqual(const std::vector<int64_t> & values,const std::vector<IntVar * > & vars)1500 Constraint* SetIsGreaterOrEqual(const std::vector<int64_t>& values,
1501 const std::vector<IntVar*>& vars) {
1502 if (bound_watcher_ == nullptr) {
1503 if (CapSub(Max(), Min()) <= 256) {
1504 solver()->SaveAndSetValue(
1505 reinterpret_cast<void**>(&bound_watcher_),
1506 reinterpret_cast<void*>(solver()->RevAlloc(
1507 new DenseUpperBoundWatcher(solver(), this))));
1508 solver()->AddConstraint(bound_watcher_);
1509 } else {
1510 solver()->SaveAndSetValue(reinterpret_cast<void**>(&bound_watcher_),
1511 reinterpret_cast<void*>(solver()->RevAlloc(
1512 new UpperBoundWatcher(solver(), this))));
1513 solver()->AddConstraint(bound_watcher_);
1514 }
1515 for (int i = 0; i < values.size(); ++i) {
1516 bound_watcher_->SetUpperBoundWatcher(vars[i], values[i]);
1517 }
1518 }
1519 return bound_watcher_;
1520 }
1521
IsLessOrEqual(int64_t constant)1522 IntVar* IsLessOrEqual(int64_t constant) override {
1523 Solver* const s = solver();
1524 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1525 this, constant, ModelCache::EXPR_CONSTANT_IS_LESS_OR_EQUAL);
1526 if (cache != nullptr) {
1527 return cache->Var();
1528 } else {
1529 IntVar* const boolvar =
1530 s->MakeDifference(1, IsGreaterOrEqual(constant + 1))->Var();
1531 s->Cache()->InsertExprConstantExpression(
1532 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_LESS_OR_EQUAL);
1533 return boolvar;
1534 }
1535 }
1536
1537 void Process();
1538 void Push();
1539 void CleanInProcess();
Size() const1540 uint64_t Size() const override {
1541 if (bits_ != nullptr) return bits_->Size();
1542 return (static_cast<uint64_t>(max_.Value()) -
1543 static_cast<uint64_t>(min_.Value()) + 1);
1544 }
Contains(int64_t v) const1545 bool Contains(int64_t v) const override {
1546 if (v < min_.Value() || v > max_.Value()) return false;
1547 return (bits_ == nullptr ? true : bits_->Contains(v));
1548 }
1549 IntVarIterator* MakeHoleIterator(bool reversible) const override;
1550 IntVarIterator* MakeDomainIterator(bool reversible) const override;
OldMin() const1551 int64_t OldMin() const override { return std::min(old_min_, min_.Value()); }
OldMax() const1552 int64_t OldMax() const override { return std::max(old_max_, max_.Value()); }
1553
1554 std::string DebugString() const override;
bitset() const1555 BitSet* bitset() const { return bits_; }
VarType() const1556 int VarType() const override { return DOMAIN_INT_VAR; }
BaseName() const1557 std::string BaseName() const override { return "IntegerVar"; }
1558
1559 friend class PlusCstDomainIntVar;
1560 friend class LinkExprAndDomainIntVar;
1561
1562 private:
CheckOldMin()1563 void CheckOldMin() {
1564 if (old_min_ > min_.Value()) {
1565 old_min_ = min_.Value();
1566 }
1567 }
CheckOldMax()1568 void CheckOldMax() {
1569 if (old_max_ < max_.Value()) {
1570 old_max_ = max_.Value();
1571 }
1572 }
1573 Rev<int64_t> min_;
1574 Rev<int64_t> max_;
1575 int64_t old_min_;
1576 int64_t old_max_;
1577 int64_t new_min_;
1578 int64_t new_max_;
1579 SimpleRevFIFO<Demon*> bound_demons_;
1580 SimpleRevFIFO<Demon*> range_demons_;
1581 SimpleRevFIFO<Demon*> domain_demons_;
1582 SimpleRevFIFO<Demon*> delayed_bound_demons_;
1583 SimpleRevFIFO<Demon*> delayed_range_demons_;
1584 SimpleRevFIFO<Demon*> delayed_domain_demons_;
1585 QueueHandler handler_;
1586 bool in_process_;
1587 BitSet* bits_;
1588 BaseValueWatcher* value_watcher_;
1589 BaseUpperBoundWatcher* bound_watcher_;
1590 };
1591
1592 // ----- BitSet -----
1593
1594 // Return whether an integer interval [a..b] (inclusive) contains at most
1595 // K values, i.e. b - a < K, in a way that's robust to overflows.
1596 // For performance reasons, in opt mode it doesn't check that [a, b] is a
1597 // valid interval, nor that K is nonnegative.
ClosedIntervalNoLargerThan(int64_t a,int64_t b,int64_t K)1598 inline bool ClosedIntervalNoLargerThan(int64_t a, int64_t b, int64_t K) {
1599 DCHECK_LE(a, b);
1600 DCHECK_GE(K, 0);
1601 if (a > 0) {
1602 return a > b - K;
1603 } else {
1604 return a + K > b;
1605 }
1606 }
1607
1608 class SimpleBitSet : public DomainIntVar::BitSet {
1609 public:
SimpleBitSet(Solver * const s,int64_t vmin,int64_t vmax)1610 SimpleBitSet(Solver* const s, int64_t vmin, int64_t vmax)
1611 : BitSet(s),
1612 bits_(nullptr),
1613 stamps_(nullptr),
1614 omin_(vmin),
1615 omax_(vmax),
1616 size_(vmax - vmin + 1),
1617 bsize_(BitLength64(size_.Value())) {
1618 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1619 << "Bitset too large: [" << vmin << ", " << vmax << "]";
1620 bits_ = new uint64_t[bsize_];
1621 stamps_ = new uint64_t[bsize_];
1622 for (int i = 0; i < bsize_; ++i) {
1623 const int bs =
1624 (i == size_.Value() - 1) ? 63 - BitPos64(size_.Value()) : 0;
1625 bits_[i] = kAllBits64 >> bs;
1626 stamps_[i] = s->stamp() - 1;
1627 }
1628 }
1629
SimpleBitSet(Solver * const s,const std::vector<int64_t> & sorted_values,int64_t vmin,int64_t vmax)1630 SimpleBitSet(Solver* const s, const std::vector<int64_t>& sorted_values,
1631 int64_t vmin, int64_t vmax)
1632 : BitSet(s),
1633 bits_(nullptr),
1634 stamps_(nullptr),
1635 omin_(vmin),
1636 omax_(vmax),
1637 size_(sorted_values.size()),
1638 bsize_(BitLength64(vmax - vmin + 1)) {
1639 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1640 << "Bitset too large: [" << vmin << ", " << vmax << "]";
1641 bits_ = new uint64_t[bsize_];
1642 stamps_ = new uint64_t[bsize_];
1643 for (int i = 0; i < bsize_; ++i) {
1644 bits_[i] = uint64_t{0};
1645 stamps_[i] = s->stamp() - 1;
1646 }
1647 for (int i = 0; i < sorted_values.size(); ++i) {
1648 const int64_t val = sorted_values[i];
1649 DCHECK(!bit(val));
1650 const int offset = BitOffset64(val - omin_);
1651 const int pos = BitPos64(val - omin_);
1652 bits_[offset] |= OneBit64(pos);
1653 }
1654 }
1655
~SimpleBitSet()1656 ~SimpleBitSet() override {
1657 delete[] bits_;
1658 delete[] stamps_;
1659 }
1660
bit(int64_t val) const1661 bool bit(int64_t val) const { return IsBitSet64(bits_, val - omin_); }
1662
ComputeNewMin(int64_t nmin,int64_t cmin,int64_t cmax)1663 int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) override {
1664 DCHECK_GE(nmin, cmin);
1665 DCHECK_LE(nmin, cmax);
1666 DCHECK_LE(cmin, cmax);
1667 DCHECK_GE(cmin, omin_);
1668 DCHECK_LE(cmax, omax_);
1669 const int64_t new_min =
1670 UnsafeLeastSignificantBitPosition64(bits_, nmin - omin_, cmax - omin_) +
1671 omin_;
1672 const uint64_t removed_bits =
1673 BitCountRange64(bits_, cmin - omin_, new_min - omin_ - 1);
1674 size_.Add(solver_, -removed_bits);
1675 return new_min;
1676 }
1677
ComputeNewMax(int64_t nmax,int64_t cmin,int64_t cmax)1678 int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) override {
1679 DCHECK_GE(nmax, cmin);
1680 DCHECK_LE(nmax, cmax);
1681 DCHECK_LE(cmin, cmax);
1682 DCHECK_GE(cmin, omin_);
1683 DCHECK_LE(cmax, omax_);
1684 const int64_t new_max =
1685 UnsafeMostSignificantBitPosition64(bits_, cmin - omin_, nmax - omin_) +
1686 omin_;
1687 const uint64_t removed_bits =
1688 BitCountRange64(bits_, new_max - omin_ + 1, cmax - omin_);
1689 size_.Add(solver_, -removed_bits);
1690 return new_max;
1691 }
1692
SetValue(int64_t val)1693 bool SetValue(int64_t val) override {
1694 DCHECK_GE(val, omin_);
1695 DCHECK_LE(val, omax_);
1696 if (bit(val)) {
1697 size_.SetValue(solver_, 1);
1698 return true;
1699 }
1700 return false;
1701 }
1702
Contains(int64_t val) const1703 bool Contains(int64_t val) const override {
1704 DCHECK_GE(val, omin_);
1705 DCHECK_LE(val, omax_);
1706 return bit(val);
1707 }
1708
RemoveValue(int64_t val)1709 bool RemoveValue(int64_t val) override {
1710 if (val < omin_ || val > omax_ || !bit(val)) {
1711 return false;
1712 }
1713 // Bitset.
1714 const int64_t val_offset = val - omin_;
1715 const int offset = BitOffset64(val_offset);
1716 const uint64_t current_stamp = solver_->stamp();
1717 if (stamps_[offset] < current_stamp) {
1718 stamps_[offset] = current_stamp;
1719 solver_->SaveValue(&bits_[offset]);
1720 }
1721 const int pos = BitPos64(val_offset);
1722 bits_[offset] &= ~OneBit64(pos);
1723 // Size.
1724 size_.Decr(solver_);
1725 // Holes.
1726 InitHoles();
1727 AddHole(val);
1728 return true;
1729 }
Size() const1730 uint64_t Size() const override { return size_.Value(); }
1731
DebugString() const1732 std::string DebugString() const override {
1733 std::string out;
1734 absl::StrAppendFormat(&out, "SimpleBitSet(%d..%d : ", omin_, omax_);
1735 for (int i = 0; i < bsize_; ++i) {
1736 absl::StrAppendFormat(&out, "%x", bits_[i]);
1737 }
1738 out += ")";
1739 return out;
1740 }
1741
DelayRemoveValue(int64_t val)1742 void DelayRemoveValue(int64_t val) override { removed_.push_back(val); }
1743
ApplyRemovedValues(DomainIntVar * var)1744 void ApplyRemovedValues(DomainIntVar* var) override {
1745 std::sort(removed_.begin(), removed_.end());
1746 for (std::vector<int64_t>::iterator it = removed_.begin();
1747 it != removed_.end(); ++it) {
1748 var->RemoveValue(*it);
1749 }
1750 }
1751
ClearRemovedValues()1752 void ClearRemovedValues() override { removed_.clear(); }
1753
pretty_DebugString(int64_t min,int64_t max) const1754 std::string pretty_DebugString(int64_t min, int64_t max) const override {
1755 std::string out;
1756 DCHECK(bit(min));
1757 DCHECK(bit(max));
1758 if (max != min) {
1759 int cumul = true;
1760 int64_t start_cumul = min;
1761 for (int64_t v = min + 1; v < max; ++v) {
1762 if (bit(v)) {
1763 if (!cumul) {
1764 cumul = true;
1765 start_cumul = v;
1766 }
1767 } else {
1768 if (cumul) {
1769 if (v == start_cumul + 1) {
1770 absl::StrAppendFormat(&out, "%d ", start_cumul);
1771 } else if (v == start_cumul + 2) {
1772 absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1773 } else {
1774 absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1775 }
1776 cumul = false;
1777 }
1778 }
1779 }
1780 if (cumul) {
1781 if (max == start_cumul + 1) {
1782 absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1783 } else {
1784 absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1785 }
1786 } else {
1787 absl::StrAppendFormat(&out, "%d", max);
1788 }
1789 } else {
1790 absl::StrAppendFormat(&out, "%d", min);
1791 }
1792 return out;
1793 }
1794
MakeIterator()1795 DomainIntVar::BitSetIterator* MakeIterator() override {
1796 return new DomainIntVar::BitSetIterator(bits_, omin_);
1797 }
1798
1799 private:
1800 uint64_t* bits_;
1801 uint64_t* stamps_;
1802 const int64_t omin_;
1803 const int64_t omax_;
1804 NumericalRev<int64_t> size_;
1805 const int bsize_;
1806 std::vector<int64_t> removed_;
1807 };
1808
// This is a special case where the whole bitset fits into one 64-bit word,
// so there is no word offset to compute. Overflows are caught by the robust
// ClosedIntervalNoLargerThan() function.
1812 class SmallBitSet : public DomainIntVar::BitSet {
1813 public:
// Single-word bitset initially holding every value of [vmin, vmax]
// (at most 64 values).
SmallBitSet(Solver* const s, int64_t vmin, int64_t vmax)
    : BitSet(s),
      bits_(uint64_t{0}),
      stamp_(s->stamp() - 1),
      omin_(vmin),
      omax_(vmax),
      size_(vmax - vmin + 1) {
  CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
  // One low bit per value: bit i stands for omin_ + i.
  const int64_t highest_bit = size_.Value() - 1;
  bits_ = OneRange64(0, highest_bit);
}

// Single-word bitset holding exactly `sorted_values` (sorted, no duplicates,
// all within [vmin, vmax]).
SmallBitSet(Solver* const s, const std::vector<int64_t>& sorted_values,
            int64_t vmin, int64_t vmax)
    : BitSet(s),
      bits_(uint64_t{0}),
      stamp_(s->stamp() - 1),
      omin_(vmin),
      omax_(vmax),
      size_(sorted_values.size()) {
  CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
  for (const int64_t val : sorted_values) {
    DCHECK_GE(val, vmin);
    DCHECK_LE(val, vmax);
    DCHECK(!IsBitSet64(&bits_, val - omin_));
    bits_ |= OneBit64(val - omin_);
  }
}
1843
~SmallBitSet()1844 ~SmallBitSet() override {}
1845
bit(int64_t val) const1846 bool bit(int64_t val) const {
1847 DCHECK_GE(val, omin_);
1848 DCHECK_LE(val, omax_);
1849 return (bits_ & OneBit64(val - omin_)) != 0;
1850 }
1851
ComputeNewMin(int64_t nmin,int64_t cmin,int64_t cmax)1852 int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) override {
1853 DCHECK_GE(nmin, cmin);
1854 DCHECK_LE(nmin, cmax);
1855 DCHECK_LE(cmin, cmax);
1856 DCHECK_GE(cmin, omin_);
1857 DCHECK_LE(cmax, omax_);
1858 // We do not clean the bits between cmin and nmin.
1859 // But we use mask to look only at 'active' bits.
1860
1861 // Create the mask and compute new bits
1862 const uint64_t new_bits = bits_ & OneRange64(nmin - omin_, cmax - omin_);
1863 if (new_bits != uint64_t{0}) {
1864 // Compute new size and new min
1865 size_.SetValue(solver_, BitCount64(new_bits));
1866 if (bit(nmin)) { // Common case, the new min is inside the bitset
1867 return nmin;
1868 }
1869 return LeastSignificantBitPosition64(new_bits) + omin_;
1870 } else { // == 0 -> Fail()
1871 solver_->Fail();
1872 return std::numeric_limits<int64_t>::max();
1873 }
1874 }
1875
ComputeNewMax(int64_t nmax,int64_t cmin,int64_t cmax)1876 int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) override {
1877 DCHECK_GE(nmax, cmin);
1878 DCHECK_LE(nmax, cmax);
1879 DCHECK_LE(cmin, cmax);
1880 DCHECK_GE(cmin, omin_);
1881 DCHECK_LE(cmax, omax_);
1882 // We do not clean the bits between nmax and cmax.
1883 // But we use mask to look only at 'active' bits.
1884
1885 // Create the mask and compute new_bits
1886 const uint64_t new_bits = bits_ & OneRange64(cmin - omin_, nmax - omin_);
1887 if (new_bits != uint64_t{0}) {
1888 // Compute new size and new min
1889 size_.SetValue(solver_, BitCount64(new_bits));
1890 if (bit(nmax)) { // Common case, the new max is inside the bitset
1891 return nmax;
1892 }
1893 return MostSignificantBitPosition64(new_bits) + omin_;
1894 } else { // == 0 -> Fail()
1895 solver_->Fail();
1896 return std::numeric_limits<int64_t>::min();
1897 }
1898 }
1899
SetValue(int64_t val)1900 bool SetValue(int64_t val) override {
1901 DCHECK_GE(val, omin_);
1902 DCHECK_LE(val, omax_);
1903 // We do not clean the bits. We will use masks to ignore the bits
1904 // that should have been cleaned.
1905 if (bit(val)) {
1906 size_.SetValue(solver_, 1);
1907 return true;
1908 }
1909 return false;
1910 }
1911
Contains(int64_t val) const1912 bool Contains(int64_t val) const override {
1913 DCHECK_GE(val, omin_);
1914 DCHECK_LE(val, omax_);
1915 return bit(val);
1916 }
1917
RemoveValue(int64_t val)1918 bool RemoveValue(int64_t val) override {
1919 DCHECK_GE(val, omin_);
1920 DCHECK_LE(val, omax_);
1921 if (bit(val)) {
1922 // Bitset.
1923 const uint64_t current_stamp = solver_->stamp();
1924 if (stamp_ < current_stamp) {
1925 stamp_ = current_stamp;
1926 solver_->SaveValue(&bits_);
1927 }
1928 bits_ &= ~OneBit64(val - omin_);
1929 DCHECK(!bit(val));
1930 // Size.
1931 size_.Decr(solver_);
1932 // Holes.
1933 InitHoles();
1934 AddHole(val);
1935 return true;
1936 } else {
1937 return false;
1938 }
1939 }
1940
Size() const1941 uint64_t Size() const override { return size_.Value(); }
1942
DebugString() const1943 std::string DebugString() const override {
1944 return absl::StrFormat("SmallBitSet(%d..%d : %llx)", omin_, omax_, bits_);
1945 }
1946
DelayRemoveValue(int64_t val)1947 void DelayRemoveValue(int64_t val) override {
1948 DCHECK_GE(val, omin_);
1949 DCHECK_LE(val, omax_);
1950 removed_.push_back(val);
1951 }
1952
ApplyRemovedValues(DomainIntVar * var)1953 void ApplyRemovedValues(DomainIntVar* var) override {
1954 std::sort(removed_.begin(), removed_.end());
1955 for (std::vector<int64_t>::iterator it = removed_.begin();
1956 it != removed_.end(); ++it) {
1957 var->RemoveValue(*it);
1958 }
1959 }
1960
ClearRemovedValues()1961 void ClearRemovedValues() override { removed_.clear(); }
1962
pretty_DebugString(int64_t min,int64_t max) const1963 std::string pretty_DebugString(int64_t min, int64_t max) const override {
1964 std::string out;
1965 DCHECK(bit(min));
1966 DCHECK(bit(max));
1967 if (max != min) {
1968 int cumul = true;
1969 int64_t start_cumul = min;
1970 for (int64_t v = min + 1; v < max; ++v) {
1971 if (bit(v)) {
1972 if (!cumul) {
1973 cumul = true;
1974 start_cumul = v;
1975 }
1976 } else {
1977 if (cumul) {
1978 if (v == start_cumul + 1) {
1979 absl::StrAppendFormat(&out, "%d ", start_cumul);
1980 } else if (v == start_cumul + 2) {
1981 absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1982 } else {
1983 absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1984 }
1985 cumul = false;
1986 }
1987 }
1988 }
1989 if (cumul) {
1990 if (max == start_cumul + 1) {
1991 absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1992 } else {
1993 absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1994 }
1995 } else {
1996 absl::StrAppendFormat(&out, "%d", max);
1997 }
1998 } else {
1999 absl::StrAppendFormat(&out, "%d", min);
2000 }
2001 return out;
2002 }
2003
MakeIterator()2004 DomainIntVar::BitSetIterator* MakeIterator() override {
2005 return new DomainIntVar::BitSetIterator(&bits_, omin_);
2006 }
2007
2008 private:
2009 uint64_t bits_;
2010 uint64_t stamp_;
2011 const int64_t omin_;
2012 const int64_t omax_;
2013 NumericalRev<int64_t> size_;
2014 std::vector<int64_t> removed_;
2015 };
2016
2017 class EmptyIterator : public IntVarIterator {
2018 public:
~EmptyIterator()2019 ~EmptyIterator() override {}
Init()2020 void Init() override {}
Ok() const2021 bool Ok() const override { return false; }
Value() const2022 int64_t Value() const override {
2023 LOG(FATAL) << "Should not be called";
2024 return 0LL;
2025 }
Next()2026 void Next() override {}
2027 };
2028
2029 class RangeIterator : public IntVarIterator {
2030 public:
RangeIterator(const IntVar * const var)2031 explicit RangeIterator(const IntVar* const var)
2032 : var_(var),
2033 min_(std::numeric_limits<int64_t>::max()),
2034 max_(std::numeric_limits<int64_t>::min()),
2035 current_(-1) {}
2036
~RangeIterator()2037 ~RangeIterator() override {}
2038
Init()2039 void Init() override {
2040 min_ = var_->Min();
2041 max_ = var_->Max();
2042 current_ = min_;
2043 }
2044
Ok() const2045 bool Ok() const override { return current_ <= max_; }
2046
Value() const2047 int64_t Value() const override { return current_; }
2048
Next()2049 void Next() override { current_++; }
2050
2051 private:
2052 const IntVar* const var_;
2053 int64_t min_;
2054 int64_t max_;
2055 int64_t current_;
2056 };
2057
2058 class DomainIntVarHoleIterator : public IntVarIterator {
2059 public:
DomainIntVarHoleIterator(const DomainIntVar * const v)2060 explicit DomainIntVarHoleIterator(const DomainIntVar* const v)
2061 : var_(v), bits_(nullptr), values_(nullptr), size_(0), index_(0) {}
2062
~DomainIntVarHoleIterator()2063 ~DomainIntVarHoleIterator() override {}
2064
Init()2065 void Init() override {
2066 bits_ = var_->bitset();
2067 if (bits_ != nullptr) {
2068 bits_->InitHoles();
2069 values_ = bits_->Holes().data();
2070 size_ = bits_->Holes().size();
2071 } else {
2072 values_ = nullptr;
2073 size_ = 0;
2074 }
2075 index_ = 0;
2076 }
2077
Ok() const2078 bool Ok() const override { return index_ < size_; }
2079
Value() const2080 int64_t Value() const override {
2081 DCHECK(bits_ != nullptr);
2082 DCHECK(index_ < size_);
2083 return values_[index_];
2084 }
2085
Next()2086 void Next() override { index_++; }
2087
2088 private:
2089 const DomainIntVar* const var_;
2090 DomainIntVar::BitSet* bits_;
2091 const int64_t* values_;
2092 int size_;
2093 int index_;
2094 };
2095
// Iterates over the current domain of a DomainIntVar.
//
// When the variable has a bitset and is not bound, iteration is delegated to
// a BitSetIterator obtained from bitset()->MakeIterator(). Otherwise the
// integers Min()..Max() are enumerated directly.
//
// If `reversible` is true, the BitSetIterator is passed to Solver::RevAlloc,
// and the pointer to it is saved with Solver::SaveValue before each change;
// this object never deletes it. If `reversible` is false, this object owns
// the BitSetIterator and deletes it.
class DomainIntVarDomainIterator : public IntVarIterator {
 public:
  explicit DomainIntVarDomainIterator(const DomainIntVar* const v,
                                      bool reversible)
      : var_(v),
        bitset_iterator_(nullptr),
        min_(std::numeric_limits<int64_t>::max()),
        max_(std::numeric_limits<int64_t>::min()),
        current_(-1),
        reversible_(reversible) {}

  // Only the non-reversible mode owns bitset_iterator_.
  ~DomainIntVarDomainIterator() override {
    if (!reversible_ && bitset_iterator_) {
      delete bitset_iterator_;
    }
  }

  // (Re)starts the iteration from the variable's current domain.
  void Init() override {
    if (var_->bitset() != nullptr && !var_->Bound()) {
      if (reversible_) {
        if (!bitset_iterator_) {
          Solver* const solver = var_->solver();
          // Save the pointer before assigning it, so the assignment is
          // trailed by the solver.
          solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
          bitset_iterator_ = solver->RevAlloc(var_->bitset()->MakeIterator());
        }
        // An already-allocated reversible iterator is simply re-initialized.
      } else {
        if (bitset_iterator_) {
          delete bitset_iterator_;
        }
        bitset_iterator_ = var_->bitset()->MakeIterator();
      }
      bitset_iterator_->Init(var_->Min(), var_->Max());
    } else {
      // No bitset, or a bound variable: drop any BitSetIterator left over
      // from a previous Init() and scan the plain range instead.
      if (bitset_iterator_) {
        if (reversible_) {
          Solver* const solver = var_->solver();
          solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
        } else {
          delete bitset_iterator_;
        }
        bitset_iterator_ = nullptr;
      }
      min_ = var_->Min();
      max_ = var_->Max();
      current_ = min_;
    }
  }

  // The three accessors below dispatch on whether a BitSetIterator is active.
  bool Ok() const override {
    return bitset_iterator_ ? bitset_iterator_->Ok() : (current_ <= max_);
  }

  int64_t Value() const override {
    return bitset_iterator_ ? bitset_iterator_->Value() : current_;
  }

  void Next() override {
    if (bitset_iterator_) {
      bitset_iterator_->Next();
    } else {
      current_++;
    }
  }

 private:
  const DomainIntVar* const var_;
  DomainIntVar::BitSetIterator* bitset_iterator_;
  // Range-scan state, used only while bitset_iterator_ is null.
  int64_t min_;
  int64_t max_;
  int64_t current_;
  const bool reversible_;
};
2168
2169 class UnaryIterator : public IntVarIterator {
2170 public:
UnaryIterator(const IntVar * const v,bool hole,bool reversible)2171 UnaryIterator(const IntVar* const v, bool hole, bool reversible)
2172 : iterator_(hole ? v->MakeHoleIterator(reversible)
2173 : v->MakeDomainIterator(reversible)),
2174 reversible_(reversible) {}
2175
~UnaryIterator()2176 ~UnaryIterator() override {
2177 if (!reversible_) {
2178 delete iterator_;
2179 }
2180 }
2181
Init()2182 void Init() override { iterator_->Init(); }
2183
Ok() const2184 bool Ok() const override { return iterator_->Ok(); }
2185
Next()2186 void Next() override { iterator_->Next(); }
2187
2188 protected:
2189 IntVarIterator* const iterator_;
2190 const bool reversible_;
2191 };
2192
DomainIntVar(Solver * const s,int64_t vmin,int64_t vmax,const std::string & name)2193 DomainIntVar::DomainIntVar(Solver* const s, int64_t vmin, int64_t vmax,
2194 const std::string& name)
2195 : IntVar(s, name),
2196 min_(vmin),
2197 max_(vmax),
2198 old_min_(vmin),
2199 old_max_(vmax),
2200 new_min_(vmin),
2201 new_max_(vmax),
2202 handler_(this),
2203 in_process_(false),
2204 bits_(nullptr),
2205 value_watcher_(nullptr),
2206 bound_watcher_(nullptr) {}
2207
DomainIntVar(Solver * const s,const std::vector<int64_t> & sorted_values,const std::string & name)2208 DomainIntVar::DomainIntVar(Solver* const s,
2209 const std::vector<int64_t>& sorted_values,
2210 const std::string& name)
2211 : IntVar(s, name),
2212 min_(std::numeric_limits<int64_t>::max()),
2213 max_(std::numeric_limits<int64_t>::min()),
2214 old_min_(std::numeric_limits<int64_t>::max()),
2215 old_max_(std::numeric_limits<int64_t>::min()),
2216 new_min_(std::numeric_limits<int64_t>::max()),
2217 new_max_(std::numeric_limits<int64_t>::min()),
2218 handler_(this),
2219 in_process_(false),
2220 bits_(nullptr),
2221 value_watcher_(nullptr),
2222 bound_watcher_(nullptr) {
2223 CHECK_GE(sorted_values.size(), 1);
2224 // We know that the vector is sorted and does not have duplicate values.
2225 const int64_t vmin = sorted_values.front();
2226 const int64_t vmax = sorted_values.back();
2227 const bool contiguous = vmax - vmin + 1 == sorted_values.size();
2228
2229 min_.SetValue(solver(), vmin);
2230 old_min_ = vmin;
2231 new_min_ = vmin;
2232 max_.SetValue(solver(), vmax);
2233 old_max_ = vmax;
2234 new_max_ = vmax;
2235
2236 if (!contiguous) {
2237 if (vmax - vmin + 1 < 65) {
2238 bits_ = solver()->RevAlloc(
2239 new SmallBitSet(solver(), sorted_values, vmin, vmax));
2240 } else {
2241 bits_ = solver()->RevAlloc(
2242 new SimpleBitSet(solver(), sorted_values, vmin, vmax));
2243 }
2244 }
2245 }
2246
~DomainIntVar()2247 DomainIntVar::~DomainIntVar() {}
2248
SetMin(int64_t m)2249 void DomainIntVar::SetMin(int64_t m) {
2250 if (m <= min_.Value()) return;
2251 if (m > max_.Value()) solver()->Fail();
2252 if (in_process_) {
2253 if (m > new_min_) {
2254 new_min_ = m;
2255 if (new_min_ > new_max_) {
2256 solver()->Fail();
2257 }
2258 }
2259 } else {
2260 CheckOldMin();
2261 const int64_t new_min =
2262 (bits_ == nullptr
2263 ? m
2264 : bits_->ComputeNewMin(m, min_.Value(), max_.Value()));
2265 min_.SetValue(solver(), new_min);
2266 if (min_.Value() > max_.Value()) {
2267 solver()->Fail();
2268 }
2269 Push();
2270 }
2271 }
2272
SetMax(int64_t m)2273 void DomainIntVar::SetMax(int64_t m) {
2274 if (m >= max_.Value()) return;
2275 if (m < min_.Value()) solver()->Fail();
2276 if (in_process_) {
2277 if (m < new_max_) {
2278 new_max_ = m;
2279 if (new_max_ < new_min_) {
2280 solver()->Fail();
2281 }
2282 }
2283 } else {
2284 CheckOldMax();
2285 const int64_t new_max =
2286 (bits_ == nullptr
2287 ? m
2288 : bits_->ComputeNewMax(m, min_.Value(), max_.Value()));
2289 max_.SetValue(solver(), new_max);
2290 if (min_.Value() > max_.Value()) {
2291 solver()->Fail();
2292 }
2293 Push();
2294 }
2295 }
2296
SetRange(int64_t mi,int64_t ma)2297 void DomainIntVar::SetRange(int64_t mi, int64_t ma) {
2298 if (mi == ma) {
2299 SetValue(mi);
2300 } else {
2301 if (mi > ma || mi > max_.Value() || ma < min_.Value()) solver()->Fail();
2302 if (mi <= min_.Value() && ma >= max_.Value()) return;
2303 if (in_process_) {
2304 if (ma < new_max_) {
2305 new_max_ = ma;
2306 }
2307 if (mi > new_min_) {
2308 new_min_ = mi;
2309 }
2310 if (new_min_ > new_max_) {
2311 solver()->Fail();
2312 }
2313 } else {
2314 if (mi > min_.Value()) {
2315 CheckOldMin();
2316 const int64_t new_min =
2317 (bits_ == nullptr
2318 ? mi
2319 : bits_->ComputeNewMin(mi, min_.Value(), max_.Value()));
2320 min_.SetValue(solver(), new_min);
2321 }
2322 if (min_.Value() > ma) {
2323 solver()->Fail();
2324 }
2325 if (ma < max_.Value()) {
2326 CheckOldMax();
2327 const int64_t new_max =
2328 (bits_ == nullptr
2329 ? ma
2330 : bits_->ComputeNewMax(ma, min_.Value(), max_.Value()));
2331 max_.SetValue(solver(), new_max);
2332 }
2333 if (min_.Value() > max_.Value()) {
2334 solver()->Fail();
2335 }
2336 Push();
2337 }
2338 }
2339 }
2340
SetValue(int64_t v)2341 void DomainIntVar::SetValue(int64_t v) {
2342 if (v != min_.Value() || v != max_.Value()) {
2343 if (v < min_.Value() || v > max_.Value()) {
2344 solver()->Fail();
2345 }
2346 if (in_process_) {
2347 if (v > new_max_ || v < new_min_) {
2348 solver()->Fail();
2349 }
2350 new_min_ = v;
2351 new_max_ = v;
2352 } else {
2353 if (bits_ && !bits_->SetValue(v)) {
2354 solver()->Fail();
2355 }
2356 CheckOldMin();
2357 CheckOldMax();
2358 min_.SetValue(solver(), v);
2359 max_.SetValue(solver(), v);
2360 Push();
2361 }
2362 }
2363 }
2364
RemoveValue(int64_t v)2365 void DomainIntVar::RemoveValue(int64_t v) {
2366 if (v < min_.Value() || v > max_.Value()) return;
2367 if (v == min_.Value()) {
2368 SetMin(v + 1);
2369 } else if (v == max_.Value()) {
2370 SetMax(v - 1);
2371 } else {
2372 if (bits_ == nullptr) {
2373 CreateBits();
2374 }
2375 if (in_process_) {
2376 if (v >= new_min_ && v <= new_max_ && bits_->Contains(v)) {
2377 bits_->DelayRemoveValue(v);
2378 }
2379 } else {
2380 if (bits_->RemoveValue(v)) {
2381 Push();
2382 }
2383 }
2384 }
2385 }
2386
RemoveInterval(int64_t l,int64_t u)2387 void DomainIntVar::RemoveInterval(int64_t l, int64_t u) {
2388 if (l <= min_.Value()) {
2389 SetMin(u + 1);
2390 } else if (u >= max_.Value()) {
2391 SetMax(l - 1);
2392 } else {
2393 for (int64_t v = l; v <= u; ++v) {
2394 RemoveValue(v);
2395 }
2396 }
2397 }
2398
CreateBits()2399 void DomainIntVar::CreateBits() {
2400 solver()->SaveValue(reinterpret_cast<void**>(&bits_));
2401 if (max_.Value() - min_.Value() < 64) {
2402 bits_ = solver()->RevAlloc(
2403 new SmallBitSet(solver(), min_.Value(), max_.Value()));
2404 } else {
2405 bits_ = solver()->RevAlloc(
2406 new SimpleBitSet(solver(), min_.Value(), max_.Value()));
2407 }
2408 }
2409
CleanInProcess()2410 void DomainIntVar::CleanInProcess() {
2411 in_process_ = false;
2412 if (bits_ != nullptr) {
2413 bits_->ClearHoles();
2414 }
2415 }
2416
Push()2417 void DomainIntVar::Push() {
2418 const bool in_process = in_process_;
2419 EnqueueVar(&handler_);
2420 CHECK_EQ(in_process, in_process_);
2421 }
2422
// Runs the demons attached to this variable after its domain changed, then
// applies the modifications requested while they ran.
//
// While in_process_ is true, SetMin/SetMax/SetRange/SetValue only record
// pending bounds in new_min_/new_max_. RemoveValue only queues interior
// values through bits_->DelayRemoveValue(). Everything is applied at the end
// of this method.
void DomainIntVar::Process() {
  CHECK(!in_process_);
  in_process_ = true;
  if (bits_ != nullptr) {
    // Forget values queued by an earlier Process() call.
    bits_->ClearRemovedValues();
  }
  // NOTE(review): presumably makes the solver call this variable's cleanup
  // (CleanInProcess) if a demon fails; reset to nullptr on success below.
  set_variable_to_clean_on_fail(this);
  new_min_ = min_.Value();
  new_max_ = max_.Value();
  // Computed once: demons' changes are deferred, so min_/max_ stay fixed
  // while they run.
  const bool is_bound = min_.Value() == max_.Value();
  const bool range_changed =
      min_.Value() != OldMin() || max_.Value() != OldMax();
  // Process immediate demons.
  if (is_bound) {
    ExecuteAll(bound_demons_);
  }
  if (range_changed) {
    ExecuteAll(range_demons_);
  }
  ExecuteAll(domain_demons_);

  // Process delayed demons.
  if (is_bound) {
    EnqueueAll(delayed_bound_demons_);
  }
  if (range_changed) {
    EnqueueAll(delayed_range_demons_);
  }
  EnqueueAll(delayed_domain_demons_);

  // Everything went well if we arrive here. Let's clean the variable.
  set_variable_to_clean_on_fail(nullptr);
  // Leaving processing mode first makes the SetMin/SetMax calls below take
  // the direct (non-deferred) path.
  CleanInProcess();
  // Presumably what OldMin()/OldMax() report on the next Process() call.
  old_min_ = min_.Value();
  old_max_ = max_.Value();
  // Apply the bounds requested by the demons.
  if (min_.Value() < new_min_) {
    SetMin(new_min_);
  }
  if (max_.Value() > new_max_) {
    SetMax(new_max_);
  }
  // Apply the interior removals queued by the demons.
  if (bits_ != nullptr) {
    bits_->ApplyRemovedValues(this);
  }
}
2468
// Helper for the Make*Iterator() factories. When `rev` is true, the freshly
// allocated object is passed to solver()->RevAlloc(); otherwise the raw
// pointer is returned and the caller owns it (see ~UnaryIterator()).
// NOTE(review): the expansion is not parenthesized and carries its own ';',
// so it is only safe as the full expression of a `return` statement.
#define COND_REV_ALLOC(rev, alloc) rev ? solver()->RevAlloc(alloc) : alloc;
2470
MakeHoleIterator(bool reversible) const2471 IntVarIterator* DomainIntVar::MakeHoleIterator(bool reversible) const {
2472 return COND_REV_ALLOC(reversible, new DomainIntVarHoleIterator(this));
2473 }
2474
MakeDomainIterator(bool reversible) const2475 IntVarIterator* DomainIntVar::MakeDomainIterator(bool reversible) const {
2476 return COND_REV_ALLOC(reversible,
2477 new DomainIntVarDomainIterator(this, reversible));
2478 }
2479
DebugString() const2480 std::string DomainIntVar::DebugString() const {
2481 std::string out;
2482 const std::string& var_name = name();
2483 if (!var_name.empty()) {
2484 out = var_name + "(";
2485 } else {
2486 out = "DomainIntVar(";
2487 }
2488 if (min_.Value() == max_.Value()) {
2489 absl::StrAppendFormat(&out, "%d", min_.Value());
2490 } else if (bits_ != nullptr) {
2491 out.append(bits_->pretty_DebugString(min_.Value(), max_.Value()));
2492 } else {
2493 absl::StrAppendFormat(&out, "%d..%d", min_.Value(), max_.Value());
2494 }
2495 out += ")";
2496 return out;
2497 }
2498
2499 // ----- Real Boolean Var -----
2500
2501 class ConcreteBooleanVar : public BooleanVar {
2502 public:
2503 // Utility classes
2504 class Handler : public Demon {
2505 public:
Handler(ConcreteBooleanVar * const var)2506 explicit Handler(ConcreteBooleanVar* const var) : Demon(), var_(var) {}
~Handler()2507 ~Handler() override {}
Run(Solver * const s)2508 void Run(Solver* const s) override {
2509 s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
2510 var_->Process();
2511 s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
2512 }
priority() const2513 Solver::DemonPriority priority() const override {
2514 return Solver::VAR_PRIORITY;
2515 }
DebugString() const2516 std::string DebugString() const override {
2517 return absl::StrFormat("Handler(%s)", var_->DebugString());
2518 }
2519
2520 private:
2521 ConcreteBooleanVar* const var_;
2522 };
2523
ConcreteBooleanVar(Solver * const s,const std::string & name)2524 ConcreteBooleanVar(Solver* const s, const std::string& name)
2525 : BooleanVar(s, name), handler_(this) {}
2526
~ConcreteBooleanVar()2527 ~ConcreteBooleanVar() override {}
2528
SetValue(int64_t v)2529 void SetValue(int64_t v) override {
2530 if (value_ == kUnboundBooleanVarValue) {
2531 if ((v & 0xfffffffffffffffe) == 0) {
2532 InternalSaveBooleanVarValue(solver(), this);
2533 value_ = static_cast<int>(v);
2534 EnqueueVar(&handler_);
2535 return;
2536 }
2537 } else if (v == value_) {
2538 return;
2539 }
2540 solver()->Fail();
2541 }
2542
Process()2543 void Process() {
2544 DCHECK_NE(value_, kUnboundBooleanVarValue);
2545 ExecuteAll(bound_demons_);
2546 for (SimpleRevFIFO<Demon*>::Iterator it(&delayed_bound_demons_); it.ok();
2547 ++it) {
2548 EnqueueDelayedDemon(*it);
2549 }
2550 }
2551
OldMin() const2552 int64_t OldMin() const override { return 0LL; }
OldMax() const2553 int64_t OldMax() const override { return 1LL; }
RestoreValue()2554 void RestoreValue() override { value_ = kUnboundBooleanVarValue; }
2555
2556 private:
2557 Handler handler_;
2558 };
2559
2560 // ----- IntConst -----
2561
2562 class IntConst : public IntVar {
2563 public:
IntConst(Solver * const s,int64_t value,const std::string & name="")2564 IntConst(Solver* const s, int64_t value, const std::string& name = "")
2565 : IntVar(s, name), value_(value) {}
~IntConst()2566 ~IntConst() override {}
2567
Min() const2568 int64_t Min() const override { return value_; }
SetMin(int64_t m)2569 void SetMin(int64_t m) override {
2570 if (m > value_) {
2571 solver()->Fail();
2572 }
2573 }
Max() const2574 int64_t Max() const override { return value_; }
SetMax(int64_t m)2575 void SetMax(int64_t m) override {
2576 if (m < value_) {
2577 solver()->Fail();
2578 }
2579 }
SetRange(int64_t l,int64_t u)2580 void SetRange(int64_t l, int64_t u) override {
2581 if (l > value_ || u < value_) {
2582 solver()->Fail();
2583 }
2584 }
SetValue(int64_t v)2585 void SetValue(int64_t v) override {
2586 if (v != value_) {
2587 solver()->Fail();
2588 }
2589 }
Bound() const2590 bool Bound() const override { return true; }
Value() const2591 int64_t Value() const override { return value_; }
RemoveValue(int64_t v)2592 void RemoveValue(int64_t v) override {
2593 if (v == value_) {
2594 solver()->Fail();
2595 }
2596 }
RemoveInterval(int64_t l,int64_t u)2597 void RemoveInterval(int64_t l, int64_t u) override {
2598 if (l <= value_ && value_ <= u) {
2599 solver()->Fail();
2600 }
2601 }
WhenBound(Demon * d)2602 void WhenBound(Demon* d) override {}
WhenRange(Demon * d)2603 void WhenRange(Demon* d) override {}
WhenDomain(Demon * d)2604 void WhenDomain(Demon* d) override {}
Size() const2605 uint64_t Size() const override { return 1; }
Contains(int64_t v) const2606 bool Contains(int64_t v) const override { return (v == value_); }
MakeHoleIterator(bool reversible) const2607 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2608 return COND_REV_ALLOC(reversible, new EmptyIterator());
2609 }
MakeDomainIterator(bool reversible) const2610 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2611 return COND_REV_ALLOC(reversible, new RangeIterator(this));
2612 }
OldMin() const2613 int64_t OldMin() const override { return value_; }
OldMax() const2614 int64_t OldMax() const override { return value_; }
DebugString() const2615 std::string DebugString() const override {
2616 std::string out;
2617 if (solver()->HasName(this)) {
2618 const std::string& var_name = name();
2619 absl::StrAppendFormat(&out, "%s(%d)", var_name, value_);
2620 } else {
2621 absl::StrAppendFormat(&out, "IntConst(%d)", value_);
2622 }
2623 return out;
2624 }
2625
VarType() const2626 int VarType() const override { return CONST_VAR; }
2627
IsEqual(int64_t constant)2628 IntVar* IsEqual(int64_t constant) override {
2629 if (constant == value_) {
2630 return solver()->MakeIntConst(1);
2631 } else {
2632 return solver()->MakeIntConst(0);
2633 }
2634 }
2635
IsDifferent(int64_t constant)2636 IntVar* IsDifferent(int64_t constant) override {
2637 if (constant == value_) {
2638 return solver()->MakeIntConst(0);
2639 } else {
2640 return solver()->MakeIntConst(1);
2641 }
2642 }
2643
IsGreaterOrEqual(int64_t constant)2644 IntVar* IsGreaterOrEqual(int64_t constant) override {
2645 return solver()->MakeIntConst(value_ >= constant);
2646 }
2647
IsLessOrEqual(int64_t constant)2648 IntVar* IsLessOrEqual(int64_t constant) override {
2649 return solver()->MakeIntConst(value_ <= constant);
2650 }
2651
name() const2652 std::string name() const override {
2653 if (solver()->HasName(this)) {
2654 return PropagationBaseObject::name();
2655 } else {
2656 return absl::StrCat(value_);
2657 }
2658 }
2659
2660 private:
2661 int64_t value_;
2662 };
2663
2664 // ----- x + c variable, optimized case -----
2665
2666 class PlusCstVar : public IntVar {
2667 public:
PlusCstVar(Solver * const s,IntVar * v,int64_t c)2668 PlusCstVar(Solver* const s, IntVar* v, int64_t c)
2669 : IntVar(s), var_(v), cst_(c) {}
2670
~PlusCstVar()2671 ~PlusCstVar() override {}
2672
WhenRange(Demon * d)2673 void WhenRange(Demon* d) override { var_->WhenRange(d); }
2674
WhenBound(Demon * d)2675 void WhenBound(Demon* d) override { var_->WhenBound(d); }
2676
WhenDomain(Demon * d)2677 void WhenDomain(Demon* d) override { var_->WhenDomain(d); }
2678
OldMin() const2679 int64_t OldMin() const override { return CapAdd(var_->OldMin(), cst_); }
2680
OldMax() const2681 int64_t OldMax() const override { return CapAdd(var_->OldMax(), cst_); }
2682
DebugString() const2683 std::string DebugString() const override {
2684 if (HasName()) {
2685 return absl::StrFormat("%s(%s + %d)", name(), var_->DebugString(), cst_);
2686 } else {
2687 return absl::StrFormat("(%s + %d)", var_->DebugString(), cst_);
2688 }
2689 }
2690
VarType() const2691 int VarType() const override { return VAR_ADD_CST; }
2692
Accept(ModelVisitor * const visitor) const2693 void Accept(ModelVisitor* const visitor) const override {
2694 visitor->VisitIntegerVariable(this, ModelVisitor::kSumOperation, cst_,
2695 var_);
2696 }
2697
IsEqual(int64_t constant)2698 IntVar* IsEqual(int64_t constant) override {
2699 return var_->IsEqual(constant - cst_);
2700 }
2701
IsDifferent(int64_t constant)2702 IntVar* IsDifferent(int64_t constant) override {
2703 return var_->IsDifferent(constant - cst_);
2704 }
2705
IsGreaterOrEqual(int64_t constant)2706 IntVar* IsGreaterOrEqual(int64_t constant) override {
2707 return var_->IsGreaterOrEqual(constant - cst_);
2708 }
2709
IsLessOrEqual(int64_t constant)2710 IntVar* IsLessOrEqual(int64_t constant) override {
2711 return var_->IsLessOrEqual(constant - cst_);
2712 }
2713
SubVar() const2714 IntVar* SubVar() const { return var_; }
2715
Constant() const2716 int64_t Constant() const { return cst_; }
2717
2718 protected:
2719 IntVar* const var_;
2720 const int64_t cst_;
2721 };
2722
2723 class PlusCstIntVar : public PlusCstVar {
2724 public:
2725 class PlusCstIntVarIterator : public UnaryIterator {
2726 public:
PlusCstIntVarIterator(const IntVar * const v,int64_t c,bool hole,bool rev)2727 PlusCstIntVarIterator(const IntVar* const v, int64_t c, bool hole, bool rev)
2728 : UnaryIterator(v, hole, rev), cst_(c) {}
2729
~PlusCstIntVarIterator()2730 ~PlusCstIntVarIterator() override {}
2731
Value() const2732 int64_t Value() const override { return iterator_->Value() + cst_; }
2733
2734 private:
2735 const int64_t cst_;
2736 };
2737
PlusCstIntVar(Solver * const s,IntVar * v,int64_t c)2738 PlusCstIntVar(Solver* const s, IntVar* v, int64_t c) : PlusCstVar(s, v, c) {}
2739
~PlusCstIntVar()2740 ~PlusCstIntVar() override {}
2741
Min() const2742 int64_t Min() const override { return var_->Min() + cst_; }
2743
SetMin(int64_t m)2744 void SetMin(int64_t m) override { var_->SetMin(CapSub(m, cst_)); }
2745
Max() const2746 int64_t Max() const override { return var_->Max() + cst_; }
2747
SetMax(int64_t m)2748 void SetMax(int64_t m) override { var_->SetMax(CapSub(m, cst_)); }
2749
SetRange(int64_t l,int64_t u)2750 void SetRange(int64_t l, int64_t u) override {
2751 var_->SetRange(CapSub(l, cst_), CapSub(u, cst_));
2752 }
2753
SetValue(int64_t v)2754 void SetValue(int64_t v) override { var_->SetValue(v - cst_); }
2755
Value() const2756 int64_t Value() const override { return var_->Value() + cst_; }
2757
Bound() const2758 bool Bound() const override { return var_->Bound(); }
2759
RemoveValue(int64_t v)2760 void RemoveValue(int64_t v) override { var_->RemoveValue(v - cst_); }
2761
RemoveInterval(int64_t l,int64_t u)2762 void RemoveInterval(int64_t l, int64_t u) override {
2763 var_->RemoveInterval(l - cst_, u - cst_);
2764 }
2765
Size() const2766 uint64_t Size() const override { return var_->Size(); }
2767
Contains(int64_t v) const2768 bool Contains(int64_t v) const override { return var_->Contains(v - cst_); }
2769
MakeHoleIterator(bool reversible) const2770 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2771 return COND_REV_ALLOC(
2772 reversible, new PlusCstIntVarIterator(var_, cst_, true, reversible));
2773 }
MakeDomainIterator(bool reversible) const2774 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2775 return COND_REV_ALLOC(
2776 reversible, new PlusCstIntVarIterator(var_, cst_, false, reversible));
2777 }
2778 };
2779
2780 class PlusCstDomainIntVar : public PlusCstVar {
2781 public:
2782 class PlusCstDomainIntVarIterator : public UnaryIterator {
2783 public:
PlusCstDomainIntVarIterator(const IntVar * const v,int64_t c,bool hole,bool reversible)2784 PlusCstDomainIntVarIterator(const IntVar* const v, int64_t c, bool hole,
2785 bool reversible)
2786 : UnaryIterator(v, hole, reversible), cst_(c) {}
2787
~PlusCstDomainIntVarIterator()2788 ~PlusCstDomainIntVarIterator() override {}
2789
Value() const2790 int64_t Value() const override { return iterator_->Value() + cst_; }
2791
2792 private:
2793 const int64_t cst_;
2794 };
2795
PlusCstDomainIntVar(Solver * const s,DomainIntVar * v,int64_t c)2796 PlusCstDomainIntVar(Solver* const s, DomainIntVar* v, int64_t c)
2797 : PlusCstVar(s, v, c) {}
2798
~PlusCstDomainIntVar()2799 ~PlusCstDomainIntVar() override {}
2800
2801 int64_t Min() const override;
2802 void SetMin(int64_t m) override;
2803 int64_t Max() const override;
2804 void SetMax(int64_t m) override;
2805 void SetRange(int64_t l, int64_t u) override;
2806 void SetValue(int64_t v) override;
2807 bool Bound() const override;
2808 int64_t Value() const override;
2809 void RemoveValue(int64_t v) override;
2810 void RemoveInterval(int64_t l, int64_t u) override;
2811 uint64_t Size() const override;
2812 bool Contains(int64_t v) const override;
2813
domain_int_var() const2814 DomainIntVar* domain_int_var() const {
2815 return reinterpret_cast<DomainIntVar*>(var_);
2816 }
2817
MakeHoleIterator(bool reversible) const2818 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2819 return COND_REV_ALLOC(reversible, new PlusCstDomainIntVarIterator(
2820 var_, cst_, true, reversible));
2821 }
MakeDomainIterator(bool reversible) const2822 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2823 return COND_REV_ALLOC(reversible, new PlusCstDomainIntVarIterator(
2824 var_, cst_, false, reversible));
2825 }
2826 };
2827
Min() const2828 int64_t PlusCstDomainIntVar::Min() const {
2829 return domain_int_var()->min_.Value() + cst_;
2830 }
2831
SetMin(int64_t m)2832 void PlusCstDomainIntVar::SetMin(int64_t m) {
2833 domain_int_var()->DomainIntVar::SetMin(m - cst_);
2834 }
2835
Max() const2836 int64_t PlusCstDomainIntVar::Max() const {
2837 return domain_int_var()->max_.Value() + cst_;
2838 }
2839
SetMax(int64_t m)2840 void PlusCstDomainIntVar::SetMax(int64_t m) {
2841 domain_int_var()->DomainIntVar::SetMax(m - cst_);
2842 }
2843
SetRange(int64_t l,int64_t u)2844 void PlusCstDomainIntVar::SetRange(int64_t l, int64_t u) {
2845 domain_int_var()->DomainIntVar::SetRange(l - cst_, u - cst_);
2846 }
2847
SetValue(int64_t v)2848 void PlusCstDomainIntVar::SetValue(int64_t v) {
2849 domain_int_var()->DomainIntVar::SetValue(v - cst_);
2850 }
2851
Bound() const2852 bool PlusCstDomainIntVar::Bound() const {
2853 return domain_int_var()->min_.Value() == domain_int_var()->max_.Value();
2854 }
2855
Value() const2856 int64_t PlusCstDomainIntVar::Value() const {
2857 CHECK_EQ(domain_int_var()->min_.Value(), domain_int_var()->max_.Value())
2858 << " variable is not bound";
2859 return domain_int_var()->min_.Value() + cst_;
2860 }
2861
RemoveValue(int64_t v)2862 void PlusCstDomainIntVar::RemoveValue(int64_t v) {
2863 domain_int_var()->DomainIntVar::RemoveValue(v - cst_);
2864 }
2865
RemoveInterval(int64_t l,int64_t u)2866 void PlusCstDomainIntVar::RemoveInterval(int64_t l, int64_t u) {
2867 domain_int_var()->DomainIntVar::RemoveInterval(l - cst_, u - cst_);
2868 }
2869
Size() const2870 uint64_t PlusCstDomainIntVar::Size() const {
2871 return domain_int_var()->DomainIntVar::Size();
2872 }
2873
Contains(int64_t v) const2874 bool PlusCstDomainIntVar::Contains(int64_t v) const {
2875 return domain_int_var()->DomainIntVar::Contains(v - cst_);
2876 }
2877
2878 // c - x variable, optimized case
2879
2880 class SubCstIntVar : public IntVar {
2881 public:
2882 class SubCstIntVarIterator : public UnaryIterator {
2883 public:
SubCstIntVarIterator(const IntVar * const v,int64_t c,bool hole,bool rev)2884 SubCstIntVarIterator(const IntVar* const v, int64_t c, bool hole, bool rev)
2885 : UnaryIterator(v, hole, rev), cst_(c) {}
~SubCstIntVarIterator()2886 ~SubCstIntVarIterator() override {}
2887
Value() const2888 int64_t Value() const override { return cst_ - iterator_->Value(); }
2889
2890 private:
2891 const int64_t cst_;
2892 };
2893
2894 SubCstIntVar(Solver* const s, IntVar* v, int64_t c);
2895 ~SubCstIntVar() override;
2896
2897 int64_t Min() const override;
2898 void SetMin(int64_t m) override;
2899 int64_t Max() const override;
2900 void SetMax(int64_t m) override;
2901 void SetRange(int64_t l, int64_t u) override;
2902 void SetValue(int64_t v) override;
2903 bool Bound() const override;
2904 int64_t Value() const override;
2905 void RemoveValue(int64_t v) override;
2906 void RemoveInterval(int64_t l, int64_t u) override;
2907 uint64_t Size() const override;
2908 bool Contains(int64_t v) const override;
2909 void WhenRange(Demon* d) override;
2910 void WhenBound(Demon* d) override;
2911 void WhenDomain(Demon* d) override;
MakeHoleIterator(bool reversible) const2912 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2913 return COND_REV_ALLOC(
2914 reversible, new SubCstIntVarIterator(var_, cst_, true, reversible));
2915 }
MakeDomainIterator(bool reversible) const2916 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2917 return COND_REV_ALLOC(
2918 reversible, new SubCstIntVarIterator(var_, cst_, false, reversible));
2919 }
OldMin() const2920 int64_t OldMin() const override { return CapSub(cst_, var_->OldMax()); }
OldMax() const2921 int64_t OldMax() const override { return CapSub(cst_, var_->OldMin()); }
2922 std::string DebugString() const override;
2923 std::string name() const override;
VarType() const2924 int VarType() const override { return CST_SUB_VAR; }
2925
Accept(ModelVisitor * const visitor) const2926 void Accept(ModelVisitor* const visitor) const override {
2927 visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation,
2928 cst_, var_);
2929 }
2930
IsEqual(int64_t constant)2931 IntVar* IsEqual(int64_t constant) override {
2932 return var_->IsEqual(cst_ - constant);
2933 }
2934
IsDifferent(int64_t constant)2935 IntVar* IsDifferent(int64_t constant) override {
2936 return var_->IsDifferent(cst_ - constant);
2937 }
2938
IsGreaterOrEqual(int64_t constant)2939 IntVar* IsGreaterOrEqual(int64_t constant) override {
2940 return var_->IsLessOrEqual(cst_ - constant);
2941 }
2942
IsLessOrEqual(int64_t constant)2943 IntVar* IsLessOrEqual(int64_t constant) override {
2944 return var_->IsGreaterOrEqual(cst_ - constant);
2945 }
2946
SubVar() const2947 IntVar* SubVar() const { return var_; }
Constant() const2948 int64_t Constant() const { return cst_; }
2949
2950 private:
2951 IntVar* const var_;
2952 const int64_t cst_;
2953 };
2954
SubCstIntVar(Solver * const s,IntVar * v,int64_t c)2955 SubCstIntVar::SubCstIntVar(Solver* const s, IntVar* v, int64_t c)
2956 : IntVar(s), var_(v), cst_(c) {}
2957
~SubCstIntVar()2958 SubCstIntVar::~SubCstIntVar() {}
2959
Min() const2960 int64_t SubCstIntVar::Min() const { return cst_ - var_->Max(); }
2961
SetMin(int64_t m)2962 void SubCstIntVar::SetMin(int64_t m) { var_->SetMax(CapSub(cst_, m)); }
2963
Max() const2964 int64_t SubCstIntVar::Max() const { return cst_ - var_->Min(); }
2965
SetMax(int64_t m)2966 void SubCstIntVar::SetMax(int64_t m) { var_->SetMin(CapSub(cst_, m)); }
2967
SetRange(int64_t l,int64_t u)2968 void SubCstIntVar::SetRange(int64_t l, int64_t u) {
2969 var_->SetRange(CapSub(cst_, u), CapSub(cst_, l));
2970 }
2971
SetValue(int64_t v)2972 void SubCstIntVar::SetValue(int64_t v) { var_->SetValue(cst_ - v); }
2973
Bound() const2974 bool SubCstIntVar::Bound() const { return var_->Bound(); }
2975
WhenRange(Demon * d)2976 void SubCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
2977
Value() const2978 int64_t SubCstIntVar::Value() const { return cst_ - var_->Value(); }
2979
RemoveValue(int64_t v)2980 void SubCstIntVar::RemoveValue(int64_t v) { var_->RemoveValue(cst_ - v); }
2981
RemoveInterval(int64_t l,int64_t u)2982 void SubCstIntVar::RemoveInterval(int64_t l, int64_t u) {
2983 var_->RemoveInterval(cst_ - u, cst_ - l);
2984 }
2985
WhenBound(Demon * d)2986 void SubCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
2987
WhenDomain(Demon * d)2988 void SubCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
2989
Size() const2990 uint64_t SubCstIntVar::Size() const { return var_->Size(); }
2991
Contains(int64_t v) const2992 bool SubCstIntVar::Contains(int64_t v) const {
2993 return var_->Contains(cst_ - v);
2994 }
2995
DebugString() const2996 std::string SubCstIntVar::DebugString() const {
2997 if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
2998 return absl::StrFormat("Not(%s)", var_->DebugString());
2999 } else {
3000 return absl::StrFormat("(%d - %s)", cst_, var_->DebugString());
3001 }
3002 }
3003
name() const3004 std::string SubCstIntVar::name() const {
3005 if (solver()->HasName(this)) {
3006 return PropagationBaseObject::name();
3007 } else if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
3008 return absl::StrFormat("Not(%s)", var_->name());
3009 } else {
3010 return absl::StrFormat("(%d - %s)", cst_, var_->name());
3011 }
3012 }
3013
3014 // -x variable, optimized case
3015
3016 class OppIntVar : public IntVar {
3017 public:
3018 class OppIntVarIterator : public UnaryIterator {
3019 public:
OppIntVarIterator(const IntVar * const v,bool hole,bool reversible)3020 OppIntVarIterator(const IntVar* const v, bool hole, bool reversible)
3021 : UnaryIterator(v, hole, reversible) {}
~OppIntVarIterator()3022 ~OppIntVarIterator() override {}
3023
Value() const3024 int64_t Value() const override { return -iterator_->Value(); }
3025 };
3026
3027 OppIntVar(Solver* const s, IntVar* v);
3028 ~OppIntVar() override;
3029
3030 int64_t Min() const override;
3031 void SetMin(int64_t m) override;
3032 int64_t Max() const override;
3033 void SetMax(int64_t m) override;
3034 void SetRange(int64_t l, int64_t u) override;
3035 void SetValue(int64_t v) override;
3036 bool Bound() const override;
3037 int64_t Value() const override;
3038 void RemoveValue(int64_t v) override;
3039 void RemoveInterval(int64_t l, int64_t u) override;
3040 uint64_t Size() const override;
3041 bool Contains(int64_t v) const override;
3042 void WhenRange(Demon* d) override;
3043 void WhenBound(Demon* d) override;
3044 void WhenDomain(Demon* d) override;
MakeHoleIterator(bool reversible) const3045 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3046 return COND_REV_ALLOC(reversible,
3047 new OppIntVarIterator(var_, true, reversible));
3048 }
MakeDomainIterator(bool reversible) const3049 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3050 return COND_REV_ALLOC(reversible,
3051 new OppIntVarIterator(var_, false, reversible));
3052 }
OldMin() const3053 int64_t OldMin() const override { return CapOpp(var_->OldMax()); }
OldMax() const3054 int64_t OldMax() const override { return CapOpp(var_->OldMin()); }
3055 std::string DebugString() const override;
VarType() const3056 int VarType() const override { return OPP_VAR; }
3057
Accept(ModelVisitor * const visitor) const3058 void Accept(ModelVisitor* const visitor) const override {
3059 visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation, 0,
3060 var_);
3061 }
3062
IsEqual(int64_t constant)3063 IntVar* IsEqual(int64_t constant) override {
3064 return var_->IsEqual(-constant);
3065 }
3066
IsDifferent(int64_t constant)3067 IntVar* IsDifferent(int64_t constant) override {
3068 return var_->IsDifferent(-constant);
3069 }
3070
IsGreaterOrEqual(int64_t constant)3071 IntVar* IsGreaterOrEqual(int64_t constant) override {
3072 return var_->IsLessOrEqual(-constant);
3073 }
3074
IsLessOrEqual(int64_t constant)3075 IntVar* IsLessOrEqual(int64_t constant) override {
3076 return var_->IsGreaterOrEqual(-constant);
3077 }
3078
SubVar() const3079 IntVar* SubVar() const { return var_; }
3080
3081 private:
3082 IntVar* const var_;
3083 };
3084
OppIntVar(Solver * const s,IntVar * v)3085 OppIntVar::OppIntVar(Solver* const s, IntVar* v) : IntVar(s), var_(v) {}
3086
~OppIntVar()3087 OppIntVar::~OppIntVar() {}
3088
Min() const3089 int64_t OppIntVar::Min() const { return -var_->Max(); }
3090
SetMin(int64_t m)3091 void OppIntVar::SetMin(int64_t m) { var_->SetMax(CapOpp(m)); }
3092
Max() const3093 int64_t OppIntVar::Max() const { return -var_->Min(); }
3094
SetMax(int64_t m)3095 void OppIntVar::SetMax(int64_t m) { var_->SetMin(CapOpp(m)); }
3096
SetRange(int64_t l,int64_t u)3097 void OppIntVar::SetRange(int64_t l, int64_t u) {
3098 var_->SetRange(CapOpp(u), CapOpp(l));
3099 }
3100
SetValue(int64_t v)3101 void OppIntVar::SetValue(int64_t v) { var_->SetValue(CapOpp(v)); }
3102
Bound() const3103 bool OppIntVar::Bound() const { return var_->Bound(); }
3104
WhenRange(Demon * d)3105 void OppIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3106
Value() const3107 int64_t OppIntVar::Value() const { return -var_->Value(); }
3108
RemoveValue(int64_t v)3109 void OppIntVar::RemoveValue(int64_t v) { var_->RemoveValue(-v); }
3110
RemoveInterval(int64_t l,int64_t u)3111 void OppIntVar::RemoveInterval(int64_t l, int64_t u) {
3112 var_->RemoveInterval(-u, -l);
3113 }
3114
WhenBound(Demon * d)3115 void OppIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3116
WhenDomain(Demon * d)3117 void OppIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3118
Size() const3119 uint64_t OppIntVar::Size() const { return var_->Size(); }
3120
Contains(int64_t v) const3121 bool OppIntVar::Contains(int64_t v) const { return var_->Contains(-v); }
3122
DebugString() const3123 std::string OppIntVar::DebugString() const {
3124 return absl::StrFormat("-(%s)", var_->DebugString());
3125 }
3126
// ----- Variables scaled by a constant (x * c) -----
3128
3129 // x * c variable, optimized case
3130
3131 class TimesCstIntVar : public IntVar {
3132 public:
TimesCstIntVar(Solver * const s,IntVar * v,int64_t c)3133 TimesCstIntVar(Solver* const s, IntVar* v, int64_t c)
3134 : IntVar(s), var_(v), cst_(c) {}
~TimesCstIntVar()3135 ~TimesCstIntVar() override {}
3136
SubVar() const3137 IntVar* SubVar() const { return var_; }
Constant() const3138 int64_t Constant() const { return cst_; }
3139
Accept(ModelVisitor * const visitor) const3140 void Accept(ModelVisitor* const visitor) const override {
3141 visitor->VisitIntegerVariable(this, ModelVisitor::kProductOperation, cst_,
3142 var_);
3143 }
3144
IsEqual(int64_t constant)3145 IntVar* IsEqual(int64_t constant) override {
3146 if (constant % cst_ == 0) {
3147 return var_->IsEqual(constant / cst_);
3148 } else {
3149 return solver()->MakeIntConst(0);
3150 }
3151 }
3152
IsDifferent(int64_t constant)3153 IntVar* IsDifferent(int64_t constant) override {
3154 if (constant % cst_ == 0) {
3155 return var_->IsDifferent(constant / cst_);
3156 } else {
3157 return solver()->MakeIntConst(1);
3158 }
3159 }
3160
IsGreaterOrEqual(int64_t constant)3161 IntVar* IsGreaterOrEqual(int64_t constant) override {
3162 if (cst_ > 0) {
3163 return var_->IsGreaterOrEqual(PosIntDivUp(constant, cst_));
3164 } else {
3165 return var_->IsLessOrEqual(PosIntDivDown(-constant, -cst_));
3166 }
3167 }
3168
IsLessOrEqual(int64_t constant)3169 IntVar* IsLessOrEqual(int64_t constant) override {
3170 if (cst_ > 0) {
3171 return var_->IsLessOrEqual(PosIntDivDown(constant, cst_));
3172 } else {
3173 return var_->IsGreaterOrEqual(PosIntDivUp(-constant, -cst_));
3174 }
3175 }
3176
DebugString() const3177 std::string DebugString() const override {
3178 return absl::StrFormat("(%s * %d)", var_->DebugString(), cst_);
3179 }
3180
VarType() const3181 int VarType() const override { return VAR_TIMES_CST; }
3182
3183 protected:
3184 IntVar* const var_;
3185 const int64_t cst_;
3186 };
3187
3188 class TimesPosCstIntVar : public TimesCstIntVar {
3189 public:
3190 class TimesPosCstIntVarIterator : public UnaryIterator {
3191 public:
TimesPosCstIntVarIterator(const IntVar * const v,int64_t c,bool hole,bool reversible)3192 TimesPosCstIntVarIterator(const IntVar* const v, int64_t c, bool hole,
3193 bool reversible)
3194 : UnaryIterator(v, hole, reversible), cst_(c) {}
~TimesPosCstIntVarIterator()3195 ~TimesPosCstIntVarIterator() override {}
3196
Value() const3197 int64_t Value() const override { return iterator_->Value() * cst_; }
3198
3199 private:
3200 const int64_t cst_;
3201 };
3202
3203 TimesPosCstIntVar(Solver* const s, IntVar* v, int64_t c);
3204 ~TimesPosCstIntVar() override;
3205
3206 int64_t Min() const override;
3207 void SetMin(int64_t m) override;
3208 int64_t Max() const override;
3209 void SetMax(int64_t m) override;
3210 void SetRange(int64_t l, int64_t u) override;
3211 void SetValue(int64_t v) override;
3212 bool Bound() const override;
3213 int64_t Value() const override;
3214 void RemoveValue(int64_t v) override;
3215 void RemoveInterval(int64_t l, int64_t u) override;
3216 uint64_t Size() const override;
3217 bool Contains(int64_t v) const override;
3218 void WhenRange(Demon* d) override;
3219 void WhenBound(Demon* d) override;
3220 void WhenDomain(Demon* d) override;
MakeHoleIterator(bool reversible) const3221 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3222 return COND_REV_ALLOC(reversible, new TimesPosCstIntVarIterator(
3223 var_, cst_, true, reversible));
3224 }
MakeDomainIterator(bool reversible) const3225 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3226 return COND_REV_ALLOC(reversible, new TimesPosCstIntVarIterator(
3227 var_, cst_, false, reversible));
3228 }
OldMin() const3229 int64_t OldMin() const override { return CapProd(var_->OldMin(), cst_); }
OldMax() const3230 int64_t OldMax() const override { return CapProd(var_->OldMax(), cst_); }
3231 };
3232
3233 // ----- TimesPosCstIntVar -----
3234
TimesPosCstIntVar(Solver * const s,IntVar * v,int64_t c)3235 TimesPosCstIntVar::TimesPosCstIntVar(Solver* const s, IntVar* v, int64_t c)
3236 : TimesCstIntVar(s, v, c) {}
3237
~TimesPosCstIntVar()3238 TimesPosCstIntVar::~TimesPosCstIntVar() {}
3239
Min() const3240 int64_t TimesPosCstIntVar::Min() const { return CapProd(var_->Min(), cst_); }
3241
SetMin(int64_t m)3242 void TimesPosCstIntVar::SetMin(int64_t m) {
3243 if (m != std::numeric_limits<int64_t>::min()) {
3244 var_->SetMin(PosIntDivUp(m, cst_));
3245 }
3246 }
3247
Max() const3248 int64_t TimesPosCstIntVar::Max() const { return CapProd(var_->Max(), cst_); }
3249
SetMax(int64_t m)3250 void TimesPosCstIntVar::SetMax(int64_t m) {
3251 if (m != std::numeric_limits<int64_t>::max()) {
3252 var_->SetMax(PosIntDivDown(m, cst_));
3253 }
3254 }
3255
SetRange(int64_t l,int64_t u)3256 void TimesPosCstIntVar::SetRange(int64_t l, int64_t u) {
3257 var_->SetRange(PosIntDivUp(l, cst_), PosIntDivDown(u, cst_));
3258 }
3259
SetValue(int64_t v)3260 void TimesPosCstIntVar::SetValue(int64_t v) {
3261 if (v % cst_ != 0) {
3262 solver()->Fail();
3263 }
3264 var_->SetValue(v / cst_);
3265 }
3266
Bound() const3267 bool TimesPosCstIntVar::Bound() const { return var_->Bound(); }
3268
WhenRange(Demon * d)3269 void TimesPosCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3270
Value() const3271 int64_t TimesPosCstIntVar::Value() const {
3272 return CapProd(var_->Value(), cst_);
3273 }
3274
RemoveValue(int64_t v)3275 void TimesPosCstIntVar::RemoveValue(int64_t v) {
3276 if (v % cst_ == 0) {
3277 var_->RemoveValue(v / cst_);
3278 }
3279 }
3280
RemoveInterval(int64_t l,int64_t u)3281 void TimesPosCstIntVar::RemoveInterval(int64_t l, int64_t u) {
3282 for (int64_t v = l; v <= u; ++v) {
3283 RemoveValue(v);
3284 }
3285 // TODO(user) : Improve me
3286 }
3287
WhenBound(Demon * d)3288 void TimesPosCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3289
WhenDomain(Demon * d)3290 void TimesPosCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3291
Size() const3292 uint64_t TimesPosCstIntVar::Size() const { return var_->Size(); }
3293
Contains(int64_t v) const3294 bool TimesPosCstIntVar::Contains(int64_t v) const {
3295 return (v % cst_ == 0 && var_->Contains(v / cst_));
3296 }
3297
3298 // b * c variable, optimized case
3299
3300 class TimesPosCstBoolVar : public TimesCstIntVar {
3301 public:
3302 class TimesPosCstBoolVarIterator : public UnaryIterator {
3303 public:
3304 // TODO(user) : optimize this.
TimesPosCstBoolVarIterator(const IntVar * const v,int64_t c,bool hole,bool reversible)3305 TimesPosCstBoolVarIterator(const IntVar* const v, int64_t c, bool hole,
3306 bool reversible)
3307 : UnaryIterator(v, hole, reversible), cst_(c) {}
~TimesPosCstBoolVarIterator()3308 ~TimesPosCstBoolVarIterator() override {}
3309
Value() const3310 int64_t Value() const override { return iterator_->Value() * cst_; }
3311
3312 private:
3313 const int64_t cst_;
3314 };
3315
3316 TimesPosCstBoolVar(Solver* const s, BooleanVar* v, int64_t c);
3317 ~TimesPosCstBoolVar() override;
3318
3319 int64_t Min() const override;
3320 void SetMin(int64_t m) override;
3321 int64_t Max() const override;
3322 void SetMax(int64_t m) override;
3323 void SetRange(int64_t l, int64_t u) override;
3324 void SetValue(int64_t v) override;
3325 bool Bound() const override;
3326 int64_t Value() const override;
3327 void RemoveValue(int64_t v) override;
3328 void RemoveInterval(int64_t l, int64_t u) override;
3329 uint64_t Size() const override;
3330 bool Contains(int64_t v) const override;
3331 void WhenRange(Demon* d) override;
3332 void WhenBound(Demon* d) override;
3333 void WhenDomain(Demon* d) override;
MakeHoleIterator(bool reversible) const3334 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3335 return COND_REV_ALLOC(reversible, new EmptyIterator());
3336 }
MakeDomainIterator(bool reversible) const3337 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3338 return COND_REV_ALLOC(
3339 reversible,
3340 new TimesPosCstBoolVarIterator(boolean_var(), cst_, false, reversible));
3341 }
OldMin() const3342 int64_t OldMin() const override { return 0; }
OldMax() const3343 int64_t OldMax() const override { return cst_; }
3344
boolean_var() const3345 BooleanVar* boolean_var() const {
3346 return reinterpret_cast<BooleanVar*>(var_);
3347 }
3348 };
3349
3350 // ----- TimesPosCstBoolVar -----
3351
TimesPosCstBoolVar(Solver * const s,BooleanVar * v,int64_t c)3352 TimesPosCstBoolVar::TimesPosCstBoolVar(Solver* const s, BooleanVar* v,
3353 int64_t c)
3354 : TimesCstIntVar(s, v, c) {}
3355
~TimesPosCstBoolVar()3356 TimesPosCstBoolVar::~TimesPosCstBoolVar() {}
3357
Min() const3358 int64_t TimesPosCstBoolVar::Min() const {
3359 return (boolean_var()->RawValue() == 1) * cst_;
3360 }
3361
SetMin(int64_t m)3362 void TimesPosCstBoolVar::SetMin(int64_t m) {
3363 if (m > cst_) {
3364 solver()->Fail();
3365 } else if (m > 0) {
3366 boolean_var()->SetMin(1);
3367 }
3368 }
3369
Max() const3370 int64_t TimesPosCstBoolVar::Max() const {
3371 return (boolean_var()->RawValue() != 0) * cst_;
3372 }
3373
SetMax(int64_t m)3374 void TimesPosCstBoolVar::SetMax(int64_t m) {
3375 if (m < 0) {
3376 solver()->Fail();
3377 } else if (m < cst_) {
3378 boolean_var()->SetMax(0);
3379 }
3380 }
3381
SetRange(int64_t l,int64_t u)3382 void TimesPosCstBoolVar::SetRange(int64_t l, int64_t u) {
3383 if (u < 0 || l > cst_ || l > u) {
3384 solver()->Fail();
3385 }
3386 if (l > 0) {
3387 boolean_var()->SetMin(1);
3388 } else if (u < cst_) {
3389 boolean_var()->SetMax(0);
3390 }
3391 }
3392
SetValue(int64_t v)3393 void TimesPosCstBoolVar::SetValue(int64_t v) {
3394 if (v == 0) {
3395 boolean_var()->SetValue(0);
3396 } else if (v == cst_) {
3397 boolean_var()->SetValue(1);
3398 } else {
3399 solver()->Fail();
3400 }
3401 }
3402
Bound() const3403 bool TimesPosCstBoolVar::Bound() const {
3404 return boolean_var()->RawValue() != BooleanVar::kUnboundBooleanVarValue;
3405 }
3406
WhenRange(Demon * d)3407 void TimesPosCstBoolVar::WhenRange(Demon* d) { boolean_var()->WhenRange(d); }
3408
Value() const3409 int64_t TimesPosCstBoolVar::Value() const {
3410 CHECK_NE(boolean_var()->RawValue(), BooleanVar::kUnboundBooleanVarValue)
3411 << " variable is not bound";
3412 return boolean_var()->RawValue() * cst_;
3413 }
3414
RemoveValue(int64_t v)3415 void TimesPosCstBoolVar::RemoveValue(int64_t v) {
3416 if (v == 0) {
3417 boolean_var()->RemoveValue(0);
3418 } else if (v == cst_) {
3419 boolean_var()->RemoveValue(1);
3420 }
3421 }
3422
RemoveInterval(int64_t l,int64_t u)3423 void TimesPosCstBoolVar::RemoveInterval(int64_t l, int64_t u) {
3424 if (l <= 0 && u >= 0) {
3425 boolean_var()->RemoveValue(0);
3426 }
3427 if (l <= cst_ && u >= cst_) {
3428 boolean_var()->RemoveValue(1);
3429 }
3430 }
3431
WhenBound(Demon * d)3432 void TimesPosCstBoolVar::WhenBound(Demon* d) { boolean_var()->WhenBound(d); }
3433
WhenDomain(Demon * d)3434 void TimesPosCstBoolVar::WhenDomain(Demon* d) { boolean_var()->WhenDomain(d); }
3435
Size() const3436 uint64_t TimesPosCstBoolVar::Size() const {
3437 return (1 +
3438 (boolean_var()->RawValue() == BooleanVar::kUnboundBooleanVarValue));
3439 }
3440
Contains(int64_t v) const3441 bool TimesPosCstBoolVar::Contains(int64_t v) const {
3442 if (v == 0) {
3443 return boolean_var()->RawValue() != 1;
3444 } else if (v == cst_) {
3445 return boolean_var()->RawValue() != 0;
3446 }
3447 return false;
3448 }
3449
3450 // TimesNegCstIntVar
3451
3452 class TimesNegCstIntVar : public TimesCstIntVar {
3453 public:
3454 class TimesNegCstIntVarIterator : public UnaryIterator {
3455 public:
TimesNegCstIntVarIterator(const IntVar * const v,int64_t c,bool hole,bool reversible)3456 TimesNegCstIntVarIterator(const IntVar* const v, int64_t c, bool hole,
3457 bool reversible)
3458 : UnaryIterator(v, hole, reversible), cst_(c) {}
~TimesNegCstIntVarIterator()3459 ~TimesNegCstIntVarIterator() override {}
3460
Value() const3461 int64_t Value() const override { return iterator_->Value() * cst_; }
3462
3463 private:
3464 const int64_t cst_;
3465 };
3466
3467 TimesNegCstIntVar(Solver* const s, IntVar* v, int64_t c);
3468 ~TimesNegCstIntVar() override;
3469
3470 int64_t Min() const override;
3471 void SetMin(int64_t m) override;
3472 int64_t Max() const override;
3473 void SetMax(int64_t m) override;
3474 void SetRange(int64_t l, int64_t u) override;
3475 void SetValue(int64_t v) override;
3476 bool Bound() const override;
3477 int64_t Value() const override;
3478 void RemoveValue(int64_t v) override;
3479 void RemoveInterval(int64_t l, int64_t u) override;
3480 uint64_t Size() const override;
3481 bool Contains(int64_t v) const override;
3482 void WhenRange(Demon* d) override;
3483 void WhenBound(Demon* d) override;
3484 void WhenDomain(Demon* d) override;
MakeHoleIterator(bool reversible) const3485 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3486 return COND_REV_ALLOC(reversible, new TimesNegCstIntVarIterator(
3487 var_, cst_, true, reversible));
3488 }
MakeDomainIterator(bool reversible) const3489 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3490 return COND_REV_ALLOC(reversible, new TimesNegCstIntVarIterator(
3491 var_, cst_, false, reversible));
3492 }
OldMin() const3493 int64_t OldMin() const override { return CapProd(var_->OldMax(), cst_); }
OldMax() const3494 int64_t OldMax() const override { return CapProd(var_->OldMin(), cst_); }
3495 };
3496
3497 // ----- TimesNegCstIntVar -----
3498
TimesNegCstIntVar(Solver * const s,IntVar * v,int64_t c)3499 TimesNegCstIntVar::TimesNegCstIntVar(Solver* const s, IntVar* v, int64_t c)
3500 : TimesCstIntVar(s, v, c) {}
3501
~TimesNegCstIntVar()3502 TimesNegCstIntVar::~TimesNegCstIntVar() {}
3503
Min() const3504 int64_t TimesNegCstIntVar::Min() const { return CapProd(var_->Max(), cst_); }
3505
SetMin(int64_t m)3506 void TimesNegCstIntVar::SetMin(int64_t m) {
3507 if (m != std::numeric_limits<int64_t>::min()) {
3508 var_->SetMax(PosIntDivDown(-m, -cst_));
3509 }
3510 }
3511
Max() const3512 int64_t TimesNegCstIntVar::Max() const { return CapProd(var_->Min(), cst_); }
3513
SetMax(int64_t m)3514 void TimesNegCstIntVar::SetMax(int64_t m) {
3515 if (m != std::numeric_limits<int64_t>::max()) {
3516 var_->SetMin(PosIntDivUp(-m, -cst_));
3517 }
3518 }
3519
SetRange(int64_t l,int64_t u)3520 void TimesNegCstIntVar::SetRange(int64_t l, int64_t u) {
3521 var_->SetRange(PosIntDivUp(-u, -cst_), PosIntDivDown(-l, -cst_));
3522 }
3523
SetValue(int64_t v)3524 void TimesNegCstIntVar::SetValue(int64_t v) {
3525 if (v % cst_ != 0) {
3526 solver()->Fail();
3527 }
3528 var_->SetValue(v / cst_);
3529 }
3530
Bound() const3531 bool TimesNegCstIntVar::Bound() const { return var_->Bound(); }
3532
WhenRange(Demon * d)3533 void TimesNegCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3534
Value() const3535 int64_t TimesNegCstIntVar::Value() const {
3536 return CapProd(var_->Value(), cst_);
3537 }
3538
RemoveValue(int64_t v)3539 void TimesNegCstIntVar::RemoveValue(int64_t v) {
3540 if (v % cst_ == 0) {
3541 var_->RemoveValue(v / cst_);
3542 }
3543 }
3544
RemoveInterval(int64_t l,int64_t u)3545 void TimesNegCstIntVar::RemoveInterval(int64_t l, int64_t u) {
3546 for (int64_t v = l; v <= u; ++v) {
3547 RemoveValue(v);
3548 }
3549 // TODO(user) : Improve me
3550 }
3551
WhenBound(Demon * d)3552 void TimesNegCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3553
WhenDomain(Demon * d)3554 void TimesNegCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3555
Size() const3556 uint64_t TimesNegCstIntVar::Size() const { return var_->Size(); }
3557
Contains(int64_t v) const3558 bool TimesNegCstIntVar::Contains(int64_t v) const {
3559 return (v % cst_ == 0 && var_->Contains(v / cst_));
3560 }
3561
3562 // ---------- arithmetic expressions ----------
3563
3564 // ----- PlusIntExpr -----
3565
3566 class PlusIntExpr : public BaseIntExpr {
3567 public:
PlusIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)3568 PlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3569 : BaseIntExpr(s), left_(l), right_(r) {}
3570
~PlusIntExpr()3571 ~PlusIntExpr() override {}
3572
Min() const3573 int64_t Min() const override { return left_->Min() + right_->Min(); }
3574
SetMin(int64_t m)3575 void SetMin(int64_t m) override {
3576 if (m > left_->Min() + right_->Min()) {
3577 left_->SetMin(m - right_->Max());
3578 right_->SetMin(m - left_->Max());
3579 }
3580 }
3581
SetRange(int64_t l,int64_t u)3582 void SetRange(int64_t l, int64_t u) override {
3583 const int64_t left_min = left_->Min();
3584 const int64_t right_min = right_->Min();
3585 const int64_t left_max = left_->Max();
3586 const int64_t right_max = right_->Max();
3587 if (l > left_min + right_min) {
3588 left_->SetMin(l - right_max);
3589 right_->SetMin(l - left_max);
3590 }
3591 if (u < left_max + right_max) {
3592 left_->SetMax(u - right_min);
3593 right_->SetMax(u - left_min);
3594 }
3595 }
3596
Max() const3597 int64_t Max() const override { return left_->Max() + right_->Max(); }
3598
SetMax(int64_t m)3599 void SetMax(int64_t m) override {
3600 if (m < left_->Max() + right_->Max()) {
3601 left_->SetMax(m - right_->Min());
3602 right_->SetMax(m - left_->Min());
3603 }
3604 }
3605
Bound() const3606 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3607
Range(int64_t * const mi,int64_t * const ma)3608 void Range(int64_t* const mi, int64_t* const ma) override {
3609 *mi = left_->Min() + right_->Min();
3610 *ma = left_->Max() + right_->Max();
3611 }
3612
name() const3613 std::string name() const override {
3614 return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3615 }
3616
DebugString() const3617 std::string DebugString() const override {
3618 return absl::StrFormat("(%s + %s)", left_->DebugString(),
3619 right_->DebugString());
3620 }
3621
WhenRange(Demon * d)3622 void WhenRange(Demon* d) override {
3623 left_->WhenRange(d);
3624 right_->WhenRange(d);
3625 }
3626
ExpandPlusIntExpr(IntExpr * const expr,std::vector<IntExpr * > * subs)3627 void ExpandPlusIntExpr(IntExpr* const expr, std::vector<IntExpr*>* subs) {
3628 PlusIntExpr* const casted = dynamic_cast<PlusIntExpr*>(expr);
3629 if (casted != nullptr) {
3630 ExpandPlusIntExpr(casted->left_, subs);
3631 ExpandPlusIntExpr(casted->right_, subs);
3632 } else {
3633 subs->push_back(expr);
3634 }
3635 }
3636
CastToVar()3637 IntVar* CastToVar() override {
3638 if (dynamic_cast<PlusIntExpr*>(left_) != nullptr ||
3639 dynamic_cast<PlusIntExpr*>(right_) != nullptr) {
3640 std::vector<IntExpr*> sub_exprs;
3641 ExpandPlusIntExpr(left_, &sub_exprs);
3642 ExpandPlusIntExpr(right_, &sub_exprs);
3643 if (sub_exprs.size() >= 3) {
3644 std::vector<IntVar*> sub_vars(sub_exprs.size());
3645 for (int i = 0; i < sub_exprs.size(); ++i) {
3646 sub_vars[i] = sub_exprs[i]->Var();
3647 }
3648 return solver()->MakeSum(sub_vars)->Var();
3649 }
3650 }
3651 return BaseIntExpr::CastToVar();
3652 }
3653
Accept(ModelVisitor * const visitor) const3654 void Accept(ModelVisitor* const visitor) const override {
3655 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3656 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3657 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3658 right_);
3659 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3660 }
3661
3662 private:
3663 IntExpr* const left_;
3664 IntExpr* const right_;
3665 };
3666
3667 class SafePlusIntExpr : public BaseIntExpr {
3668 public:
SafePlusIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)3669 SafePlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3670 : BaseIntExpr(s), left_(l), right_(r) {}
3671
~SafePlusIntExpr()3672 ~SafePlusIntExpr() override {}
3673
Min() const3674 int64_t Min() const override { return CapAdd(left_->Min(), right_->Min()); }
3675
SetMin(int64_t m)3676 void SetMin(int64_t m) override {
3677 left_->SetMin(CapSub(m, right_->Max()));
3678 right_->SetMin(CapSub(m, left_->Max()));
3679 }
3680
SetRange(int64_t l,int64_t u)3681 void SetRange(int64_t l, int64_t u) override {
3682 const int64_t left_min = left_->Min();
3683 const int64_t right_min = right_->Min();
3684 const int64_t left_max = left_->Max();
3685 const int64_t right_max = right_->Max();
3686 if (l > CapAdd(left_min, right_min)) {
3687 left_->SetMin(CapSub(l, right_max));
3688 right_->SetMin(CapSub(l, left_max));
3689 }
3690 if (u < CapAdd(left_max, right_max)) {
3691 left_->SetMax(CapSub(u, right_min));
3692 right_->SetMax(CapSub(u, left_min));
3693 }
3694 }
3695
Max() const3696 int64_t Max() const override { return CapAdd(left_->Max(), right_->Max()); }
3697
SetMax(int64_t m)3698 void SetMax(int64_t m) override {
3699 left_->SetMax(CapSub(m, right_->Min()));
3700 right_->SetMax(CapSub(m, left_->Min()));
3701 }
3702
Bound() const3703 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3704
name() const3705 std::string name() const override {
3706 return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3707 }
3708
DebugString() const3709 std::string DebugString() const override {
3710 return absl::StrFormat("(%s + %s)", left_->DebugString(),
3711 right_->DebugString());
3712 }
3713
WhenRange(Demon * d)3714 void WhenRange(Demon* d) override {
3715 left_->WhenRange(d);
3716 right_->WhenRange(d);
3717 }
3718
Accept(ModelVisitor * const visitor) const3719 void Accept(ModelVisitor* const visitor) const override {
3720 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3721 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3722 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3723 right_);
3724 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3725 }
3726
3727 private:
3728 IntExpr* const left_;
3729 IntExpr* const right_;
3730 };
3731
3732 // ----- PlusIntCstExpr -----
3733
3734 class PlusIntCstExpr : public BaseIntExpr {
3735 public:
PlusIntCstExpr(Solver * const s,IntExpr * const e,int64_t v)3736 PlusIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3737 : BaseIntExpr(s), expr_(e), value_(v) {}
~PlusIntCstExpr()3738 ~PlusIntCstExpr() override {}
Min() const3739 int64_t Min() const override { return CapAdd(expr_->Min(), value_); }
SetMin(int64_t m)3740 void SetMin(int64_t m) override { expr_->SetMin(CapSub(m, value_)); }
Max() const3741 int64_t Max() const override { return CapAdd(expr_->Max(), value_); }
SetMax(int64_t m)3742 void SetMax(int64_t m) override { expr_->SetMax(CapSub(m, value_)); }
Bound() const3743 bool Bound() const override { return (expr_->Bound()); }
name() const3744 std::string name() const override {
3745 return absl::StrFormat("(%s + %d)", expr_->name(), value_);
3746 }
DebugString() const3747 std::string DebugString() const override {
3748 return absl::StrFormat("(%s + %d)", expr_->DebugString(), value_);
3749 }
WhenRange(Demon * d)3750 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3751 IntVar* CastToVar() override;
Accept(ModelVisitor * const visitor) const3752 void Accept(ModelVisitor* const visitor) const override {
3753 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3754 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3755 expr_);
3756 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3757 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3758 }
3759
3760 private:
3761 IntExpr* const expr_;
3762 const int64_t value_;
3763 };
3764
CastToVar()3765 IntVar* PlusIntCstExpr::CastToVar() {
3766 Solver* const s = solver();
3767 IntVar* const var = expr_->Var();
3768 IntVar* cast = nullptr;
3769 if (AddOverflows(value_, expr_->Max()) ||
3770 AddOverflows(value_, expr_->Min())) {
3771 return BaseIntExpr::CastToVar();
3772 }
3773 switch (var->VarType()) {
3774 case DOMAIN_INT_VAR:
3775 cast = s->RegisterIntVar(s->RevAlloc(new PlusCstDomainIntVar(
3776 s, reinterpret_cast<DomainIntVar*>(var), value_)));
3777 // FIXME: Break was inserted during fallthrough cleanup. Please check.
3778 break;
3779 default:
3780 cast = s->RegisterIntVar(s->RevAlloc(new PlusCstIntVar(s, var, value_)));
3781 break;
3782 }
3783 return cast;
3784 }
3785
3786 // ----- SubIntExpr -----
3787
3788 class SubIntExpr : public BaseIntExpr {
3789 public:
SubIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)3790 SubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3791 : BaseIntExpr(s), left_(l), right_(r) {}
3792
~SubIntExpr()3793 ~SubIntExpr() override {}
3794
Min() const3795 int64_t Min() const override { return left_->Min() - right_->Max(); }
3796
SetMin(int64_t m)3797 void SetMin(int64_t m) override {
3798 left_->SetMin(CapAdd(m, right_->Min()));
3799 right_->SetMax(CapSub(left_->Max(), m));
3800 }
3801
Max() const3802 int64_t Max() const override { return left_->Max() - right_->Min(); }
3803
SetMax(int64_t m)3804 void SetMax(int64_t m) override {
3805 left_->SetMax(CapAdd(m, right_->Max()));
3806 right_->SetMin(CapSub(left_->Min(), m));
3807 }
3808
Range(int64_t * mi,int64_t * ma)3809 void Range(int64_t* mi, int64_t* ma) override {
3810 *mi = left_->Min() - right_->Max();
3811 *ma = left_->Max() - right_->Min();
3812 }
3813
SetRange(int64_t l,int64_t u)3814 void SetRange(int64_t l, int64_t u) override {
3815 const int64_t left_min = left_->Min();
3816 const int64_t right_min = right_->Min();
3817 const int64_t left_max = left_->Max();
3818 const int64_t right_max = right_->Max();
3819 if (l > left_min - right_max) {
3820 left_->SetMin(CapAdd(l, right_min));
3821 right_->SetMax(CapSub(left_max, l));
3822 }
3823 if (u < left_max - right_min) {
3824 left_->SetMax(CapAdd(u, right_max));
3825 right_->SetMin(CapSub(left_min, u));
3826 }
3827 }
3828
Bound() const3829 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3830
name() const3831 std::string name() const override {
3832 return absl::StrFormat("(%s - %s)", left_->name(), right_->name());
3833 }
3834
DebugString() const3835 std::string DebugString() const override {
3836 return absl::StrFormat("(%s - %s)", left_->DebugString(),
3837 right_->DebugString());
3838 }
3839
WhenRange(Demon * d)3840 void WhenRange(Demon* d) override {
3841 left_->WhenRange(d);
3842 right_->WhenRange(d);
3843 }
3844
Accept(ModelVisitor * const visitor) const3845 void Accept(ModelVisitor* const visitor) const override {
3846 visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3847 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3848 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3849 right_);
3850 visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3851 }
3852
left() const3853 IntExpr* left() const { return left_; }
right() const3854 IntExpr* right() const { return right_; }
3855
3856 protected:
3857 IntExpr* const left_;
3858 IntExpr* const right_;
3859 };
3860
3861 class SafeSubIntExpr : public SubIntExpr {
3862 public:
SafeSubIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)3863 SafeSubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3864 : SubIntExpr(s, l, r) {}
3865
~SafeSubIntExpr()3866 ~SafeSubIntExpr() override {}
3867
Min() const3868 int64_t Min() const override { return CapSub(left_->Min(), right_->Max()); }
3869
SetMin(int64_t m)3870 void SetMin(int64_t m) override {
3871 left_->SetMin(CapAdd(m, right_->Min()));
3872 right_->SetMax(CapSub(left_->Max(), m));
3873 }
3874
SetRange(int64_t l,int64_t u)3875 void SetRange(int64_t l, int64_t u) override {
3876 const int64_t left_min = left_->Min();
3877 const int64_t right_min = right_->Min();
3878 const int64_t left_max = left_->Max();
3879 const int64_t right_max = right_->Max();
3880 if (l > CapSub(left_min, right_max)) {
3881 left_->SetMin(CapAdd(l, right_min));
3882 right_->SetMax(CapSub(left_max, l));
3883 }
3884 if (u < CapSub(left_max, right_min)) {
3885 left_->SetMax(CapAdd(u, right_max));
3886 right_->SetMin(CapSub(left_min, u));
3887 }
3888 }
3889
Range(int64_t * mi,int64_t * ma)3890 void Range(int64_t* mi, int64_t* ma) override {
3891 *mi = CapSub(left_->Min(), right_->Max());
3892 *ma = CapSub(left_->Max(), right_->Min());
3893 }
3894
Max() const3895 int64_t Max() const override { return CapSub(left_->Max(), right_->Min()); }
3896
SetMax(int64_t m)3897 void SetMax(int64_t m) override {
3898 left_->SetMax(CapAdd(m, right_->Max()));
3899 right_->SetMin(CapSub(left_->Min(), m));
3900 }
3901 };
3902
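// The classes above compute l - r on two expressions; SubIntCstExpr below
// computes value - expr for a constant value.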
3904
3905 // ----- SubIntCstExpr -----
3906
3907 class SubIntCstExpr : public BaseIntExpr {
3908 public:
SubIntCstExpr(Solver * const s,IntExpr * const e,int64_t v)3909 SubIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3910 : BaseIntExpr(s), expr_(e), value_(v) {}
~SubIntCstExpr()3911 ~SubIntCstExpr() override {}
Min() const3912 int64_t Min() const override { return CapSub(value_, expr_->Max()); }
SetMin(int64_t m)3913 void SetMin(int64_t m) override { expr_->SetMax(CapSub(value_, m)); }
Max() const3914 int64_t Max() const override { return CapSub(value_, expr_->Min()); }
SetMax(int64_t m)3915 void SetMax(int64_t m) override { expr_->SetMin(CapSub(value_, m)); }
Bound() const3916 bool Bound() const override { return (expr_->Bound()); }
name() const3917 std::string name() const override {
3918 return absl::StrFormat("(%d - %s)", value_, expr_->name());
3919 }
DebugString() const3920 std::string DebugString() const override {
3921 return absl::StrFormat("(%d - %s)", value_, expr_->DebugString());
3922 }
WhenRange(Demon * d)3923 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3924 IntVar* CastToVar() override;
3925
Accept(ModelVisitor * const visitor) const3926 void Accept(ModelVisitor* const visitor) const override {
3927 visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3928 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3929 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3930 expr_);
3931 visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3932 }
3933
3934 private:
3935 IntExpr* const expr_;
3936 const int64_t value_;
3937 };
3938
CastToVar()3939 IntVar* SubIntCstExpr::CastToVar() {
3940 if (SubOverflows(value_, expr_->Min()) ||
3941 SubOverflows(value_, expr_->Max())) {
3942 return BaseIntExpr::CastToVar();
3943 }
3944 Solver* const s = solver();
3945 IntVar* const var =
3946 s->RegisterIntVar(s->RevAlloc(new SubCstIntVar(s, expr_->Var(), value_)));
3947 return var;
3948 }
3949
3950 // ----- OppIntExpr -----
3951
3952 class OppIntExpr : public BaseIntExpr {
3953 public:
OppIntExpr(Solver * const s,IntExpr * const e)3954 OppIntExpr(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
~OppIntExpr()3955 ~OppIntExpr() override {}
Min() const3956 int64_t Min() const override { return (CapOpp(expr_->Max())); }
SetMin(int64_t m)3957 void SetMin(int64_t m) override { expr_->SetMax(CapOpp(m)); }
Max() const3958 int64_t Max() const override { return (CapOpp(expr_->Min())); }
SetMax(int64_t m)3959 void SetMax(int64_t m) override { expr_->SetMin(CapOpp(m)); }
Bound() const3960 bool Bound() const override { return (expr_->Bound()); }
name() const3961 std::string name() const override {
3962 return absl::StrFormat("(-%s)", expr_->name());
3963 }
DebugString() const3964 std::string DebugString() const override {
3965 return absl::StrFormat("(-%s)", expr_->DebugString());
3966 }
WhenRange(Demon * d)3967 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3968 IntVar* CastToVar() override;
3969
Accept(ModelVisitor * const visitor) const3970 void Accept(ModelVisitor* const visitor) const override {
3971 visitor->BeginVisitIntegerExpression(ModelVisitor::kOpposite, this);
3972 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3973 expr_);
3974 visitor->EndVisitIntegerExpression(ModelVisitor::kOpposite, this);
3975 }
3976
3977 private:
3978 IntExpr* const expr_;
3979 };
3980
CastToVar()3981 IntVar* OppIntExpr::CastToVar() {
3982 Solver* const s = solver();
3983 IntVar* const var =
3984 s->RegisterIntVar(s->RevAlloc(new OppIntVar(s, expr_->Var())));
3985 return var;
3986 }
3987
3988 // ----- TimesIntCstExpr -----
3989
3990 class TimesIntCstExpr : public BaseIntExpr {
3991 public:
TimesIntCstExpr(Solver * const s,IntExpr * const e,int64_t v)3992 TimesIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3993 : BaseIntExpr(s), expr_(e), value_(v) {}
3994
~TimesIntCstExpr()3995 ~TimesIntCstExpr() override {}
3996
Bound() const3997 bool Bound() const override { return (expr_->Bound()); }
3998
name() const3999 std::string name() const override {
4000 return absl::StrFormat("(%s * %d)", expr_->name(), value_);
4001 }
4002
DebugString() const4003 std::string DebugString() const override {
4004 return absl::StrFormat("(%s * %d)", expr_->DebugString(), value_);
4005 }
4006
WhenRange(Demon * d)4007 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4008
Expr() const4009 IntExpr* Expr() const { return expr_; }
4010
Constant() const4011 int64_t Constant() const { return value_; }
4012
Accept(ModelVisitor * const visitor) const4013 void Accept(ModelVisitor* const visitor) const override {
4014 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4015 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4016 expr_);
4017 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4018 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4019 }
4020
4021 protected:
4022 IntExpr* const expr_;
4023 const int64_t value_;
4024 };
4025
4026 // ----- TimesPosIntCstExpr -----
4027
4028 class TimesPosIntCstExpr : public TimesIntCstExpr {
4029 public:
TimesPosIntCstExpr(Solver * const s,IntExpr * const e,int64_t v)4030 TimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4031 : TimesIntCstExpr(s, e, v) {
4032 CHECK_GT(v, 0);
4033 }
4034
~TimesPosIntCstExpr()4035 ~TimesPosIntCstExpr() override {}
4036
Min() const4037 int64_t Min() const override { return expr_->Min() * value_; }
4038
SetMin(int64_t m)4039 void SetMin(int64_t m) override { expr_->SetMin(PosIntDivUp(m, value_)); }
4040
Max() const4041 int64_t Max() const override { return expr_->Max() * value_; }
4042
SetMax(int64_t m)4043 void SetMax(int64_t m) override { expr_->SetMax(PosIntDivDown(m, value_)); }
4044
CastToVar()4045 IntVar* CastToVar() override {
4046 Solver* const s = solver();
4047 IntVar* var = nullptr;
4048 if (expr_->IsVar() &&
4049 reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4050 var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4051 s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4052 } else {
4053 var = s->RegisterIntVar(
4054 s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4055 }
4056 return var;
4057 }
4058 };
4059
4060 // This expressions adds safe arithmetic (w.r.t. overflows) compared
4061 // to the previous one.
4062 class SafeTimesPosIntCstExpr : public TimesIntCstExpr {
4063 public:
SafeTimesPosIntCstExpr(Solver * const s,IntExpr * const e,int64_t v)4064 SafeTimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4065 : TimesIntCstExpr(s, e, v) {
4066 CHECK_GT(v, 0);
4067 }
4068
~SafeTimesPosIntCstExpr()4069 ~SafeTimesPosIntCstExpr() override {}
4070
Min() const4071 int64_t Min() const override { return CapProd(expr_->Min(), value_); }
4072
SetMin(int64_t m)4073 void SetMin(int64_t m) override {
4074 if (m != std::numeric_limits<int64_t>::min()) {
4075 expr_->SetMin(PosIntDivUp(m, value_));
4076 }
4077 }
4078
Max() const4079 int64_t Max() const override { return CapProd(expr_->Max(), value_); }
4080
SetMax(int64_t m)4081 void SetMax(int64_t m) override {
4082 if (m != std::numeric_limits<int64_t>::max()) {
4083 expr_->SetMax(PosIntDivDown(m, value_));
4084 }
4085 }
4086
CastToVar()4087 IntVar* CastToVar() override {
4088 Solver* const s = solver();
4089 IntVar* var = nullptr;
4090 if (expr_->IsVar() &&
4091 reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4092 var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4093 s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4094 } else {
4095 // TODO(user): Check overflows.
4096 var = s->RegisterIntVar(
4097 s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4098 }
4099 return var;
4100 }
4101 };
4102
4103 // ----- TimesIntNegCstExpr -----
4104
4105 class TimesIntNegCstExpr : public TimesIntCstExpr {
4106 public:
TimesIntNegCstExpr(Solver * const s,IntExpr * const e,int64_t v)4107 TimesIntNegCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4108 : TimesIntCstExpr(s, e, v) {
4109 CHECK_LT(v, 0);
4110 }
4111
~TimesIntNegCstExpr()4112 ~TimesIntNegCstExpr() override {}
4113
Min() const4114 int64_t Min() const override { return CapProd(expr_->Max(), value_); }
4115
SetMin(int64_t m)4116 void SetMin(int64_t m) override {
4117 if (m != std::numeric_limits<int64_t>::min()) {
4118 expr_->SetMax(PosIntDivDown(-m, -value_));
4119 }
4120 }
4121
Max() const4122 int64_t Max() const override { return CapProd(expr_->Min(), value_); }
4123
SetMax(int64_t m)4124 void SetMax(int64_t m) override {
4125 if (m != std::numeric_limits<int64_t>::max()) {
4126 expr_->SetMin(PosIntDivUp(-m, -value_));
4127 }
4128 }
4129
CastToVar()4130 IntVar* CastToVar() override {
4131 Solver* const s = solver();
4132 IntVar* var = nullptr;
4133 var = s->RegisterIntVar(
4134 s->RevAlloc(new TimesNegCstIntVar(s, expr_->Var(), value_)));
4135 return var;
4136 }
4137 };
4138
4139 // ----- Utilities for product expression -----
4140
4141 // Propagates set_min on left * right, left and right >= 0.
SetPosPosMinExpr(IntExpr * const left,IntExpr * const right,int64_t m)4142 void SetPosPosMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4143 DCHECK_GE(left->Min(), 0);
4144 DCHECK_GE(right->Min(), 0);
4145 const int64_t lmax = left->Max();
4146 const int64_t rmax = right->Max();
4147 if (m > CapProd(lmax, rmax)) {
4148 left->solver()->Fail();
4149 }
4150 if (m > CapProd(left->Min(), right->Min())) {
4151 // Ok for m == 0 due to left and right being positive
4152 if (0 != rmax) {
4153 left->SetMin(PosIntDivUp(m, rmax));
4154 }
4155 if (0 != lmax) {
4156 right->SetMin(PosIntDivUp(m, lmax));
4157 }
4158 }
4159 }
4160
4161 // Propagates set_max on left * right, left and right >= 0.
SetPosPosMaxExpr(IntExpr * const left,IntExpr * const right,int64_t m)4162 void SetPosPosMaxExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4163 DCHECK_GE(left->Min(), 0);
4164 DCHECK_GE(right->Min(), 0);
4165 const int64_t lmin = left->Min();
4166 const int64_t rmin = right->Min();
4167 if (m < CapProd(lmin, rmin)) {
4168 left->solver()->Fail();
4169 }
4170 if (m < CapProd(left->Max(), right->Max())) {
4171 if (0 != lmin) {
4172 right->SetMax(PosIntDivDown(m, lmin));
4173 }
4174 if (0 != rmin) {
4175 left->SetMax(PosIntDivDown(m, rmin));
4176 }
4177 // else do nothing: 0 is supporting any value from other expr.
4178 }
4179 }
4180
4181 // Propagates set_min on left * right, left >= 0, right across 0.
SetPosGenMinExpr(IntExpr * const left,IntExpr * const right,int64_t m)4182 void SetPosGenMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4183 DCHECK_GE(left->Min(), 0);
4184 DCHECK_GT(right->Max(), 0);
4185 DCHECK_LT(right->Min(), 0);
4186 const int64_t lmax = left->Max();
4187 const int64_t rmax = right->Max();
4188 if (m > CapProd(lmax, rmax)) {
4189 left->solver()->Fail();
4190 }
4191 if (left->Max() == 0) { // left is bound to 0, product is bound to 0.
4192 DCHECK_EQ(0, left->Min());
4193 DCHECK_LE(m, 0);
4194 } else {
4195 if (m > 0) { // We deduce right > 0.
4196 left->SetMin(PosIntDivUp(m, rmax));
4197 right->SetMin(PosIntDivUp(m, lmax));
4198 } else if (m == 0) {
4199 const int64_t lmin = left->Min();
4200 if (lmin > 0) {
4201 right->SetMin(0);
4202 }
4203 } else { // m < 0
4204 const int64_t lmin = left->Min();
4205 if (0 != lmin) { // We cannot deduce anything if 0 is in the domain.
4206 right->SetMin(-PosIntDivDown(-m, lmin));
4207 }
4208 }
4209 }
4210 }
4211
4212 // Propagates set_min on left * right, left and right across 0.
SetGenGenMinExpr(IntExpr * const left,IntExpr * const right,int64_t m)4213 void SetGenGenMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4214 DCHECK_LT(left->Min(), 0);
4215 DCHECK_GT(left->Max(), 0);
4216 DCHECK_GT(right->Max(), 0);
4217 DCHECK_LT(right->Min(), 0);
4218 const int64_t lmin = left->Min();
4219 const int64_t lmax = left->Max();
4220 const int64_t rmin = right->Min();
4221 const int64_t rmax = right->Max();
4222 if (m > std::max(CapProd(lmin, rmin), CapProd(lmax, rmax))) {
4223 left->solver()->Fail();
4224 }
4225 if (m > lmin * rmin) { // Must be positive section * positive section.
4226 left->SetMin(PosIntDivUp(m, rmax));
4227 right->SetMin(PosIntDivUp(m, lmax));
4228 } else if (m > CapProd(lmax, rmax)) { // Negative section * negative section.
4229 left->SetMax(-PosIntDivUp(m, -rmin));
4230 right->SetMax(-PosIntDivUp(m, -lmin));
4231 }
4232 }
4233
TimesSetMin(IntExpr * const left,IntExpr * const right,IntExpr * const minus_left,IntExpr * const minus_right,int64_t m)4234 void TimesSetMin(IntExpr* const left, IntExpr* const right,
4235 IntExpr* const minus_left, IntExpr* const minus_right,
4236 int64_t m) {
4237 if (left->Min() >= 0) {
4238 if (right->Min() >= 0) {
4239 SetPosPosMinExpr(left, right, m);
4240 } else if (right->Max() <= 0) {
4241 SetPosPosMaxExpr(left, minus_right, -m);
4242 } else { // right->Min() < 0 && right->Max() > 0
4243 SetPosGenMinExpr(left, right, m);
4244 }
4245 } else if (left->Max() <= 0) {
4246 if (right->Min() >= 0) {
4247 SetPosPosMaxExpr(right, minus_left, -m);
4248 } else if (right->Max() <= 0) {
4249 SetPosPosMinExpr(minus_left, minus_right, m);
4250 } else { // right->Min() < 0 && right->Max() > 0
4251 SetPosGenMinExpr(minus_left, minus_right, m);
4252 }
4253 } else if (right->Min() >= 0) { // left->Min() < 0 && left->Max() > 0
4254 SetPosGenMinExpr(right, left, m);
4255 } else if (right->Max() <= 0) { // left->Min() < 0 && left->Max() > 0
4256 SetPosGenMinExpr(minus_right, minus_left, m);
4257 } else { // left->Min() < 0 && left->Max() > 0 &&
4258 // right->Min() < 0 && right->Max() > 0
4259 SetGenGenMinExpr(left, right, m);
4260 }
4261 }
4262
4263 class TimesIntExpr : public BaseIntExpr {
4264 public:
TimesIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)4265 TimesIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4266 : BaseIntExpr(s),
4267 left_(l),
4268 right_(r),
4269 minus_left_(s->MakeOpposite(left_)),
4270 minus_right_(s->MakeOpposite(right_)) {}
~TimesIntExpr()4271 ~TimesIntExpr() override {}
Min() const4272 int64_t Min() const override {
4273 const int64_t lmin = left_->Min();
4274 const int64_t lmax = left_->Max();
4275 const int64_t rmin = right_->Min();
4276 const int64_t rmax = right_->Max();
4277 return std::min(std::min(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4278 std::min(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4279 }
4280 void SetMin(int64_t m) override;
Max() const4281 int64_t Max() const override {
4282 const int64_t lmin = left_->Min();
4283 const int64_t lmax = left_->Max();
4284 const int64_t rmin = right_->Min();
4285 const int64_t rmax = right_->Max();
4286 return std::max(std::max(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4287 std::max(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4288 }
4289 void SetMax(int64_t m) override;
4290 bool Bound() const override;
name() const4291 std::string name() const override {
4292 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4293 }
DebugString() const4294 std::string DebugString() const override {
4295 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4296 right_->DebugString());
4297 }
WhenRange(Demon * d)4298 void WhenRange(Demon* d) override {
4299 left_->WhenRange(d);
4300 right_->WhenRange(d);
4301 }
4302
Accept(ModelVisitor * const visitor) const4303 void Accept(ModelVisitor* const visitor) const override {
4304 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4305 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4306 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4307 right_);
4308 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4309 }
4310
4311 private:
4312 IntExpr* const left_;
4313 IntExpr* const right_;
4314 IntExpr* const minus_left_;
4315 IntExpr* const minus_right_;
4316 };
4317
SetMin(int64_t m)4318 void TimesIntExpr::SetMin(int64_t m) {
4319 if (m != std::numeric_limits<int64_t>::min()) {
4320 TimesSetMin(left_, right_, minus_left_, minus_right_, m);
4321 }
4322 }
4323
SetMax(int64_t m)4324 void TimesIntExpr::SetMax(int64_t m) {
4325 if (m != std::numeric_limits<int64_t>::max()) {
4326 TimesSetMin(left_, minus_right_, minus_left_, right_, CapOpp(m));
4327 }
4328 }
4329
Bound() const4330 bool TimesIntExpr::Bound() const {
4331 const bool left_bound = left_->Bound();
4332 const bool right_bound = right_->Bound();
4333 return ((left_bound && left_->Max() == 0) ||
4334 (right_bound && right_->Max() == 0) || (left_bound && right_bound));
4335 }
4336
4337 // ----- TimesPosIntExpr -----
4338
4339 class TimesPosIntExpr : public BaseIntExpr {
4340 public:
TimesPosIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)4341 TimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4342 : BaseIntExpr(s), left_(l), right_(r) {}
~TimesPosIntExpr()4343 ~TimesPosIntExpr() override {}
Min() const4344 int64_t Min() const override { return (left_->Min() * right_->Min()); }
4345 void SetMin(int64_t m) override;
Max() const4346 int64_t Max() const override { return (left_->Max() * right_->Max()); }
4347 void SetMax(int64_t m) override;
4348 bool Bound() const override;
name() const4349 std::string name() const override {
4350 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4351 }
DebugString() const4352 std::string DebugString() const override {
4353 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4354 right_->DebugString());
4355 }
WhenRange(Demon * d)4356 void WhenRange(Demon* d) override {
4357 left_->WhenRange(d);
4358 right_->WhenRange(d);
4359 }
4360
Accept(ModelVisitor * const visitor) const4361 void Accept(ModelVisitor* const visitor) const override {
4362 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4363 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4364 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4365 right_);
4366 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4367 }
4368
4369 private:
4370 IntExpr* const left_;
4371 IntExpr* const right_;
4372 };
4373
SetMin(int64_t m)4374 void TimesPosIntExpr::SetMin(int64_t m) { SetPosPosMinExpr(left_, right_, m); }
4375
SetMax(int64_t m)4376 void TimesPosIntExpr::SetMax(int64_t m) { SetPosPosMaxExpr(left_, right_, m); }
4377
Bound() const4378 bool TimesPosIntExpr::Bound() const {
4379 return (left_->Max() == 0 || right_->Max() == 0 ||
4380 (left_->Bound() && right_->Bound()));
4381 }
4382
4383 // ----- SafeTimesPosIntExpr -----
4384
4385 class SafeTimesPosIntExpr : public BaseIntExpr {
4386 public:
SafeTimesPosIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)4387 SafeTimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4388 : BaseIntExpr(s), left_(l), right_(r) {}
~SafeTimesPosIntExpr()4389 ~SafeTimesPosIntExpr() override {}
Min() const4390 int64_t Min() const override { return CapProd(left_->Min(), right_->Min()); }
SetMin(int64_t m)4391 void SetMin(int64_t m) override {
4392 if (m != std::numeric_limits<int64_t>::min()) {
4393 SetPosPosMinExpr(left_, right_, m);
4394 }
4395 }
Max() const4396 int64_t Max() const override { return CapProd(left_->Max(), right_->Max()); }
SetMax(int64_t m)4397 void SetMax(int64_t m) override {
4398 if (m != std::numeric_limits<int64_t>::max()) {
4399 SetPosPosMaxExpr(left_, right_, m);
4400 }
4401 }
Bound() const4402 bool Bound() const override {
4403 return (left_->Max() == 0 || right_->Max() == 0 ||
4404 (left_->Bound() && right_->Bound()));
4405 }
name() const4406 std::string name() const override {
4407 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4408 }
DebugString() const4409 std::string DebugString() const override {
4410 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4411 right_->DebugString());
4412 }
WhenRange(Demon * d)4413 void WhenRange(Demon* d) override {
4414 left_->WhenRange(d);
4415 right_->WhenRange(d);
4416 }
4417
Accept(ModelVisitor * const visitor) const4418 void Accept(ModelVisitor* const visitor) const override {
4419 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4420 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4421 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4422 right_);
4423 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4424 }
4425
4426 private:
4427 IntExpr* const left_;
4428 IntExpr* const right_;
4429 };
4430
4431 // ----- TimesBooleanPosIntExpr -----
4432
4433 class TimesBooleanPosIntExpr : public BaseIntExpr {
4434 public:
TimesBooleanPosIntExpr(Solver * const s,BooleanVar * const b,IntExpr * const e)4435 TimesBooleanPosIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4436 : BaseIntExpr(s), boolvar_(b), expr_(e) {}
~TimesBooleanPosIntExpr()4437 ~TimesBooleanPosIntExpr() override {}
Min() const4438 int64_t Min() const override {
4439 return (boolvar_->RawValue() == 1 ? expr_->Min() : 0);
4440 }
4441 void SetMin(int64_t m) override;
Max() const4442 int64_t Max() const override {
4443 return (boolvar_->RawValue() == 0 ? 0 : expr_->Max());
4444 }
4445 void SetMax(int64_t m) override;
4446 void Range(int64_t* mi, int64_t* ma) override;
4447 void SetRange(int64_t mi, int64_t ma) override;
4448 bool Bound() const override;
name() const4449 std::string name() const override {
4450 return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4451 }
DebugString() const4452 std::string DebugString() const override {
4453 return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4454 expr_->DebugString());
4455 }
WhenRange(Demon * d)4456 void WhenRange(Demon* d) override {
4457 boolvar_->WhenRange(d);
4458 expr_->WhenRange(d);
4459 }
4460
Accept(ModelVisitor * const visitor) const4461 void Accept(ModelVisitor* const visitor) const override {
4462 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4463 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4464 boolvar_);
4465 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4466 expr_);
4467 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4468 }
4469
4470 private:
4471 BooleanVar* const boolvar_;
4472 IntExpr* const expr_;
4473 };
4474
SetMin(int64_t m)4475 void TimesBooleanPosIntExpr::SetMin(int64_t m) {
4476 if (m > 0) {
4477 boolvar_->SetValue(1);
4478 expr_->SetMin(m);
4479 }
4480 }
4481
SetMax(int64_t m)4482 void TimesBooleanPosIntExpr::SetMax(int64_t m) {
4483 if (m < 0) {
4484 solver()->Fail();
4485 }
4486 if (m < expr_->Min()) {
4487 boolvar_->SetValue(0);
4488 }
4489 if (boolvar_->RawValue() == 1) {
4490 expr_->SetMax(m);
4491 }
4492 }
4493
Range(int64_t * mi,int64_t * ma)4494 void TimesBooleanPosIntExpr::Range(int64_t* mi, int64_t* ma) {
4495 const int value = boolvar_->RawValue();
4496 if (value == 0) {
4497 *mi = 0;
4498 *ma = 0;
4499 } else if (value == 1) {
4500 expr_->Range(mi, ma);
4501 } else {
4502 *mi = 0;
4503 *ma = expr_->Max();
4504 }
4505 }
4506
SetRange(int64_t mi,int64_t ma)4507 void TimesBooleanPosIntExpr::SetRange(int64_t mi, int64_t ma) {
4508 if (ma < 0 || mi > ma) {
4509 solver()->Fail();
4510 }
4511 if (mi > 0) {
4512 boolvar_->SetValue(1);
4513 expr_->SetMin(mi);
4514 }
4515 if (ma < expr_->Min()) {
4516 boolvar_->SetValue(0);
4517 }
4518 if (boolvar_->RawValue() == 1) {
4519 expr_->SetMax(ma);
4520 }
4521 }
4522
Bound() const4523 bool TimesBooleanPosIntExpr::Bound() const {
4524 return (boolvar_->RawValue() == 0 || expr_->Max() == 0 ||
4525 (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue &&
4526 expr_->Bound()));
4527 }
4528
4529 // ----- TimesBooleanIntExpr -----
4530
4531 class TimesBooleanIntExpr : public BaseIntExpr {
4532 public:
TimesBooleanIntExpr(Solver * const s,BooleanVar * const b,IntExpr * const e)4533 TimesBooleanIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4534 : BaseIntExpr(s), boolvar_(b), expr_(e) {}
~TimesBooleanIntExpr()4535 ~TimesBooleanIntExpr() override {}
Min() const4536 int64_t Min() const override {
4537 switch (boolvar_->RawValue()) {
4538 case 0: {
4539 return 0LL;
4540 }
4541 case 1: {
4542 return expr_->Min();
4543 }
4544 default: {
4545 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4546 return std::min(int64_t{0}, expr_->Min());
4547 }
4548 }
4549 }
4550 void SetMin(int64_t m) override;
Max() const4551 int64_t Max() const override {
4552 switch (boolvar_->RawValue()) {
4553 case 0: {
4554 return 0LL;
4555 }
4556 case 1: {
4557 return expr_->Max();
4558 }
4559 default: {
4560 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4561 return std::max(int64_t{0}, expr_->Max());
4562 }
4563 }
4564 }
4565 void SetMax(int64_t m) override;
4566 void Range(int64_t* mi, int64_t* ma) override;
4567 void SetRange(int64_t mi, int64_t ma) override;
4568 bool Bound() const override;
name() const4569 std::string name() const override {
4570 return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4571 }
DebugString() const4572 std::string DebugString() const override {
4573 return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4574 expr_->DebugString());
4575 }
WhenRange(Demon * d)4576 void WhenRange(Demon* d) override {
4577 boolvar_->WhenRange(d);
4578 expr_->WhenRange(d);
4579 }
4580
Accept(ModelVisitor * const visitor) const4581 void Accept(ModelVisitor* const visitor) const override {
4582 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4583 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4584 boolvar_);
4585 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4586 expr_);
4587 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4588 }
4589
4590 private:
4591 BooleanVar* const boolvar_;
4592 IntExpr* const expr_;
4593 };
4594
SetMin(int64_t m)4595 void TimesBooleanIntExpr::SetMin(int64_t m) {
4596 switch (boolvar_->RawValue()) {
4597 case 0: {
4598 if (m > 0) {
4599 solver()->Fail();
4600 }
4601 break;
4602 }
4603 case 1: {
4604 expr_->SetMin(m);
4605 break;
4606 }
4607 default: {
4608 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4609 if (m > 0) { // 0 is no longer possible for boolvar because min > 0.
4610 boolvar_->SetValue(1);
4611 expr_->SetMin(m);
4612 } else if (m <= 0 && expr_->Max() < m) {
4613 boolvar_->SetValue(0);
4614 }
4615 }
4616 }
4617 }
4618
SetMax(int64_t m)4619 void TimesBooleanIntExpr::SetMax(int64_t m) {
4620 switch (boolvar_->RawValue()) {
4621 case 0: {
4622 if (m < 0) {
4623 solver()->Fail();
4624 }
4625 break;
4626 }
4627 case 1: {
4628 expr_->SetMax(m);
4629 break;
4630 }
4631 default: {
4632 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4633 if (m < 0) { // 0 is no longer possible for boolvar because max < 0.
4634 boolvar_->SetValue(1);
4635 expr_->SetMax(m);
4636 } else if (m >= 0 && expr_->Min() > m) {
4637 boolvar_->SetValue(0);
4638 }
4639 }
4640 }
4641 }
4642
Range(int64_t * mi,int64_t * ma)4643 void TimesBooleanIntExpr::Range(int64_t* mi, int64_t* ma) {
4644 switch (boolvar_->RawValue()) {
4645 case 0: {
4646 *mi = 0;
4647 *ma = 0;
4648 break;
4649 }
4650 case 1: {
4651 *mi = expr_->Min();
4652 *ma = expr_->Max();
4653 break;
4654 }
4655 default: {
4656 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4657 *mi = std::min(int64_t{0}, expr_->Min());
4658 *ma = std::max(int64_t{0}, expr_->Max());
4659 break;
4660 }
4661 }
4662 }
4663
SetRange(int64_t mi,int64_t ma)4664 void TimesBooleanIntExpr::SetRange(int64_t mi, int64_t ma) {
4665 if (mi > ma) {
4666 solver()->Fail();
4667 }
4668 switch (boolvar_->RawValue()) {
4669 case 0: {
4670 if (mi > 0 || ma < 0) {
4671 solver()->Fail();
4672 }
4673 break;
4674 }
4675 case 1: {
4676 expr_->SetRange(mi, ma);
4677 break;
4678 }
4679 default: {
4680 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4681 if (mi > 0) {
4682 boolvar_->SetValue(1);
4683 expr_->SetMin(mi);
4684 } else if (mi == 0 && expr_->Max() < 0) {
4685 boolvar_->SetValue(0);
4686 }
4687 if (ma < 0) {
4688 boolvar_->SetValue(1);
4689 expr_->SetMax(ma);
4690 } else if (ma == 0 && expr_->Min() > 0) {
4691 boolvar_->SetValue(0);
4692 }
4693 break;
4694 }
4695 }
4696 }
4697
Bound() const4698 bool TimesBooleanIntExpr::Bound() const {
4699 return (boolvar_->RawValue() == 0 ||
4700 (expr_->Bound() &&
4701 (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue ||
4702 expr_->Max() == 0)));
4703 }
4704
4705 // ----- DivPosIntCstExpr -----
4706
4707 class DivPosIntCstExpr : public BaseIntExpr {
4708 public:
DivPosIntCstExpr(Solver * const s,IntExpr * const e,int64_t v)4709 DivPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4710 : BaseIntExpr(s), expr_(e), value_(v) {
4711 CHECK_GE(v, 0);
4712 }
~DivPosIntCstExpr()4713 ~DivPosIntCstExpr() override {}
4714
Min() const4715 int64_t Min() const override { return expr_->Min() / value_; }
4716
SetMin(int64_t m)4717 void SetMin(int64_t m) override {
4718 if (m > 0) {
4719 expr_->SetMin(m * value_);
4720 } else {
4721 expr_->SetMin((m - 1) * value_ + 1);
4722 }
4723 }
Max() const4724 int64_t Max() const override { return expr_->Max() / value_; }
4725
SetMax(int64_t m)4726 void SetMax(int64_t m) override {
4727 if (m >= 0) {
4728 expr_->SetMax((m + 1) * value_ - 1);
4729 } else {
4730 expr_->SetMax(m * value_);
4731 }
4732 }
4733
name() const4734 std::string name() const override {
4735 return absl::StrFormat("(%s div %d)", expr_->name(), value_);
4736 }
4737
DebugString() const4738 std::string DebugString() const override {
4739 return absl::StrFormat("(%s div %d)", expr_->DebugString(), value_);
4740 }
4741
WhenRange(Demon * d)4742 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4743
Accept(ModelVisitor * const visitor) const4744 void Accept(ModelVisitor* const visitor) const override {
4745 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4746 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4747 expr_);
4748 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4749 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4750 }
4751
4752 private:
4753 IntExpr* const expr_;
4754 const int64_t value_;
4755 };
4756
4757 // DivPosIntExpr
4758
4759 class DivPosIntExpr : public BaseIntExpr {
4760 public:
DivPosIntExpr(Solver * const s,IntExpr * const num,IntExpr * const denom)4761 DivPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4762 : BaseIntExpr(s),
4763 num_(num),
4764 denom_(denom),
4765 opp_num_(s->MakeOpposite(num)) {}
4766
~DivPosIntExpr()4767 ~DivPosIntExpr() override {}
4768
Min() const4769 int64_t Min() const override {
4770 return num_->Min() >= 0
4771 ? num_->Min() / denom_->Max()
4772 : (denom_->Min() == 0 ? num_->Min()
4773 : num_->Min() / denom_->Min());
4774 }
4775
Max() const4776 int64_t Max() const override {
4777 return num_->Max() >= 0 ? (denom_->Min() == 0 ? num_->Max()
4778 : num_->Max() / denom_->Min())
4779 : num_->Max() / denom_->Max();
4780 }
4781
SetPosMin(IntExpr * const num,IntExpr * const denom,int64_t m)4782 static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64_t m) {
4783 num->SetMin(m * denom->Min());
4784 denom->SetMax(num->Max() / m);
4785 }
4786
SetPosMax(IntExpr * const num,IntExpr * const denom,int64_t m)4787 static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64_t m) {
4788 num->SetMax((m + 1) * denom->Max() - 1);
4789 denom->SetMin(num->Min() / (m + 1) + 1);
4790 }
4791
SetMin(int64_t m)4792 void SetMin(int64_t m) override {
4793 if (m > 0) {
4794 SetPosMin(num_, denom_, m);
4795 } else {
4796 SetPosMax(opp_num_, denom_, -m);
4797 }
4798 }
4799
SetMax(int64_t m)4800 void SetMax(int64_t m) override {
4801 if (m >= 0) {
4802 SetPosMax(num_, denom_, m);
4803 } else {
4804 SetPosMin(opp_num_, denom_, -m);
4805 }
4806 }
4807
name() const4808 std::string name() const override {
4809 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4810 }
DebugString() const4811 std::string DebugString() const override {
4812 return absl::StrFormat("(%s div %s)", num_->DebugString(),
4813 denom_->DebugString());
4814 }
WhenRange(Demon * d)4815 void WhenRange(Demon* d) override {
4816 num_->WhenRange(d);
4817 denom_->WhenRange(d);
4818 }
4819
Accept(ModelVisitor * const visitor) const4820 void Accept(ModelVisitor* const visitor) const override {
4821 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4822 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4823 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4824 denom_);
4825 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4826 }
4827
4828 private:
4829 IntExpr* const num_;
4830 IntExpr* const denom_;
4831 IntExpr* const opp_num_;
4832 };
4833
4834 class DivPosPosIntExpr : public BaseIntExpr {
4835 public:
DivPosPosIntExpr(Solver * const s,IntExpr * const num,IntExpr * const denom)4836 DivPosPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4837 : BaseIntExpr(s), num_(num), denom_(denom) {}
4838
~DivPosPosIntExpr()4839 ~DivPosPosIntExpr() override {}
4840
Min() const4841 int64_t Min() const override {
4842 if (denom_->Max() == 0) {
4843 solver()->Fail();
4844 }
4845 return num_->Min() / denom_->Max();
4846 }
4847
Max() const4848 int64_t Max() const override {
4849 if (denom_->Min() == 0) {
4850 return num_->Max();
4851 } else {
4852 return num_->Max() / denom_->Min();
4853 }
4854 }
4855
SetMin(int64_t m)4856 void SetMin(int64_t m) override {
4857 if (m > 0) {
4858 num_->SetMin(m * denom_->Min());
4859 denom_->SetMax(num_->Max() / m);
4860 }
4861 }
4862
SetMax(int64_t m)4863 void SetMax(int64_t m) override {
4864 if (m >= 0) {
4865 num_->SetMax((m + 1) * denom_->Max() - 1);
4866 denom_->SetMin(num_->Min() / (m + 1) + 1);
4867 } else {
4868 solver()->Fail();
4869 }
4870 }
4871
name() const4872 std::string name() const override {
4873 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4874 }
4875
DebugString() const4876 std::string DebugString() const override {
4877 return absl::StrFormat("(%s div %s)", num_->DebugString(),
4878 denom_->DebugString());
4879 }
4880
WhenRange(Demon * d)4881 void WhenRange(Demon* d) override {
4882 num_->WhenRange(d);
4883 denom_->WhenRange(d);
4884 }
4885
Accept(ModelVisitor * const visitor) const4886 void Accept(ModelVisitor* const visitor) const override {
4887 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4888 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4889 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4890 denom_);
4891 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4892 }
4893
4894 private:
4895 IntExpr* const num_;
4896 IntExpr* const denom_;
4897 };
4898
4899 // DivIntExpr
4900
// num / denom where both the numerator and the denominator may take any sign
// (C++ truncating division). Bounds are derived case by case from the sign of
// the denominator's range.
class DivIntExpr : public BaseIntExpr {
 public:
  DivIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
      : BaseIntExpr(s),
        num_(num),
        denom_(denom),
        opp_num_(s->MakeOpposite(num)) {}

  ~DivIntExpr() override {}

  int64_t Min() const override {
    const int64_t num_min = num_->Min();
    const int64_t num_max = num_->Max();
    const int64_t denom_min = denom_->Min();
    const int64_t denom_max = denom_->Max();

    // Denominator fixed to 0: report an empty range (min > max).
    if (denom_min == 0 && denom_max == 0) {
      return std::numeric_limits<int64_t>::max();  // TODO(user): Check this
                                                   // convention.
    }

    if (denom_min >= 0) {  // Denominator non-negative; a 0 min counts as 1.
      DCHECK_GT(denom_max, 0);
      const int64_t adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
      return num_min >= 0 ? num_min / denom_max : num_min / adjusted_denom_min;
    } else if (denom_max <= 0) {  // Denominator non-positive; a 0 max counts
                                  // as -1.
      DCHECK_LT(denom_min, 0);
      const int64_t adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
      return num_max >= 0 ? num_max / adjusted_denom_max : num_max / denom_min;
    } else {  // Denominator across 0: dividing by +1 or -1 is possible.
      return std::min(num_min, -num_max);
    }
  }

  int64_t Max() const override {
    const int64_t num_min = num_->Min();
    const int64_t num_max = num_->Max();
    const int64_t denom_min = denom_->Min();
    const int64_t denom_max = denom_->Max();

    // Denominator fixed to 0: report an empty range (max < min).
    if (denom_min == 0 && denom_max == 0) {
      return std::numeric_limits<int64_t>::min();  // TODO(user): Check this
                                                   // convention.
    }

    if (denom_min >= 0) {  // Denominator non-negative; a 0 min counts as 1.
      DCHECK_GT(denom_max, 0);
      const int64_t adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
      return num_max >= 0 ? num_max / adjusted_denom_min : num_max / denom_max;
    } else if (denom_max <= 0) {  // Denominator non-positive; a 0 max counts
                                  // as -1.
      DCHECK_LT(denom_min, 0);
      const int64_t adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
      return num_min >= 0 ? num_min / denom_min
                          : -num_min / -adjusted_denom_max;
    } else {  // Denominator across 0: dividing by +1 or -1 is possible.
      return std::max(num_max, -num_min);
    }
  }

  // Shrinks the denominator away from a 0 bound, so that the SetPos* helpers
  // below never see 0 as a denominator bound.
  void AdjustDenominator() {
    if (denom_->Min() == 0) {
      denom_->SetMin(1);
    } else if (denom_->Max() == 0) {
      denom_->SetMax(-1);
    }
  }

  // Enforces num / denom >= m.
  // m > 0.
  static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64_t m) {
    DCHECK_GT(m, 0);
    const int64_t num_min = num->Min();
    const int64_t num_max = num->Max();
    const int64_t denom_min = denom->Min();
    const int64_t denom_max = denom->Max();
    DCHECK_NE(denom_min, 0);
    DCHECK_NE(denom_max, 0);
    if (denom_min > 0) {  // Denominator strictly positive.
      num->SetMin(m * denom_min);
      denom->SetMax(num_max / m);
    } else if (denom_max < 0) {  // Denominator strictly negative.
      num->SetMax(m * denom_max);
      denom->SetMin(num_min / m);
    } else {  // Denominator across 0.
      // A positive quotient needs num and denom of the same sign, with
      // |num| >= m.
      if (num_min >= 0) {
        num->SetMin(m);
        denom->SetRange(1, num_max / m);
      } else if (num_max <= 0) {
        num->SetMax(-m);
        denom->SetRange(num_min / m, -1);
      } else {
        if (m > -num_min) {  // Denominator is forced positive.
          num->SetMin(m);
          denom->SetRange(1, num_max / m);
        } else if (m > num_max) {  // Denominator is forced negative.
          num->SetMax(-m);
          denom->SetRange(num_min / m, -1);
        } else {
          denom->SetRange(num_min / m, num_max / m);
        }
      }
    }
  }

  // Enforces num / denom <= m.
  // m >= 0.
  static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64_t m) {
    DCHECK_GE(m, 0);
    const int64_t num_min = num->Min();
    const int64_t num_max = num->Max();
    const int64_t denom_min = denom->Min();
    const int64_t denom_max = denom->Max();
    DCHECK_NE(denom_min, 0);
    DCHECK_NE(denom_max, 0);
    if (denom_min > 0) {  // Denominator strictly positive.
      num->SetMax((m + 1) * denom_max - 1);
      denom->SetMin((num_min / (m + 1)) + 1);
    } else if (denom_max < 0) {  // Denominator strictly negative.
      num->SetMin((m + 1) * denom_min + 1);
      denom->SetMax(num_max / (m + 1) - 1);
    } else if (num_min > (m + 1) * denom_max - 1) {
      // No positive denominator can keep the quotient <= m.
      denom->SetMax(-1);
    } else if (num_max < (m + 1) * denom_min + 1) {
      // No negative denominator can keep the quotient <= m.
      denom->SetMin(1);
    }
  }

  // Non-positive bounds are expressed on the opposite numerator, using
  // (-num) / denom == -(num / denom) under truncating division.
  void SetMin(int64_t m) override {
    AdjustDenominator();
    if (m > 0) {
      SetPosMin(num_, denom_, m);
    } else {
      SetPosMax(opp_num_, denom_, -m);
    }
  }

  void SetMax(int64_t m) override {
    AdjustDenominator();
    if (m >= 0) {
      SetPosMax(num_, denom_, m);
    } else {
      SetPosMin(opp_num_, denom_, -m);
    }
  }

  std::string name() const override {
    return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
  }
  std::string DebugString() const override {
    return absl::StrFormat("(%s div %s)", num_->DebugString(),
                           denom_->DebugString());
  }
  void WhenRange(Demon* d) override {
    num_->WhenRange(d);
    denom_->WhenRange(d);
  }

  void Accept(ModelVisitor* const visitor) const override {
    visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
                                            denom_);
    visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
  }

 private:
  IntExpr* const num_;
  IntExpr* const denom_;
  // -num_, created in the constructor via MakeOpposite.
  IntExpr* const opp_num_;
};
5069
5070 // ----- IntAbs And IntAbsConstraint ------
5071
5072 class IntAbsConstraint : public CastConstraint {
5073 public:
IntAbsConstraint(Solver * const s,IntVar * const sub,IntVar * const target)5074 IntAbsConstraint(Solver* const s, IntVar* const sub, IntVar* const target)
5075 : CastConstraint(s, target), sub_(sub) {}
5076
~IntAbsConstraint()5077 ~IntAbsConstraint() override {}
5078
Post()5079 void Post() override {
5080 Demon* const sub_demon = MakeConstraintDemon0(
5081 solver(), this, &IntAbsConstraint::PropagateSub, "PropagateSub");
5082 sub_->WhenRange(sub_demon);
5083 Demon* const target_demon = MakeConstraintDemon0(
5084 solver(), this, &IntAbsConstraint::PropagateTarget, "PropagateTarget");
5085 target_var_->WhenRange(target_demon);
5086 }
5087
InitialPropagate()5088 void InitialPropagate() override {
5089 PropagateSub();
5090 PropagateTarget();
5091 }
5092
PropagateSub()5093 void PropagateSub() {
5094 const int64_t smin = sub_->Min();
5095 const int64_t smax = sub_->Max();
5096 if (smax <= 0) {
5097 target_var_->SetRange(-smax, -smin);
5098 } else if (smin >= 0) {
5099 target_var_->SetRange(smin, smax);
5100 } else {
5101 target_var_->SetRange(0, std::max(-smin, smax));
5102 }
5103 }
5104
PropagateTarget()5105 void PropagateTarget() {
5106 const int64_t target_max = target_var_->Max();
5107 sub_->SetRange(-target_max, target_max);
5108 const int64_t target_min = target_var_->Min();
5109 if (target_min > 0) {
5110 if (sub_->Min() > -target_min) {
5111 sub_->SetMin(target_min);
5112 } else if (sub_->Max() < target_min) {
5113 sub_->SetMax(-target_min);
5114 }
5115 }
5116 }
5117
DebugString() const5118 std::string DebugString() const override {
5119 return absl::StrFormat("IntAbsConstraint(%s, %s)", sub_->DebugString(),
5120 target_var_->DebugString());
5121 }
5122
Accept(ModelVisitor * const visitor) const5123 void Accept(ModelVisitor* const visitor) const override {
5124 visitor->BeginVisitConstraint(ModelVisitor::kAbsEqual, this);
5125 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5126 sub_);
5127 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
5128 target_var_);
5129 visitor->EndVisitConstraint(ModelVisitor::kAbsEqual, this);
5130 }
5131
5132 private:
5133 IntVar* const sub_;
5134 };
5135
5136 class IntAbs : public BaseIntExpr {
5137 public:
IntAbs(Solver * const s,IntExpr * const e)5138 IntAbs(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5139
~IntAbs()5140 ~IntAbs() override {}
5141
Min() const5142 int64_t Min() const override {
5143 int64_t emin = 0;
5144 int64_t emax = 0;
5145 expr_->Range(&emin, &emax);
5146 if (emin >= 0) {
5147 return emin;
5148 }
5149 if (emax <= 0) {
5150 return -emax;
5151 }
5152 return 0;
5153 }
5154
SetMin(int64_t m)5155 void SetMin(int64_t m) override {
5156 if (m > 0) {
5157 int64_t emin = 0;
5158 int64_t emax = 0;
5159 expr_->Range(&emin, &emax);
5160 if (emin > -m) {
5161 expr_->SetMin(m);
5162 } else if (emax < m) {
5163 expr_->SetMax(-m);
5164 }
5165 }
5166 }
5167
Max() const5168 int64_t Max() const override {
5169 int64_t emin = 0;
5170 int64_t emax = 0;
5171 expr_->Range(&emin, &emax);
5172 return std::max(-emin, emax);
5173 }
5174
SetMax(int64_t m)5175 void SetMax(int64_t m) override { expr_->SetRange(-m, m); }
5176
SetRange(int64_t mi,int64_t ma)5177 void SetRange(int64_t mi, int64_t ma) override {
5178 expr_->SetRange(-ma, ma);
5179 if (mi > 0) {
5180 int64_t emin = 0;
5181 int64_t emax = 0;
5182 expr_->Range(&emin, &emax);
5183 if (emin > -mi) {
5184 expr_->SetMin(mi);
5185 } else if (emax < mi) {
5186 expr_->SetMax(-mi);
5187 }
5188 }
5189 }
5190
Range(int64_t * mi,int64_t * ma)5191 void Range(int64_t* mi, int64_t* ma) override {
5192 int64_t emin = 0;
5193 int64_t emax = 0;
5194 expr_->Range(&emin, &emax);
5195 if (emin >= 0) {
5196 *mi = emin;
5197 *ma = emax;
5198 } else if (emax <= 0) {
5199 *mi = -emax;
5200 *ma = -emin;
5201 } else {
5202 *mi = 0;
5203 *ma = std::max(-emin, emax);
5204 }
5205 }
5206
Bound() const5207 bool Bound() const override { return expr_->Bound(); }
5208
WhenRange(Demon * d)5209 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5210
name() const5211 std::string name() const override {
5212 return absl::StrFormat("IntAbs(%s)", expr_->name());
5213 }
5214
DebugString() const5215 std::string DebugString() const override {
5216 return absl::StrFormat("IntAbs(%s)", expr_->DebugString());
5217 }
5218
Accept(ModelVisitor * const visitor) const5219 void Accept(ModelVisitor* const visitor) const override {
5220 visitor->BeginVisitIntegerExpression(ModelVisitor::kAbs, this);
5221 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5222 expr_);
5223 visitor->EndVisitIntegerExpression(ModelVisitor::kAbs, this);
5224 }
5225
CastToVar()5226 IntVar* CastToVar() override {
5227 int64_t min_value = 0;
5228 int64_t max_value = 0;
5229 Range(&min_value, &max_value);
5230 Solver* const s = solver();
5231 const std::string name = absl::StrFormat("AbsVar(%s)", expr_->name());
5232 IntVar* const target = s->MakeIntVar(min_value, max_value, name);
5233 CastConstraint* const ct =
5234 s->RevAlloc(new IntAbsConstraint(s, expr_->Var(), target));
5235 s->AddCastConstraint(ct, target, this);
5236 return target;
5237 }
5238
5239 private:
5240 IntExpr* const expr_;
5241 };
5242
5243 // ----- Square -----
5244
5245 // TODO(user): shouldn't we compare to kint32max^2 instead of kint64max?
5246 class IntSquare : public BaseIntExpr {
5247 public:
IntSquare(Solver * const s,IntExpr * const e)5248 IntSquare(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
~IntSquare()5249 ~IntSquare() override {}
5250
Min() const5251 int64_t Min() const override {
5252 const int64_t emin = expr_->Min();
5253 if (emin >= 0) {
5254 return emin >= std::numeric_limits<int32_t>::max()
5255 ? std::numeric_limits<int64_t>::max()
5256 : emin * emin;
5257 }
5258 const int64_t emax = expr_->Max();
5259 if (emax < 0) {
5260 return emax <= -std::numeric_limits<int32_t>::max()
5261 ? std::numeric_limits<int64_t>::max()
5262 : emax * emax;
5263 }
5264 return 0LL;
5265 }
SetMin(int64_t m)5266 void SetMin(int64_t m) override {
5267 if (m <= 0) {
5268 return;
5269 }
5270 // TODO(user): What happens if m is kint64max?
5271 const int64_t emin = expr_->Min();
5272 const int64_t emax = expr_->Max();
5273 const int64_t root =
5274 static_cast<int64_t>(ceil(sqrt(static_cast<double>(m))));
5275 if (emin >= 0) {
5276 expr_->SetMin(root);
5277 } else if (emax <= 0) {
5278 expr_->SetMax(-root);
5279 } else if (expr_->IsVar()) {
5280 reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5281 }
5282 }
Max() const5283 int64_t Max() const override {
5284 const int64_t emax = expr_->Max();
5285 const int64_t emin = expr_->Min();
5286 if (emax >= std::numeric_limits<int32_t>::max() ||
5287 emin <= -std::numeric_limits<int32_t>::max()) {
5288 return std::numeric_limits<int64_t>::max();
5289 }
5290 return std::max(emin * emin, emax * emax);
5291 }
SetMax(int64_t m)5292 void SetMax(int64_t m) override {
5293 if (m < 0) {
5294 solver()->Fail();
5295 }
5296 if (m == std::numeric_limits<int64_t>::max()) {
5297 return;
5298 }
5299 const int64_t root =
5300 static_cast<int64_t>(floor(sqrt(static_cast<double>(m))));
5301 expr_->SetRange(-root, root);
5302 }
Bound() const5303 bool Bound() const override { return expr_->Bound(); }
WhenRange(Demon * d)5304 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
name() const5305 std::string name() const override {
5306 return absl::StrFormat("IntSquare(%s)", expr_->name());
5307 }
DebugString() const5308 std::string DebugString() const override {
5309 return absl::StrFormat("IntSquare(%s)", expr_->DebugString());
5310 }
5311
Accept(ModelVisitor * const visitor) const5312 void Accept(ModelVisitor* const visitor) const override {
5313 visitor->BeginVisitIntegerExpression(ModelVisitor::kSquare, this);
5314 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5315 expr_);
5316 visitor->EndVisitIntegerExpression(ModelVisitor::kSquare, this);
5317 }
5318
expr() const5319 IntExpr* expr() const { return expr_; }
5320
5321 protected:
5322 IntExpr* const expr_;
5323 };
5324
5325 class PosIntSquare : public IntSquare {
5326 public:
PosIntSquare(Solver * const s,IntExpr * const e)5327 PosIntSquare(Solver* const s, IntExpr* const e) : IntSquare(s, e) {}
~PosIntSquare()5328 ~PosIntSquare() override {}
5329
Min() const5330 int64_t Min() const override {
5331 const int64_t emin = expr_->Min();
5332 return emin >= std::numeric_limits<int32_t>::max()
5333 ? std::numeric_limits<int64_t>::max()
5334 : emin * emin;
5335 }
SetMin(int64_t m)5336 void SetMin(int64_t m) override {
5337 if (m <= 0) {
5338 return;
5339 }
5340 const int64_t root =
5341 static_cast<int64_t>(ceil(sqrt(static_cast<double>(m))));
5342 expr_->SetMin(root);
5343 }
Max() const5344 int64_t Max() const override {
5345 const int64_t emax = expr_->Max();
5346 return emax >= std::numeric_limits<int32_t>::max()
5347 ? std::numeric_limits<int64_t>::max()
5348 : emax * emax;
5349 }
SetMax(int64_t m)5350 void SetMax(int64_t m) override {
5351 if (m < 0) {
5352 solver()->Fail();
5353 }
5354 if (m == std::numeric_limits<int64_t>::max()) {
5355 return;
5356 }
5357 const int64_t root =
5358 static_cast<int64_t>(floor(sqrt(static_cast<double>(m))));
5359 expr_->SetMax(root);
5360 }
5361 };
5362
5363 // ----- EvenPower -----
5364
// Returns value^power using binary exponentiation (O(log power)
// multiplications). A non-positive power returns 1 (x^0); BasePower only
// calls this with power > 0. No overflow check is done here: callers such as
// BasePower::Pown guard with OverflowLimit first. The base is only squared
// while a higher bit of power remains, so when the final result fits in
// int64 no intermediate product overflows.
int64_t IntPower(int64_t value, int64_t power) {
  if (power <= 0) {
    return 1;
  }
  int64_t result = 1;
  int64_t base = value;
  int64_t remaining = power;
  while (true) {
    if (remaining & 1) {
      result *= base;
    }
    remaining >>= 1;
    if (remaining == 0) {
      break;
    }
    base *= base;
  }
  return result;
}
5373
// Returns floor(kint64max^(1 / power)), computed in double precision. Pown
// treats any |value| >= this limit as overflowing. For power <= 1 no int64
// value can overflow, and the floating-point formula would divide by zero
// (power == 0) or could cast a value rounded up to 2^63 into int64_t
// (power == 1, undefined behavior), so kint64max is returned directly.
int64_t OverflowLimit(int64_t power) {
  if (power <= 1) {
    return std::numeric_limits<int64_t>::max();
  }
  return static_cast<int64_t>(floor(exp(
      log(static_cast<double>(std::numeric_limits<int64_t>::max())) / power)));
}
5378
// Common base for expr^pow_ with pow_ > 0. Provides saturated evaluation
// (Pown) and integer n-th roots rounded down / up (SqrnDown / SqrnUp) that
// subclasses use to map bounds through the power function.
class BasePower : public BaseIntExpr {
 public:
  BasePower(Solver* const s, IntExpr* const e, int64_t n)
      : BaseIntExpr(s), expr_(e), pow_(n), limit_(OverflowLimit(n)) {
    CHECK_GT(n, 0);
  }

  ~BasePower() override {}

  bool Bound() const override { return expr_->Bound(); }

  IntExpr* expr() const { return expr_; }

  int64_t exponant() const { return pow_; }

  void WhenRange(Demon* d) override { expr_->WhenRange(d); }

  std::string name() const override {
    return absl::StrFormat("IntPower(%s, %d)", expr_->name(), pow_);
  }

  std::string DebugString() const override {
    return absl::StrFormat("IntPower(%s, %d)", expr_->DebugString(), pow_);
  }

  void Accept(ModelVisitor* const visitor) const override {
    visitor->BeginVisitIntegerExpression(ModelVisitor::kPower, this);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
                                            expr_);
    visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, pow_);
    visitor->EndVisitIntegerExpression(ModelVisitor::kPower, this);
  }

 protected:
  // value^pow_, saturated: returns kint64max (or kint64min for a negative
  // value with an odd exponent) once |value| reaches limit_.
  int64_t Pown(int64_t value) const {
    if (value >= limit_) {
      return std::numeric_limits<int64_t>::max();
    }
    if (value <= -limit_) {
      if (pow_ % 2 == 0) {
        return std::numeric_limits<int64_t>::max();
      } else {
        return std::numeric_limits<int64_t>::min();
      }
    }
    return IntPower(value, pow_);
  }

  // Integer pow_-th root of value, rounded down. The floating-point estimate
  // is corrected by one step using Pown(res + 1). A negative value is only
  // allowed for odd exponents (CHECK). kint64min / kint64max map to
  // themselves.
  int64_t SqrnDown(int64_t value) const {
    if (value == std::numeric_limits<int64_t>::min()) {
      return std::numeric_limits<int64_t>::min();
    }
    if (value == std::numeric_limits<int64_t>::max()) {
      return std::numeric_limits<int64_t>::max();
    }
    int64_t res = 0;
    const double d_value = static_cast<double>(value);
    if (value >= 0) {
      const double sq = exp(log(d_value) / pow_);
      res = static_cast<int64_t>(floor(sq));
    } else {
      CHECK_EQ(1, pow_ % 2);
      const double sq = exp(log(-d_value) / pow_);
      res = -static_cast<int64_t>(ceil(sq));
    }
    // Bump the estimate up by one if (res + 1)^pow_ still fits under value.
    const int64_t pow_res = Pown(res + 1);
    if (pow_res <= value) {
      return res + 1;
    } else {
      return res;
    }
  }

  // Integer pow_-th root of value, rounded up. The floating-point estimate is
  // corrected by one step using Pown(res - 1). A negative value is only
  // allowed for odd exponents (CHECK). kint64min / kint64max map to
  // themselves.
  int64_t SqrnUp(int64_t value) const {
    if (value == std::numeric_limits<int64_t>::min()) {
      return std::numeric_limits<int64_t>::min();
    }
    if (value == std::numeric_limits<int64_t>::max()) {
      return std::numeric_limits<int64_t>::max();
    }
    int64_t res = 0;
    const double d_value = static_cast<double>(value);
    if (value >= 0) {
      const double sq = exp(log(d_value) / pow_);
      res = static_cast<int64_t>(ceil(sq));
    } else {
      CHECK_EQ(1, pow_ % 2);
      const double sq = exp(log(-d_value) / pow_);
      res = -static_cast<int64_t>(floor(sq));
    }
    // Bump the estimate down by one if (res - 1)^pow_ still reaches value.
    const int64_t pow_res = Pown(res - 1);
    if (pow_res >= value) {
      return res - 1;
    } else {
      return res;
    }
  }

  IntExpr* const expr_;
  // Exponent; > 0 per the constructor CHECK.
  const int64_t pow_;
  // Overflow threshold for Pown, from OverflowLimit(pow_).
  const int64_t limit_;
};
5481
5482 class IntEvenPower : public BasePower {
5483 public:
IntEvenPower(Solver * const s,IntExpr * const e,int64_t n)5484 IntEvenPower(Solver* const s, IntExpr* const e, int64_t n)
5485 : BasePower(s, e, n) {
5486 CHECK_EQ(0, n % 2);
5487 }
5488
~IntEvenPower()5489 ~IntEvenPower() override {}
5490
Min() const5491 int64_t Min() const override {
5492 int64_t emin = 0;
5493 int64_t emax = 0;
5494 expr_->Range(&emin, &emax);
5495 if (emin >= 0) {
5496 return Pown(emin);
5497 }
5498 if (emax < 0) {
5499 return Pown(emax);
5500 }
5501 return 0LL;
5502 }
SetMin(int64_t m)5503 void SetMin(int64_t m) override {
5504 if (m <= 0) {
5505 return;
5506 }
5507 int64_t emin = 0;
5508 int64_t emax = 0;
5509 expr_->Range(&emin, &emax);
5510 const int64_t root = SqrnUp(m);
5511 if (emin > -root) {
5512 expr_->SetMin(root);
5513 } else if (emax < root) {
5514 expr_->SetMax(-root);
5515 } else if (expr_->IsVar()) {
5516 reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5517 }
5518 }
5519
Max() const5520 int64_t Max() const override {
5521 return std::max(Pown(expr_->Min()), Pown(expr_->Max()));
5522 }
5523
SetMax(int64_t m)5524 void SetMax(int64_t m) override {
5525 if (m < 0) {
5526 solver()->Fail();
5527 }
5528 if (m == std::numeric_limits<int64_t>::max()) {
5529 return;
5530 }
5531 const int64_t root = SqrnDown(m);
5532 expr_->SetRange(-root, root);
5533 }
5534 };
5535
5536 class PosIntEvenPower : public BasePower {
5537 public:
PosIntEvenPower(Solver * const s,IntExpr * const e,int64_t pow)5538 PosIntEvenPower(Solver* const s, IntExpr* const e, int64_t pow)
5539 : BasePower(s, e, pow) {
5540 CHECK_EQ(0, pow % 2);
5541 }
5542
~PosIntEvenPower()5543 ~PosIntEvenPower() override {}
5544
Min() const5545 int64_t Min() const override { return Pown(expr_->Min()); }
5546
SetMin(int64_t m)5547 void SetMin(int64_t m) override {
5548 if (m <= 0) {
5549 return;
5550 }
5551 expr_->SetMin(SqrnUp(m));
5552 }
Max() const5553 int64_t Max() const override { return Pown(expr_->Max()); }
5554
SetMax(int64_t m)5555 void SetMax(int64_t m) override {
5556 if (m < 0) {
5557 solver()->Fail();
5558 }
5559 if (m == std::numeric_limits<int64_t>::max()) {
5560 return;
5561 }
5562 expr_->SetMax(SqrnDown(m));
5563 }
5564 };
5565
5566 class IntOddPower : public BasePower {
5567 public:
IntOddPower(Solver * const s,IntExpr * const e,int64_t n)5568 IntOddPower(Solver* const s, IntExpr* const e, int64_t n)
5569 : BasePower(s, e, n) {
5570 CHECK_EQ(1, n % 2);
5571 }
5572
~IntOddPower()5573 ~IntOddPower() override {}
5574
Min() const5575 int64_t Min() const override { return Pown(expr_->Min()); }
5576
SetMin(int64_t m)5577 void SetMin(int64_t m) override { expr_->SetMin(SqrnUp(m)); }
5578
Max() const5579 int64_t Max() const override { return Pown(expr_->Max()); }
5580
SetMax(int64_t m)5581 void SetMax(int64_t m) override { expr_->SetMax(SqrnDown(m)); }
5582 };
5583
5584 // ----- Min(expr, expr) -----
5585
5586 class MinIntExpr : public BaseIntExpr {
5587 public:
MinIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)5588 MinIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5589 : BaseIntExpr(s), left_(l), right_(r) {}
~MinIntExpr()5590 ~MinIntExpr() override {}
Min() const5591 int64_t Min() const override {
5592 const int64_t lmin = left_->Min();
5593 const int64_t rmin = right_->Min();
5594 return std::min(lmin, rmin);
5595 }
SetMin(int64_t m)5596 void SetMin(int64_t m) override {
5597 left_->SetMin(m);
5598 right_->SetMin(m);
5599 }
Max() const5600 int64_t Max() const override {
5601 const int64_t lmax = left_->Max();
5602 const int64_t rmax = right_->Max();
5603 return std::min(lmax, rmax);
5604 }
SetMax(int64_t m)5605 void SetMax(int64_t m) override {
5606 if (left_->Min() > m) {
5607 right_->SetMax(m);
5608 }
5609 if (right_->Min() > m) {
5610 left_->SetMax(m);
5611 }
5612 }
name() const5613 std::string name() const override {
5614 return absl::StrFormat("MinIntExpr(%s, %s)", left_->name(), right_->name());
5615 }
DebugString() const5616 std::string DebugString() const override {
5617 return absl::StrFormat("MinIntExpr(%s, %s)", left_->DebugString(),
5618 right_->DebugString());
5619 }
WhenRange(Demon * d)5620 void WhenRange(Demon* d) override {
5621 left_->WhenRange(d);
5622 right_->WhenRange(d);
5623 }
5624
Accept(ModelVisitor * const visitor) const5625 void Accept(ModelVisitor* const visitor) const override {
5626 visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5627 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5628 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5629 right_);
5630 visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5631 }
5632
5633 private:
5634 IntExpr* const left_;
5635 IntExpr* const right_;
5636 };
5637
5638 // ----- Min(expr, constant) -----
5639
5640 class MinCstIntExpr : public BaseIntExpr {
5641 public:
MinCstIntExpr(Solver * const s,IntExpr * const e,int64_t v)5642 MinCstIntExpr(Solver* const s, IntExpr* const e, int64_t v)
5643 : BaseIntExpr(s), expr_(e), value_(v) {}
5644
~MinCstIntExpr()5645 ~MinCstIntExpr() override {}
5646
Min() const5647 int64_t Min() const override { return std::min(expr_->Min(), value_); }
5648
SetMin(int64_t m)5649 void SetMin(int64_t m) override {
5650 if (m > value_) {
5651 solver()->Fail();
5652 }
5653 expr_->SetMin(m);
5654 }
5655
Max() const5656 int64_t Max() const override { return std::min(expr_->Max(), value_); }
5657
SetMax(int64_t m)5658 void SetMax(int64_t m) override {
5659 if (value_ > m) {
5660 expr_->SetMax(m);
5661 }
5662 }
5663
Bound() const5664 bool Bound() const override {
5665 return (expr_->Bound() || expr_->Min() >= value_);
5666 }
5667
name() const5668 std::string name() const override {
5669 return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->name(), value_);
5670 }
5671
DebugString() const5672 std::string DebugString() const override {
5673 return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->DebugString(),
5674 value_);
5675 }
5676
WhenRange(Demon * d)5677 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5678
Accept(ModelVisitor * const visitor) const5679 void Accept(ModelVisitor* const visitor) const override {
5680 visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5681 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5682 expr_);
5683 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5684 visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5685 }
5686
5687 private:
5688 IntExpr* const expr_;
5689 const int64_t value_;
5690 };
5691
5692 // ----- Max(expr, expr) -----
5693
5694 class MaxIntExpr : public BaseIntExpr {
5695 public:
MaxIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)5696 MaxIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5697 : BaseIntExpr(s), left_(l), right_(r) {}
5698
~MaxIntExpr()5699 ~MaxIntExpr() override {}
5700
Min() const5701 int64_t Min() const override { return std::max(left_->Min(), right_->Min()); }
5702
SetMin(int64_t m)5703 void SetMin(int64_t m) override {
5704 if (left_->Max() < m) {
5705 right_->SetMin(m);
5706 } else {
5707 if (right_->Max() < m) {
5708 left_->SetMin(m);
5709 }
5710 }
5711 }
5712
Max() const5713 int64_t Max() const override { return std::max(left_->Max(), right_->Max()); }
5714
SetMax(int64_t m)5715 void SetMax(int64_t m) override {
5716 left_->SetMax(m);
5717 right_->SetMax(m);
5718 }
5719
name() const5720 std::string name() const override {
5721 return absl::StrFormat("MaxIntExpr(%s, %s)", left_->name(), right_->name());
5722 }
5723
DebugString() const5724 std::string DebugString() const override {
5725 return absl::StrFormat("MaxIntExpr(%s, %s)", left_->DebugString(),
5726 right_->DebugString());
5727 }
5728
WhenRange(Demon * d)5729 void WhenRange(Demon* d) override {
5730 left_->WhenRange(d);
5731 right_->WhenRange(d);
5732 }
5733
Accept(ModelVisitor * const visitor) const5734 void Accept(ModelVisitor* const visitor) const override {
5735 visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5736 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5737 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5738 right_);
5739 visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5740 }
5741
5742 private:
5743 IntExpr* const left_;
5744 IntExpr* const right_;
5745 };
5746
5747 // ----- Max(expr, constant) -----
5748
5749 class MaxCstIntExpr : public BaseIntExpr {
5750 public:
MaxCstIntExpr(Solver * const s,IntExpr * const e,int64_t v)5751 MaxCstIntExpr(Solver* const s, IntExpr* const e, int64_t v)
5752 : BaseIntExpr(s), expr_(e), value_(v) {}
5753
~MaxCstIntExpr()5754 ~MaxCstIntExpr() override {}
5755
Min() const5756 int64_t Min() const override { return std::max(expr_->Min(), value_); }
5757
SetMin(int64_t m)5758 void SetMin(int64_t m) override {
5759 if (value_ < m) {
5760 expr_->SetMin(m);
5761 }
5762 }
5763
Max() const5764 int64_t Max() const override { return std::max(expr_->Max(), value_); }
5765
SetMax(int64_t m)5766 void SetMax(int64_t m) override {
5767 if (m < value_) {
5768 solver()->Fail();
5769 }
5770 expr_->SetMax(m);
5771 }
5772
Bound() const5773 bool Bound() const override {
5774 return (expr_->Bound() || expr_->Max() <= value_);
5775 }
5776
name() const5777 std::string name() const override {
5778 return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->name(), value_);
5779 }
5780
DebugString() const5781 std::string DebugString() const override {
5782 return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->DebugString(),
5783 value_);
5784 }
5785
WhenRange(Demon * d)5786 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5787
Accept(ModelVisitor * const visitor) const5788 void Accept(ModelVisitor* const visitor) const override {
5789 visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5790 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5791 expr_);
5792 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5793 visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5794 }
5795
5796 private:
5797 IntExpr* const expr_;
5798 const int64_t value_;
5799 };
5800
5801 // ----- Convex Piecewise -----
5802
// This class is a very simple convex piecewise linear function. The
// argument of the function is the expression. Between early_date and
// late_date, the value of the function is 0. Before early_date, it
// is affine and the cost is early_cost * (early_date - x). After
// late_date, the cost is late_cost * (x - late_date).
5808
5809 class SimpleConvexPiecewiseExpr : public BaseIntExpr {
5810 public:
SimpleConvexPiecewiseExpr(Solver * const s,IntExpr * const e,int64_t ec,int64_t ed,int64_t ld,int64_t lc)5811 SimpleConvexPiecewiseExpr(Solver* const s, IntExpr* const e, int64_t ec,
5812 int64_t ed, int64_t ld, int64_t lc)
5813 : BaseIntExpr(s),
5814 expr_(e),
5815 early_cost_(ec),
5816 early_date_(ec == 0 ? std::numeric_limits<int64_t>::min() : ed),
5817 late_date_(lc == 0 ? std::numeric_limits<int64_t>::max() : ld),
5818 late_cost_(lc) {
5819 DCHECK_GE(ec, int64_t{0});
5820 DCHECK_GE(lc, int64_t{0});
5821 DCHECK_GE(ld, ed);
5822
5823 // If the penalty is 0, we can push the "confort zone or zone
5824 // of no cost towards infinity.
5825 }
5826
~SimpleConvexPiecewiseExpr()5827 ~SimpleConvexPiecewiseExpr() override {}
5828
Min() const5829 int64_t Min() const override {
5830 const int64_t vmin = expr_->Min();
5831 const int64_t vmax = expr_->Max();
5832 if (vmin >= late_date_) {
5833 return (vmin - late_date_) * late_cost_;
5834 } else if (vmax <= early_date_) {
5835 return (early_date_ - vmax) * early_cost_;
5836 } else {
5837 return 0LL;
5838 }
5839 }
5840
SetMin(int64_t m)5841 void SetMin(int64_t m) override {
5842 if (m <= 0) {
5843 return;
5844 }
5845 int64_t vmin = 0;
5846 int64_t vmax = 0;
5847 expr_->Range(&vmin, &vmax);
5848
5849 const int64_t rb =
5850 (late_cost_ == 0 ? vmax : late_date_ + PosIntDivUp(m, late_cost_) - 1);
5851 const int64_t lb =
5852 (early_cost_ == 0 ? vmin
5853 : early_date_ - PosIntDivUp(m, early_cost_) + 1);
5854
5855 if (expr_->IsVar()) {
5856 expr_->Var()->RemoveInterval(lb, rb);
5857 }
5858 }
5859
Max() const5860 int64_t Max() const override {
5861 const int64_t vmin = expr_->Min();
5862 const int64_t vmax = expr_->Max();
5863 const int64_t mr = vmax > late_date_ ? (vmax - late_date_) * late_cost_ : 0;
5864 const int64_t ml =
5865 vmin < early_date_ ? (early_date_ - vmin) * early_cost_ : 0;
5866 return std::max(mr, ml);
5867 }
5868
SetMax(int64_t m)5869 void SetMax(int64_t m) override {
5870 if (m < 0) {
5871 solver()->Fail();
5872 }
5873 if (late_cost_ != 0LL) {
5874 const int64_t rb = late_date_ + PosIntDivDown(m, late_cost_);
5875 if (early_cost_ != 0LL) {
5876 const int64_t lb = early_date_ - PosIntDivDown(m, early_cost_);
5877 expr_->SetRange(lb, rb);
5878 } else {
5879 expr_->SetMax(rb);
5880 }
5881 } else {
5882 if (early_cost_ != 0LL) {
5883 const int64_t lb = early_date_ - PosIntDivDown(m, early_cost_);
5884 expr_->SetMin(lb);
5885 }
5886 }
5887 }
5888
name() const5889 std::string name() const override {
5890 return absl::StrFormat(
5891 "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5892 expr_->name(), early_cost_, early_date_, late_date_, late_cost_);
5893 }
5894
DebugString() const5895 std::string DebugString() const override {
5896 return absl::StrFormat(
5897 "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5898 expr_->DebugString(), early_cost_, early_date_, late_date_, late_cost_);
5899 }
5900
WhenRange(Demon * d)5901 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5902
Accept(ModelVisitor * const visitor) const5903 void Accept(ModelVisitor* const visitor) const override {
5904 visitor->BeginVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5905 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5906 expr_);
5907 visitor->VisitIntegerArgument(ModelVisitor::kEarlyCostArgument,
5908 early_cost_);
5909 visitor->VisitIntegerArgument(ModelVisitor::kEarlyDateArgument,
5910 early_date_);
5911 visitor->VisitIntegerArgument(ModelVisitor::kLateCostArgument, late_cost_);
5912 visitor->VisitIntegerArgument(ModelVisitor::kLateDateArgument, late_date_);
5913 visitor->EndVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5914 }
5915
5916 private:
5917 IntExpr* const expr_;
5918 const int64_t early_cost_;
5919 const int64_t early_date_;
5920 const int64_t late_date_;
5921 const int64_t late_cost_;
5922 };
5923
5924 // ----- Semi Continuous -----
5925
5926 class SemiContinuousExpr : public BaseIntExpr {
5927 public:
SemiContinuousExpr(Solver * const s,IntExpr * const e,int64_t fixed_charge,int64_t step)5928 SemiContinuousExpr(Solver* const s, IntExpr* const e, int64_t fixed_charge,
5929 int64_t step)
5930 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge), step_(step) {
5931 DCHECK_GE(fixed_charge, int64_t{0});
5932 DCHECK_GT(step, int64_t{0});
5933 }
5934
~SemiContinuousExpr()5935 ~SemiContinuousExpr() override {}
5936
Value(int64_t x) const5937 int64_t Value(int64_t x) const {
5938 if (x <= 0) {
5939 return 0;
5940 } else {
5941 return CapAdd(fixed_charge_, CapProd(x, step_));
5942 }
5943 }
5944
Min() const5945 int64_t Min() const override { return Value(expr_->Min()); }
5946
SetMin(int64_t m)5947 void SetMin(int64_t m) override {
5948 if (m >= CapAdd(fixed_charge_, step_)) {
5949 const int64_t y = PosIntDivUp(CapSub(m, fixed_charge_), step_);
5950 expr_->SetMin(y);
5951 } else if (m > 0) {
5952 expr_->SetMin(1);
5953 }
5954 }
5955
Max() const5956 int64_t Max() const override { return Value(expr_->Max()); }
5957
SetMax(int64_t m)5958 void SetMax(int64_t m) override {
5959 if (m < 0) {
5960 solver()->Fail();
5961 }
5962 if (m == std::numeric_limits<int64_t>::max()) {
5963 return;
5964 }
5965 if (m < CapAdd(fixed_charge_, step_)) {
5966 expr_->SetMax(0);
5967 } else {
5968 const int64_t y = PosIntDivDown(CapSub(m, fixed_charge_), step_);
5969 expr_->SetMax(y);
5970 }
5971 }
5972
name() const5973 std::string name() const override {
5974 return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
5975 expr_->name(), fixed_charge_, step_);
5976 }
5977
DebugString() const5978 std::string DebugString() const override {
5979 return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
5980 expr_->DebugString(), fixed_charge_, step_);
5981 }
5982
WhenRange(Demon * d)5983 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5984
Accept(ModelVisitor * const visitor) const5985 void Accept(ModelVisitor* const visitor) const override {
5986 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
5987 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5988 expr_);
5989 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
5990 fixed_charge_);
5991 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, step_);
5992 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
5993 }
5994
5995 private:
5996 IntExpr* const expr_;
5997 const int64_t fixed_charge_;
5998 const int64_t step_;
5999 };
6000
6001 class SemiContinuousStepOneExpr : public BaseIntExpr {
6002 public:
SemiContinuousStepOneExpr(Solver * const s,IntExpr * const e,int64_t fixed_charge)6003 SemiContinuousStepOneExpr(Solver* const s, IntExpr* const e,
6004 int64_t fixed_charge)
6005 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6006 DCHECK_GE(fixed_charge, int64_t{0});
6007 }
6008
~SemiContinuousStepOneExpr()6009 ~SemiContinuousStepOneExpr() override {}
6010
Value(int64_t x) const6011 int64_t Value(int64_t x) const {
6012 if (x <= 0) {
6013 return 0;
6014 } else {
6015 return fixed_charge_ + x;
6016 }
6017 }
6018
Min() const6019 int64_t Min() const override { return Value(expr_->Min()); }
6020
SetMin(int64_t m)6021 void SetMin(int64_t m) override {
6022 if (m >= fixed_charge_ + 1) {
6023 expr_->SetMin(m - fixed_charge_);
6024 } else if (m > 0) {
6025 expr_->SetMin(1);
6026 }
6027 }
6028
Max() const6029 int64_t Max() const override { return Value(expr_->Max()); }
6030
SetMax(int64_t m)6031 void SetMax(int64_t m) override {
6032 if (m < 0) {
6033 solver()->Fail();
6034 }
6035 if (m < fixed_charge_ + 1) {
6036 expr_->SetMax(0);
6037 } else {
6038 expr_->SetMax(m - fixed_charge_);
6039 }
6040 }
6041
name() const6042 std::string name() const override {
6043 return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6044 expr_->name(), fixed_charge_);
6045 }
6046
DebugString() const6047 std::string DebugString() const override {
6048 return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6049 expr_->DebugString(), fixed_charge_);
6050 }
6051
WhenRange(Demon * d)6052 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6053
Accept(ModelVisitor * const visitor) const6054 void Accept(ModelVisitor* const visitor) const override {
6055 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6056 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6057 expr_);
6058 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6059 fixed_charge_);
6060 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 1);
6061 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6062 }
6063
6064 private:
6065 IntExpr* const expr_;
6066 const int64_t fixed_charge_;
6067 };
6068
6069 class SemiContinuousStepZeroExpr : public BaseIntExpr {
6070 public:
SemiContinuousStepZeroExpr(Solver * const s,IntExpr * const e,int64_t fixed_charge)6071 SemiContinuousStepZeroExpr(Solver* const s, IntExpr* const e,
6072 int64_t fixed_charge)
6073 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6074 DCHECK_GT(fixed_charge, int64_t{0});
6075 }
6076
~SemiContinuousStepZeroExpr()6077 ~SemiContinuousStepZeroExpr() override {}
6078
Value(int64_t x) const6079 int64_t Value(int64_t x) const {
6080 if (x <= 0) {
6081 return 0;
6082 } else {
6083 return fixed_charge_;
6084 }
6085 }
6086
Min() const6087 int64_t Min() const override { return Value(expr_->Min()); }
6088
SetMin(int64_t m)6089 void SetMin(int64_t m) override {
6090 if (m >= fixed_charge_) {
6091 solver()->Fail();
6092 } else if (m > 0) {
6093 expr_->SetMin(1);
6094 }
6095 }
6096
Max() const6097 int64_t Max() const override { return Value(expr_->Max()); }
6098
SetMax(int64_t m)6099 void SetMax(int64_t m) override {
6100 if (m < 0) {
6101 solver()->Fail();
6102 }
6103 if (m < fixed_charge_) {
6104 expr_->SetMax(0);
6105 }
6106 }
6107
name() const6108 std::string name() const override {
6109 return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6110 expr_->name(), fixed_charge_);
6111 }
6112
DebugString() const6113 std::string DebugString() const override {
6114 return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6115 expr_->DebugString(), fixed_charge_);
6116 }
6117
WhenRange(Demon * d)6118 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6119
Accept(ModelVisitor * const visitor) const6120 void Accept(ModelVisitor* const visitor) const override {
6121 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6122 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6123 expr_);
6124 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6125 fixed_charge_);
6126 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 0);
6127 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6128 }
6129
6130 private:
6131 IntExpr* const expr_;
6132 const int64_t fixed_charge_;
6133 };
6134
// This constraint links an expression and the variable it is cast into.
6136 class LinkExprAndVar : public CastConstraint {
6137 public:
LinkExprAndVar(Solver * const s,IntExpr * const expr,IntVar * const var)6138 LinkExprAndVar(Solver* const s, IntExpr* const expr, IntVar* const var)
6139 : CastConstraint(s, var), expr_(expr) {}
6140
~LinkExprAndVar()6141 ~LinkExprAndVar() override {}
6142
Post()6143 void Post() override {
6144 Solver* const s = solver();
6145 Demon* d = s->MakeConstraintInitialPropagateCallback(this);
6146 expr_->WhenRange(d);
6147 target_var_->WhenRange(d);
6148 }
6149
InitialPropagate()6150 void InitialPropagate() override {
6151 expr_->SetRange(target_var_->Min(), target_var_->Max());
6152 int64_t l, u;
6153 expr_->Range(&l, &u);
6154 target_var_->SetRange(l, u);
6155 }
6156
DebugString() const6157 std::string DebugString() const override {
6158 return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6159 target_var_->DebugString());
6160 }
6161
Accept(ModelVisitor * const visitor) const6162 void Accept(ModelVisitor* const visitor) const override {
6163 visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6164 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6165 expr_);
6166 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6167 target_var_);
6168 visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6169 }
6170
6171 private:
6172 IntExpr* const expr_;
6173 };
6174
6175 // ----- Conditional Expression -----
6176
6177 class ExprWithEscapeValue : public BaseIntExpr {
6178 public:
ExprWithEscapeValue(Solver * const s,IntVar * const c,IntExpr * const e,int64_t unperformed_value)6179 ExprWithEscapeValue(Solver* const s, IntVar* const c, IntExpr* const e,
6180 int64_t unperformed_value)
6181 : BaseIntExpr(s),
6182 condition_(c),
6183 expression_(e),
6184 unperformed_value_(unperformed_value) {}
6185
~ExprWithEscapeValue()6186 ~ExprWithEscapeValue() override {}
6187
Min() const6188 int64_t Min() const override {
6189 if (condition_->Min() == 1) {
6190 return expression_->Min();
6191 } else if (condition_->Max() == 1) {
6192 return std::min(unperformed_value_, expression_->Min());
6193 } else {
6194 return unperformed_value_;
6195 }
6196 }
6197
SetMin(int64_t m)6198 void SetMin(int64_t m) override {
6199 if (m > unperformed_value_) {
6200 condition_->SetValue(1);
6201 expression_->SetMin(m);
6202 } else if (condition_->Min() == 1) {
6203 expression_->SetMin(m);
6204 } else if (m > expression_->Max()) {
6205 condition_->SetValue(0);
6206 }
6207 }
6208
Max() const6209 int64_t Max() const override {
6210 if (condition_->Min() == 1) {
6211 return expression_->Max();
6212 } else if (condition_->Max() == 1) {
6213 return std::max(unperformed_value_, expression_->Max());
6214 } else {
6215 return unperformed_value_;
6216 }
6217 }
6218
SetMax(int64_t m)6219 void SetMax(int64_t m) override {
6220 if (m < unperformed_value_) {
6221 condition_->SetValue(1);
6222 expression_->SetMax(m);
6223 } else if (condition_->Min() == 1) {
6224 expression_->SetMax(m);
6225 } else if (m < expression_->Min()) {
6226 condition_->SetValue(0);
6227 }
6228 }
6229
SetRange(int64_t mi,int64_t ma)6230 void SetRange(int64_t mi, int64_t ma) override {
6231 if (ma < unperformed_value_ || mi > unperformed_value_) {
6232 condition_->SetValue(1);
6233 expression_->SetRange(mi, ma);
6234 } else if (condition_->Min() == 1) {
6235 expression_->SetRange(mi, ma);
6236 } else if (ma < expression_->Min() || mi > expression_->Max()) {
6237 condition_->SetValue(0);
6238 }
6239 }
6240
SetValue(int64_t v)6241 void SetValue(int64_t v) override {
6242 if (v != unperformed_value_) {
6243 condition_->SetValue(1);
6244 expression_->SetValue(v);
6245 } else if (condition_->Min() == 1) {
6246 expression_->SetValue(v);
6247 } else if (v < expression_->Min() || v > expression_->Max()) {
6248 condition_->SetValue(0);
6249 }
6250 }
6251
Bound() const6252 bool Bound() const override {
6253 return condition_->Max() == 0 || expression_->Bound();
6254 }
6255
WhenRange(Demon * d)6256 void WhenRange(Demon* d) override {
6257 expression_->WhenRange(d);
6258 condition_->WhenBound(d);
6259 }
6260
DebugString() const6261 std::string DebugString() const override {
6262 return absl::StrFormat("ConditionExpr(%s, %s, %d)",
6263 condition_->DebugString(),
6264 expression_->DebugString(), unperformed_value_);
6265 }
6266
Accept(ModelVisitor * const visitor) const6267 void Accept(ModelVisitor* const visitor) const override {
6268 visitor->BeginVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6269 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
6270 condition_);
6271 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6272 expression_);
6273 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument,
6274 unperformed_value_);
6275 visitor->EndVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6276 }
6277
6278 private:
6279 IntVar* const condition_;
6280 IntExpr* const expression_;
6281 const int64_t unperformed_value_;
6282 DISALLOW_COPY_AND_ASSIGN(ExprWithEscapeValue);
6283 };
6284
// ----- Specialized case, used when the exact type of the variable is known -----
6286 class LinkExprAndDomainIntVar : public CastConstraint {
6287 public:
LinkExprAndDomainIntVar(Solver * const s,IntExpr * const expr,DomainIntVar * const var)6288 LinkExprAndDomainIntVar(Solver* const s, IntExpr* const expr,
6289 DomainIntVar* const var)
6290 : CastConstraint(s, var),
6291 expr_(expr),
6292 cached_min_(std::numeric_limits<int64_t>::min()),
6293 cached_max_(std::numeric_limits<int64_t>::max()),
6294 fail_stamp_(uint64_t{0}) {}
6295
~LinkExprAndDomainIntVar()6296 ~LinkExprAndDomainIntVar() override {}
6297
var() const6298 DomainIntVar* var() const {
6299 return reinterpret_cast<DomainIntVar*>(target_var_);
6300 }
6301
Post()6302 void Post() override {
6303 Solver* const s = solver();
6304 Demon* const d = s->MakeConstraintInitialPropagateCallback(this);
6305 expr_->WhenRange(d);
6306 Demon* const target_var_demon = MakeConstraintDemon0(
6307 solver(), this, &LinkExprAndDomainIntVar::Propagate, "Propagate");
6308 target_var_->WhenRange(target_var_demon);
6309 }
6310
InitialPropagate()6311 void InitialPropagate() override {
6312 expr_->SetRange(var()->min_.Value(), var()->max_.Value());
6313 expr_->Range(&cached_min_, &cached_max_);
6314 var()->DomainIntVar::SetRange(cached_min_, cached_max_);
6315 }
6316
Propagate()6317 void Propagate() {
6318 if (var()->min_.Value() > cached_min_ ||
6319 var()->max_.Value() < cached_max_ ||
6320 solver()->fail_stamp() != fail_stamp_) {
6321 InitialPropagate();
6322 fail_stamp_ = solver()->fail_stamp();
6323 }
6324 }
6325
DebugString() const6326 std::string DebugString() const override {
6327 return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6328 target_var_->DebugString());
6329 }
6330
Accept(ModelVisitor * const visitor) const6331 void Accept(ModelVisitor* const visitor) const override {
6332 visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6333 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6334 expr_);
6335 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6336 target_var_);
6337 visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6338 }
6339
6340 private:
6341 IntExpr* const expr_;
6342 int64_t cached_min_;
6343 int64_t cached_max_;
6344 uint64_t fail_stamp_;
6345 };
6346 } // namespace
6347
6348 // ----- Misc -----
6349
MakeHoleIterator(bool reversible) const6350 IntVarIterator* BooleanVar::MakeHoleIterator(bool reversible) const {
6351 return COND_REV_ALLOC(reversible, new EmptyIterator());
6352 }
MakeDomainIterator(bool reversible) const6353 IntVarIterator* BooleanVar::MakeDomainIterator(bool reversible) const {
6354 return COND_REV_ALLOC(reversible, new RangeIterator(this));
6355 }
6356
6357 // ----- API -----
6358
CleanVariableOnFail(IntVar * const var)6359 void CleanVariableOnFail(IntVar* const var) {
6360 DCHECK_EQ(DOMAIN_INT_VAR, var->VarType());
6361 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6362 dvar->CleanInProcess();
6363 }
6364
SetIsEqual(IntVar * const var,const std::vector<int64_t> & values,const std::vector<IntVar * > & vars)6365 Constraint* SetIsEqual(IntVar* const var, const std::vector<int64_t>& values,
6366 const std::vector<IntVar*>& vars) {
6367 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6368 CHECK(dvar != nullptr);
6369 return dvar->SetIsEqual(values, vars);
6370 }
6371
SetIsGreaterOrEqual(IntVar * const var,const std::vector<int64_t> & values,const std::vector<IntVar * > & vars)6372 Constraint* SetIsGreaterOrEqual(IntVar* const var,
6373 const std::vector<int64_t>& values,
6374 const std::vector<IntVar*>& vars) {
6375 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6376 CHECK(dvar != nullptr);
6377 return dvar->SetIsGreaterOrEqual(values, vars);
6378 }
6379
RestoreBoolValue(IntVar * const var)6380 void RestoreBoolValue(IntVar* const var) {
6381 DCHECK_EQ(BOOLEAN_VAR, var->VarType());
6382 BooleanVar* const boolean_var = reinterpret_cast<BooleanVar*>(var);
6383 boolean_var->RestoreValue();
6384 }
6385
6386 // ----- API -----
6387
MakeIntVar(int64_t min,int64_t max,const std::string & name)6388 IntVar* Solver::MakeIntVar(int64_t min, int64_t max, const std::string& name) {
6389 if (min == max) {
6390 return MakeIntConst(min, name);
6391 }
6392 if (min == 0 && max == 1) {
6393 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6394 } else if (CapSub(max, min) == 1) {
6395 const std::string inner_name = "inner_" + name;
6396 return RegisterIntVar(
6397 MakeSum(RevAlloc(new ConcreteBooleanVar(this, inner_name)), min)
6398 ->VarWithName(name));
6399 } else {
6400 return RegisterIntVar(RevAlloc(new DomainIntVar(this, min, max, name)));
6401 }
6402 }
6403
MakeIntVar(int64_t min,int64_t max)6404 IntVar* Solver::MakeIntVar(int64_t min, int64_t max) {
6405 return MakeIntVar(min, max, "");
6406 }
6407
MakeBoolVar(const std::string & name)6408 IntVar* Solver::MakeBoolVar(const std::string& name) {
6409 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6410 }
6411
MakeBoolVar()6412 IntVar* Solver::MakeBoolVar() {
6413 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, "")));
6414 }
6415
MakeIntVar(const std::vector<int64_t> & values,const std::string & name)6416 IntVar* Solver::MakeIntVar(const std::vector<int64_t>& values,
6417 const std::string& name) {
6418 DCHECK(!values.empty());
6419 // Fast-track the case where we have a single value.
6420 if (values.size() == 1) return MakeIntConst(values[0], name);
6421 // Sort and remove duplicates.
6422 std::vector<int64_t> unique_sorted_values = values;
6423 gtl::STLSortAndRemoveDuplicates(&unique_sorted_values);
6424 // Case when we have a single value, after clean-up.
6425 if (unique_sorted_values.size() == 1) return MakeIntConst(values[0], name);
6426 // Case when the values are a dense interval of integers.
6427 if (unique_sorted_values.size() ==
6428 unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6429 return MakeIntVar(unique_sorted_values.front(), unique_sorted_values.back(),
6430 name);
6431 }
6432 // Compute the GCD: if it's not 1, we can express the variable's domain as
6433 // the product of the GCD and of a domain with smaller values.
6434 int64_t gcd = 0;
6435 for (const int64_t v : unique_sorted_values) {
6436 if (gcd == 0) {
6437 gcd = std::abs(v);
6438 } else {
6439 gcd = MathUtil::GCD64(gcd, std::abs(v)); // Supports v==0.
6440 }
6441 if (gcd == 1) {
6442 // If it's 1, though, we can't do anything special, so we
6443 // immediately return a new DomainIntVar.
6444 return RegisterIntVar(
6445 RevAlloc(new DomainIntVar(this, unique_sorted_values, name)));
6446 }
6447 }
6448 DCHECK_GT(gcd, 1);
6449 for (int64_t& v : unique_sorted_values) {
6450 DCHECK_EQ(0, v % gcd);
6451 v /= gcd;
6452 }
6453 const std::string new_name = name.empty() ? "" : "inner_" + name;
6454 // Catch the case where the divided values are a dense set of integers.
6455 IntVar* inner_intvar = nullptr;
6456 if (unique_sorted_values.size() ==
6457 unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6458 inner_intvar = MakeIntVar(unique_sorted_values.front(),
6459 unique_sorted_values.back(), new_name);
6460 } else {
6461 inner_intvar = RegisterIntVar(
6462 RevAlloc(new DomainIntVar(this, unique_sorted_values, new_name)));
6463 }
6464 return MakeProd(inner_intvar, gcd)->Var();
6465 }
6466
MakeIntVar(const std::vector<int64_t> & values)6467 IntVar* Solver::MakeIntVar(const std::vector<int64_t>& values) {
6468 return MakeIntVar(values, "");
6469 }
6470
MakeIntVar(const std::vector<int> & values,const std::string & name)6471 IntVar* Solver::MakeIntVar(const std::vector<int>& values,
6472 const std::string& name) {
6473 return MakeIntVar(ToInt64Vector(values), name);
6474 }
6475
MakeIntVar(const std::vector<int> & values)6476 IntVar* Solver::MakeIntVar(const std::vector<int>& values) {
6477 return MakeIntVar(values, "");
6478 }
6479
MakeIntConst(int64_t val,const std::string & name)6480 IntVar* Solver::MakeIntConst(int64_t val, const std::string& name) {
6481 // If IntConst is going to be named after its creation,
6482 // cp_share_int_consts should be set to false otherwise names can potentially
6483 // be overwritten.
6484 if (absl::GetFlag(FLAGS_cp_share_int_consts) && name.empty() &&
6485 val >= MIN_CACHED_INT_CONST && val <= MAX_CACHED_INT_CONST) {
6486 return cached_constants_[val - MIN_CACHED_INT_CONST];
6487 }
6488 return RevAlloc(new IntConst(this, val, name));
6489 }
6490
MakeIntConst(int64_t val)6491 IntVar* Solver::MakeIntConst(int64_t val) { return MakeIntConst(val, ""); }
6492
6493 // ----- Int Var and associated methods -----
6494
6495 namespace {
IndexedName(const std::string & prefix,int index,int max_index)6496 std::string IndexedName(const std::string& prefix, int index, int max_index) {
6497 #if 0
6498 #if defined(_MSC_VER)
6499 const int digits = max_index > 0 ?
6500 static_cast<int>(log(1.0L * max_index) / log(10.0L)) + 1 :
6501 1;
6502 #else
6503 const int digits = max_index > 0 ? static_cast<int>(log10(max_index)) + 1: 1;
6504 #endif
6505 return absl::StrFormat("%s%0*d", prefix, digits, index);
6506 #else
6507 return absl::StrCat(prefix, index);
6508 #endif
6509 }
6510 } // namespace
6511
MakeIntVarArray(int var_count,int64_t vmin,int64_t vmax,const std::string & name,std::vector<IntVar * > * vars)6512 void Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6513 const std::string& name,
6514 std::vector<IntVar*>* vars) {
6515 for (int i = 0; i < var_count; ++i) {
6516 vars->push_back(MakeIntVar(vmin, vmax, IndexedName(name, i, var_count)));
6517 }
6518 }
6519
MakeIntVarArray(int var_count,int64_t vmin,int64_t vmax,std::vector<IntVar * > * vars)6520 void Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6521 std::vector<IntVar*>* vars) {
6522 for (int i = 0; i < var_count; ++i) {
6523 vars->push_back(MakeIntVar(vmin, vmax));
6524 }
6525 }
6526
MakeIntVarArray(int var_count,int64_t vmin,int64_t vmax,const std::string & name)6527 IntVar** Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6528 const std::string& name) {
6529 IntVar** vars = new IntVar*[var_count];
6530 for (int i = 0; i < var_count; ++i) {
6531 vars[i] = MakeIntVar(vmin, vmax, IndexedName(name, i, var_count));
6532 }
6533 return vars;
6534 }
6535
MakeBoolVarArray(int var_count,const std::string & name,std::vector<IntVar * > * vars)6536 void Solver::MakeBoolVarArray(int var_count, const std::string& name,
6537 std::vector<IntVar*>* vars) {
6538 for (int i = 0; i < var_count; ++i) {
6539 vars->push_back(MakeBoolVar(IndexedName(name, i, var_count)));
6540 }
6541 }
6542
MakeBoolVarArray(int var_count,std::vector<IntVar * > * vars)6543 void Solver::MakeBoolVarArray(int var_count, std::vector<IntVar*>* vars) {
6544 for (int i = 0; i < var_count; ++i) {
6545 vars->push_back(MakeBoolVar());
6546 }
6547 }
6548
MakeBoolVarArray(int var_count,const std::string & name)6549 IntVar** Solver::MakeBoolVarArray(int var_count, const std::string& name) {
6550 IntVar** vars = new IntVar*[var_count];
6551 for (int i = 0; i < var_count; ++i) {
6552 vars[i] = MakeBoolVar(IndexedName(name, i, var_count));
6553 }
6554 return vars;
6555 }
6556
InitCachedIntConstants()6557 void Solver::InitCachedIntConstants() {
6558 for (int i = MIN_CACHED_INT_CONST; i <= MAX_CACHED_INT_CONST; ++i) {
6559 cached_constants_[i - MIN_CACHED_INT_CONST] =
6560 RevAlloc(new IntConst(this, i, "")); // note the empty name
6561 }
6562 }
6563
MakeSum(IntExpr * const left,IntExpr * const right)6564 IntExpr* Solver::MakeSum(IntExpr* const left, IntExpr* const right) {
6565 CHECK_EQ(this, left->solver());
6566 CHECK_EQ(this, right->solver());
6567 if (right->Bound()) {
6568 return MakeSum(left, right->Min());
6569 }
6570 if (left->Bound()) {
6571 return MakeSum(right, left->Min());
6572 }
6573 if (left == right) {
6574 return MakeProd(left, 2);
6575 }
6576 IntExpr* cache = model_cache_->FindExprExprExpression(
6577 left, right, ModelCache::EXPR_EXPR_SUM);
6578 if (cache == nullptr) {
6579 cache = model_cache_->FindExprExprExpression(right, left,
6580 ModelCache::EXPR_EXPR_SUM);
6581 }
6582 if (cache != nullptr) {
6583 return cache;
6584 } else {
6585 IntExpr* const result =
6586 AddOverflows(left->Max(), right->Max()) ||
6587 AddOverflows(left->Min(), right->Min())
6588 ? RegisterIntExpr(RevAlloc(new SafePlusIntExpr(this, left, right)))
6589 : RegisterIntExpr(RevAlloc(new PlusIntExpr(this, left, right)));
6590 model_cache_->InsertExprExprExpression(result, left, right,
6591 ModelCache::EXPR_EXPR_SUM);
6592 return result;
6593 }
6594 }
6595
MakeSum(IntExpr * const expr,int64_t value)6596 IntExpr* Solver::MakeSum(IntExpr* const expr, int64_t value) {
6597 CHECK_EQ(this, expr->solver());
6598 if (expr->Bound()) {
6599 return MakeIntConst(expr->Min() + value);
6600 }
6601 if (value == 0) {
6602 return expr;
6603 }
6604 IntExpr* result = Cache()->FindExprConstantExpression(
6605 expr, value, ModelCache::EXPR_CONSTANT_SUM);
6606 if (result == nullptr) {
6607 if (expr->IsVar() && !AddOverflows(value, expr->Max()) &&
6608 !AddOverflows(value, expr->Min())) {
6609 IntVar* const var = expr->Var();
6610 switch (var->VarType()) {
6611 case DOMAIN_INT_VAR: {
6612 result = RegisterIntExpr(RevAlloc(new PlusCstDomainIntVar(
6613 this, reinterpret_cast<DomainIntVar*>(var), value)));
6614 break;
6615 }
6616 case CONST_VAR: {
6617 result = RegisterIntExpr(MakeIntConst(var->Min() + value));
6618 break;
6619 }
6620 case VAR_ADD_CST: {
6621 PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6622 IntVar* const sub_var = add_var->SubVar();
6623 const int64_t new_constant = value + add_var->Constant();
6624 if (new_constant == 0) {
6625 result = sub_var;
6626 } else {
6627 if (sub_var->VarType() == DOMAIN_INT_VAR) {
6628 DomainIntVar* const dvar =
6629 reinterpret_cast<DomainIntVar*>(sub_var);
6630 result = RegisterIntExpr(
6631 RevAlloc(new PlusCstDomainIntVar(this, dvar, new_constant)));
6632 } else {
6633 result = RegisterIntExpr(
6634 RevAlloc(new PlusCstIntVar(this, sub_var, new_constant)));
6635 }
6636 }
6637 break;
6638 }
6639 case CST_SUB_VAR: {
6640 SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6641 IntVar* const sub_var = add_var->SubVar();
6642 const int64_t new_constant = value + add_var->Constant();
6643 result = RegisterIntExpr(
6644 RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6645 break;
6646 }
6647 case OPP_VAR: {
6648 OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6649 IntVar* const sub_var = add_var->SubVar();
6650 result =
6651 RegisterIntExpr(RevAlloc(new SubCstIntVar(this, sub_var, value)));
6652 break;
6653 }
6654 default:
6655 result =
6656 RegisterIntExpr(RevAlloc(new PlusCstIntVar(this, var, value)));
6657 }
6658 } else {
6659 result = RegisterIntExpr(RevAlloc(new PlusIntCstExpr(this, expr, value)));
6660 }
6661 Cache()->InsertExprConstantExpression(result, expr, value,
6662 ModelCache::EXPR_CONSTANT_SUM);
6663 }
6664 return result;
6665 }
6666
MakeDifference(IntExpr * const left,IntExpr * const right)6667 IntExpr* Solver::MakeDifference(IntExpr* const left, IntExpr* const right) {
6668 CHECK_EQ(this, left->solver());
6669 CHECK_EQ(this, right->solver());
6670 if (left->Bound()) {
6671 return MakeDifference(left->Min(), right);
6672 }
6673 if (right->Bound()) {
6674 return MakeSum(left, -right->Min());
6675 }
6676 IntExpr* sub_left = nullptr;
6677 IntExpr* sub_right = nullptr;
6678 int64_t left_coef = 1;
6679 int64_t right_coef = 1;
6680 if (IsProduct(left, &sub_left, &left_coef) &&
6681 IsProduct(right, &sub_right, &right_coef)) {
6682 const int64_t abs_gcd =
6683 MathUtil::GCD64(std::abs(left_coef), std::abs(right_coef));
6684 if (abs_gcd != 0 && abs_gcd != 1) {
6685 return MakeProd(MakeDifference(MakeProd(sub_left, left_coef / abs_gcd),
6686 MakeProd(sub_right, right_coef / abs_gcd)),
6687 abs_gcd);
6688 }
6689 }
6690
6691 IntExpr* result = Cache()->FindExprExprExpression(
6692 left, right, ModelCache::EXPR_EXPR_DIFFERENCE);
6693 if (result == nullptr) {
6694 if (!SubOverflows(left->Min(), right->Max()) &&
6695 !SubOverflows(left->Max(), right->Min())) {
6696 result = RegisterIntExpr(RevAlloc(new SubIntExpr(this, left, right)));
6697 } else {
6698 result = RegisterIntExpr(RevAlloc(new SafeSubIntExpr(this, left, right)));
6699 }
6700 Cache()->InsertExprExprExpression(result, left, right,
6701 ModelCache::EXPR_EXPR_DIFFERENCE);
6702 }
6703 return result;
6704 }
6705
6706 // warning: this is 'value - expr'.
MakeDifference(int64_t value,IntExpr * const expr)6707 IntExpr* Solver::MakeDifference(int64_t value, IntExpr* const expr) {
6708 CHECK_EQ(this, expr->solver());
6709 if (expr->Bound()) {
6710 return MakeIntConst(value - expr->Min());
6711 }
6712 if (value == 0) {
6713 return MakeOpposite(expr);
6714 }
6715 IntExpr* result = Cache()->FindExprConstantExpression(
6716 expr, value, ModelCache::EXPR_CONSTANT_DIFFERENCE);
6717 if (result == nullptr) {
6718 if (expr->IsVar() && expr->Min() != std::numeric_limits<int64_t>::min() &&
6719 !SubOverflows(value, expr->Min()) &&
6720 !SubOverflows(value, expr->Max())) {
6721 IntVar* const var = expr->Var();
6722 switch (var->VarType()) {
6723 case VAR_ADD_CST: {
6724 PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6725 IntVar* const sub_var = add_var->SubVar();
6726 const int64_t new_constant = value - add_var->Constant();
6727 if (new_constant == 0) {
6728 result = sub_var;
6729 } else {
6730 result = RegisterIntExpr(
6731 RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6732 }
6733 break;
6734 }
6735 case CST_SUB_VAR: {
6736 SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6737 IntVar* const sub_var = add_var->SubVar();
6738 const int64_t new_constant = value - add_var->Constant();
6739 result = MakeSum(sub_var, new_constant);
6740 break;
6741 }
6742 case OPP_VAR: {
6743 OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6744 IntVar* const sub_var = add_var->SubVar();
6745 result = MakeSum(sub_var, value);
6746 break;
6747 }
6748 default:
6749 result =
6750 RegisterIntExpr(RevAlloc(new SubCstIntVar(this, var, value)));
6751 }
6752 } else {
6753 result = RegisterIntExpr(RevAlloc(new SubIntCstExpr(this, expr, value)));
6754 }
6755 Cache()->InsertExprConstantExpression(result, expr, value,
6756 ModelCache::EXPR_CONSTANT_DIFFERENCE);
6757 }
6758 return result;
6759 }
6760
MakeOpposite(IntExpr * const expr)6761 IntExpr* Solver::MakeOpposite(IntExpr* const expr) {
6762 CHECK_EQ(this, expr->solver());
6763 if (expr->Bound()) {
6764 return MakeIntConst(-expr->Min());
6765 }
6766 IntExpr* result =
6767 Cache()->FindExprExpression(expr, ModelCache::EXPR_OPPOSITE);
6768 if (result == nullptr) {
6769 if (expr->IsVar()) {
6770 result = RegisterIntVar(RevAlloc(new OppIntExpr(this, expr))->Var());
6771 } else {
6772 result = RegisterIntExpr(RevAlloc(new OppIntExpr(this, expr)));
6773 }
6774 Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_OPPOSITE);
6775 }
6776 return result;
6777 }
6778
MakeProd(IntExpr * const expr,int64_t value)6779 IntExpr* Solver::MakeProd(IntExpr* const expr, int64_t value) {
6780 CHECK_EQ(this, expr->solver());
6781 IntExpr* result = Cache()->FindExprConstantExpression(
6782 expr, value, ModelCache::EXPR_CONSTANT_PROD);
6783 if (result != nullptr) {
6784 return result;
6785 } else {
6786 IntExpr* m_expr = nullptr;
6787 int64_t coefficient = 1;
6788 if (IsProduct(expr, &m_expr, &coefficient)) {
6789 coefficient *= value;
6790 } else {
6791 m_expr = expr;
6792 coefficient = value;
6793 }
6794 if (m_expr->Bound()) {
6795 return MakeIntConst(coefficient * m_expr->Min());
6796 } else if (coefficient == 1) {
6797 return m_expr;
6798 } else if (coefficient == -1) {
6799 return MakeOpposite(m_expr);
6800 } else if (coefficient > 0) {
6801 if (m_expr->Max() > std::numeric_limits<int64_t>::max() / coefficient ||
6802 m_expr->Min() < std::numeric_limits<int64_t>::min() / coefficient) {
6803 result = RegisterIntExpr(
6804 RevAlloc(new SafeTimesPosIntCstExpr(this, m_expr, coefficient)));
6805 } else {
6806 result = RegisterIntExpr(
6807 RevAlloc(new TimesPosIntCstExpr(this, m_expr, coefficient)));
6808 }
6809 } else if (coefficient == 0) {
6810 result = MakeIntConst(0);
6811 } else { // coefficient < 0.
6812 result = RegisterIntExpr(
6813 RevAlloc(new TimesIntNegCstExpr(this, m_expr, coefficient)));
6814 }
6815 if (m_expr->IsVar() &&
6816 !absl::GetFlag(FLAGS_cp_disable_expression_optimization)) {
6817 result = result->Var();
6818 }
6819 Cache()->InsertExprConstantExpression(result, expr, value,
6820 ModelCache::EXPR_CONSTANT_PROD);
6821 return result;
6822 }
6823 }
6824
6825 namespace {
ExtractPower(IntExpr ** const expr,int64_t * const exponant)6826 void ExtractPower(IntExpr** const expr, int64_t* const exponant) {
6827 if (dynamic_cast<BasePower*>(*expr) != nullptr) {
6828 BasePower* const power = dynamic_cast<BasePower*>(*expr);
6829 *expr = power->expr();
6830 *exponant = power->exponant();
6831 }
6832 if (dynamic_cast<IntSquare*>(*expr) != nullptr) {
6833 IntSquare* const power = dynamic_cast<IntSquare*>(*expr);
6834 *expr = power->expr();
6835 *exponant = 2;
6836 }
6837 if ((*expr)->IsVar()) {
6838 IntVar* const var = (*expr)->Var();
6839 IntExpr* const sub = var->solver()->CastExpression(var);
6840 if (sub != nullptr && dynamic_cast<BasePower*>(sub) != nullptr) {
6841 BasePower* const power = dynamic_cast<BasePower*>(sub);
6842 *expr = power->expr();
6843 *exponant = power->exponant();
6844 }
6845 if (sub != nullptr && dynamic_cast<IntSquare*>(sub) != nullptr) {
6846 IntSquare* const power = dynamic_cast<IntSquare*>(sub);
6847 *expr = power->expr();
6848 *exponant = 2;
6849 }
6850 }
6851 }
6852
ExtractProduct(IntExpr ** const expr,int64_t * const coefficient,bool * modified)6853 void ExtractProduct(IntExpr** const expr, int64_t* const coefficient,
6854 bool* modified) {
6855 if (dynamic_cast<TimesCstIntVar*>(*expr) != nullptr) {
6856 TimesCstIntVar* const left_prod = dynamic_cast<TimesCstIntVar*>(*expr);
6857 *coefficient *= left_prod->Constant();
6858 *expr = left_prod->SubVar();
6859 *modified = true;
6860 } else if (dynamic_cast<TimesIntCstExpr*>(*expr) != nullptr) {
6861 TimesIntCstExpr* const left_prod = dynamic_cast<TimesIntCstExpr*>(*expr);
6862 *coefficient *= left_prod->Constant();
6863 *expr = left_prod->Expr();
6864 *modified = true;
6865 }
6866 }
6867 } // namespace
6868
MakeProd(IntExpr * const left,IntExpr * const right)6869 IntExpr* Solver::MakeProd(IntExpr* const left, IntExpr* const right) {
6870 if (left->Bound()) {
6871 return MakeProd(right, left->Min());
6872 }
6873
6874 if (right->Bound()) {
6875 return MakeProd(left, right->Min());
6876 }
6877
6878 // ----- Discover squares and powers -----
6879
6880 IntExpr* m_left = left;
6881 IntExpr* m_right = right;
6882 int64_t left_exponant = 1;
6883 int64_t right_exponant = 1;
6884 ExtractPower(&m_left, &left_exponant);
6885 ExtractPower(&m_right, &right_exponant);
6886
6887 if (m_left == m_right) {
6888 return MakePower(m_left, left_exponant + right_exponant);
6889 }
6890
6891 // ----- Discover nested products -----
6892
6893 m_left = left;
6894 m_right = right;
6895 int64_t coefficient = 1;
6896 bool modified = false;
6897
6898 ExtractProduct(&m_left, &coefficient, &modified);
6899 ExtractProduct(&m_right, &coefficient, &modified);
6900 if (modified) {
6901 return MakeProd(MakeProd(m_left, m_right), coefficient);
6902 }
6903
6904 // ----- Standard build -----
6905
6906 CHECK_EQ(this, left->solver());
6907 CHECK_EQ(this, right->solver());
6908 IntExpr* result = model_cache_->FindExprExprExpression(
6909 left, right, ModelCache::EXPR_EXPR_PROD);
6910 if (result == nullptr) {
6911 result = model_cache_->FindExprExprExpression(right, left,
6912 ModelCache::EXPR_EXPR_PROD);
6913 }
6914 if (result != nullptr) {
6915 return result;
6916 }
6917 if (left->IsVar() && left->Var()->VarType() == BOOLEAN_VAR) {
6918 if (right->Min() >= 0) {
6919 result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6920 this, reinterpret_cast<BooleanVar*>(left), right)));
6921 } else {
6922 result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6923 this, reinterpret_cast<BooleanVar*>(left), right)));
6924 }
6925 } else if (right->IsVar() &&
6926 reinterpret_cast<IntVar*>(right)->VarType() == BOOLEAN_VAR) {
6927 if (left->Min() >= 0) {
6928 result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6929 this, reinterpret_cast<BooleanVar*>(right), left)));
6930 } else {
6931 result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6932 this, reinterpret_cast<BooleanVar*>(right), left)));
6933 }
6934 } else if (left->Min() >= 0 && right->Min() >= 0) {
6935 if (CapProd(left->Max(), right->Max()) ==
6936 std::numeric_limits<int64_t>::max()) { // Potential overflow.
6937 result =
6938 RegisterIntExpr(RevAlloc(new SafeTimesPosIntExpr(this, left, right)));
6939 } else {
6940 result =
6941 RegisterIntExpr(RevAlloc(new TimesPosIntExpr(this, left, right)));
6942 }
6943 } else {
6944 result = RegisterIntExpr(RevAlloc(new TimesIntExpr(this, left, right)));
6945 }
6946 model_cache_->InsertExprExprExpression(result, left, right,
6947 ModelCache::EXPR_EXPR_PROD);
6948 return result;
6949 }
6950
MakeDiv(IntExpr * const numerator,IntExpr * const denominator)6951 IntExpr* Solver::MakeDiv(IntExpr* const numerator, IntExpr* const denominator) {
6952 CHECK(numerator != nullptr);
6953 CHECK(denominator != nullptr);
6954 if (denominator->Bound()) {
6955 return MakeDiv(numerator, denominator->Min());
6956 }
6957 IntExpr* result = model_cache_->FindExprExprExpression(
6958 numerator, denominator, ModelCache::EXPR_EXPR_DIV);
6959 if (result != nullptr) {
6960 return result;
6961 }
6962
6963 if (denominator->Min() <= 0 && denominator->Max() >= 0) {
6964 AddConstraint(MakeNonEquality(denominator, 0));
6965 }
6966
6967 if (denominator->Min() >= 0) {
6968 if (numerator->Min() >= 0) {
6969 result = RevAlloc(new DivPosPosIntExpr(this, numerator, denominator));
6970 } else {
6971 result = RevAlloc(new DivPosIntExpr(this, numerator, denominator));
6972 }
6973 } else if (denominator->Max() <= 0) {
6974 if (numerator->Max() <= 0) {
6975 result = RevAlloc(new DivPosPosIntExpr(this, MakeOpposite(numerator),
6976 MakeOpposite(denominator)));
6977 } else {
6978 result = MakeOpposite(RevAlloc(
6979 new DivPosIntExpr(this, numerator, MakeOpposite(denominator))));
6980 }
6981 } else {
6982 result = RevAlloc(new DivIntExpr(this, numerator, denominator));
6983 }
6984 model_cache_->InsertExprExprExpression(result, numerator, denominator,
6985 ModelCache::EXPR_EXPR_DIV);
6986 return result;
6987 }
6988
MakeDiv(IntExpr * const expr,int64_t value)6989 IntExpr* Solver::MakeDiv(IntExpr* const expr, int64_t value) {
6990 CHECK(expr != nullptr);
6991 CHECK_EQ(this, expr->solver());
6992 if (expr->Bound()) {
6993 return MakeIntConst(expr->Min() / value);
6994 } else if (value == 1) {
6995 return expr;
6996 } else if (value == -1) {
6997 return MakeOpposite(expr);
6998 } else if (value > 0) {
6999 return RegisterIntExpr(RevAlloc(new DivPosIntCstExpr(this, expr, value)));
7000 } else if (value == 0) {
7001 LOG(FATAL) << "Cannot divide by 0";
7002 return nullptr;
7003 } else {
7004 return RegisterIntExpr(
7005 MakeOpposite(RevAlloc(new DivPosIntCstExpr(this, expr, -value))));
7006 // TODO(user) : implement special case.
7007 }
7008 }
7009
MakeAbsEquality(IntVar * const var,IntVar * const abs_var)7010 Constraint* Solver::MakeAbsEquality(IntVar* const var, IntVar* const abs_var) {
7011 if (Cache()->FindExprExpression(var, ModelCache::EXPR_ABS) == nullptr) {
7012 Cache()->InsertExprExpression(abs_var, var, ModelCache::EXPR_ABS);
7013 }
7014 return RevAlloc(new IntAbsConstraint(this, var, abs_var));
7015 }
7016
MakeAbs(IntExpr * const e)7017 IntExpr* Solver::MakeAbs(IntExpr* const e) {
7018 CHECK_EQ(this, e->solver());
7019 if (e->Min() >= 0) {
7020 return e;
7021 } else if (e->Max() <= 0) {
7022 return MakeOpposite(e);
7023 }
7024 IntExpr* result = Cache()->FindExprExpression(e, ModelCache::EXPR_ABS);
7025 if (result == nullptr) {
7026 int64_t coefficient = 1;
7027 IntExpr* expr = nullptr;
7028 if (IsProduct(e, &expr, &coefficient)) {
7029 result = MakeProd(MakeAbs(expr), std::abs(coefficient));
7030 } else {
7031 result = RegisterIntExpr(RevAlloc(new IntAbs(this, e)));
7032 }
7033 Cache()->InsertExprExpression(result, e, ModelCache::EXPR_ABS);
7034 }
7035 return result;
7036 }
7037
MakeSquare(IntExpr * const expr)7038 IntExpr* Solver::MakeSquare(IntExpr* const expr) {
7039 CHECK_EQ(this, expr->solver());
7040 if (expr->Bound()) {
7041 const int64_t v = expr->Min();
7042 return MakeIntConst(v * v);
7043 }
7044 IntExpr* result = Cache()->FindExprExpression(expr, ModelCache::EXPR_SQUARE);
7045 if (result == nullptr) {
7046 if (expr->Min() >= 0) {
7047 result = RegisterIntExpr(RevAlloc(new PosIntSquare(this, expr)));
7048 } else {
7049 result = RegisterIntExpr(RevAlloc(new IntSquare(this, expr)));
7050 }
7051 Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_SQUARE);
7052 }
7053 return result;
7054 }
7055
MakePower(IntExpr * const expr,int64_t n)7056 IntExpr* Solver::MakePower(IntExpr* const expr, int64_t n) {
7057 CHECK_EQ(this, expr->solver());
7058 CHECK_GE(n, 0);
7059 if (expr->Bound()) {
7060 const int64_t v = expr->Min();
7061 if (v >= OverflowLimit(n)) { // Overflow.
7062 return MakeIntConst(std::numeric_limits<int64_t>::max());
7063 }
7064 return MakeIntConst(IntPower(v, n));
7065 }
7066 switch (n) {
7067 case 0:
7068 return MakeIntConst(1);
7069 case 1:
7070 return expr;
7071 case 2:
7072 return MakeSquare(expr);
7073 default: {
7074 IntExpr* result = nullptr;
7075 if (n % 2 == 0) { // even.
7076 if (expr->Min() >= 0) {
7077 result =
7078 RegisterIntExpr(RevAlloc(new PosIntEvenPower(this, expr, n)));
7079 } else {
7080 result = RegisterIntExpr(RevAlloc(new IntEvenPower(this, expr, n)));
7081 }
7082 } else {
7083 result = RegisterIntExpr(RevAlloc(new IntOddPower(this, expr, n)));
7084 }
7085 return result;
7086 }
7087 }
7088 }
7089
MakeMin(IntExpr * const left,IntExpr * const right)7090 IntExpr* Solver::MakeMin(IntExpr* const left, IntExpr* const right) {
7091 CHECK_EQ(this, left->solver());
7092 CHECK_EQ(this, right->solver());
7093 if (left->Bound()) {
7094 return MakeMin(right, left->Min());
7095 }
7096 if (right->Bound()) {
7097 return MakeMin(left, right->Min());
7098 }
7099 if (left->Min() >= right->Max()) {
7100 return right;
7101 }
7102 if (right->Min() >= left->Max()) {
7103 return left;
7104 }
7105 return RegisterIntExpr(RevAlloc(new MinIntExpr(this, left, right)));
7106 }
7107
MakeMin(IntExpr * const expr,int64_t value)7108 IntExpr* Solver::MakeMin(IntExpr* const expr, int64_t value) {
7109 CHECK_EQ(this, expr->solver());
7110 if (value <= expr->Min()) {
7111 return MakeIntConst(value);
7112 }
7113 if (expr->Bound()) {
7114 return MakeIntConst(std::min(expr->Min(), value));
7115 }
7116 if (expr->Max() <= value) {
7117 return expr;
7118 }
7119 return RegisterIntExpr(RevAlloc(new MinCstIntExpr(this, expr, value)));
7120 }
7121
MakeMin(IntExpr * const expr,int value)7122 IntExpr* Solver::MakeMin(IntExpr* const expr, int value) {
7123 return MakeMin(expr, static_cast<int64_t>(value));
7124 }
7125
MakeMax(IntExpr * const left,IntExpr * const right)7126 IntExpr* Solver::MakeMax(IntExpr* const left, IntExpr* const right) {
7127 CHECK_EQ(this, left->solver());
7128 CHECK_EQ(this, right->solver());
7129 if (left->Bound()) {
7130 return MakeMax(right, left->Min());
7131 }
7132 if (right->Bound()) {
7133 return MakeMax(left, right->Min());
7134 }
7135 if (left->Min() >= right->Max()) {
7136 return left;
7137 }
7138 if (right->Min() >= left->Max()) {
7139 return right;
7140 }
7141 return RegisterIntExpr(RevAlloc(new MaxIntExpr(this, left, right)));
7142 }
7143
MakeMax(IntExpr * const expr,int64_t value)7144 IntExpr* Solver::MakeMax(IntExpr* const expr, int64_t value) {
7145 CHECK_EQ(this, expr->solver());
7146 if (expr->Bound()) {
7147 return MakeIntConst(std::max(expr->Min(), value));
7148 }
7149 if (value <= expr->Min()) {
7150 return expr;
7151 }
7152 if (expr->Max() <= value) {
7153 return MakeIntConst(value);
7154 }
7155 return RegisterIntExpr(RevAlloc(new MaxCstIntExpr(this, expr, value)));
7156 }
7157
MakeMax(IntExpr * const expr,int value)7158 IntExpr* Solver::MakeMax(IntExpr* const expr, int value) {
7159 return MakeMax(expr, static_cast<int64_t>(value));
7160 }
7161
MakeConvexPiecewiseExpr(IntExpr * expr,int64_t early_cost,int64_t early_date,int64_t late_date,int64_t late_cost)7162 IntExpr* Solver::MakeConvexPiecewiseExpr(IntExpr* expr, int64_t early_cost,
7163 int64_t early_date, int64_t late_date,
7164 int64_t late_cost) {
7165 return RegisterIntExpr(RevAlloc(new SimpleConvexPiecewiseExpr(
7166 this, expr, early_cost, early_date, late_date, late_cost)));
7167 }
7168
MakeSemiContinuousExpr(IntExpr * const expr,int64_t fixed_charge,int64_t step)7169 IntExpr* Solver::MakeSemiContinuousExpr(IntExpr* const expr,
7170 int64_t fixed_charge, int64_t step) {
7171 if (step == 0) {
7172 if (fixed_charge == 0) {
7173 return MakeIntConst(int64_t{0});
7174 } else {
7175 return RegisterIntExpr(
7176 RevAlloc(new SemiContinuousStepZeroExpr(this, expr, fixed_charge)));
7177 }
7178 } else if (step == 1) {
7179 return RegisterIntExpr(
7180 RevAlloc(new SemiContinuousStepOneExpr(this, expr, fixed_charge)));
7181 } else {
7182 return RegisterIntExpr(
7183 RevAlloc(new SemiContinuousExpr(this, expr, fixed_charge, step)));
7184 }
7185 // TODO(user) : benchmark with virtualization of
7186 // PosIntDivDown and PosIntDivUp - or function pointers.
7187 }
7188
7189 // ----- Piecewise Linear -----
7190
7191 class PiecewiseLinearExpr : public BaseIntExpr {
7192 public:
PiecewiseLinearExpr(Solver * solver,IntExpr * expr,const PiecewiseLinearFunction & f)7193 PiecewiseLinearExpr(Solver* solver, IntExpr* expr,
7194 const PiecewiseLinearFunction& f)
7195 : BaseIntExpr(solver), expr_(expr), f_(f) {}
~PiecewiseLinearExpr()7196 ~PiecewiseLinearExpr() override {}
Min() const7197 int64_t Min() const override {
7198 return f_.GetMinimum(expr_->Min(), expr_->Max());
7199 }
SetMin(int64_t m)7200 void SetMin(int64_t m) override {
7201 const auto& range =
7202 f_.GetSmallestRangeGreaterThanValue(expr_->Min(), expr_->Max(), m);
7203 expr_->SetRange(range.first, range.second);
7204 }
7205
Max() const7206 int64_t Max() const override {
7207 return f_.GetMaximum(expr_->Min(), expr_->Max());
7208 }
7209
SetMax(int64_t m)7210 void SetMax(int64_t m) override {
7211 const auto& range =
7212 f_.GetSmallestRangeLessThanValue(expr_->Min(), expr_->Max(), m);
7213 expr_->SetRange(range.first, range.second);
7214 }
7215
SetRange(int64_t l,int64_t u)7216 void SetRange(int64_t l, int64_t u) override {
7217 const auto& range =
7218 f_.GetSmallestRangeInValueRange(expr_->Min(), expr_->Max(), l, u);
7219 expr_->SetRange(range.first, range.second);
7220 }
name() const7221 std::string name() const override {
7222 return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->name(),
7223 f_.DebugString());
7224 }
7225
DebugString() const7226 std::string DebugString() const override {
7227 return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->DebugString(),
7228 f_.DebugString());
7229 }
7230
WhenRange(Demon * d)7231 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
7232
Accept(ModelVisitor * const visitor) const7233 void Accept(ModelVisitor* const visitor) const override {
7234 // TODO(user): Implement visitor.
7235 }
7236
7237 private:
7238 IntExpr* const expr_;
7239 const PiecewiseLinearFunction f_;
7240 };
7241
MakePiecewiseLinearExpr(IntExpr * expr,const PiecewiseLinearFunction & f)7242 IntExpr* Solver::MakePiecewiseLinearExpr(IntExpr* expr,
7243 const PiecewiseLinearFunction& f) {
7244 return RegisterIntExpr(RevAlloc(new PiecewiseLinearExpr(this, expr, f)));
7245 }
7246
7247 // ----- Conditional Expression -----
7248
MakeConditionalExpression(IntVar * const condition,IntExpr * const expr,int64_t unperformed_value)7249 IntExpr* Solver::MakeConditionalExpression(IntVar* const condition,
7250 IntExpr* const expr,
7251 int64_t unperformed_value) {
7252 if (condition->Min() == 1) {
7253 return expr;
7254 } else if (condition->Max() == 0) {
7255 return MakeIntConst(unperformed_value);
7256 } else {
7257 IntExpr* cache = Cache()->FindExprExprConstantExpression(
7258 condition, expr, unperformed_value,
7259 ModelCache::EXPR_EXPR_CONSTANT_CONDITIONAL);
7260 if (cache == nullptr) {
7261 cache = RevAlloc(
7262 new ExprWithEscapeValue(this, condition, expr, unperformed_value));
7263 Cache()->InsertExprExprConstantExpression(
7264 cache, condition, expr, unperformed_value,
7265 ModelCache::EXPR_EXPR_CONSTANT_CONDITIONAL);
7266 }
7267 return cache;
7268 }
7269 }
7270
7271 // ----- Modulo -----
7272
MakeModulo(IntExpr * const x,int64_t mod)7273 IntExpr* Solver::MakeModulo(IntExpr* const x, int64_t mod) {
7274 IntVar* const result =
7275 MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7276 if (mod >= 0) {
7277 AddConstraint(MakeBetweenCt(result, 0, mod - 1));
7278 } else {
7279 AddConstraint(MakeBetweenCt(result, mod + 1, 0));
7280 }
7281 return result;
7282 }
7283
MakeModulo(IntExpr * const x,IntExpr * const mod)7284 IntExpr* Solver::MakeModulo(IntExpr* const x, IntExpr* const mod) {
7285 if (mod->Bound()) {
7286 return MakeModulo(x, mod->Min());
7287 }
7288 IntVar* const result =
7289 MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7290 AddConstraint(MakeLess(result, MakeAbs(mod)));
7291 AddConstraint(MakeGreater(result, MakeOpposite(MakeAbs(mod))));
7292 return result;
7293 }
7294
7295 // --------- IntVar ---------
7296
VarType() const7297 int IntVar::VarType() const { return UNSPECIFIED; }
7298
RemoveValues(const std::vector<int64_t> & values)7299 void IntVar::RemoveValues(const std::vector<int64_t>& values) {
7300 // TODO(user): Check and maybe inline this code.
7301 const int size = values.size();
7302 DCHECK_GE(size, 0);
7303 switch (size) {
7304 case 0: {
7305 return;
7306 }
7307 case 1: {
7308 RemoveValue(values[0]);
7309 return;
7310 }
7311 case 2: {
7312 RemoveValue(values[0]);
7313 RemoveValue(values[1]);
7314 return;
7315 }
7316 case 3: {
7317 RemoveValue(values[0]);
7318 RemoveValue(values[1]);
7319 RemoveValue(values[2]);
7320 return;
7321 }
7322 default: {
7323 // 4 values, let's start doing some more clever things.
7324 // TODO(user) : Sort values!
7325 int start_index = 0;
7326 int64_t new_min = Min();
7327 if (values[start_index] <= new_min) {
7328 while (start_index < size - 1 &&
7329 values[start_index + 1] == values[start_index] + 1) {
7330 new_min = values[start_index + 1] + 1;
7331 start_index++;
7332 }
7333 }
7334 int end_index = size - 1;
7335 int64_t new_max = Max();
7336 if (values[end_index] >= new_max) {
7337 while (end_index > start_index + 1 &&
7338 values[end_index - 1] == values[end_index] - 1) {
7339 new_max = values[end_index - 1] - 1;
7340 end_index--;
7341 }
7342 }
7343 SetRange(new_min, new_max);
7344 for (int i = start_index; i <= end_index; ++i) {
7345 RemoveValue(values[i]);
7346 }
7347 }
7348 }
7349 }
7350
Accept(ModelVisitor * const visitor) const7351 void IntVar::Accept(ModelVisitor* const visitor) const {
7352 IntExpr* const casted = solver()->CastExpression(this);
7353 visitor->VisitIntegerVariable(this, casted);
7354 }
7355
SetValues(const std::vector<int64_t> & values)7356 void IntVar::SetValues(const std::vector<int64_t>& values) {
7357 switch (values.size()) {
7358 case 0: {
7359 solver()->Fail();
7360 break;
7361 }
7362 case 1: {
7363 SetValue(values.back());
7364 break;
7365 }
7366 case 2: {
7367 if (Contains(values[0])) {
7368 if (Contains(values[1])) {
7369 const int64_t l = std::min(values[0], values[1]);
7370 const int64_t u = std::max(values[0], values[1]);
7371 SetRange(l, u);
7372 if (u > l + 1) {
7373 RemoveInterval(l + 1, u - 1);
7374 }
7375 } else {
7376 SetValue(values[0]);
7377 }
7378 } else {
7379 SetValue(values[1]);
7380 }
7381 break;
7382 }
7383 default: {
7384 // TODO(user): use a clean and safe SortedUniqueCopy() class
7385 // that uses a global, static shared (and locked) storage.
7386 // TODO(user): [optional] consider porting
7387 // STLSortAndRemoveDuplicates from ortools/base/stl_util.h to the
7388 // existing open_source/base/stl_util.h and using it here.
7389 // TODO(user): We could filter out values not in the var.
7390 std::vector<int64_t>& tmp = solver()->tmp_vector_;
7391 tmp.clear();
7392 tmp.insert(tmp.end(), values.begin(), values.end());
7393 std::sort(tmp.begin(), tmp.end());
7394 tmp.erase(std::unique(tmp.begin(), tmp.end()), tmp.end());
7395 const int size = tmp.size();
7396 const int64_t vmin = Min();
7397 const int64_t vmax = Max();
7398 int first = 0;
7399 int last = size - 1;
7400 if (tmp.front() > vmax || tmp.back() < vmin) {
7401 solver()->Fail();
7402 }
7403 // TODO(user) : We could find the first position >= vmin by dichotomy.
7404 while (tmp[first] < vmin || !Contains(tmp[first])) {
7405 ++first;
7406 if (first > last || tmp[first] > vmax) {
7407 solver()->Fail();
7408 }
7409 }
7410 while (last > first && (tmp[last] > vmax || !Contains(tmp[last]))) {
7411 // Note that last >= first implies tmp[last] >= vmin.
7412 --last;
7413 }
7414 DCHECK_GE(last, first);
7415 SetRange(tmp[first], tmp[last]);
7416 while (first < last) {
7417 const int64_t start = tmp[first] + 1;
7418 const int64_t end = tmp[first + 1] - 1;
7419 if (start <= end) {
7420 RemoveInterval(start, end);
7421 }
7422 first++;
7423 }
7424 }
7425 }
7426 }
7427 // ---------- BaseIntExpr ---------
7428
LinkVarExpr(Solver * const s,IntExpr * const expr,IntVar * const var)7429 void LinkVarExpr(Solver* const s, IntExpr* const expr, IntVar* const var) {
7430 if (!var->Bound()) {
7431 if (var->VarType() == DOMAIN_INT_VAR) {
7432 DomainIntVar* dvar = reinterpret_cast<DomainIntVar*>(var);
7433 s->AddCastConstraint(
7434 s->RevAlloc(new LinkExprAndDomainIntVar(s, expr, dvar)), dvar, expr);
7435 } else {
7436 s->AddCastConstraint(s->RevAlloc(new LinkExprAndVar(s, expr, var)), var,
7437 expr);
7438 }
7439 }
7440 }
7441
Var()7442 IntVar* BaseIntExpr::Var() {
7443 if (var_ == nullptr) {
7444 solver()->SaveValue(reinterpret_cast<void**>(&var_));
7445 var_ = CastToVar();
7446 }
7447 return var_;
7448 }
7449
CastToVar()7450 IntVar* BaseIntExpr::CastToVar() {
7451 int64_t vmin, vmax;
7452 Range(&vmin, &vmax);
7453 IntVar* const var = solver()->MakeIntVar(vmin, vmax);
7454 LinkVarExpr(solver(), this, var);
7455 return var;
7456 }
7457
7458 // Discovery methods
IsADifference(IntExpr * expr,IntExpr ** const left,IntExpr ** const right)7459 bool Solver::IsADifference(IntExpr* expr, IntExpr** const left,
7460 IntExpr** const right) {
7461 if (expr->IsVar()) {
7462 IntVar* const expr_var = expr->Var();
7463 expr = CastExpression(expr_var);
7464 }
7465 // This is a dynamic cast to check the type of expr.
7466 // It returns nullptr is expr is not a subclass of SubIntExpr.
7467 SubIntExpr* const sub_expr = dynamic_cast<SubIntExpr*>(expr);
7468 if (sub_expr != nullptr) {
7469 *left = sub_expr->left();
7470 *right = sub_expr->right();
7471 return true;
7472 }
7473 return false;
7474 }
7475
IsBooleanVar(IntExpr * const expr,IntVar ** inner_var,bool * is_negated) const7476 bool Solver::IsBooleanVar(IntExpr* const expr, IntVar** inner_var,
7477 bool* is_negated) const {
7478 if (expr->IsVar() && expr->Var()->VarType() == BOOLEAN_VAR) {
7479 *inner_var = expr->Var();
7480 *is_negated = false;
7481 return true;
7482 } else if (expr->IsVar() && expr->Var()->VarType() == CST_SUB_VAR) {
7483 SubCstIntVar* const sub_var = reinterpret_cast<SubCstIntVar*>(expr);
7484 if (sub_var != nullptr && sub_var->Constant() == 1 &&
7485 sub_var->SubVar()->VarType() == BOOLEAN_VAR) {
7486 *is_negated = true;
7487 *inner_var = sub_var->SubVar();
7488 return true;
7489 }
7490 }
7491 return false;
7492 }
7493
IsProduct(IntExpr * const expr,IntExpr ** inner_expr,int64_t * coefficient)7494 bool Solver::IsProduct(IntExpr* const expr, IntExpr** inner_expr,
7495 int64_t* coefficient) {
7496 if (dynamic_cast<TimesCstIntVar*>(expr) != nullptr) {
7497 TimesCstIntVar* const var = dynamic_cast<TimesCstIntVar*>(expr);
7498 *coefficient = var->Constant();
7499 *inner_expr = var->SubVar();
7500 return true;
7501 } else if (dynamic_cast<TimesIntCstExpr*>(expr) != nullptr) {
7502 TimesIntCstExpr* const prod = dynamic_cast<TimesIntCstExpr*>(expr);
7503 *coefficient = prod->Constant();
7504 *inner_expr = prod->Expr();
7505 return true;
7506 }
7507 *inner_expr = expr;
7508 *coefficient = 1;
7509 return false;
7510 }
7511
7512 #undef COND_REV_ALLOC
7513
7514 } // namespace operations_research
7515