1 // Copyright 2010-2021 Google LLC
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 #include <algorithm>
15 #include <cmath>
16 #include <cstdint>
17 #include <limits>
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "ortools/base/commandlineflags.h"
27 #include "ortools/base/integral_types.h"
28 #include "ortools/base/logging.h"
29 #include "ortools/base/map_util.h"
30 #include "ortools/base/mathutil.h"
31 #include "ortools/base/stl_util.h"
32 #include "ortools/constraint_solver/constraint_solver.h"
33 #include "ortools/constraint_solver/constraint_solveri.h"
34 #include "ortools/util/bitset.h"
35 #include "ortools/util/saturated_arithmetic.h"
36 #include "ortools/util/string_array.h"
37 
38 ABSL_FLAG(bool, cp_disable_expression_optimization, false,
39           "Disable special optimization when creating expressions.");
40 ABSL_FLAG(bool, cp_share_int_consts, true,
41           "Share IntConst's with the same value.");
42 
43 #if defined(_MSC_VER)
44 #pragma warning(disable : 4351 4355)
45 #endif
46 
47 namespace operations_research {
48 
49 // ---------- IntExpr ----------
50 
VarWithName(const std::string & name)51 IntVar* IntExpr::VarWithName(const std::string& name) {
52   IntVar* const var = Var();
53   var->set_name(name);
54   return var;
55 }
56 
57 // ---------- IntVar ----------
58 
IntVar(Solver * const s)59 IntVar::IntVar(Solver* const s) : IntExpr(s), index_(s->GetNewIntVarIndex()) {}
60 
IntVar(Solver * const s,const std::string & name)61 IntVar::IntVar(Solver* const s, const std::string& name)
62     : IntExpr(s), index_(s->GetNewIntVarIndex()) {
63   set_name(name);
64 }
65 
66 // ----- Boolean variable -----
67 
// Sentinel stored in value_ while the variable is not bound: the methods below
// treat value_ == kUnboundBooleanVarValue as "domain is still {0, 1}".
const int BooleanVar::kUnboundBooleanVarValue = 2;
69 
SetMin(int64_t m)70 void BooleanVar::SetMin(int64_t m) {
71   if (m <= 0) return;
72   if (m > 1) solver()->Fail();
73   SetValue(1);
74 }
75 
SetMax(int64_t m)76 void BooleanVar::SetMax(int64_t m) {
77   if (m >= 1) return;
78   if (m < 0) solver()->Fail();
79   SetValue(0);
80 }
81 
SetRange(int64_t mi,int64_t ma)82 void BooleanVar::SetRange(int64_t mi, int64_t ma) {
83   if (mi > 1 || ma < 0 || mi > ma) {
84     solver()->Fail();
85   }
86   if (mi == 1) {
87     SetValue(1);
88   } else if (ma == 0) {
89     SetValue(0);
90   }
91 }
92 
RemoveValue(int64_t v)93 void BooleanVar::RemoveValue(int64_t v) {
94   if (value_ == kUnboundBooleanVarValue) {
95     if (v == 0) {
96       SetValue(1);
97     } else if (v == 1) {
98       SetValue(0);
99     }
100   } else if (v == value_) {
101     solver()->Fail();
102   }
103 }
104 
RemoveInterval(int64_t l,int64_t u)105 void BooleanVar::RemoveInterval(int64_t l, int64_t u) {
106   if (u < l) return;
107   if (l <= 0 && u >= 1) {
108     solver()->Fail();
109   } else if (l == 1) {
110     SetValue(0);
111   } else if (u == 0) {
112     SetValue(1);
113   }
114 }
115 
WhenBound(Demon * d)116 void BooleanVar::WhenBound(Demon* d) {
117   if (value_ == kUnboundBooleanVarValue) {
118     if (d->priority() == Solver::DELAYED_PRIORITY) {
119       delayed_bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
120     } else {
121       bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
122     }
123   }
124 }
125 
Size() const126 uint64_t BooleanVar::Size() const {
127   return (1 + (value_ == kUnboundBooleanVarValue));
128 }
129 
Contains(int64_t v) const130 bool BooleanVar::Contains(int64_t v) const {
131   return ((v == 0 && value_ != 1) || (v == 1 && value_ != 0));
132 }
133 
IsEqual(int64_t constant)134 IntVar* BooleanVar::IsEqual(int64_t constant) {
135   if (constant > 1 || constant < 0) {
136     return solver()->MakeIntConst(0);
137   }
138   if (constant == 1) {
139     return this;
140   } else {  // constant == 0.
141     return solver()->MakeDifference(1, this)->Var();
142   }
143 }
144 
IsDifferent(int64_t constant)145 IntVar* BooleanVar::IsDifferent(int64_t constant) {
146   if (constant > 1 || constant < 0) {
147     return solver()->MakeIntConst(1);
148   }
149   if (constant == 1) {
150     return solver()->MakeDifference(1, this)->Var();
151   } else {  // constant == 0.
152     return this;
153   }
154 }
155 
IsGreaterOrEqual(int64_t constant)156 IntVar* BooleanVar::IsGreaterOrEqual(int64_t constant) {
157   if (constant > 1) {
158     return solver()->MakeIntConst(0);
159   } else if (constant <= 0) {
160     return solver()->MakeIntConst(1);
161   } else {
162     return this;
163   }
164 }
165 
IsLessOrEqual(int64_t constant)166 IntVar* BooleanVar::IsLessOrEqual(int64_t constant) {
167   if (constant < 0) {
168     return solver()->MakeIntConst(0);
169   } else if (constant >= 1) {
170     return solver()->MakeIntConst(1);
171   } else {
172     return IsEqual(0);
173   }
174 }
175 
DebugString() const176 std::string BooleanVar::DebugString() const {
177   std::string out;
178   const std::string& var_name = name();
179   if (!var_name.empty()) {
180     out = var_name + "(";
181   } else {
182     out = "BooleanVar(";
183   }
184   switch (value_) {
185     case 0:
186       out += "0";
187       break;
188     case 1:
189       out += "1";
190       break;
191     case kUnboundBooleanVarValue:
192       out += "0 .. 1";
193       break;
194   }
195   out += ")";
196   return out;
197 }
198 
199 namespace {
200 // ---------- Subclasses of IntVar ----------
201 
202 // ----- Domain Int Var: base class for variables -----
203 // It Contains bounds and a bitset representation of possible values.
204 class DomainIntVar : public IntVar {
205  public:
206   // Utility classes
207   class BitSetIterator : public BaseObject {
208    public:
BitSetIterator(uint64_t * const bitset,int64_t omin)209     BitSetIterator(uint64_t* const bitset, int64_t omin)
210         : bitset_(bitset),
211           omin_(omin),
212           max_(std::numeric_limits<int64_t>::min()),
213           current_(std::numeric_limits<int64_t>::max()) {}
214 
~BitSetIterator()215     ~BitSetIterator() override {}
216 
Init(int64_t min,int64_t max)217     void Init(int64_t min, int64_t max) {
218       max_ = max;
219       current_ = min;
220     }
221 
Ok() const222     bool Ok() const { return current_ <= max_; }
223 
Value() const224     int64_t Value() const { return current_; }
225 
Next()226     void Next() {
227       if (++current_ <= max_) {
228         current_ = UnsafeLeastSignificantBitPosition64(
229                        bitset_, current_ - omin_, max_ - omin_) +
230                    omin_;
231       }
232     }
233 
DebugString() const234     std::string DebugString() const override { return "BitSetIterator"; }
235 
236    private:
237     uint64_t* const bitset_;
238     const int64_t omin_;
239     int64_t max_;
240     int64_t current_;
241   };
242 
243   class BitSet : public BaseObject {
244    public:
BitSet(Solver * const s)245     explicit BitSet(Solver* const s) : solver_(s), holes_stamp_(0) {}
~BitSet()246     ~BitSet() override {}
247 
248     virtual int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) = 0;
249     virtual int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) = 0;
250     virtual bool Contains(int64_t val) const = 0;
251     virtual bool SetValue(int64_t val) = 0;
252     virtual bool RemoveValue(int64_t val) = 0;
253     virtual uint64_t Size() const = 0;
254     virtual void DelayRemoveValue(int64_t val) = 0;
255     virtual void ApplyRemovedValues(DomainIntVar* var) = 0;
256     virtual void ClearRemovedValues() = 0;
257     virtual std::string pretty_DebugString(int64_t min, int64_t max) const = 0;
258     virtual BitSetIterator* MakeIterator() = 0;
259 
InitHoles()260     void InitHoles() {
261       const uint64_t current_stamp = solver_->stamp();
262       if (holes_stamp_ < current_stamp) {
263         holes_.clear();
264         holes_stamp_ = current_stamp;
265       }
266     }
267 
ClearHoles()268     virtual void ClearHoles() { holes_.clear(); }
269 
Holes()270     const std::vector<int64_t>& Holes() { return holes_; }
271 
AddHole(int64_t value)272     void AddHole(int64_t value) { holes_.push_back(value); }
273 
NumHoles() const274     int NumHoles() const {
275       return holes_stamp_ < solver_->stamp() ? 0 : holes_.size();
276     }
277 
278    protected:
279     Solver* const solver_;
280 
281    private:
282     std::vector<int64_t> holes_;
283     uint64_t holes_stamp_;
284   };
285 
286   class QueueHandler : public Demon {
287    public:
QueueHandler(DomainIntVar * const var)288     explicit QueueHandler(DomainIntVar* const var) : var_(var) {}
~QueueHandler()289     ~QueueHandler() override {}
Run(Solver * const s)290     void Run(Solver* const s) override {
291       s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
292       var_->Process();
293       s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
294     }
priority() const295     Solver::DemonPriority priority() const override {
296       return Solver::VAR_PRIORITY;
297     }
DebugString() const298     std::string DebugString() const override {
299       return absl::StrFormat("Handler(%s)", var_->DebugString());
300     }
301 
302    private:
303     DomainIntVar* const var_;
304   };
305 
306   // Bounds and Value watchers
307 
308   // This class stores the watchers variables attached to values. It is
309   // reversible and it helps maintaining the set of 'active' watchers
310   // (variables not bound to a single value).
311   template <class T>
312   class RevIntPtrMap {
313    public:
RevIntPtrMap(Solver * const solver,int64_t rmin,int64_t rmax)314     RevIntPtrMap(Solver* const solver, int64_t rmin, int64_t rmax)
315         : solver_(solver), range_min_(rmin), start_(0) {}
316 
~RevIntPtrMap()317     ~RevIntPtrMap() {}
318 
Empty() const319     bool Empty() const { return start_.Value() == elements_.size(); }
320 
SortActive()321     void SortActive() { std::sort(elements_.begin(), elements_.end()); }
322 
323     // Access with value API.
324 
325     // Add the pointer to the map attached to the given value.
UnsafeRevInsert(int64_t value,T * elem)326     void UnsafeRevInsert(int64_t value, T* elem) {
327       elements_.push_back(std::make_pair(value, elem));
328       if (solver_->state() != Solver::OUTSIDE_SEARCH) {
329         solver_->AddBacktrackAction(
330             [this, value](Solver* s) { Uninsert(value); }, false);
331       }
332     }
333 
FindPtrOrNull(int64_t value,int * position)334     T* FindPtrOrNull(int64_t value, int* position) {
335       for (int pos = start_.Value(); pos < elements_.size(); ++pos) {
336         if (elements_[pos].first == value) {
337           if (position != nullptr) *position = pos;
338           return At(pos).second;
339         }
340       }
341       return nullptr;
342     }
343 
344     // Access map through the underlying vector.
RemoveAt(int position)345     void RemoveAt(int position) {
346       const int start = start_.Value();
347       DCHECK_GE(position, start);
348       DCHECK_LT(position, elements_.size());
349       if (position > start) {
350         // Swap the current element with the one at the start position, and
351         // increase start.
352         const std::pair<int64_t, T*> copy = elements_[start];
353         elements_[start] = elements_[position];
354         elements_[position] = copy;
355       }
356       start_.Incr(solver_);
357     }
358 
At(int position) const359     const std::pair<int64_t, T*>& At(int position) const {
360       DCHECK_GE(position, start_.Value());
361       DCHECK_LT(position, elements_.size());
362       return elements_[position];
363     }
364 
RemoveAll()365     void RemoveAll() { start_.SetValue(solver_, elements_.size()); }
366 
start() const367     int start() const { return start_.Value(); }
end() const368     int end() const { return elements_.size(); }
369     // Number of active elements.
Size() const370     int Size() const { return elements_.size() - start_.Value(); }
371 
372     // Removes the object permanently from the map.
Uninsert(int64_t value)373     void Uninsert(int64_t value) {
374       for (int pos = 0; pos < elements_.size(); ++pos) {
375         if (elements_[pos].first == value) {
376           DCHECK_GE(pos, start_.Value());
377           const int last = elements_.size() - 1;
378           if (pos != last) {  // Swap the current with the last.
379             elements_[pos] = elements_.back();
380           }
381           elements_.pop_back();
382           return;
383         }
384       }
385       LOG(FATAL) << "The element should have been removed";
386     }
387 
388    private:
389     Solver* const solver_;
390     const int64_t range_min_;
391     NumericalRev<int> start_;
392     std::vector<std::pair<int64_t, T*>> elements_;
393   };
394 
395   // Base class for value watchers
396   class BaseValueWatcher : public Constraint {
397    public:
BaseValueWatcher(Solver * const solver)398     explicit BaseValueWatcher(Solver* const solver) : Constraint(solver) {}
399 
~BaseValueWatcher()400     ~BaseValueWatcher() override {}
401 
402     virtual IntVar* GetOrMakeValueWatcher(int64_t value) = 0;
403 
404     virtual void SetValueWatcher(IntVar* const boolvar, int64_t value) = 0;
405   };
406 
407   // This class monitors the domain of the variable and updates the
408   // IsEqual/IsDifferent boolean variables accordingly.
409   class ValueWatcher : public BaseValueWatcher {
410    public:
411     class WatchDemon : public Demon {
412      public:
WatchDemon(ValueWatcher * const watcher,int64_t value,IntVar * var)413       WatchDemon(ValueWatcher* const watcher, int64_t value, IntVar* var)
414           : value_watcher_(watcher), value_(value), var_(var) {}
~WatchDemon()415       ~WatchDemon() override {}
416 
Run(Solver * const solver)417       void Run(Solver* const solver) override {
418         value_watcher_->ProcessValueWatcher(value_, var_);
419       }
420 
421      private:
422       ValueWatcher* const value_watcher_;
423       const int64_t value_;
424       IntVar* const var_;
425     };
426 
427     class VarDemon : public Demon {
428      public:
VarDemon(ValueWatcher * const watcher)429       explicit VarDemon(ValueWatcher* const watcher)
430           : value_watcher_(watcher) {}
431 
~VarDemon()432       ~VarDemon() override {}
433 
Run(Solver * const solver)434       void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
435 
436      private:
437       ValueWatcher* const value_watcher_;
438     };
439 
ValueWatcher(Solver * const solver,DomainIntVar * const variable)440     ValueWatcher(Solver* const solver, DomainIntVar* const variable)
441         : BaseValueWatcher(solver),
442           variable_(variable),
443           hole_iterator_(variable_->MakeHoleIterator(true)),
444           var_demon_(nullptr),
445           watchers_(solver, variable->Min(), variable->Max()) {}
446 
~ValueWatcher()447     ~ValueWatcher() override {}
448 
GetOrMakeValueWatcher(int64_t value)449     IntVar* GetOrMakeValueWatcher(int64_t value) override {
450       IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
451       if (watcher != nullptr) return watcher;
452       if (variable_->Contains(value)) {
453         if (variable_->Bound()) {
454           return solver()->MakeIntConst(1);
455         } else {
456           const std::string vname = variable_->HasName()
457                                         ? variable_->name()
458                                         : variable_->DebugString();
459           const std::string bname =
460               absl::StrFormat("Watch<%s == %d>", vname, value);
461           IntVar* const boolvar = solver()->MakeBoolVar(bname);
462           watchers_.UnsafeRevInsert(value, boolvar);
463           if (posted_.Switched()) {
464             boolvar->WhenBound(
465                 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
466             var_demon_->desinhibit(solver());
467           }
468           return boolvar;
469         }
470       } else {
471         return variable_->solver()->MakeIntConst(0);
472       }
473     }
474 
SetValueWatcher(IntVar * const boolvar,int64_t value)475     void SetValueWatcher(IntVar* const boolvar, int64_t value) override {
476       CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
477       if (!boolvar->Bound()) {
478         watchers_.UnsafeRevInsert(value, boolvar);
479         if (posted_.Switched() && !boolvar->Bound()) {
480           boolvar->WhenBound(
481               solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
482           var_demon_->desinhibit(solver());
483         }
484       }
485     }
486 
Post()487     void Post() override {
488       var_demon_ = solver()->RevAlloc(new VarDemon(this));
489       variable_->WhenDomain(var_demon_);
490       for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
491         const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
492         const int64_t value = w.first;
493         IntVar* const boolvar = w.second;
494         if (!boolvar->Bound() && variable_->Contains(value)) {
495           boolvar->WhenBound(
496               solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
497         }
498       }
499       posted_.Switch(solver());
500     }
501 
InitialPropagate()502     void InitialPropagate() override {
503       if (variable_->Bound()) {
504         VariableBound();
505       } else {
506         for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
507           const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
508           const int64_t value = w.first;
509           IntVar* const boolvar = w.second;
510           if (!variable_->Contains(value)) {
511             boolvar->SetValue(0);
512             watchers_.RemoveAt(pos);
513           } else {
514             if (boolvar->Bound()) {
515               ProcessValueWatcher(value, boolvar);
516               watchers_.RemoveAt(pos);
517             }
518           }
519         }
520         CheckInhibit();
521       }
522     }
523 
ProcessValueWatcher(int64_t value,IntVar * boolvar)524     void ProcessValueWatcher(int64_t value, IntVar* boolvar) {
525       if (boolvar->Min() == 0) {
526         if (variable_->Size() < 0xFFFFFF) {
527           variable_->RemoveValue(value);
528         } else {
529           // Delay removal.
530           solver()->AddConstraint(solver()->MakeNonEquality(variable_, value));
531         }
532       } else {
533         variable_->SetValue(value);
534       }
535     }
536 
ProcessVar()537     void ProcessVar() {
538       const int kSmallList = 16;
539       if (variable_->Bound()) {
540         VariableBound();
541       } else if (watchers_.Size() <= kSmallList ||
542                  variable_->Min() != variable_->OldMin() ||
543                  variable_->Max() != variable_->OldMax()) {
544         // Brute force loop for small numbers of watchers, or if the bounds have
545         // changed, which would have required a sort (n log(n)) anyway to take
546         // advantage of.
547         ScanWatchers();
548         CheckInhibit();
549       } else {
550         // If there is no bitset, then there are no holes.
551         // In that case, the two loops above should have performed all
552         // propagation. Otherwise, scan the remaining watchers.
553         BitSet* const bitset = variable_->bitset();
554         if (bitset != nullptr && !watchers_.Empty()) {
555           if (bitset->NumHoles() * 2 < watchers_.Size()) {
556             for (const int64_t hole : InitAndGetValues(hole_iterator_)) {
557               int pos = 0;
558               IntVar* const boolvar = watchers_.FindPtrOrNull(hole, &pos);
559               if (boolvar != nullptr) {
560                 boolvar->SetValue(0);
561                 watchers_.RemoveAt(pos);
562               }
563             }
564           } else {
565             ScanWatchers();
566           }
567         }
568         CheckInhibit();
569       }
570     }
571 
572     // Optimized case if the variable is bound.
VariableBound()573     void VariableBound() {
574       DCHECK(variable_->Bound());
575       const int64_t value = variable_->Min();
576       for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
577         const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
578         w.second->SetValue(w.first == value);
579       }
580       watchers_.RemoveAll();
581       var_demon_->inhibit(solver());
582     }
583 
584     // Scans all the watchers to check and assign them.
ScanWatchers()585     void ScanWatchers() {
586       for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
587         const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
588         if (!variable_->Contains(w.first)) {
589           IntVar* const boolvar = w.second;
590           boolvar->SetValue(0);
591           watchers_.RemoveAt(pos);
592         }
593       }
594     }
595 
596     // If the set of active watchers is empty, we can inhibit the demon on the
597     // main variable.
CheckInhibit()598     void CheckInhibit() {
599       if (watchers_.Empty()) {
600         var_demon_->inhibit(solver());
601       }
602     }
603 
Accept(ModelVisitor * const visitor) const604     void Accept(ModelVisitor* const visitor) const override {
605       visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
606       visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
607                                               variable_);
608       std::vector<int64_t> all_coefficients;
609       std::vector<IntVar*> all_bool_vars;
610       for (int position = watchers_.start(); position < watchers_.end();
611            ++position) {
612         const std::pair<int64_t, IntVar*>& w = watchers_.At(position);
613         all_coefficients.push_back(w.first);
614         all_bool_vars.push_back(w.second);
615       }
616       visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
617                                                  all_bool_vars);
618       visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
619                                          all_coefficients);
620       visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
621     }
622 
DebugString() const623     std::string DebugString() const override {
624       return absl::StrFormat("ValueWatcher(%s)", variable_->DebugString());
625     }
626 
627    private:
628     DomainIntVar* const variable_;
629     IntVarIterator* const hole_iterator_;
630     RevSwitch posted_;
631     Demon* var_demon_;
632     RevIntPtrMap<IntVar> watchers_;
633   };
634 
635   // Optimized case for small maps.
636   class DenseValueWatcher : public BaseValueWatcher {
637    public:
638     class WatchDemon : public Demon {
639      public:
WatchDemon(DenseValueWatcher * const watcher,int64_t value,IntVar * var)640       WatchDemon(DenseValueWatcher* const watcher, int64_t value, IntVar* var)
641           : value_watcher_(watcher), value_(value), var_(var) {}
~WatchDemon()642       ~WatchDemon() override {}
643 
Run(Solver * const solver)644       void Run(Solver* const solver) override {
645         value_watcher_->ProcessValueWatcher(value_, var_);
646       }
647 
648      private:
649       DenseValueWatcher* const value_watcher_;
650       const int64_t value_;
651       IntVar* const var_;
652     };
653 
654     class VarDemon : public Demon {
655      public:
VarDemon(DenseValueWatcher * const watcher)656       explicit VarDemon(DenseValueWatcher* const watcher)
657           : value_watcher_(watcher) {}
658 
~VarDemon()659       ~VarDemon() override {}
660 
Run(Solver * const solver)661       void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
662 
663      private:
664       DenseValueWatcher* const value_watcher_;
665     };
666 
DenseValueWatcher(Solver * const solver,DomainIntVar * const variable)667     DenseValueWatcher(Solver* const solver, DomainIntVar* const variable)
668         : BaseValueWatcher(solver),
669           variable_(variable),
670           hole_iterator_(variable_->MakeHoleIterator(true)),
671           var_demon_(nullptr),
672           offset_(variable->Min()),
673           watchers_(variable->Max() - variable->Min() + 1, nullptr),
674           active_watchers_(0) {}
675 
~DenseValueWatcher()676     ~DenseValueWatcher() override {}
677 
GetOrMakeValueWatcher(int64_t value)678     IntVar* GetOrMakeValueWatcher(int64_t value) override {
679       const int64_t var_max = offset_ + watchers_.size() - 1;  // Bad cast.
680       if (value < offset_ || value > var_max) {
681         return solver()->MakeIntConst(0);
682       }
683       const int index = value - offset_;
684       IntVar* const watcher = watchers_[index];
685       if (watcher != nullptr) return watcher;
686       if (variable_->Contains(value)) {
687         if (variable_->Bound()) {
688           return solver()->MakeIntConst(1);
689         } else {
690           const std::string vname = variable_->HasName()
691                                         ? variable_->name()
692                                         : variable_->DebugString();
693           const std::string bname =
694               absl::StrFormat("Watch<%s == %d>", vname, value);
695           IntVar* const boolvar = solver()->MakeBoolVar(bname);
696           RevInsert(index, boolvar);
697           if (posted_.Switched()) {
698             boolvar->WhenBound(
699                 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
700             var_demon_->desinhibit(solver());
701           }
702           return boolvar;
703         }
704       } else {
705         return variable_->solver()->MakeIntConst(0);
706       }
707     }
708 
SetValueWatcher(IntVar * const boolvar,int64_t value)709     void SetValueWatcher(IntVar* const boolvar, int64_t value) override {
710       const int index = value - offset_;
711       CHECK(watchers_[index] == nullptr);
712       if (!boolvar->Bound()) {
713         RevInsert(index, boolvar);
714         if (posted_.Switched() && !boolvar->Bound()) {
715           boolvar->WhenBound(
716               solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
717           var_demon_->desinhibit(solver());
718         }
719       }
720     }
721 
Post()722     void Post() override {
723       var_demon_ = solver()->RevAlloc(new VarDemon(this));
724       variable_->WhenDomain(var_demon_);
725       for (int pos = 0; pos < watchers_.size(); ++pos) {
726         const int64_t value = pos + offset_;
727         IntVar* const boolvar = watchers_[pos];
728         if (boolvar != nullptr && !boolvar->Bound() &&
729             variable_->Contains(value)) {
730           boolvar->WhenBound(
731               solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
732         }
733       }
734       posted_.Switch(solver());
735     }
736 
InitialPropagate()737     void InitialPropagate() override {
738       if (variable_->Bound()) {
739         VariableBound();
740       } else {
741         for (int pos = 0; pos < watchers_.size(); ++pos) {
742           IntVar* const boolvar = watchers_[pos];
743           if (boolvar == nullptr) continue;
744           const int64_t value = pos + offset_;
745           if (!variable_->Contains(value)) {
746             boolvar->SetValue(0);
747             RevRemove(pos);
748           } else if (boolvar->Bound()) {
749             ProcessValueWatcher(value, boolvar);
750             RevRemove(pos);
751           }
752         }
753         if (active_watchers_.Value() == 0) {
754           var_demon_->inhibit(solver());
755         }
756       }
757     }
758 
ProcessValueWatcher(int64_t value,IntVar * boolvar)759     void ProcessValueWatcher(int64_t value, IntVar* boolvar) {
760       if (boolvar->Min() == 0) {
761         variable_->RemoveValue(value);
762       } else {
763         variable_->SetValue(value);
764       }
765     }
766 
ProcessVar()767     void ProcessVar() {
768       if (variable_->Bound()) {
769         VariableBound();
770       } else {
771         // Brute force loop for small numbers of watchers.
772         ScanWatchers();
773         if (active_watchers_.Value() == 0) {
774           var_demon_->inhibit(solver());
775         }
776       }
777     }
778 
779     // Optimized case if the variable is bound.
VariableBound()780     void VariableBound() {
781       DCHECK(variable_->Bound());
782       const int64_t value = variable_->Min();
783       for (int pos = 0; pos < watchers_.size(); ++pos) {
784         IntVar* const boolvar = watchers_[pos];
785         if (boolvar != nullptr) {
786           boolvar->SetValue(pos + offset_ == value);
787           RevRemove(pos);
788         }
789       }
790       var_demon_->inhibit(solver());
791     }
792 
793     // Scans all the watchers to check and assign them.
ScanWatchers()794     void ScanWatchers() {
795       const int64_t old_min_index = variable_->OldMin() - offset_;
796       const int64_t old_max_index = variable_->OldMax() - offset_;
797       const int64_t min_index = variable_->Min() - offset_;
798       const int64_t max_index = variable_->Max() - offset_;
799       for (int pos = old_min_index; pos < min_index; ++pos) {
800         IntVar* const boolvar = watchers_[pos];
801         if (boolvar != nullptr) {
802           boolvar->SetValue(0);
803           RevRemove(pos);
804         }
805       }
806       for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
807         IntVar* const boolvar = watchers_[pos];
808         if (boolvar != nullptr) {
809           boolvar->SetValue(0);
810           RevRemove(pos);
811         }
812       }
813       BitSet* const bitset = variable_->bitset();
814       if (bitset != nullptr) {
815         if (bitset->NumHoles() * 2 < active_watchers_.Value()) {
816           for (const int64_t hole : InitAndGetValues(hole_iterator_)) {
817             IntVar* const boolvar = watchers_[hole - offset_];
818             if (boolvar != nullptr) {
819               boolvar->SetValue(0);
820               RevRemove(hole - offset_);
821             }
822           }
823         } else {
824           for (int pos = min_index + 1; pos < max_index; ++pos) {
825             IntVar* const boolvar = watchers_[pos];
826             if (boolvar != nullptr && !variable_->Contains(offset_ + pos)) {
827               boolvar->SetValue(0);
828               RevRemove(pos);
829             }
830           }
831         }
832       }
833     }
834 
RevRemove(int pos)835     void RevRemove(int pos) {
836       solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
837       watchers_[pos] = nullptr;
838       active_watchers_.Decr(solver());
839     }
840 
RevInsert(int pos,IntVar * boolvar)841     void RevInsert(int pos, IntVar* boolvar) {
842       solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
843       watchers_[pos] = boolvar;
844       active_watchers_.Incr(solver());
845     }
846 
Accept(ModelVisitor * const visitor) const847     void Accept(ModelVisitor* const visitor) const override {
848       visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
849       visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
850                                               variable_);
851       std::vector<int64_t> all_coefficients;
852       std::vector<IntVar*> all_bool_vars;
853       for (int position = 0; position < watchers_.size(); ++position) {
854         if (watchers_[position] != nullptr) {
855           all_coefficients.push_back(position + offset_);
856           all_bool_vars.push_back(watchers_[position]);
857         }
858       }
859       visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
860                                                  all_bool_vars);
861       visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
862                                          all_coefficients);
863       visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
864     }
865 
DebugString() const866     std::string DebugString() const override {
867       return absl::StrFormat("DenseValueWatcher(%s)", variable_->DebugString());
868     }
869 
870    private:
871     DomainIntVar* const variable_;
872     IntVarIterator* const hole_iterator_;
873     RevSwitch posted_;
874     Demon* var_demon_;
875     const int64_t offset_;
876     std::vector<IntVar*> watchers_;
877     NumericalRev<int> active_watchers_;
878   };
879 
880   class BaseUpperBoundWatcher : public Constraint {
881    public:
BaseUpperBoundWatcher(Solver * const solver)882     explicit BaseUpperBoundWatcher(Solver* const solver) : Constraint(solver) {}
883 
~BaseUpperBoundWatcher()884     ~BaseUpperBoundWatcher() override {}
885 
886     virtual IntVar* GetOrMakeUpperBoundWatcher(int64_t value) = 0;
887 
888     virtual void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) = 0;
889   };
890 
  // This class watches the bounds of the variable and keeps a set of boolean
  // variables in sync with them: the boolean attached to `value` is forced to
  // 1 once variable_->Min() >= value and to 0 once variable_->Max() < value;
  // conversely, binding the boolean tightens the variable's bounds (see
  // ProcessUpperBoundWatcher). These booleans back the
  // IsGreater/IsGreaterOrEqual/IsLess/IsLessOrEqual expressions.
  //
  // Watchers are stored as (value, boolvar) pairs in a RevIntPtrMap. When
  // more than kTooSmallToSort pairs exist at Post() time they are sorted and
  // the still-undecided ones are tracked by the window [start_, end_];
  // otherwise every pass scans all pairs that are still in the map.
  class UpperBoundWatcher : public BaseUpperBoundWatcher {
   public:
    // Demon attached to a single boolean watcher: when run, it forwards the
    // watched value and the boolean to ProcessUpperBoundWatcher().
    class WatchDemon : public Demon {
     public:
      // NOTE: `index` is the watched value (callers pass `value`), not a
      // position in the map.
      WatchDemon(UpperBoundWatcher* const watcher, int64_t index,
                 IntVar* const var)
          : value_watcher_(watcher), index_(index), var_(var) {}
      ~WatchDemon() override {}

      void Run(Solver* const solver) override {
        value_watcher_->ProcessUpperBoundWatcher(index_, var_);
      }

     private:
      UpperBoundWatcher* const value_watcher_;
      const int64_t index_;
      IntVar* const var_;
    };

    // Demon attached to the watched variable: when run, it calls
    // ProcessVar() on the owning watcher.
    class VarDemon : public Demon {
     public:
      explicit VarDemon(UpperBoundWatcher* const watcher)
          : value_watcher_(watcher) {}
      ~VarDemon() override {}

      void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }

     private:
      UpperBoundWatcher* const value_watcher_;
    };

    // Builds an empty watcher for `variable`. The map is constructed with the
    // variable's current Min()/Max(); no demon exists until Post().
    UpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
        : BaseUpperBoundWatcher(solver),
          variable_(variable),
          var_demon_(nullptr),
          watchers_(solver, variable->Min(), variable->Max()),
          start_(0),
          end_(0),
          sorted_(false) {}

    ~UpperBoundWatcher() override {}

    // Returns the boolean attached to `value`:
    //  - the already registered one, if any;
    //  - the constant 1 when variable_->Min() >= value;
    //  - the constant 0 when variable_->Max() < value;
    //  - otherwise a new boolean variable named "Watch<var >= value>", which
    //    is inserted in the map.
    IntVar* GetOrMakeUpperBoundWatcher(int64_t value) override {
      IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
      if (watcher != nullptr) {
        return watcher;
      }
      if (variable_->Max() >= value) {
        if (variable_->Min() >= value) {
          return solver()->MakeIntConst(1);
        } else {
          const std::string vname = variable_->HasName()
                                        ? variable_->name()
                                        : variable_->DebugString();
          const std::string bname =
              absl::StrFormat("Watch<%s >= %d>", vname, value);
          IntVar* const boolvar = solver()->MakeBoolVar(bname);
          watchers_.UnsafeRevInsert(value, boolvar);
          // Post() only wires up the watchers that exist when it runs, so a
          // watcher added afterwards gets its demon here.
          if (posted_.Switched()) {
            boolvar->WhenBound(
                solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
            // NOTE(review): presumably re-enables var_demon_ after an earlier
            // inhibit() -- confirm against Demon::desinhibit.
            var_demon_->desinhibit(solver());
            // The new pair was appended, so the sorted order established in
            // Post() can no longer be relied upon.
            sorted_ = false;
          }
          return boolvar;
        }
      } else {
        return variable_->solver()->MakeIntConst(0);
      }
    }

    // Registers the caller-supplied `boolvar` as the watcher of `value`.
    // CHECK-fails when a watcher is already registered for `value`.
    void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) override {
      CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
      watchers_.UnsafeRevInsert(value, boolvar);
      // Same late wiring as in GetOrMakeUpperBoundWatcher(), skipped when the
      // boolean is already bound.
      if (posted_.Switched() && !boolvar->Bound()) {
        boolvar->WhenBound(
            solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
        var_demon_->desinhibit(solver());
        sorted_ = false;
      }
    }

    // Creates the variable demon, optionally switches to sorted mode, and
    // attaches a WatchDemon to every unbound watcher whose value lies in
    // (variable_->Min(), variable_->Max()].
    void Post() override {
      // Sorted mode is only entered when the map holds more than this many
      // watchers.
      const int kTooSmallToSort = 8;
      var_demon_ = solver()->RevAlloc(new VarDemon(this));
      variable_->WhenRange(var_demon_);

      if (watchers_.Size() > kTooSmallToSort) {
        watchers_.SortActive();
        sorted_ = true;
        // [start_, end_] is an inclusive window of positions in the map.
        start_.SetValue(solver(), watchers_.start());
        end_.SetValue(solver(), watchers_.end() - 1);
      }

      for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
        IntVar* const boolvar = w.second;
        const int64_t value = w.first;
        if (!boolvar->Bound() && value > variable_->Min() &&
            value <= variable_->Max()) {
          boolvar->WhenBound(
              solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
        }
      }
      posted_.Switch(solver());
    }

    // Brings every watcher in line with the variable's current bounds and
    // pushes the information carried by already-bound booleans back onto the
    // variable.
    void InitialPropagate() override {
      const int64_t var_min = variable_->Min();
      const int64_t var_max = variable_->Max();
      if (sorted_) {
        // Leading watchers with value <= var_min are true; shrink the window
        // from the left.
        while (start_.Value() <= end_.Value()) {
          const std::pair<int64_t, IntVar*>& w = watchers_.At(start_.Value());
          if (w.first <= var_min) {
            w.second->SetValue(1);
            start_.Incr(solver());
          } else {
            break;
          }
        }
        // Trailing watchers with value > var_max are false; shrink the window
        // from the right.
        while (end_.Value() >= start_.Value()) {
          const std::pair<int64_t, IntVar*>& w = watchers_.At(end_.Value());
          if (w.first > var_max) {
            w.second->SetValue(0);
            end_.Decr(solver());
          } else {
            break;
          }
        }
        // Booleans left in the window that are already bound constrain the
        // variable.
        for (int i = start_.Value(); i <= end_.Value(); ++i) {
          const std::pair<int64_t, IntVar*>& w = watchers_.At(i);
          if (w.second->Bound()) {
            ProcessUpperBoundWatcher(w.first, w.second);
          }
        }
        // Empty window: the variable demon has nothing left to update.
        if (start_.Value() > end_.Value()) {
          var_demon_->inhibit(solver());
        }
      } else {
        // Unsorted mode: scan all pairs and drop the decided ones from the
        // map.
        // NOTE(review): RemoveAt() is called while iterating by position;
        // this relies on RevIntPtrMap::RemoveAt keeping the not-yet-visited
        // entries at later positions -- confirm against its definition.
        // `value` and `boolvar` are copied out of `w` before any removal.
        for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
          const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
          const int64_t value = w.first;
          IntVar* const boolvar = w.second;

          if (value <= var_min) {
            boolvar->SetValue(1);
            watchers_.RemoveAt(pos);
          } else if (value > var_max) {
            boolvar->SetValue(0);
            watchers_.RemoveAt(pos);
          } else if (boolvar->Bound()) {
            ProcessUpperBoundWatcher(value, boolvar);
            watchers_.RemoveAt(pos);
          }
        }
      }
    }

    // Describes this constraint to `visitor` as a kVarBoundWatcher: the
    // watched variable, then parallel arrays of the boolean watchers still in
    // the map and the values they are attached to.
    void Accept(ModelVisitor* const visitor) const override {
      visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
      visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
                                              variable_);
      std::vector<int64_t> all_coefficients;
      std::vector<IntVar*> all_bool_vars;
      for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
        all_coefficients.push_back(w.first);
        all_bool_vars.push_back(w.second);
      }
      visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
                                                 all_bool_vars);
      visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
                                         all_coefficients);
      visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
    }

    std::string DebugString() const override {
      return absl::StrFormat("UpperBoundWatcher(%s)", variable_->DebugString());
    }

   private:
    // Pushes the state of `boolvar` onto the variable: when its Min() is 0
    // the variable must stay below `value` (SetMax(value - 1)); otherwise the
    // variable must reach at least `value` (SetMin(value)). Callers in this
    // class invoke it for bound booleans (or from the WhenBound demon).
    void ProcessUpperBoundWatcher(int64_t value, IntVar* const boolvar) {
      if (boolvar->Min() == 0) {
        variable_->SetMax(value - 1);
      } else {
        variable_->SetMin(value);
      }
    }

    // Run by VarDemon: decides every watcher whose value has fallen at or
    // below the variable's Min() (true) or above its Max() (false), and
    // inhibits the variable demon once nothing is left to watch.
    void ProcessVar() {
      const int64_t var_min = variable_->Min();
      const int64_t var_max = variable_->Max();
      if (sorted_) {
        // Same window shrinking as in InitialPropagate().
        while (start_.Value() <= end_.Value()) {
          const std::pair<int64_t, IntVar*>& w = watchers_.At(start_.Value());
          if (w.first <= var_min) {
            w.second->SetValue(1);
            start_.Incr(solver());
          } else {
            break;
          }
        }
        while (end_.Value() >= start_.Value()) {
          const std::pair<int64_t, IntVar*>& w = watchers_.At(end_.Value());
          if (w.first > var_max) {
            w.second->SetValue(0);
            end_.Decr(solver());
          } else {
            break;
          }
        }
        if (start_.Value() > end_.Value()) {
          var_demon_->inhibit(solver());
        }
      } else {
        // Unsorted mode: full scan, removing decided watchers from the map.
        for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
          const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
          const int64_t value = w.first;
          IntVar* const boolvar = w.second;

          if (value <= var_min) {
            boolvar->SetValue(1);
            watchers_.RemoveAt(pos);
          } else if (value > var_max) {
            boolvar->SetValue(0);
            watchers_.RemoveAt(pos);
          }
        }
        if (watchers_.Empty()) {
          var_demon_->inhibit(solver());
        }
      }
    }

    // The variable whose bounds are watched.
    DomainIntVar* const variable_;
    // Switched at the end of Post(); tells the insertion methods that demons
    // must be attached immediately.
    RevSwitch posted_;
    // Demon registered on variable_ via WhenRange(); created in Post().
    Demon* var_demon_;
    // (value, boolean watcher) pairs.
    RevIntPtrMap<IntVar> watchers_;
    // Inclusive window of undecided positions, used in sorted mode only.
    NumericalRev<int> start_;
    NumericalRev<int> end_;
    // True while the pairs are known to be sorted. This is a plain bool, so
    // it is not restored when the solver backtracks.
    bool sorted_;
  };
1136 
1137   // Optimized case for small maps.
1138   class DenseUpperBoundWatcher : public BaseUpperBoundWatcher {
1139    public:
1140     class WatchDemon : public Demon {
1141      public:
WatchDemon(DenseUpperBoundWatcher * const watcher,int64_t value,IntVar * var)1142       WatchDemon(DenseUpperBoundWatcher* const watcher, int64_t value,
1143                  IntVar* var)
1144           : value_watcher_(watcher), value_(value), var_(var) {}
~WatchDemon()1145       ~WatchDemon() override {}
1146 
Run(Solver * const solver)1147       void Run(Solver* const solver) override {
1148         value_watcher_->ProcessUpperBoundWatcher(value_, var_);
1149       }
1150 
1151      private:
1152       DenseUpperBoundWatcher* const value_watcher_;
1153       const int64_t value_;
1154       IntVar* const var_;
1155     };
1156 
1157     class VarDemon : public Demon {
1158      public:
VarDemon(DenseUpperBoundWatcher * const watcher)1159       explicit VarDemon(DenseUpperBoundWatcher* const watcher)
1160           : value_watcher_(watcher) {}
1161 
~VarDemon()1162       ~VarDemon() override {}
1163 
Run(Solver * const solver)1164       void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
1165 
1166      private:
1167       DenseUpperBoundWatcher* const value_watcher_;
1168     };
1169 
DenseUpperBoundWatcher(Solver * const solver,DomainIntVar * const variable)1170     DenseUpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
1171         : BaseUpperBoundWatcher(solver),
1172           variable_(variable),
1173           var_demon_(nullptr),
1174           offset_(variable->Min()),
1175           watchers_(variable->Max() - variable->Min() + 1, nullptr),
1176           active_watchers_(0) {}
1177 
~DenseUpperBoundWatcher()1178     ~DenseUpperBoundWatcher() override {}
1179 
GetOrMakeUpperBoundWatcher(int64_t value)1180     IntVar* GetOrMakeUpperBoundWatcher(int64_t value) override {
1181       if (variable_->Max() >= value) {
1182         if (variable_->Min() >= value) {
1183           return solver()->MakeIntConst(1);
1184         } else {
1185           const std::string vname = variable_->HasName()
1186                                         ? variable_->name()
1187                                         : variable_->DebugString();
1188           const std::string bname =
1189               absl::StrFormat("Watch<%s >= %d>", vname, value);
1190           IntVar* const boolvar = solver()->MakeBoolVar(bname);
1191           RevInsert(value - offset_, boolvar);
1192           if (posted_.Switched()) {
1193             boolvar->WhenBound(
1194                 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1195             var_demon_->desinhibit(solver());
1196           }
1197           return boolvar;
1198         }
1199       } else {
1200         return variable_->solver()->MakeIntConst(0);
1201       }
1202     }
1203 
SetUpperBoundWatcher(IntVar * const boolvar,int64_t value)1204     void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) override {
1205       const int index = value - offset_;
1206       CHECK(watchers_[index] == nullptr);
1207       if (!boolvar->Bound()) {
1208         RevInsert(index, boolvar);
1209         if (posted_.Switched() && !boolvar->Bound()) {
1210           boolvar->WhenBound(
1211               solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1212           var_demon_->desinhibit(solver());
1213         }
1214       }
1215     }
1216 
Post()1217     void Post() override {
1218       var_demon_ = solver()->RevAlloc(new VarDemon(this));
1219       variable_->WhenRange(var_demon_);
1220       for (int pos = 0; pos < watchers_.size(); ++pos) {
1221         const int64_t value = pos + offset_;
1222         IntVar* const boolvar = watchers_[pos];
1223         if (boolvar != nullptr && !boolvar->Bound() &&
1224             value > variable_->Min() && value <= variable_->Max()) {
1225           boolvar->WhenBound(
1226               solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1227         }
1228       }
1229       posted_.Switch(solver());
1230     }
1231 
InitialPropagate()1232     void InitialPropagate() override {
1233       for (int pos = 0; pos < watchers_.size(); ++pos) {
1234         IntVar* const boolvar = watchers_[pos];
1235         if (boolvar == nullptr) continue;
1236         const int64_t value = pos + offset_;
1237         if (value <= variable_->Min()) {
1238           boolvar->SetValue(1);
1239           RevRemove(pos);
1240         } else if (value > variable_->Max()) {
1241           boolvar->SetValue(0);
1242           RevRemove(pos);
1243         } else if (boolvar->Bound()) {
1244           ProcessUpperBoundWatcher(value, boolvar);
1245           RevRemove(pos);
1246         }
1247       }
1248       if (active_watchers_.Value() == 0) {
1249         var_demon_->inhibit(solver());
1250       }
1251     }
1252 
ProcessUpperBoundWatcher(int64_t value,IntVar * boolvar)1253     void ProcessUpperBoundWatcher(int64_t value, IntVar* boolvar) {
1254       if (boolvar->Min() == 0) {
1255         variable_->SetMax(value - 1);
1256       } else {
1257         variable_->SetMin(value);
1258       }
1259     }
1260 
ProcessVar()1261     void ProcessVar() {
1262       const int64_t old_min_index = variable_->OldMin() - offset_;
1263       const int64_t old_max_index = variable_->OldMax() - offset_;
1264       const int64_t min_index = variable_->Min() - offset_;
1265       const int64_t max_index = variable_->Max() - offset_;
1266       for (int pos = old_min_index; pos <= min_index; ++pos) {
1267         IntVar* const boolvar = watchers_[pos];
1268         if (boolvar != nullptr) {
1269           boolvar->SetValue(1);
1270           RevRemove(pos);
1271         }
1272       }
1273 
1274       for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
1275         IntVar* const boolvar = watchers_[pos];
1276         if (boolvar != nullptr) {
1277           boolvar->SetValue(0);
1278           RevRemove(pos);
1279         }
1280       }
1281       if (active_watchers_.Value() == 0) {
1282         var_demon_->inhibit(solver());
1283       }
1284     }
1285 
RevRemove(int pos)1286     void RevRemove(int pos) {
1287       solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1288       watchers_[pos] = nullptr;
1289       active_watchers_.Decr(solver());
1290     }
1291 
RevInsert(int pos,IntVar * boolvar)1292     void RevInsert(int pos, IntVar* boolvar) {
1293       solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1294       watchers_[pos] = boolvar;
1295       active_watchers_.Incr(solver());
1296     }
1297 
Accept(ModelVisitor * const visitor) const1298     void Accept(ModelVisitor* const visitor) const override {
1299       visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1300       visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1301                                               variable_);
1302       std::vector<int64_t> all_coefficients;
1303       std::vector<IntVar*> all_bool_vars;
1304       for (int position = 0; position < watchers_.size(); ++position) {
1305         if (watchers_[position] != nullptr) {
1306           all_coefficients.push_back(position + offset_);
1307           all_bool_vars.push_back(watchers_[position]);
1308         }
1309       }
1310       visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1311                                                  all_bool_vars);
1312       visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1313                                          all_coefficients);
1314       visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1315     }
1316 
DebugString() const1317     std::string DebugString() const override {
1318       return absl::StrFormat("DenseUpperBoundWatcher(%s)",
1319                              variable_->DebugString());
1320     }
1321 
1322    private:
1323     DomainIntVar* const variable_;
1324     RevSwitch posted_;
1325     Demon* var_demon_;
1326     const int64_t offset_;
1327     std::vector<IntVar*> watchers_;
1328     NumericalRev<int> active_watchers_;
1329   };
1330 
1331   // ----- Main Class -----
1332   DomainIntVar(Solver* const s, int64_t vmin, int64_t vmax,
1333                const std::string& name);
1334   DomainIntVar(Solver* const s, const std::vector<int64_t>& sorted_values,
1335                const std::string& name);
1336   ~DomainIntVar() override;
1337 
Min() const1338   int64_t Min() const override { return min_.Value(); }
1339   void SetMin(int64_t m) override;
Max() const1340   int64_t Max() const override { return max_.Value(); }
1341   void SetMax(int64_t m) override;
1342   void SetRange(int64_t mi, int64_t ma) override;
1343   void SetValue(int64_t v) override;
Bound() const1344   bool Bound() const override { return (min_.Value() == max_.Value()); }
Value() const1345   int64_t Value() const override {
1346     CHECK_EQ(min_.Value(), max_.Value())
1347         << " variable " << DebugString() << " is not bound.";
1348     return min_.Value();
1349   }
1350   void RemoveValue(int64_t v) override;
1351   void RemoveInterval(int64_t l, int64_t u) override;
1352   void CreateBits();
WhenBound(Demon * d)1353   void WhenBound(Demon* d) override {
1354     if (min_.Value() != max_.Value()) {
1355       if (d->priority() == Solver::DELAYED_PRIORITY) {
1356         delayed_bound_demons_.PushIfNotTop(solver(),
1357                                            solver()->RegisterDemon(d));
1358       } else {
1359         bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1360       }
1361     }
1362   }
WhenRange(Demon * d)1363   void WhenRange(Demon* d) override {
1364     if (min_.Value() != max_.Value()) {
1365       if (d->priority() == Solver::DELAYED_PRIORITY) {
1366         delayed_range_demons_.PushIfNotTop(solver(),
1367                                            solver()->RegisterDemon(d));
1368       } else {
1369         range_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1370       }
1371     }
1372   }
WhenDomain(Demon * d)1373   void WhenDomain(Demon* d) override {
1374     if (min_.Value() != max_.Value()) {
1375       if (d->priority() == Solver::DELAYED_PRIORITY) {
1376         delayed_domain_demons_.PushIfNotTop(solver(),
1377                                             solver()->RegisterDemon(d));
1378       } else {
1379         domain_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1380       }
1381     }
1382   }
1383 
IsEqual(int64_t constant)1384   IntVar* IsEqual(int64_t constant) override {
1385     Solver* const s = solver();
1386     if (constant == min_.Value() && value_watcher_ == nullptr) {
1387       return s->MakeIsLessOrEqualCstVar(this, constant);
1388     }
1389     if (constant == max_.Value() && value_watcher_ == nullptr) {
1390       return s->MakeIsGreaterOrEqualCstVar(this, constant);
1391     }
1392     if (!Contains(constant)) {
1393       return s->MakeIntConst(int64_t{0});
1394     }
1395     if (Bound() && min_.Value() == constant) {
1396       return s->MakeIntConst(int64_t{1});
1397     }
1398     IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1399         this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1400     if (cache != nullptr) {
1401       return cache->Var();
1402     } else {
1403       if (value_watcher_ == nullptr) {
1404         if (CapSub(Max(), Min()) <= 256) {
1405           solver()->SaveAndSetValue(
1406               reinterpret_cast<void**>(&value_watcher_),
1407               reinterpret_cast<void*>(
1408                   solver()->RevAlloc(new DenseValueWatcher(solver(), this))));
1409 
1410         } else {
1411           solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1412                                     reinterpret_cast<void*>(solver()->RevAlloc(
1413                                         new ValueWatcher(solver(), this))));
1414         }
1415         solver()->AddConstraint(value_watcher_);
1416       }
1417       IntVar* const boolvar = value_watcher_->GetOrMakeValueWatcher(constant);
1418       s->Cache()->InsertExprConstantExpression(
1419           boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1420       return boolvar;
1421     }
1422   }
1423 
SetIsEqual(const std::vector<int64_t> & values,const std::vector<IntVar * > & vars)1424   Constraint* SetIsEqual(const std::vector<int64_t>& values,
1425                          const std::vector<IntVar*>& vars) {
1426     if (value_watcher_ == nullptr) {
1427       solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1428                                 reinterpret_cast<void*>(solver()->RevAlloc(
1429                                     new ValueWatcher(solver(), this))));
1430       for (int i = 0; i < vars.size(); ++i) {
1431         value_watcher_->SetValueWatcher(vars[i], values[i]);
1432       }
1433     }
1434     return value_watcher_;
1435   }
1436 
IsDifferent(int64_t constant)1437   IntVar* IsDifferent(int64_t constant) override {
1438     Solver* const s = solver();
1439     if (constant == min_.Value() && value_watcher_ == nullptr) {
1440       return s->MakeIsGreaterOrEqualCstVar(this, constant + 1);
1441     }
1442     if (constant == max_.Value() && value_watcher_ == nullptr) {
1443       return s->MakeIsLessOrEqualCstVar(this, constant - 1);
1444     }
1445     if (!Contains(constant)) {
1446       return s->MakeIntConst(int64_t{1});
1447     }
1448     if (Bound() && min_.Value() == constant) {
1449       return s->MakeIntConst(int64_t{0});
1450     }
1451     IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1452         this, constant, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
1453     if (cache != nullptr) {
1454       return cache->Var();
1455     } else {
1456       IntVar* const boolvar = s->MakeDifference(1, IsEqual(constant))->Var();
1457       s->Cache()->InsertExprConstantExpression(
1458           boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
1459       return boolvar;
1460     }
1461   }
1462 
IsGreaterOrEqual(int64_t constant)1463   IntVar* IsGreaterOrEqual(int64_t constant) override {
1464     Solver* const s = solver();
1465     if (max_.Value() < constant) {
1466       return s->MakeIntConst(int64_t{0});
1467     }
1468     if (min_.Value() >= constant) {
1469       return s->MakeIntConst(int64_t{1});
1470     }
1471     IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1472         this, constant, ModelCache::EXPR_CONSTANT_IS_GREATER_OR_EQUAL);
1473     if (cache != nullptr) {
1474       return cache->Var();
1475     } else {
1476       if (bound_watcher_ == nullptr) {
1477         if (CapSub(Max(), Min()) <= 256) {
1478           solver()->SaveAndSetValue(
1479               reinterpret_cast<void**>(&bound_watcher_),
1480               reinterpret_cast<void*>(solver()->RevAlloc(
1481                   new DenseUpperBoundWatcher(solver(), this))));
1482           solver()->AddConstraint(bound_watcher_);
1483         } else {
1484           solver()->SaveAndSetValue(
1485               reinterpret_cast<void**>(&bound_watcher_),
1486               reinterpret_cast<void*>(
1487                   solver()->RevAlloc(new UpperBoundWatcher(solver(), this))));
1488           solver()->AddConstraint(bound_watcher_);
1489         }
1490       }
1491       IntVar* const boolvar =
1492           bound_watcher_->GetOrMakeUpperBoundWatcher(constant);
1493       s->Cache()->InsertExprConstantExpression(
1494           boolvar, this, constant,
1495           ModelCache::EXPR_CONSTANT_IS_GREATER_OR_EQUAL);
1496       return boolvar;
1497     }
1498   }
1499 
SetIsGreaterOrEqual(const std::vector<int64_t> & values,const std::vector<IntVar * > & vars)1500   Constraint* SetIsGreaterOrEqual(const std::vector<int64_t>& values,
1501                                   const std::vector<IntVar*>& vars) {
1502     if (bound_watcher_ == nullptr) {
1503       if (CapSub(Max(), Min()) <= 256) {
1504         solver()->SaveAndSetValue(
1505             reinterpret_cast<void**>(&bound_watcher_),
1506             reinterpret_cast<void*>(solver()->RevAlloc(
1507                 new DenseUpperBoundWatcher(solver(), this))));
1508         solver()->AddConstraint(bound_watcher_);
1509       } else {
1510         solver()->SaveAndSetValue(reinterpret_cast<void**>(&bound_watcher_),
1511                                   reinterpret_cast<void*>(solver()->RevAlloc(
1512                                       new UpperBoundWatcher(solver(), this))));
1513         solver()->AddConstraint(bound_watcher_);
1514       }
1515       for (int i = 0; i < values.size(); ++i) {
1516         bound_watcher_->SetUpperBoundWatcher(vars[i], values[i]);
1517       }
1518     }
1519     return bound_watcher_;
1520   }
1521 
IsLessOrEqual(int64_t constant)1522   IntVar* IsLessOrEqual(int64_t constant) override {
1523     Solver* const s = solver();
1524     IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1525         this, constant, ModelCache::EXPR_CONSTANT_IS_LESS_OR_EQUAL);
1526     if (cache != nullptr) {
1527       return cache->Var();
1528     } else {
1529       IntVar* const boolvar =
1530           s->MakeDifference(1, IsGreaterOrEqual(constant + 1))->Var();
1531       s->Cache()->InsertExprConstantExpression(
1532           boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_LESS_OR_EQUAL);
1533       return boolvar;
1534     }
1535   }
1536 
1537   void Process();
1538   void Push();
1539   void CleanInProcess();
Size() const1540   uint64_t Size() const override {
1541     if (bits_ != nullptr) return bits_->Size();
1542     return (static_cast<uint64_t>(max_.Value()) -
1543             static_cast<uint64_t>(min_.Value()) + 1);
1544   }
Contains(int64_t v) const1545   bool Contains(int64_t v) const override {
1546     if (v < min_.Value() || v > max_.Value()) return false;
1547     return (bits_ == nullptr ? true : bits_->Contains(v));
1548   }
  IntVarIterator* MakeHoleIterator(bool reversible) const override;
  IntVarIterator* MakeDomainIterator(bool reversible) const override;
  // The smaller of the recorded old_min_ and the current min, so the result
  // never exceeds the current lower bound.
  int64_t OldMin() const override { return std::min(old_min_, min_.Value()); }
  // The larger of the recorded old_max_ and the current max, so the result
  // is never below the current upper bound.
  int64_t OldMax() const override { return std::max(old_max_, max_.Value()); }

  std::string DebugString() const override;
  // Raw access to the bitset member; nullptr when the variable has none
  // (Size() and Contains() test bits_ against nullptr).
  BitSet* bitset() const { return bits_; }
  // Type tag of this variable class.
  int VarType() const override { return DOMAIN_INT_VAR; }
  // Base name string for this kind of variable.
  std::string BaseName() const override { return "IntegerVar"; }
1558 
1559   friend class PlusCstDomainIntVar;
1560   friend class LinkExprAndDomainIntVar;
1561 
1562  private:
CheckOldMin()1563   void CheckOldMin() {
1564     if (old_min_ > min_.Value()) {
1565       old_min_ = min_.Value();
1566     }
1567   }
CheckOldMax()1568   void CheckOldMax() {
1569     if (old_max_ < max_.Value()) {
1570       old_max_ = max_.Value();
1571     }
1572   }
1573   Rev<int64_t> min_;
1574   Rev<int64_t> max_;
1575   int64_t old_min_;
1576   int64_t old_max_;
1577   int64_t new_min_;
1578   int64_t new_max_;
1579   SimpleRevFIFO<Demon*> bound_demons_;
1580   SimpleRevFIFO<Demon*> range_demons_;
1581   SimpleRevFIFO<Demon*> domain_demons_;
1582   SimpleRevFIFO<Demon*> delayed_bound_demons_;
1583   SimpleRevFIFO<Demon*> delayed_range_demons_;
1584   SimpleRevFIFO<Demon*> delayed_domain_demons_;
1585   QueueHandler handler_;
1586   bool in_process_;
1587   BitSet* bits_;
1588   BaseValueWatcher* value_watcher_;
1589   BaseUpperBoundWatcher* bound_watcher_;
1590 };
1591 
1592 // ----- BitSet -----
1593 
1594 // Return whether an integer interval [a..b] (inclusive) contains at most
1595 // K values, i.e. b - a < K, in a way that's robust to overflows.
1596 // For performance reasons, in opt mode it doesn't check that [a, b] is a
1597 // valid interval, nor that K is nonnegative.
ClosedIntervalNoLargerThan(int64_t a,int64_t b,int64_t K)1598 inline bool ClosedIntervalNoLargerThan(int64_t a, int64_t b, int64_t K) {
1599   DCHECK_LE(a, b);
1600   DCHECK_GE(K, 0);
1601   if (a > 0) {
1602     return a > b - K;
1603   } else {
1604     return a + K > b;
1605   }
1606 }
1607 
1608 class SimpleBitSet : public DomainIntVar::BitSet {
1609  public:
SimpleBitSet(Solver * const s,int64_t vmin,int64_t vmax)1610   SimpleBitSet(Solver* const s, int64_t vmin, int64_t vmax)
1611       : BitSet(s),
1612         bits_(nullptr),
1613         stamps_(nullptr),
1614         omin_(vmin),
1615         omax_(vmax),
1616         size_(vmax - vmin + 1),
1617         bsize_(BitLength64(size_.Value())) {
1618     CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1619         << "Bitset too large: [" << vmin << ", " << vmax << "]";
1620     bits_ = new uint64_t[bsize_];
1621     stamps_ = new uint64_t[bsize_];
1622     for (int i = 0; i < bsize_; ++i) {
1623       const int bs =
1624           (i == size_.Value() - 1) ? 63 - BitPos64(size_.Value()) : 0;
1625       bits_[i] = kAllBits64 >> bs;
1626       stamps_[i] = s->stamp() - 1;
1627     }
1628   }
1629 
SimpleBitSet(Solver * const s,const std::vector<int64_t> & sorted_values,int64_t vmin,int64_t vmax)1630   SimpleBitSet(Solver* const s, const std::vector<int64_t>& sorted_values,
1631                int64_t vmin, int64_t vmax)
1632       : BitSet(s),
1633         bits_(nullptr),
1634         stamps_(nullptr),
1635         omin_(vmin),
1636         omax_(vmax),
1637         size_(sorted_values.size()),
1638         bsize_(BitLength64(vmax - vmin + 1)) {
1639     CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1640         << "Bitset too large: [" << vmin << ", " << vmax << "]";
1641     bits_ = new uint64_t[bsize_];
1642     stamps_ = new uint64_t[bsize_];
1643     for (int i = 0; i < bsize_; ++i) {
1644       bits_[i] = uint64_t{0};
1645       stamps_[i] = s->stamp() - 1;
1646     }
1647     for (int i = 0; i < sorted_values.size(); ++i) {
1648       const int64_t val = sorted_values[i];
1649       DCHECK(!bit(val));
1650       const int offset = BitOffset64(val - omin_);
1651       const int pos = BitPos64(val - omin_);
1652       bits_[offset] |= OneBit64(pos);
1653     }
1654   }
1655 
~SimpleBitSet()1656   ~SimpleBitSet() override {
1657     delete[] bits_;
1658     delete[] stamps_;
1659   }
1660 
bit(int64_t val) const1661   bool bit(int64_t val) const { return IsBitSet64(bits_, val - omin_); }
1662 
ComputeNewMin(int64_t nmin,int64_t cmin,int64_t cmax)1663   int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) override {
1664     DCHECK_GE(nmin, cmin);
1665     DCHECK_LE(nmin, cmax);
1666     DCHECK_LE(cmin, cmax);
1667     DCHECK_GE(cmin, omin_);
1668     DCHECK_LE(cmax, omax_);
1669     const int64_t new_min =
1670         UnsafeLeastSignificantBitPosition64(bits_, nmin - omin_, cmax - omin_) +
1671         omin_;
1672     const uint64_t removed_bits =
1673         BitCountRange64(bits_, cmin - omin_, new_min - omin_ - 1);
1674     size_.Add(solver_, -removed_bits);
1675     return new_min;
1676   }
1677 
ComputeNewMax(int64_t nmax,int64_t cmin,int64_t cmax)1678   int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) override {
1679     DCHECK_GE(nmax, cmin);
1680     DCHECK_LE(nmax, cmax);
1681     DCHECK_LE(cmin, cmax);
1682     DCHECK_GE(cmin, omin_);
1683     DCHECK_LE(cmax, omax_);
1684     const int64_t new_max =
1685         UnsafeMostSignificantBitPosition64(bits_, cmin - omin_, nmax - omin_) +
1686         omin_;
1687     const uint64_t removed_bits =
1688         BitCountRange64(bits_, new_max - omin_ + 1, cmax - omin_);
1689     size_.Add(solver_, -removed_bits);
1690     return new_max;
1691   }
1692 
SetValue(int64_t val)1693   bool SetValue(int64_t val) override {
1694     DCHECK_GE(val, omin_);
1695     DCHECK_LE(val, omax_);
1696     if (bit(val)) {
1697       size_.SetValue(solver_, 1);
1698       return true;
1699     }
1700     return false;
1701   }
1702 
Contains(int64_t val) const1703   bool Contains(int64_t val) const override {
1704     DCHECK_GE(val, omin_);
1705     DCHECK_LE(val, omax_);
1706     return bit(val);
1707   }
1708 
RemoveValue(int64_t val)1709   bool RemoveValue(int64_t val) override {
1710     if (val < omin_ || val > omax_ || !bit(val)) {
1711       return false;
1712     }
1713     // Bitset.
1714     const int64_t val_offset = val - omin_;
1715     const int offset = BitOffset64(val_offset);
1716     const uint64_t current_stamp = solver_->stamp();
1717     if (stamps_[offset] < current_stamp) {
1718       stamps_[offset] = current_stamp;
1719       solver_->SaveValue(&bits_[offset]);
1720     }
1721     const int pos = BitPos64(val_offset);
1722     bits_[offset] &= ~OneBit64(pos);
1723     // Size.
1724     size_.Decr(solver_);
1725     // Holes.
1726     InitHoles();
1727     AddHole(val);
1728     return true;
1729   }
Size() const1730   uint64_t Size() const override { return size_.Value(); }
1731 
DebugString() const1732   std::string DebugString() const override {
1733     std::string out;
1734     absl::StrAppendFormat(&out, "SimpleBitSet(%d..%d : ", omin_, omax_);
1735     for (int i = 0; i < bsize_; ++i) {
1736       absl::StrAppendFormat(&out, "%x", bits_[i]);
1737     }
1738     out += ")";
1739     return out;
1740   }
1741 
DelayRemoveValue(int64_t val)1742   void DelayRemoveValue(int64_t val) override { removed_.push_back(val); }
1743 
ApplyRemovedValues(DomainIntVar * var)1744   void ApplyRemovedValues(DomainIntVar* var) override {
1745     std::sort(removed_.begin(), removed_.end());
1746     for (std::vector<int64_t>::iterator it = removed_.begin();
1747          it != removed_.end(); ++it) {
1748       var->RemoveValue(*it);
1749     }
1750   }
1751 
ClearRemovedValues()1752   void ClearRemovedValues() override { removed_.clear(); }
1753 
pretty_DebugString(int64_t min,int64_t max) const1754   std::string pretty_DebugString(int64_t min, int64_t max) const override {
1755     std::string out;
1756     DCHECK(bit(min));
1757     DCHECK(bit(max));
1758     if (max != min) {
1759       int cumul = true;
1760       int64_t start_cumul = min;
1761       for (int64_t v = min + 1; v < max; ++v) {
1762         if (bit(v)) {
1763           if (!cumul) {
1764             cumul = true;
1765             start_cumul = v;
1766           }
1767         } else {
1768           if (cumul) {
1769             if (v == start_cumul + 1) {
1770               absl::StrAppendFormat(&out, "%d ", start_cumul);
1771             } else if (v == start_cumul + 2) {
1772               absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1773             } else {
1774               absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1775             }
1776             cumul = false;
1777           }
1778         }
1779       }
1780       if (cumul) {
1781         if (max == start_cumul + 1) {
1782           absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1783         } else {
1784           absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1785         }
1786       } else {
1787         absl::StrAppendFormat(&out, "%d", max);
1788       }
1789     } else {
1790       absl::StrAppendFormat(&out, "%d", min);
1791     }
1792     return out;
1793   }
1794 
MakeIterator()1795   DomainIntVar::BitSetIterator* MakeIterator() override {
1796     return new DomainIntVar::BitSetIterator(bits_, omin_);
1797   }
1798 
1799  private:
1800   uint64_t* bits_;
1801   uint64_t* stamps_;
1802   const int64_t omin_;
1803   const int64_t omax_;
1804   NumericalRev<int64_t> size_;
1805   const int bsize_;
1806   std::vector<int64_t> removed_;
1807 };
1808 
1809 // This is a special case where the bitset fits into one 64 bit integer.
1810 // In that case, there are no offset to compute.
1811 // Overflows are caught by the robust ClosedIntervalNoLargerThan() method.
1812 class SmallBitSet : public DomainIntVar::BitSet {
1813  public:
SmallBitSet(Solver * const s,int64_t vmin,int64_t vmax)1814   SmallBitSet(Solver* const s, int64_t vmin, int64_t vmax)
1815       : BitSet(s),
1816         bits_(uint64_t{0}),
1817         stamp_(s->stamp() - 1),
1818         omin_(vmin),
1819         omax_(vmax),
1820         size_(vmax - vmin + 1) {
1821     CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1822     bits_ = OneRange64(0, size_.Value() - 1);
1823   }
1824 
SmallBitSet(Solver * const s,const std::vector<int64_t> & sorted_values,int64_t vmin,int64_t vmax)1825   SmallBitSet(Solver* const s, const std::vector<int64_t>& sorted_values,
1826               int64_t vmin, int64_t vmax)
1827       : BitSet(s),
1828         bits_(uint64_t{0}),
1829         stamp_(s->stamp() - 1),
1830         omin_(vmin),
1831         omax_(vmax),
1832         size_(sorted_values.size()) {
1833     CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1834     // We know the array is sorted and does not contains duplicate values.
1835     for (int i = 0; i < sorted_values.size(); ++i) {
1836       const int64_t val = sorted_values[i];
1837       DCHECK_GE(val, vmin);
1838       DCHECK_LE(val, vmax);
1839       DCHECK(!IsBitSet64(&bits_, val - omin_));
1840       bits_ |= OneBit64(val - omin_);
1841     }
1842   }
1843 
~SmallBitSet()1844   ~SmallBitSet() override {}
1845 
bit(int64_t val) const1846   bool bit(int64_t val) const {
1847     DCHECK_GE(val, omin_);
1848     DCHECK_LE(val, omax_);
1849     return (bits_ & OneBit64(val - omin_)) != 0;
1850   }
1851 
ComputeNewMin(int64_t nmin,int64_t cmin,int64_t cmax)1852   int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) override {
1853     DCHECK_GE(nmin, cmin);
1854     DCHECK_LE(nmin, cmax);
1855     DCHECK_LE(cmin, cmax);
1856     DCHECK_GE(cmin, omin_);
1857     DCHECK_LE(cmax, omax_);
1858     // We do not clean the bits between cmin and nmin.
1859     // But we use mask to look only at 'active' bits.
1860 
1861     // Create the mask and compute new bits
1862     const uint64_t new_bits = bits_ & OneRange64(nmin - omin_, cmax - omin_);
1863     if (new_bits != uint64_t{0}) {
1864       // Compute new size and new min
1865       size_.SetValue(solver_, BitCount64(new_bits));
1866       if (bit(nmin)) {  // Common case, the new min is inside the bitset
1867         return nmin;
1868       }
1869       return LeastSignificantBitPosition64(new_bits) + omin_;
1870     } else {  // == 0 -> Fail()
1871       solver_->Fail();
1872       return std::numeric_limits<int64_t>::max();
1873     }
1874   }
1875 
ComputeNewMax(int64_t nmax,int64_t cmin,int64_t cmax)1876   int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) override {
1877     DCHECK_GE(nmax, cmin);
1878     DCHECK_LE(nmax, cmax);
1879     DCHECK_LE(cmin, cmax);
1880     DCHECK_GE(cmin, omin_);
1881     DCHECK_LE(cmax, omax_);
1882     // We do not clean the bits between nmax and cmax.
1883     // But we use mask to look only at 'active' bits.
1884 
1885     // Create the mask and compute new_bits
1886     const uint64_t new_bits = bits_ & OneRange64(cmin - omin_, nmax - omin_);
1887     if (new_bits != uint64_t{0}) {
1888       // Compute new size and new min
1889       size_.SetValue(solver_, BitCount64(new_bits));
1890       if (bit(nmax)) {  // Common case, the new max is inside the bitset
1891         return nmax;
1892       }
1893       return MostSignificantBitPosition64(new_bits) + omin_;
1894     } else {  // == 0 -> Fail()
1895       solver_->Fail();
1896       return std::numeric_limits<int64_t>::min();
1897     }
1898   }
1899 
SetValue(int64_t val)1900   bool SetValue(int64_t val) override {
1901     DCHECK_GE(val, omin_);
1902     DCHECK_LE(val, omax_);
1903     // We do not clean the bits. We will use masks to ignore the bits
1904     // that should have been cleaned.
1905     if (bit(val)) {
1906       size_.SetValue(solver_, 1);
1907       return true;
1908     }
1909     return false;
1910   }
1911 
Contains(int64_t val) const1912   bool Contains(int64_t val) const override {
1913     DCHECK_GE(val, omin_);
1914     DCHECK_LE(val, omax_);
1915     return bit(val);
1916   }
1917 
RemoveValue(int64_t val)1918   bool RemoveValue(int64_t val) override {
1919     DCHECK_GE(val, omin_);
1920     DCHECK_LE(val, omax_);
1921     if (bit(val)) {
1922       // Bitset.
1923       const uint64_t current_stamp = solver_->stamp();
1924       if (stamp_ < current_stamp) {
1925         stamp_ = current_stamp;
1926         solver_->SaveValue(&bits_);
1927       }
1928       bits_ &= ~OneBit64(val - omin_);
1929       DCHECK(!bit(val));
1930       // Size.
1931       size_.Decr(solver_);
1932       // Holes.
1933       InitHoles();
1934       AddHole(val);
1935       return true;
1936     } else {
1937       return false;
1938     }
1939   }
1940 
Size() const1941   uint64_t Size() const override { return size_.Value(); }
1942 
DebugString() const1943   std::string DebugString() const override {
1944     return absl::StrFormat("SmallBitSet(%d..%d : %llx)", omin_, omax_, bits_);
1945   }
1946 
DelayRemoveValue(int64_t val)1947   void DelayRemoveValue(int64_t val) override {
1948     DCHECK_GE(val, omin_);
1949     DCHECK_LE(val, omax_);
1950     removed_.push_back(val);
1951   }
1952 
ApplyRemovedValues(DomainIntVar * var)1953   void ApplyRemovedValues(DomainIntVar* var) override {
1954     std::sort(removed_.begin(), removed_.end());
1955     for (std::vector<int64_t>::iterator it = removed_.begin();
1956          it != removed_.end(); ++it) {
1957       var->RemoveValue(*it);
1958     }
1959   }
1960 
ClearRemovedValues()1961   void ClearRemovedValues() override { removed_.clear(); }
1962 
pretty_DebugString(int64_t min,int64_t max) const1963   std::string pretty_DebugString(int64_t min, int64_t max) const override {
1964     std::string out;
1965     DCHECK(bit(min));
1966     DCHECK(bit(max));
1967     if (max != min) {
1968       int cumul = true;
1969       int64_t start_cumul = min;
1970       for (int64_t v = min + 1; v < max; ++v) {
1971         if (bit(v)) {
1972           if (!cumul) {
1973             cumul = true;
1974             start_cumul = v;
1975           }
1976         } else {
1977           if (cumul) {
1978             if (v == start_cumul + 1) {
1979               absl::StrAppendFormat(&out, "%d ", start_cumul);
1980             } else if (v == start_cumul + 2) {
1981               absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1982             } else {
1983               absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1984             }
1985             cumul = false;
1986           }
1987         }
1988       }
1989       if (cumul) {
1990         if (max == start_cumul + 1) {
1991           absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1992         } else {
1993           absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1994         }
1995       } else {
1996         absl::StrAppendFormat(&out, "%d", max);
1997       }
1998     } else {
1999       absl::StrAppendFormat(&out, "%d", min);
2000     }
2001     return out;
2002   }
2003 
MakeIterator()2004   DomainIntVar::BitSetIterator* MakeIterator() override {
2005     return new DomainIntVar::BitSetIterator(&bits_, omin_);
2006   }
2007 
2008  private:
2009   uint64_t bits_;
2010   uint64_t stamp_;
2011   const int64_t omin_;
2012   const int64_t omax_;
2013   NumericalRev<int64_t> size_;
2014   std::vector<int64_t> removed_;
2015 };
2016 
2017 class EmptyIterator : public IntVarIterator {
2018  public:
~EmptyIterator()2019   ~EmptyIterator() override {}
Init()2020   void Init() override {}
Ok() const2021   bool Ok() const override { return false; }
Value() const2022   int64_t Value() const override {
2023     LOG(FATAL) << "Should not be called";
2024     return 0LL;
2025   }
Next()2026   void Next() override {}
2027 };
2028 
2029 class RangeIterator : public IntVarIterator {
2030  public:
RangeIterator(const IntVar * const var)2031   explicit RangeIterator(const IntVar* const var)
2032       : var_(var),
2033         min_(std::numeric_limits<int64_t>::max()),
2034         max_(std::numeric_limits<int64_t>::min()),
2035         current_(-1) {}
2036 
~RangeIterator()2037   ~RangeIterator() override {}
2038 
Init()2039   void Init() override {
2040     min_ = var_->Min();
2041     max_ = var_->Max();
2042     current_ = min_;
2043   }
2044 
Ok() const2045   bool Ok() const override { return current_ <= max_; }
2046 
Value() const2047   int64_t Value() const override { return current_; }
2048 
Next()2049   void Next() override { current_++; }
2050 
2051  private:
2052   const IntVar* const var_;
2053   int64_t min_;
2054   int64_t max_;
2055   int64_t current_;
2056 };
2057 
2058 class DomainIntVarHoleIterator : public IntVarIterator {
2059  public:
DomainIntVarHoleIterator(const DomainIntVar * const v)2060   explicit DomainIntVarHoleIterator(const DomainIntVar* const v)
2061       : var_(v), bits_(nullptr), values_(nullptr), size_(0), index_(0) {}
2062 
~DomainIntVarHoleIterator()2063   ~DomainIntVarHoleIterator() override {}
2064 
Init()2065   void Init() override {
2066     bits_ = var_->bitset();
2067     if (bits_ != nullptr) {
2068       bits_->InitHoles();
2069       values_ = bits_->Holes().data();
2070       size_ = bits_->Holes().size();
2071     } else {
2072       values_ = nullptr;
2073       size_ = 0;
2074     }
2075     index_ = 0;
2076   }
2077 
Ok() const2078   bool Ok() const override { return index_ < size_; }
2079 
Value() const2080   int64_t Value() const override {
2081     DCHECK(bits_ != nullptr);
2082     DCHECK(index_ < size_);
2083     return values_[index_];
2084   }
2085 
Next()2086   void Next() override { index_++; }
2087 
2088  private:
2089   const DomainIntVar* const var_;
2090   DomainIntVar::BitSet* bits_;
2091   const int64_t* values_;
2092   int size_;
2093   int index_;
2094 };
2095 
2096 class DomainIntVarDomainIterator : public IntVarIterator {
2097  public:
DomainIntVarDomainIterator(const DomainIntVar * const v,bool reversible)2098   explicit DomainIntVarDomainIterator(const DomainIntVar* const v,
2099                                       bool reversible)
2100       : var_(v),
2101         bitset_iterator_(nullptr),
2102         min_(std::numeric_limits<int64_t>::max()),
2103         max_(std::numeric_limits<int64_t>::min()),
2104         current_(-1),
2105         reversible_(reversible) {}
2106 
~DomainIntVarDomainIterator()2107   ~DomainIntVarDomainIterator() override {
2108     if (!reversible_ && bitset_iterator_) {
2109       delete bitset_iterator_;
2110     }
2111   }
2112 
Init()2113   void Init() override {
2114     if (var_->bitset() != nullptr && !var_->Bound()) {
2115       if (reversible_) {
2116         if (!bitset_iterator_) {
2117           Solver* const solver = var_->solver();
2118           solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
2119           bitset_iterator_ = solver->RevAlloc(var_->bitset()->MakeIterator());
2120         }
2121       } else {
2122         if (bitset_iterator_) {
2123           delete bitset_iterator_;
2124         }
2125         bitset_iterator_ = var_->bitset()->MakeIterator();
2126       }
2127       bitset_iterator_->Init(var_->Min(), var_->Max());
2128     } else {
2129       if (bitset_iterator_) {
2130         if (reversible_) {
2131           Solver* const solver = var_->solver();
2132           solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
2133         } else {
2134           delete bitset_iterator_;
2135         }
2136         bitset_iterator_ = nullptr;
2137       }
2138       min_ = var_->Min();
2139       max_ = var_->Max();
2140       current_ = min_;
2141     }
2142   }
2143 
Ok() const2144   bool Ok() const override {
2145     return bitset_iterator_ ? bitset_iterator_->Ok() : (current_ <= max_);
2146   }
2147 
Value() const2148   int64_t Value() const override {
2149     return bitset_iterator_ ? bitset_iterator_->Value() : current_;
2150   }
2151 
Next()2152   void Next() override {
2153     if (bitset_iterator_) {
2154       bitset_iterator_->Next();
2155     } else {
2156       current_++;
2157     }
2158   }
2159 
2160  private:
2161   const DomainIntVar* const var_;
2162   DomainIntVar::BitSetIterator* bitset_iterator_;
2163   int64_t min_;
2164   int64_t max_;
2165   int64_t current_;
2166   const bool reversible_;
2167 };
2168 
2169 class UnaryIterator : public IntVarIterator {
2170  public:
UnaryIterator(const IntVar * const v,bool hole,bool reversible)2171   UnaryIterator(const IntVar* const v, bool hole, bool reversible)
2172       : iterator_(hole ? v->MakeHoleIterator(reversible)
2173                        : v->MakeDomainIterator(reversible)),
2174         reversible_(reversible) {}
2175 
~UnaryIterator()2176   ~UnaryIterator() override {
2177     if (!reversible_) {
2178       delete iterator_;
2179     }
2180   }
2181 
Init()2182   void Init() override { iterator_->Init(); }
2183 
Ok() const2184   bool Ok() const override { return iterator_->Ok(); }
2185 
Next()2186   void Next() override { iterator_->Next(); }
2187 
2188  protected:
2189   IntVarIterator* const iterator_;
2190   const bool reversible_;
2191 };
2192 
DomainIntVar(Solver * const s,int64_t vmin,int64_t vmax,const std::string & name)2193 DomainIntVar::DomainIntVar(Solver* const s, int64_t vmin, int64_t vmax,
2194                            const std::string& name)
2195     : IntVar(s, name),
2196       min_(vmin),
2197       max_(vmax),
2198       old_min_(vmin),
2199       old_max_(vmax),
2200       new_min_(vmin),
2201       new_max_(vmax),
2202       handler_(this),
2203       in_process_(false),
2204       bits_(nullptr),
2205       value_watcher_(nullptr),
2206       bound_watcher_(nullptr) {}
2207 
DomainIntVar(Solver * const s,const std::vector<int64_t> & sorted_values,const std::string & name)2208 DomainIntVar::DomainIntVar(Solver* const s,
2209                            const std::vector<int64_t>& sorted_values,
2210                            const std::string& name)
2211     : IntVar(s, name),
2212       min_(std::numeric_limits<int64_t>::max()),
2213       max_(std::numeric_limits<int64_t>::min()),
2214       old_min_(std::numeric_limits<int64_t>::max()),
2215       old_max_(std::numeric_limits<int64_t>::min()),
2216       new_min_(std::numeric_limits<int64_t>::max()),
2217       new_max_(std::numeric_limits<int64_t>::min()),
2218       handler_(this),
2219       in_process_(false),
2220       bits_(nullptr),
2221       value_watcher_(nullptr),
2222       bound_watcher_(nullptr) {
2223   CHECK_GE(sorted_values.size(), 1);
2224   // We know that the vector is sorted and does not have duplicate values.
2225   const int64_t vmin = sorted_values.front();
2226   const int64_t vmax = sorted_values.back();
2227   const bool contiguous = vmax - vmin + 1 == sorted_values.size();
2228 
2229   min_.SetValue(solver(), vmin);
2230   old_min_ = vmin;
2231   new_min_ = vmin;
2232   max_.SetValue(solver(), vmax);
2233   old_max_ = vmax;
2234   new_max_ = vmax;
2235 
2236   if (!contiguous) {
2237     if (vmax - vmin + 1 < 65) {
2238       bits_ = solver()->RevAlloc(
2239           new SmallBitSet(solver(), sorted_values, vmin, vmax));
2240     } else {
2241       bits_ = solver()->RevAlloc(
2242           new SimpleBitSet(solver(), sorted_values, vmin, vmax));
2243     }
2244   }
2245 }
2246 
~DomainIntVar()2247 DomainIntVar::~DomainIntVar() {}
2248 
SetMin(int64_t m)2249 void DomainIntVar::SetMin(int64_t m) {
2250   if (m <= min_.Value()) return;
2251   if (m > max_.Value()) solver()->Fail();
2252   if (in_process_) {
2253     if (m > new_min_) {
2254       new_min_ = m;
2255       if (new_min_ > new_max_) {
2256         solver()->Fail();
2257       }
2258     }
2259   } else {
2260     CheckOldMin();
2261     const int64_t new_min =
2262         (bits_ == nullptr
2263              ? m
2264              : bits_->ComputeNewMin(m, min_.Value(), max_.Value()));
2265     min_.SetValue(solver(), new_min);
2266     if (min_.Value() > max_.Value()) {
2267       solver()->Fail();
2268     }
2269     Push();
2270   }
2271 }
2272 
SetMax(int64_t m)2273 void DomainIntVar::SetMax(int64_t m) {
2274   if (m >= max_.Value()) return;
2275   if (m < min_.Value()) solver()->Fail();
2276   if (in_process_) {
2277     if (m < new_max_) {
2278       new_max_ = m;
2279       if (new_max_ < new_min_) {
2280         solver()->Fail();
2281       }
2282     }
2283   } else {
2284     CheckOldMax();
2285     const int64_t new_max =
2286         (bits_ == nullptr
2287              ? m
2288              : bits_->ComputeNewMax(m, min_.Value(), max_.Value()));
2289     max_.SetValue(solver(), new_max);
2290     if (min_.Value() > max_.Value()) {
2291       solver()->Fail();
2292     }
2293     Push();
2294   }
2295 }
2296 
SetRange(int64_t mi,int64_t ma)2297 void DomainIntVar::SetRange(int64_t mi, int64_t ma) {
2298   if (mi == ma) {
2299     SetValue(mi);
2300   } else {
2301     if (mi > ma || mi > max_.Value() || ma < min_.Value()) solver()->Fail();
2302     if (mi <= min_.Value() && ma >= max_.Value()) return;
2303     if (in_process_) {
2304       if (ma < new_max_) {
2305         new_max_ = ma;
2306       }
2307       if (mi > new_min_) {
2308         new_min_ = mi;
2309       }
2310       if (new_min_ > new_max_) {
2311         solver()->Fail();
2312       }
2313     } else {
2314       if (mi > min_.Value()) {
2315         CheckOldMin();
2316         const int64_t new_min =
2317             (bits_ == nullptr
2318                  ? mi
2319                  : bits_->ComputeNewMin(mi, min_.Value(), max_.Value()));
2320         min_.SetValue(solver(), new_min);
2321       }
2322       if (min_.Value() > ma) {
2323         solver()->Fail();
2324       }
2325       if (ma < max_.Value()) {
2326         CheckOldMax();
2327         const int64_t new_max =
2328             (bits_ == nullptr
2329                  ? ma
2330                  : bits_->ComputeNewMax(ma, min_.Value(), max_.Value()));
2331         max_.SetValue(solver(), new_max);
2332       }
2333       if (min_.Value() > max_.Value()) {
2334         solver()->Fail();
2335       }
2336       Push();
2337     }
2338   }
2339 }
2340 
SetValue(int64_t v)2341 void DomainIntVar::SetValue(int64_t v) {
2342   if (v != min_.Value() || v != max_.Value()) {
2343     if (v < min_.Value() || v > max_.Value()) {
2344       solver()->Fail();
2345     }
2346     if (in_process_) {
2347       if (v > new_max_ || v < new_min_) {
2348         solver()->Fail();
2349       }
2350       new_min_ = v;
2351       new_max_ = v;
2352     } else {
2353       if (bits_ && !bits_->SetValue(v)) {
2354         solver()->Fail();
2355       }
2356       CheckOldMin();
2357       CheckOldMax();
2358       min_.SetValue(solver(), v);
2359       max_.SetValue(solver(), v);
2360       Push();
2361     }
2362   }
2363 }
2364 
RemoveValue(int64_t v)2365 void DomainIntVar::RemoveValue(int64_t v) {
2366   if (v < min_.Value() || v > max_.Value()) return;
2367   if (v == min_.Value()) {
2368     SetMin(v + 1);
2369   } else if (v == max_.Value()) {
2370     SetMax(v - 1);
2371   } else {
2372     if (bits_ == nullptr) {
2373       CreateBits();
2374     }
2375     if (in_process_) {
2376       if (v >= new_min_ && v <= new_max_ && bits_->Contains(v)) {
2377         bits_->DelayRemoveValue(v);
2378       }
2379     } else {
2380       if (bits_->RemoveValue(v)) {
2381         Push();
2382       }
2383     }
2384   }
2385 }
2386 
RemoveInterval(int64_t l,int64_t u)2387 void DomainIntVar::RemoveInterval(int64_t l, int64_t u) {
2388   if (l <= min_.Value()) {
2389     SetMin(u + 1);
2390   } else if (u >= max_.Value()) {
2391     SetMax(l - 1);
2392   } else {
2393     for (int64_t v = l; v <= u; ++v) {
2394       RemoveValue(v);
2395     }
2396   }
2397 }
2398 
CreateBits()2399 void DomainIntVar::CreateBits() {
2400   solver()->SaveValue(reinterpret_cast<void**>(&bits_));
2401   if (max_.Value() - min_.Value() < 64) {
2402     bits_ = solver()->RevAlloc(
2403         new SmallBitSet(solver(), min_.Value(), max_.Value()));
2404   } else {
2405     bits_ = solver()->RevAlloc(
2406         new SimpleBitSet(solver(), min_.Value(), max_.Value()));
2407   }
2408 }
2409 
CleanInProcess()2410 void DomainIntVar::CleanInProcess() {
2411   in_process_ = false;
2412   if (bits_ != nullptr) {
2413     bits_->ClearHoles();
2414   }
2415 }
2416 
Push()2417 void DomainIntVar::Push() {
2418   const bool in_process = in_process_;
2419   EnqueueVar(&handler_);
2420   CHECK_EQ(in_process, in_process_);
2421 }
2422 
Process()2423 void DomainIntVar::Process() {
2424   CHECK(!in_process_);
2425   in_process_ = true;
2426   if (bits_ != nullptr) {
2427     bits_->ClearRemovedValues();
2428   }
2429   set_variable_to_clean_on_fail(this);
2430   new_min_ = min_.Value();
2431   new_max_ = max_.Value();
2432   const bool is_bound = min_.Value() == max_.Value();
2433   const bool range_changed =
2434       min_.Value() != OldMin() || max_.Value() != OldMax();
2435   // Process immediate demons.
2436   if (is_bound) {
2437     ExecuteAll(bound_demons_);
2438   }
2439   if (range_changed) {
2440     ExecuteAll(range_demons_);
2441   }
2442   ExecuteAll(domain_demons_);
2443 
2444   // Process delayed demons.
2445   if (is_bound) {
2446     EnqueueAll(delayed_bound_demons_);
2447   }
2448   if (range_changed) {
2449     EnqueueAll(delayed_range_demons_);
2450   }
2451   EnqueueAll(delayed_domain_demons_);
2452 
2453   // Everything went well if we arrive here. Let's clean the variable.
2454   set_variable_to_clean_on_fail(nullptr);
2455   CleanInProcess();
2456   old_min_ = min_.Value();
2457   old_max_ = max_.Value();
2458   if (min_.Value() < new_min_) {
2459     SetMin(new_min_);
2460   }
2461   if (max_.Value() > new_max_) {
2462     SetMax(new_max_);
2463   }
2464   if (bits_ != nullptr) {
2465     bits_->ApplyRemovedValues(this);
2466   }
2467 }
2468 
// Expands to `alloc` handed to solver()->RevAlloc() when `rev` is true, and
// to the bare `alloc` expression otherwise. Only one of the two branches is
// evaluated, so an `alloc` of the form `new T(...)` allocates exactly once.
// The parameters are parenthesized so that compound arguments keep their
// meaning. NOTE: the expansion ends with a semicolon, which is kept because
// existing call sites were written against it.
#define COND_REV_ALLOC(rev, alloc) \
  (rev) ? solver()->RevAlloc(alloc) : (alloc);
2470 
MakeHoleIterator(bool reversible) const2471 IntVarIterator* DomainIntVar::MakeHoleIterator(bool reversible) const {
2472   return COND_REV_ALLOC(reversible, new DomainIntVarHoleIterator(this));
2473 }
2474 
MakeDomainIterator(bool reversible) const2475 IntVarIterator* DomainIntVar::MakeDomainIterator(bool reversible) const {
2476   return COND_REV_ALLOC(reversible,
2477                         new DomainIntVarDomainIterator(this, reversible));
2478 }
2479 
DebugString() const2480 std::string DomainIntVar::DebugString() const {
2481   std::string out;
2482   const std::string& var_name = name();
2483   if (!var_name.empty()) {
2484     out = var_name + "(";
2485   } else {
2486     out = "DomainIntVar(";
2487   }
2488   if (min_.Value() == max_.Value()) {
2489     absl::StrAppendFormat(&out, "%d", min_.Value());
2490   } else if (bits_ != nullptr) {
2491     out.append(bits_->pretty_DebugString(min_.Value(), max_.Value()));
2492   } else {
2493     absl::StrAppendFormat(&out, "%d..%d", min_.Value(), max_.Value());
2494   }
2495   out += ")";
2496   return out;
2497 }
2498 
2499 // ----- Real Boolean Var -----
2500 
2501 class ConcreteBooleanVar : public BooleanVar {
2502  public:
2503   // Utility classes
2504   class Handler : public Demon {
2505    public:
Handler(ConcreteBooleanVar * const var)2506     explicit Handler(ConcreteBooleanVar* const var) : Demon(), var_(var) {}
~Handler()2507     ~Handler() override {}
Run(Solver * const s)2508     void Run(Solver* const s) override {
2509       s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
2510       var_->Process();
2511       s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
2512     }
priority() const2513     Solver::DemonPriority priority() const override {
2514       return Solver::VAR_PRIORITY;
2515     }
DebugString() const2516     std::string DebugString() const override {
2517       return absl::StrFormat("Handler(%s)", var_->DebugString());
2518     }
2519 
2520    private:
2521     ConcreteBooleanVar* const var_;
2522   };
2523 
ConcreteBooleanVar(Solver * const s,const std::string & name)2524   ConcreteBooleanVar(Solver* const s, const std::string& name)
2525       : BooleanVar(s, name), handler_(this) {}
2526 
~ConcreteBooleanVar()2527   ~ConcreteBooleanVar() override {}
2528 
SetValue(int64_t v)2529   void SetValue(int64_t v) override {
2530     if (value_ == kUnboundBooleanVarValue) {
2531       if ((v & 0xfffffffffffffffe) == 0) {
2532         InternalSaveBooleanVarValue(solver(), this);
2533         value_ = static_cast<int>(v);
2534         EnqueueVar(&handler_);
2535         return;
2536       }
2537     } else if (v == value_) {
2538       return;
2539     }
2540     solver()->Fail();
2541   }
2542 
Process()2543   void Process() {
2544     DCHECK_NE(value_, kUnboundBooleanVarValue);
2545     ExecuteAll(bound_demons_);
2546     for (SimpleRevFIFO<Demon*>::Iterator it(&delayed_bound_demons_); it.ok();
2547          ++it) {
2548       EnqueueDelayedDemon(*it);
2549     }
2550   }
2551 
OldMin() const2552   int64_t OldMin() const override { return 0LL; }
OldMax() const2553   int64_t OldMax() const override { return 1LL; }
RestoreValue()2554   void RestoreValue() override { value_ = kUnboundBooleanVarValue; }
2555 
2556  private:
2557   Handler handler_;
2558 };
2559 
2560 // ----- IntConst -----
2561 
2562 class IntConst : public IntVar {
2563  public:
IntConst(Solver * const s,int64_t value,const std::string & name="")2564   IntConst(Solver* const s, int64_t value, const std::string& name = "")
2565       : IntVar(s, name), value_(value) {}
~IntConst()2566   ~IntConst() override {}
2567 
Min() const2568   int64_t Min() const override { return value_; }
SetMin(int64_t m)2569   void SetMin(int64_t m) override {
2570     if (m > value_) {
2571       solver()->Fail();
2572     }
2573   }
Max() const2574   int64_t Max() const override { return value_; }
SetMax(int64_t m)2575   void SetMax(int64_t m) override {
2576     if (m < value_) {
2577       solver()->Fail();
2578     }
2579   }
SetRange(int64_t l,int64_t u)2580   void SetRange(int64_t l, int64_t u) override {
2581     if (l > value_ || u < value_) {
2582       solver()->Fail();
2583     }
2584   }
SetValue(int64_t v)2585   void SetValue(int64_t v) override {
2586     if (v != value_) {
2587       solver()->Fail();
2588     }
2589   }
Bound() const2590   bool Bound() const override { return true; }
Value() const2591   int64_t Value() const override { return value_; }
RemoveValue(int64_t v)2592   void RemoveValue(int64_t v) override {
2593     if (v == value_) {
2594       solver()->Fail();
2595     }
2596   }
RemoveInterval(int64_t l,int64_t u)2597   void RemoveInterval(int64_t l, int64_t u) override {
2598     if (l <= value_ && value_ <= u) {
2599       solver()->Fail();
2600     }
2601   }
WhenBound(Demon * d)2602   void WhenBound(Demon* d) override {}
WhenRange(Demon * d)2603   void WhenRange(Demon* d) override {}
WhenDomain(Demon * d)2604   void WhenDomain(Demon* d) override {}
Size() const2605   uint64_t Size() const override { return 1; }
Contains(int64_t v) const2606   bool Contains(int64_t v) const override { return (v == value_); }
MakeHoleIterator(bool reversible) const2607   IntVarIterator* MakeHoleIterator(bool reversible) const override {
2608     return COND_REV_ALLOC(reversible, new EmptyIterator());
2609   }
MakeDomainIterator(bool reversible) const2610   IntVarIterator* MakeDomainIterator(bool reversible) const override {
2611     return COND_REV_ALLOC(reversible, new RangeIterator(this));
2612   }
OldMin() const2613   int64_t OldMin() const override { return value_; }
OldMax() const2614   int64_t OldMax() const override { return value_; }
DebugString() const2615   std::string DebugString() const override {
2616     std::string out;
2617     if (solver()->HasName(this)) {
2618       const std::string& var_name = name();
2619       absl::StrAppendFormat(&out, "%s(%d)", var_name, value_);
2620     } else {
2621       absl::StrAppendFormat(&out, "IntConst(%d)", value_);
2622     }
2623     return out;
2624   }
2625 
VarType() const2626   int VarType() const override { return CONST_VAR; }
2627 
IsEqual(int64_t constant)2628   IntVar* IsEqual(int64_t constant) override {
2629     if (constant == value_) {
2630       return solver()->MakeIntConst(1);
2631     } else {
2632       return solver()->MakeIntConst(0);
2633     }
2634   }
2635 
IsDifferent(int64_t constant)2636   IntVar* IsDifferent(int64_t constant) override {
2637     if (constant == value_) {
2638       return solver()->MakeIntConst(0);
2639     } else {
2640       return solver()->MakeIntConst(1);
2641     }
2642   }
2643 
IsGreaterOrEqual(int64_t constant)2644   IntVar* IsGreaterOrEqual(int64_t constant) override {
2645     return solver()->MakeIntConst(value_ >= constant);
2646   }
2647 
IsLessOrEqual(int64_t constant)2648   IntVar* IsLessOrEqual(int64_t constant) override {
2649     return solver()->MakeIntConst(value_ <= constant);
2650   }
2651 
name() const2652   std::string name() const override {
2653     if (solver()->HasName(this)) {
2654       return PropagationBaseObject::name();
2655     } else {
2656       return absl::StrCat(value_);
2657     }
2658   }
2659 
2660  private:
2661   int64_t value_;
2662 };
2663 
2664 // ----- x + c variable, optimized case -----
2665 
2666 class PlusCstVar : public IntVar {
2667  public:
PlusCstVar(Solver * const s,IntVar * v,int64_t c)2668   PlusCstVar(Solver* const s, IntVar* v, int64_t c)
2669       : IntVar(s), var_(v), cst_(c) {}
2670 
~PlusCstVar()2671   ~PlusCstVar() override {}
2672 
WhenRange(Demon * d)2673   void WhenRange(Demon* d) override { var_->WhenRange(d); }
2674 
WhenBound(Demon * d)2675   void WhenBound(Demon* d) override { var_->WhenBound(d); }
2676 
WhenDomain(Demon * d)2677   void WhenDomain(Demon* d) override { var_->WhenDomain(d); }
2678 
OldMin() const2679   int64_t OldMin() const override { return CapAdd(var_->OldMin(), cst_); }
2680 
OldMax() const2681   int64_t OldMax() const override { return CapAdd(var_->OldMax(), cst_); }
2682 
DebugString() const2683   std::string DebugString() const override {
2684     if (HasName()) {
2685       return absl::StrFormat("%s(%s + %d)", name(), var_->DebugString(), cst_);
2686     } else {
2687       return absl::StrFormat("(%s + %d)", var_->DebugString(), cst_);
2688     }
2689   }
2690 
VarType() const2691   int VarType() const override { return VAR_ADD_CST; }
2692 
Accept(ModelVisitor * const visitor) const2693   void Accept(ModelVisitor* const visitor) const override {
2694     visitor->VisitIntegerVariable(this, ModelVisitor::kSumOperation, cst_,
2695                                   var_);
2696   }
2697 
IsEqual(int64_t constant)2698   IntVar* IsEqual(int64_t constant) override {
2699     return var_->IsEqual(constant - cst_);
2700   }
2701 
IsDifferent(int64_t constant)2702   IntVar* IsDifferent(int64_t constant) override {
2703     return var_->IsDifferent(constant - cst_);
2704   }
2705 
IsGreaterOrEqual(int64_t constant)2706   IntVar* IsGreaterOrEqual(int64_t constant) override {
2707     return var_->IsGreaterOrEqual(constant - cst_);
2708   }
2709 
IsLessOrEqual(int64_t constant)2710   IntVar* IsLessOrEqual(int64_t constant) override {
2711     return var_->IsLessOrEqual(constant - cst_);
2712   }
2713 
SubVar() const2714   IntVar* SubVar() const { return var_; }
2715 
Constant() const2716   int64_t Constant() const { return cst_; }
2717 
2718  protected:
2719   IntVar* const var_;
2720   const int64_t cst_;
2721 };
2722 
2723 class PlusCstIntVar : public PlusCstVar {
2724  public:
2725   class PlusCstIntVarIterator : public UnaryIterator {
2726    public:
PlusCstIntVarIterator(const IntVar * const v,int64_t c,bool hole,bool rev)2727     PlusCstIntVarIterator(const IntVar* const v, int64_t c, bool hole, bool rev)
2728         : UnaryIterator(v, hole, rev), cst_(c) {}
2729 
~PlusCstIntVarIterator()2730     ~PlusCstIntVarIterator() override {}
2731 
Value() const2732     int64_t Value() const override { return iterator_->Value() + cst_; }
2733 
2734    private:
2735     const int64_t cst_;
2736   };
2737 
PlusCstIntVar(Solver * const s,IntVar * v,int64_t c)2738   PlusCstIntVar(Solver* const s, IntVar* v, int64_t c) : PlusCstVar(s, v, c) {}
2739 
~PlusCstIntVar()2740   ~PlusCstIntVar() override {}
2741 
Min() const2742   int64_t Min() const override { return var_->Min() + cst_; }
2743 
SetMin(int64_t m)2744   void SetMin(int64_t m) override { var_->SetMin(CapSub(m, cst_)); }
2745 
Max() const2746   int64_t Max() const override { return var_->Max() + cst_; }
2747 
SetMax(int64_t m)2748   void SetMax(int64_t m) override { var_->SetMax(CapSub(m, cst_)); }
2749 
SetRange(int64_t l,int64_t u)2750   void SetRange(int64_t l, int64_t u) override {
2751     var_->SetRange(CapSub(l, cst_), CapSub(u, cst_));
2752   }
2753 
SetValue(int64_t v)2754   void SetValue(int64_t v) override { var_->SetValue(v - cst_); }
2755 
Value() const2756   int64_t Value() const override { return var_->Value() + cst_; }
2757 
Bound() const2758   bool Bound() const override { return var_->Bound(); }
2759 
RemoveValue(int64_t v)2760   void RemoveValue(int64_t v) override { var_->RemoveValue(v - cst_); }
2761 
RemoveInterval(int64_t l,int64_t u)2762   void RemoveInterval(int64_t l, int64_t u) override {
2763     var_->RemoveInterval(l - cst_, u - cst_);
2764   }
2765 
Size() const2766   uint64_t Size() const override { return var_->Size(); }
2767 
Contains(int64_t v) const2768   bool Contains(int64_t v) const override { return var_->Contains(v - cst_); }
2769 
MakeHoleIterator(bool reversible) const2770   IntVarIterator* MakeHoleIterator(bool reversible) const override {
2771     return COND_REV_ALLOC(
2772         reversible, new PlusCstIntVarIterator(var_, cst_, true, reversible));
2773   }
MakeDomainIterator(bool reversible) const2774   IntVarIterator* MakeDomainIterator(bool reversible) const override {
2775     return COND_REV_ALLOC(
2776         reversible, new PlusCstIntVarIterator(var_, cst_, false, reversible));
2777   }
2778 };
2779 
2780 class PlusCstDomainIntVar : public PlusCstVar {
2781  public:
2782   class PlusCstDomainIntVarIterator : public UnaryIterator {
2783    public:
PlusCstDomainIntVarIterator(const IntVar * const v,int64_t c,bool hole,bool reversible)2784     PlusCstDomainIntVarIterator(const IntVar* const v, int64_t c, bool hole,
2785                                 bool reversible)
2786         : UnaryIterator(v, hole, reversible), cst_(c) {}
2787 
~PlusCstDomainIntVarIterator()2788     ~PlusCstDomainIntVarIterator() override {}
2789 
Value() const2790     int64_t Value() const override { return iterator_->Value() + cst_; }
2791 
2792    private:
2793     const int64_t cst_;
2794   };
2795 
PlusCstDomainIntVar(Solver * const s,DomainIntVar * v,int64_t c)2796   PlusCstDomainIntVar(Solver* const s, DomainIntVar* v, int64_t c)
2797       : PlusCstVar(s, v, c) {}
2798 
~PlusCstDomainIntVar()2799   ~PlusCstDomainIntVar() override {}
2800 
2801   int64_t Min() const override;
2802   void SetMin(int64_t m) override;
2803   int64_t Max() const override;
2804   void SetMax(int64_t m) override;
2805   void SetRange(int64_t l, int64_t u) override;
2806   void SetValue(int64_t v) override;
2807   bool Bound() const override;
2808   int64_t Value() const override;
2809   void RemoveValue(int64_t v) override;
2810   void RemoveInterval(int64_t l, int64_t u) override;
2811   uint64_t Size() const override;
2812   bool Contains(int64_t v) const override;
2813 
domain_int_var() const2814   DomainIntVar* domain_int_var() const {
2815     return reinterpret_cast<DomainIntVar*>(var_);
2816   }
2817 
MakeHoleIterator(bool reversible) const2818   IntVarIterator* MakeHoleIterator(bool reversible) const override {
2819     return COND_REV_ALLOC(reversible, new PlusCstDomainIntVarIterator(
2820                                           var_, cst_, true, reversible));
2821   }
MakeDomainIterator(bool reversible) const2822   IntVarIterator* MakeDomainIterator(bool reversible) const override {
2823     return COND_REV_ALLOC(reversible, new PlusCstDomainIntVarIterator(
2824                                           var_, cst_, false, reversible));
2825   }
2826 };
2827 
Min() const2828 int64_t PlusCstDomainIntVar::Min() const {
2829   return domain_int_var()->min_.Value() + cst_;
2830 }
2831 
SetMin(int64_t m)2832 void PlusCstDomainIntVar::SetMin(int64_t m) {
2833   domain_int_var()->DomainIntVar::SetMin(m - cst_);
2834 }
2835 
Max() const2836 int64_t PlusCstDomainIntVar::Max() const {
2837   return domain_int_var()->max_.Value() + cst_;
2838 }
2839 
SetMax(int64_t m)2840 void PlusCstDomainIntVar::SetMax(int64_t m) {
2841   domain_int_var()->DomainIntVar::SetMax(m - cst_);
2842 }
2843 
SetRange(int64_t l,int64_t u)2844 void PlusCstDomainIntVar::SetRange(int64_t l, int64_t u) {
2845   domain_int_var()->DomainIntVar::SetRange(l - cst_, u - cst_);
2846 }
2847 
SetValue(int64_t v)2848 void PlusCstDomainIntVar::SetValue(int64_t v) {
2849   domain_int_var()->DomainIntVar::SetValue(v - cst_);
2850 }
2851 
Bound() const2852 bool PlusCstDomainIntVar::Bound() const {
2853   return domain_int_var()->min_.Value() == domain_int_var()->max_.Value();
2854 }
2855 
Value() const2856 int64_t PlusCstDomainIntVar::Value() const {
2857   CHECK_EQ(domain_int_var()->min_.Value(), domain_int_var()->max_.Value())
2858       << " variable is not bound";
2859   return domain_int_var()->min_.Value() + cst_;
2860 }
2861 
RemoveValue(int64_t v)2862 void PlusCstDomainIntVar::RemoveValue(int64_t v) {
2863   domain_int_var()->DomainIntVar::RemoveValue(v - cst_);
2864 }
2865 
RemoveInterval(int64_t l,int64_t u)2866 void PlusCstDomainIntVar::RemoveInterval(int64_t l, int64_t u) {
2867   domain_int_var()->DomainIntVar::RemoveInterval(l - cst_, u - cst_);
2868 }
2869 
Size() const2870 uint64_t PlusCstDomainIntVar::Size() const {
2871   return domain_int_var()->DomainIntVar::Size();
2872 }
2873 
Contains(int64_t v) const2874 bool PlusCstDomainIntVar::Contains(int64_t v) const {
2875   return domain_int_var()->DomainIntVar::Contains(v - cst_);
2876 }
2877 
2878 // c - x variable, optimized case
2879 
2880 class SubCstIntVar : public IntVar {
2881  public:
2882   class SubCstIntVarIterator : public UnaryIterator {
2883    public:
SubCstIntVarIterator(const IntVar * const v,int64_t c,bool hole,bool rev)2884     SubCstIntVarIterator(const IntVar* const v, int64_t c, bool hole, bool rev)
2885         : UnaryIterator(v, hole, rev), cst_(c) {}
~SubCstIntVarIterator()2886     ~SubCstIntVarIterator() override {}
2887 
Value() const2888     int64_t Value() const override { return cst_ - iterator_->Value(); }
2889 
2890    private:
2891     const int64_t cst_;
2892   };
2893 
2894   SubCstIntVar(Solver* const s, IntVar* v, int64_t c);
2895   ~SubCstIntVar() override;
2896 
2897   int64_t Min() const override;
2898   void SetMin(int64_t m) override;
2899   int64_t Max() const override;
2900   void SetMax(int64_t m) override;
2901   void SetRange(int64_t l, int64_t u) override;
2902   void SetValue(int64_t v) override;
2903   bool Bound() const override;
2904   int64_t Value() const override;
2905   void RemoveValue(int64_t v) override;
2906   void RemoveInterval(int64_t l, int64_t u) override;
2907   uint64_t Size() const override;
2908   bool Contains(int64_t v) const override;
2909   void WhenRange(Demon* d) override;
2910   void WhenBound(Demon* d) override;
2911   void WhenDomain(Demon* d) override;
MakeHoleIterator(bool reversible) const2912   IntVarIterator* MakeHoleIterator(bool reversible) const override {
2913     return COND_REV_ALLOC(
2914         reversible, new SubCstIntVarIterator(var_, cst_, true, reversible));
2915   }
MakeDomainIterator(bool reversible) const2916   IntVarIterator* MakeDomainIterator(bool reversible) const override {
2917     return COND_REV_ALLOC(
2918         reversible, new SubCstIntVarIterator(var_, cst_, false, reversible));
2919   }
OldMin() const2920   int64_t OldMin() const override { return CapSub(cst_, var_->OldMax()); }
OldMax() const2921   int64_t OldMax() const override { return CapSub(cst_, var_->OldMin()); }
2922   std::string DebugString() const override;
2923   std::string name() const override;
VarType() const2924   int VarType() const override { return CST_SUB_VAR; }
2925 
Accept(ModelVisitor * const visitor) const2926   void Accept(ModelVisitor* const visitor) const override {
2927     visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation,
2928                                   cst_, var_);
2929   }
2930 
IsEqual(int64_t constant)2931   IntVar* IsEqual(int64_t constant) override {
2932     return var_->IsEqual(cst_ - constant);
2933   }
2934 
IsDifferent(int64_t constant)2935   IntVar* IsDifferent(int64_t constant) override {
2936     return var_->IsDifferent(cst_ - constant);
2937   }
2938 
IsGreaterOrEqual(int64_t constant)2939   IntVar* IsGreaterOrEqual(int64_t constant) override {
2940     return var_->IsLessOrEqual(cst_ - constant);
2941   }
2942 
IsLessOrEqual(int64_t constant)2943   IntVar* IsLessOrEqual(int64_t constant) override {
2944     return var_->IsGreaterOrEqual(cst_ - constant);
2945   }
2946 
SubVar() const2947   IntVar* SubVar() const { return var_; }
Constant() const2948   int64_t Constant() const { return cst_; }
2949 
2950  private:
2951   IntVar* const var_;
2952   const int64_t cst_;
2953 };
2954 
SubCstIntVar(Solver * const s,IntVar * v,int64_t c)2955 SubCstIntVar::SubCstIntVar(Solver* const s, IntVar* v, int64_t c)
2956     : IntVar(s), var_(v), cst_(c) {}
2957 
~SubCstIntVar()2958 SubCstIntVar::~SubCstIntVar() {}
2959 
Min() const2960 int64_t SubCstIntVar::Min() const { return cst_ - var_->Max(); }
2961 
SetMin(int64_t m)2962 void SubCstIntVar::SetMin(int64_t m) { var_->SetMax(CapSub(cst_, m)); }
2963 
Max() const2964 int64_t SubCstIntVar::Max() const { return cst_ - var_->Min(); }
2965 
SetMax(int64_t m)2966 void SubCstIntVar::SetMax(int64_t m) { var_->SetMin(CapSub(cst_, m)); }
2967 
SetRange(int64_t l,int64_t u)2968 void SubCstIntVar::SetRange(int64_t l, int64_t u) {
2969   var_->SetRange(CapSub(cst_, u), CapSub(cst_, l));
2970 }
2971 
SetValue(int64_t v)2972 void SubCstIntVar::SetValue(int64_t v) { var_->SetValue(cst_ - v); }
2973 
Bound() const2974 bool SubCstIntVar::Bound() const { return var_->Bound(); }
2975 
WhenRange(Demon * d)2976 void SubCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
2977 
Value() const2978 int64_t SubCstIntVar::Value() const { return cst_ - var_->Value(); }
2979 
RemoveValue(int64_t v)2980 void SubCstIntVar::RemoveValue(int64_t v) { var_->RemoveValue(cst_ - v); }
2981 
RemoveInterval(int64_t l,int64_t u)2982 void SubCstIntVar::RemoveInterval(int64_t l, int64_t u) {
2983   var_->RemoveInterval(cst_ - u, cst_ - l);
2984 }
2985 
WhenBound(Demon * d)2986 void SubCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
2987 
WhenDomain(Demon * d)2988 void SubCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
2989 
Size() const2990 uint64_t SubCstIntVar::Size() const { return var_->Size(); }
2991 
Contains(int64_t v) const2992 bool SubCstIntVar::Contains(int64_t v) const {
2993   return var_->Contains(cst_ - v);
2994 }
2995 
DebugString() const2996 std::string SubCstIntVar::DebugString() const {
2997   if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
2998     return absl::StrFormat("Not(%s)", var_->DebugString());
2999   } else {
3000     return absl::StrFormat("(%d - %s)", cst_, var_->DebugString());
3001   }
3002 }
3003 
name() const3004 std::string SubCstIntVar::name() const {
3005   if (solver()->HasName(this)) {
3006     return PropagationBaseObject::name();
3007   } else if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
3008     return absl::StrFormat("Not(%s)", var_->name());
3009   } else {
3010     return absl::StrFormat("(%d - %s)", cst_, var_->name());
3011   }
3012 }
3013 
3014 // -x variable, optimized case
3015 
3016 class OppIntVar : public IntVar {
3017  public:
3018   class OppIntVarIterator : public UnaryIterator {
3019    public:
OppIntVarIterator(const IntVar * const v,bool hole,bool reversible)3020     OppIntVarIterator(const IntVar* const v, bool hole, bool reversible)
3021         : UnaryIterator(v, hole, reversible) {}
~OppIntVarIterator()3022     ~OppIntVarIterator() override {}
3023 
Value() const3024     int64_t Value() const override { return -iterator_->Value(); }
3025   };
3026 
3027   OppIntVar(Solver* const s, IntVar* v);
3028   ~OppIntVar() override;
3029 
3030   int64_t Min() const override;
3031   void SetMin(int64_t m) override;
3032   int64_t Max() const override;
3033   void SetMax(int64_t m) override;
3034   void SetRange(int64_t l, int64_t u) override;
3035   void SetValue(int64_t v) override;
3036   bool Bound() const override;
3037   int64_t Value() const override;
3038   void RemoveValue(int64_t v) override;
3039   void RemoveInterval(int64_t l, int64_t u) override;
3040   uint64_t Size() const override;
3041   bool Contains(int64_t v) const override;
3042   void WhenRange(Demon* d) override;
3043   void WhenBound(Demon* d) override;
3044   void WhenDomain(Demon* d) override;
MakeHoleIterator(bool reversible) const3045   IntVarIterator* MakeHoleIterator(bool reversible) const override {
3046     return COND_REV_ALLOC(reversible,
3047                           new OppIntVarIterator(var_, true, reversible));
3048   }
MakeDomainIterator(bool reversible) const3049   IntVarIterator* MakeDomainIterator(bool reversible) const override {
3050     return COND_REV_ALLOC(reversible,
3051                           new OppIntVarIterator(var_, false, reversible));
3052   }
OldMin() const3053   int64_t OldMin() const override { return CapOpp(var_->OldMax()); }
OldMax() const3054   int64_t OldMax() const override { return CapOpp(var_->OldMin()); }
3055   std::string DebugString() const override;
VarType() const3056   int VarType() const override { return OPP_VAR; }
3057 
Accept(ModelVisitor * const visitor) const3058   void Accept(ModelVisitor* const visitor) const override {
3059     visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation, 0,
3060                                   var_);
3061   }
3062 
IsEqual(int64_t constant)3063   IntVar* IsEqual(int64_t constant) override {
3064     return var_->IsEqual(-constant);
3065   }
3066 
IsDifferent(int64_t constant)3067   IntVar* IsDifferent(int64_t constant) override {
3068     return var_->IsDifferent(-constant);
3069   }
3070 
IsGreaterOrEqual(int64_t constant)3071   IntVar* IsGreaterOrEqual(int64_t constant) override {
3072     return var_->IsLessOrEqual(-constant);
3073   }
3074 
IsLessOrEqual(int64_t constant)3075   IntVar* IsLessOrEqual(int64_t constant) override {
3076     return var_->IsGreaterOrEqual(-constant);
3077   }
3078 
SubVar() const3079   IntVar* SubVar() const { return var_; }
3080 
3081  private:
3082   IntVar* const var_;
3083 };
3084 
OppIntVar(Solver * const s,IntVar * v)3085 OppIntVar::OppIntVar(Solver* const s, IntVar* v) : IntVar(s), var_(v) {}
3086 
~OppIntVar()3087 OppIntVar::~OppIntVar() {}
3088 
Min() const3089 int64_t OppIntVar::Min() const { return -var_->Max(); }
3090 
SetMin(int64_t m)3091 void OppIntVar::SetMin(int64_t m) { var_->SetMax(CapOpp(m)); }
3092 
Max() const3093 int64_t OppIntVar::Max() const { return -var_->Min(); }
3094 
SetMax(int64_t m)3095 void OppIntVar::SetMax(int64_t m) { var_->SetMin(CapOpp(m)); }
3096 
SetRange(int64_t l,int64_t u)3097 void OppIntVar::SetRange(int64_t l, int64_t u) {
3098   var_->SetRange(CapOpp(u), CapOpp(l));
3099 }
3100 
SetValue(int64_t v)3101 void OppIntVar::SetValue(int64_t v) { var_->SetValue(CapOpp(v)); }
3102 
Bound() const3103 bool OppIntVar::Bound() const { return var_->Bound(); }
3104 
WhenRange(Demon * d)3105 void OppIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3106 
Value() const3107 int64_t OppIntVar::Value() const { return -var_->Value(); }
3108 
RemoveValue(int64_t v)3109 void OppIntVar::RemoveValue(int64_t v) { var_->RemoveValue(-v); }
3110 
RemoveInterval(int64_t l,int64_t u)3111 void OppIntVar::RemoveInterval(int64_t l, int64_t u) {
3112   var_->RemoveInterval(-u, -l);
3113 }
3114 
WhenBound(Demon * d)3115 void OppIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3116 
WhenDomain(Demon * d)3117 void OppIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3118 
Size() const3119 uint64_t OppIntVar::Size() const { return var_->Size(); }
3120 
Contains(int64_t v) const3121 bool OppIntVar::Contains(int64_t v) const { return var_->Contains(-v); }
3122 
DebugString() const3123 std::string OppIntVar::DebugString() const {
3124   return absl::StrFormat("-(%s)", var_->DebugString());
3125 }
3126 
3127 // ----- Utility functions -----
3128 
3129 // x * c variable, optimized case
3130 
3131 class TimesCstIntVar : public IntVar {
3132  public:
TimesCstIntVar(Solver * const s,IntVar * v,int64_t c)3133   TimesCstIntVar(Solver* const s, IntVar* v, int64_t c)
3134       : IntVar(s), var_(v), cst_(c) {}
~TimesCstIntVar()3135   ~TimesCstIntVar() override {}
3136 
SubVar() const3137   IntVar* SubVar() const { return var_; }
Constant() const3138   int64_t Constant() const { return cst_; }
3139 
Accept(ModelVisitor * const visitor) const3140   void Accept(ModelVisitor* const visitor) const override {
3141     visitor->VisitIntegerVariable(this, ModelVisitor::kProductOperation, cst_,
3142                                   var_);
3143   }
3144 
IsEqual(int64_t constant)3145   IntVar* IsEqual(int64_t constant) override {
3146     if (constant % cst_ == 0) {
3147       return var_->IsEqual(constant / cst_);
3148     } else {
3149       return solver()->MakeIntConst(0);
3150     }
3151   }
3152 
IsDifferent(int64_t constant)3153   IntVar* IsDifferent(int64_t constant) override {
3154     if (constant % cst_ == 0) {
3155       return var_->IsDifferent(constant / cst_);
3156     } else {
3157       return solver()->MakeIntConst(1);
3158     }
3159   }
3160 
IsGreaterOrEqual(int64_t constant)3161   IntVar* IsGreaterOrEqual(int64_t constant) override {
3162     if (cst_ > 0) {
3163       return var_->IsGreaterOrEqual(PosIntDivUp(constant, cst_));
3164     } else {
3165       return var_->IsLessOrEqual(PosIntDivDown(-constant, -cst_));
3166     }
3167   }
3168 
IsLessOrEqual(int64_t constant)3169   IntVar* IsLessOrEqual(int64_t constant) override {
3170     if (cst_ > 0) {
3171       return var_->IsLessOrEqual(PosIntDivDown(constant, cst_));
3172     } else {
3173       return var_->IsGreaterOrEqual(PosIntDivUp(-constant, -cst_));
3174     }
3175   }
3176 
DebugString() const3177   std::string DebugString() const override {
3178     return absl::StrFormat("(%s * %d)", var_->DebugString(), cst_);
3179   }
3180 
VarType() const3181   int VarType() const override { return VAR_TIMES_CST; }
3182 
3183  protected:
3184   IntVar* const var_;
3185   const int64_t cst_;
3186 };
3187 
3188 class TimesPosCstIntVar : public TimesCstIntVar {
3189  public:
3190   class TimesPosCstIntVarIterator : public UnaryIterator {
3191    public:
TimesPosCstIntVarIterator(const IntVar * const v,int64_t c,bool hole,bool reversible)3192     TimesPosCstIntVarIterator(const IntVar* const v, int64_t c, bool hole,
3193                               bool reversible)
3194         : UnaryIterator(v, hole, reversible), cst_(c) {}
~TimesPosCstIntVarIterator()3195     ~TimesPosCstIntVarIterator() override {}
3196 
Value() const3197     int64_t Value() const override { return iterator_->Value() * cst_; }
3198 
3199    private:
3200     const int64_t cst_;
3201   };
3202 
3203   TimesPosCstIntVar(Solver* const s, IntVar* v, int64_t c);
3204   ~TimesPosCstIntVar() override;
3205 
3206   int64_t Min() const override;
3207   void SetMin(int64_t m) override;
3208   int64_t Max() const override;
3209   void SetMax(int64_t m) override;
3210   void SetRange(int64_t l, int64_t u) override;
3211   void SetValue(int64_t v) override;
3212   bool Bound() const override;
3213   int64_t Value() const override;
3214   void RemoveValue(int64_t v) override;
3215   void RemoveInterval(int64_t l, int64_t u) override;
3216   uint64_t Size() const override;
3217   bool Contains(int64_t v) const override;
3218   void WhenRange(Demon* d) override;
3219   void WhenBound(Demon* d) override;
3220   void WhenDomain(Demon* d) override;
MakeHoleIterator(bool reversible) const3221   IntVarIterator* MakeHoleIterator(bool reversible) const override {
3222     return COND_REV_ALLOC(reversible, new TimesPosCstIntVarIterator(
3223                                           var_, cst_, true, reversible));
3224   }
MakeDomainIterator(bool reversible) const3225   IntVarIterator* MakeDomainIterator(bool reversible) const override {
3226     return COND_REV_ALLOC(reversible, new TimesPosCstIntVarIterator(
3227                                           var_, cst_, false, reversible));
3228   }
OldMin() const3229   int64_t OldMin() const override { return CapProd(var_->OldMin(), cst_); }
OldMax() const3230   int64_t OldMax() const override { return CapProd(var_->OldMax(), cst_); }
3231 };
3232 
3233 // ----- TimesPosCstIntVar -----
3234 
TimesPosCstIntVar(Solver * const s,IntVar * v,int64_t c)3235 TimesPosCstIntVar::TimesPosCstIntVar(Solver* const s, IntVar* v, int64_t c)
3236     : TimesCstIntVar(s, v, c) {}
3237 
~TimesPosCstIntVar()3238 TimesPosCstIntVar::~TimesPosCstIntVar() {}
3239 
Min() const3240 int64_t TimesPosCstIntVar::Min() const { return CapProd(var_->Min(), cst_); }
3241 
SetMin(int64_t m)3242 void TimesPosCstIntVar::SetMin(int64_t m) {
3243   if (m != std::numeric_limits<int64_t>::min()) {
3244     var_->SetMin(PosIntDivUp(m, cst_));
3245   }
3246 }
3247 
Max() const3248 int64_t TimesPosCstIntVar::Max() const { return CapProd(var_->Max(), cst_); }
3249 
SetMax(int64_t m)3250 void TimesPosCstIntVar::SetMax(int64_t m) {
3251   if (m != std::numeric_limits<int64_t>::max()) {
3252     var_->SetMax(PosIntDivDown(m, cst_));
3253   }
3254 }
3255 
SetRange(int64_t l,int64_t u)3256 void TimesPosCstIntVar::SetRange(int64_t l, int64_t u) {
3257   var_->SetRange(PosIntDivUp(l, cst_), PosIntDivDown(u, cst_));
3258 }
3259 
SetValue(int64_t v)3260 void TimesPosCstIntVar::SetValue(int64_t v) {
3261   if (v % cst_ != 0) {
3262     solver()->Fail();
3263   }
3264   var_->SetValue(v / cst_);
3265 }
3266 
Bound() const3267 bool TimesPosCstIntVar::Bound() const { return var_->Bound(); }
3268 
WhenRange(Demon * d)3269 void TimesPosCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3270 
Value() const3271 int64_t TimesPosCstIntVar::Value() const {
3272   return CapProd(var_->Value(), cst_);
3273 }
3274 
RemoveValue(int64_t v)3275 void TimesPosCstIntVar::RemoveValue(int64_t v) {
3276   if (v % cst_ == 0) {
3277     var_->RemoveValue(v / cst_);
3278   }
3279 }
3280 
RemoveInterval(int64_t l,int64_t u)3281 void TimesPosCstIntVar::RemoveInterval(int64_t l, int64_t u) {
3282   for (int64_t v = l; v <= u; ++v) {
3283     RemoveValue(v);
3284   }
3285   // TODO(user) : Improve me
3286 }
3287 
WhenBound(Demon * d)3288 void TimesPosCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3289 
WhenDomain(Demon * d)3290 void TimesPosCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3291 
Size() const3292 uint64_t TimesPosCstIntVar::Size() const { return var_->Size(); }
3293 
Contains(int64_t v) const3294 bool TimesPosCstIntVar::Contains(int64_t v) const {
3295   return (v % cst_ == 0 && var_->Contains(v / cst_));
3296 }
3297 
3298 // b * c variable, optimized case
3299 
3300 class TimesPosCstBoolVar : public TimesCstIntVar {
3301  public:
3302   class TimesPosCstBoolVarIterator : public UnaryIterator {
3303    public:
3304     // TODO(user) : optimize this.
TimesPosCstBoolVarIterator(const IntVar * const v,int64_t c,bool hole,bool reversible)3305     TimesPosCstBoolVarIterator(const IntVar* const v, int64_t c, bool hole,
3306                                bool reversible)
3307         : UnaryIterator(v, hole, reversible), cst_(c) {}
~TimesPosCstBoolVarIterator()3308     ~TimesPosCstBoolVarIterator() override {}
3309 
Value() const3310     int64_t Value() const override { return iterator_->Value() * cst_; }
3311 
3312    private:
3313     const int64_t cst_;
3314   };
3315 
3316   TimesPosCstBoolVar(Solver* const s, BooleanVar* v, int64_t c);
3317   ~TimesPosCstBoolVar() override;
3318 
3319   int64_t Min() const override;
3320   void SetMin(int64_t m) override;
3321   int64_t Max() const override;
3322   void SetMax(int64_t m) override;
3323   void SetRange(int64_t l, int64_t u) override;
3324   void SetValue(int64_t v) override;
3325   bool Bound() const override;
3326   int64_t Value() const override;
3327   void RemoveValue(int64_t v) override;
3328   void RemoveInterval(int64_t l, int64_t u) override;
3329   uint64_t Size() const override;
3330   bool Contains(int64_t v) const override;
3331   void WhenRange(Demon* d) override;
3332   void WhenBound(Demon* d) override;
3333   void WhenDomain(Demon* d) override;
MakeHoleIterator(bool reversible) const3334   IntVarIterator* MakeHoleIterator(bool reversible) const override {
3335     return COND_REV_ALLOC(reversible, new EmptyIterator());
3336   }
MakeDomainIterator(bool reversible) const3337   IntVarIterator* MakeDomainIterator(bool reversible) const override {
3338     return COND_REV_ALLOC(
3339         reversible,
3340         new TimesPosCstBoolVarIterator(boolean_var(), cst_, false, reversible));
3341   }
OldMin() const3342   int64_t OldMin() const override { return 0; }
OldMax() const3343   int64_t OldMax() const override { return cst_; }
3344 
boolean_var() const3345   BooleanVar* boolean_var() const {
3346     return reinterpret_cast<BooleanVar*>(var_);
3347   }
3348 };
3349 
3350 // ----- TimesPosCstBoolVar -----
3351 
TimesPosCstBoolVar(Solver * const s,BooleanVar * v,int64_t c)3352 TimesPosCstBoolVar::TimesPosCstBoolVar(Solver* const s, BooleanVar* v,
3353                                        int64_t c)
3354     : TimesCstIntVar(s, v, c) {}
3355 
~TimesPosCstBoolVar()3356 TimesPosCstBoolVar::~TimesPosCstBoolVar() {}
3357 
Min() const3358 int64_t TimesPosCstBoolVar::Min() const {
3359   return (boolean_var()->RawValue() == 1) * cst_;
3360 }
3361 
SetMin(int64_t m)3362 void TimesPosCstBoolVar::SetMin(int64_t m) {
3363   if (m > cst_) {
3364     solver()->Fail();
3365   } else if (m > 0) {
3366     boolean_var()->SetMin(1);
3367   }
3368 }
3369 
Max() const3370 int64_t TimesPosCstBoolVar::Max() const {
3371   return (boolean_var()->RawValue() != 0) * cst_;
3372 }
3373 
SetMax(int64_t m)3374 void TimesPosCstBoolVar::SetMax(int64_t m) {
3375   if (m < 0) {
3376     solver()->Fail();
3377   } else if (m < cst_) {
3378     boolean_var()->SetMax(0);
3379   }
3380 }
3381 
SetRange(int64_t l,int64_t u)3382 void TimesPosCstBoolVar::SetRange(int64_t l, int64_t u) {
3383   if (u < 0 || l > cst_ || l > u) {
3384     solver()->Fail();
3385   }
3386   if (l > 0) {
3387     boolean_var()->SetMin(1);
3388   } else if (u < cst_) {
3389     boolean_var()->SetMax(0);
3390   }
3391 }
3392 
SetValue(int64_t v)3393 void TimesPosCstBoolVar::SetValue(int64_t v) {
3394   if (v == 0) {
3395     boolean_var()->SetValue(0);
3396   } else if (v == cst_) {
3397     boolean_var()->SetValue(1);
3398   } else {
3399     solver()->Fail();
3400   }
3401 }
3402 
Bound() const3403 bool TimesPosCstBoolVar::Bound() const {
3404   return boolean_var()->RawValue() != BooleanVar::kUnboundBooleanVarValue;
3405 }
3406 
WhenRange(Demon * d)3407 void TimesPosCstBoolVar::WhenRange(Demon* d) { boolean_var()->WhenRange(d); }
3408 
Value() const3409 int64_t TimesPosCstBoolVar::Value() const {
3410   CHECK_NE(boolean_var()->RawValue(), BooleanVar::kUnboundBooleanVarValue)
3411       << " variable is not bound";
3412   return boolean_var()->RawValue() * cst_;
3413 }
3414 
RemoveValue(int64_t v)3415 void TimesPosCstBoolVar::RemoveValue(int64_t v) {
3416   if (v == 0) {
3417     boolean_var()->RemoveValue(0);
3418   } else if (v == cst_) {
3419     boolean_var()->RemoveValue(1);
3420   }
3421 }
3422 
RemoveInterval(int64_t l,int64_t u)3423 void TimesPosCstBoolVar::RemoveInterval(int64_t l, int64_t u) {
3424   if (l <= 0 && u >= 0) {
3425     boolean_var()->RemoveValue(0);
3426   }
3427   if (l <= cst_ && u >= cst_) {
3428     boolean_var()->RemoveValue(1);
3429   }
3430 }
3431 
WhenBound(Demon * d)3432 void TimesPosCstBoolVar::WhenBound(Demon* d) { boolean_var()->WhenBound(d); }
3433 
WhenDomain(Demon * d)3434 void TimesPosCstBoolVar::WhenDomain(Demon* d) { boolean_var()->WhenDomain(d); }
3435 
Size() const3436 uint64_t TimesPosCstBoolVar::Size() const {
3437   return (1 +
3438           (boolean_var()->RawValue() == BooleanVar::kUnboundBooleanVarValue));
3439 }
3440 
Contains(int64_t v) const3441 bool TimesPosCstBoolVar::Contains(int64_t v) const {
3442   if (v == 0) {
3443     return boolean_var()->RawValue() != 1;
3444   } else if (v == cst_) {
3445     return boolean_var()->RawValue() != 0;
3446   }
3447   return false;
3448 }
3449 
3450 // TimesNegCstIntVar
3451 
3452 class TimesNegCstIntVar : public TimesCstIntVar {
3453  public:
3454   class TimesNegCstIntVarIterator : public UnaryIterator {
3455    public:
TimesNegCstIntVarIterator(const IntVar * const v,int64_t c,bool hole,bool reversible)3456     TimesNegCstIntVarIterator(const IntVar* const v, int64_t c, bool hole,
3457                               bool reversible)
3458         : UnaryIterator(v, hole, reversible), cst_(c) {}
~TimesNegCstIntVarIterator()3459     ~TimesNegCstIntVarIterator() override {}
3460 
Value() const3461     int64_t Value() const override { return iterator_->Value() * cst_; }
3462 
3463    private:
3464     const int64_t cst_;
3465   };
3466 
3467   TimesNegCstIntVar(Solver* const s, IntVar* v, int64_t c);
3468   ~TimesNegCstIntVar() override;
3469 
3470   int64_t Min() const override;
3471   void SetMin(int64_t m) override;
3472   int64_t Max() const override;
3473   void SetMax(int64_t m) override;
3474   void SetRange(int64_t l, int64_t u) override;
3475   void SetValue(int64_t v) override;
3476   bool Bound() const override;
3477   int64_t Value() const override;
3478   void RemoveValue(int64_t v) override;
3479   void RemoveInterval(int64_t l, int64_t u) override;
3480   uint64_t Size() const override;
3481   bool Contains(int64_t v) const override;
3482   void WhenRange(Demon* d) override;
3483   void WhenBound(Demon* d) override;
3484   void WhenDomain(Demon* d) override;
MakeHoleIterator(bool reversible) const3485   IntVarIterator* MakeHoleIterator(bool reversible) const override {
3486     return COND_REV_ALLOC(reversible, new TimesNegCstIntVarIterator(
3487                                           var_, cst_, true, reversible));
3488   }
MakeDomainIterator(bool reversible) const3489   IntVarIterator* MakeDomainIterator(bool reversible) const override {
3490     return COND_REV_ALLOC(reversible, new TimesNegCstIntVarIterator(
3491                                           var_, cst_, false, reversible));
3492   }
OldMin() const3493   int64_t OldMin() const override { return CapProd(var_->OldMax(), cst_); }
OldMax() const3494   int64_t OldMax() const override { return CapProd(var_->OldMin(), cst_); }
3495 };
3496 
3497 // ----- TimesNegCstIntVar -----
3498 
TimesNegCstIntVar(Solver * const s,IntVar * v,int64_t c)3499 TimesNegCstIntVar::TimesNegCstIntVar(Solver* const s, IntVar* v, int64_t c)
3500     : TimesCstIntVar(s, v, c) {}
3501 
~TimesNegCstIntVar()3502 TimesNegCstIntVar::~TimesNegCstIntVar() {}
3503 
Min() const3504 int64_t TimesNegCstIntVar::Min() const { return CapProd(var_->Max(), cst_); }
3505 
SetMin(int64_t m)3506 void TimesNegCstIntVar::SetMin(int64_t m) {
3507   if (m != std::numeric_limits<int64_t>::min()) {
3508     var_->SetMax(PosIntDivDown(-m, -cst_));
3509   }
3510 }
3511 
Max() const3512 int64_t TimesNegCstIntVar::Max() const { return CapProd(var_->Min(), cst_); }
3513 
SetMax(int64_t m)3514 void TimesNegCstIntVar::SetMax(int64_t m) {
3515   if (m != std::numeric_limits<int64_t>::max()) {
3516     var_->SetMin(PosIntDivUp(-m, -cst_));
3517   }
3518 }
3519 
SetRange(int64_t l,int64_t u)3520 void TimesNegCstIntVar::SetRange(int64_t l, int64_t u) {
3521   var_->SetRange(PosIntDivUp(-u, -cst_), PosIntDivDown(-l, -cst_));
3522 }
3523 
SetValue(int64_t v)3524 void TimesNegCstIntVar::SetValue(int64_t v) {
3525   if (v % cst_ != 0) {
3526     solver()->Fail();
3527   }
3528   var_->SetValue(v / cst_);
3529 }
3530 
Bound() const3531 bool TimesNegCstIntVar::Bound() const { return var_->Bound(); }
3532 
WhenRange(Demon * d)3533 void TimesNegCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3534 
Value() const3535 int64_t TimesNegCstIntVar::Value() const {
3536   return CapProd(var_->Value(), cst_);
3537 }
3538 
RemoveValue(int64_t v)3539 void TimesNegCstIntVar::RemoveValue(int64_t v) {
3540   if (v % cst_ == 0) {
3541     var_->RemoveValue(v / cst_);
3542   }
3543 }
3544 
RemoveInterval(int64_t l,int64_t u)3545 void TimesNegCstIntVar::RemoveInterval(int64_t l, int64_t u) {
3546   for (int64_t v = l; v <= u; ++v) {
3547     RemoveValue(v);
3548   }
3549   // TODO(user) : Improve me
3550 }
3551 
WhenBound(Demon * d)3552 void TimesNegCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3553 
WhenDomain(Demon * d)3554 void TimesNegCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3555 
Size() const3556 uint64_t TimesNegCstIntVar::Size() const { return var_->Size(); }
3557 
Contains(int64_t v) const3558 bool TimesNegCstIntVar::Contains(int64_t v) const {
3559   return (v % cst_ == 0 && var_->Contains(v / cst_));
3560 }
3561 
3562 // ---------- arithmetic expressions ----------
3563 
3564 // ----- PlusIntExpr -----
3565 
3566 class PlusIntExpr : public BaseIntExpr {
3567  public:
PlusIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)3568   PlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3569       : BaseIntExpr(s), left_(l), right_(r) {}
3570 
~PlusIntExpr()3571   ~PlusIntExpr() override {}
3572 
Min() const3573   int64_t Min() const override { return left_->Min() + right_->Min(); }
3574 
SetMin(int64_t m)3575   void SetMin(int64_t m) override {
3576     if (m > left_->Min() + right_->Min()) {
3577       left_->SetMin(m - right_->Max());
3578       right_->SetMin(m - left_->Max());
3579     }
3580   }
3581 
SetRange(int64_t l,int64_t u)3582   void SetRange(int64_t l, int64_t u) override {
3583     const int64_t left_min = left_->Min();
3584     const int64_t right_min = right_->Min();
3585     const int64_t left_max = left_->Max();
3586     const int64_t right_max = right_->Max();
3587     if (l > left_min + right_min) {
3588       left_->SetMin(l - right_max);
3589       right_->SetMin(l - left_max);
3590     }
3591     if (u < left_max + right_max) {
3592       left_->SetMax(u - right_min);
3593       right_->SetMax(u - left_min);
3594     }
3595   }
3596 
Max() const3597   int64_t Max() const override { return left_->Max() + right_->Max(); }
3598 
SetMax(int64_t m)3599   void SetMax(int64_t m) override {
3600     if (m < left_->Max() + right_->Max()) {
3601       left_->SetMax(m - right_->Min());
3602       right_->SetMax(m - left_->Min());
3603     }
3604   }
3605 
Bound() const3606   bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3607 
Range(int64_t * const mi,int64_t * const ma)3608   void Range(int64_t* const mi, int64_t* const ma) override {
3609     *mi = left_->Min() + right_->Min();
3610     *ma = left_->Max() + right_->Max();
3611   }
3612 
name() const3613   std::string name() const override {
3614     return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3615   }
3616 
DebugString() const3617   std::string DebugString() const override {
3618     return absl::StrFormat("(%s + %s)", left_->DebugString(),
3619                            right_->DebugString());
3620   }
3621 
WhenRange(Demon * d)3622   void WhenRange(Demon* d) override {
3623     left_->WhenRange(d);
3624     right_->WhenRange(d);
3625   }
3626 
ExpandPlusIntExpr(IntExpr * const expr,std::vector<IntExpr * > * subs)3627   void ExpandPlusIntExpr(IntExpr* const expr, std::vector<IntExpr*>* subs) {
3628     PlusIntExpr* const casted = dynamic_cast<PlusIntExpr*>(expr);
3629     if (casted != nullptr) {
3630       ExpandPlusIntExpr(casted->left_, subs);
3631       ExpandPlusIntExpr(casted->right_, subs);
3632     } else {
3633       subs->push_back(expr);
3634     }
3635   }
3636 
CastToVar()3637   IntVar* CastToVar() override {
3638     if (dynamic_cast<PlusIntExpr*>(left_) != nullptr ||
3639         dynamic_cast<PlusIntExpr*>(right_) != nullptr) {
3640       std::vector<IntExpr*> sub_exprs;
3641       ExpandPlusIntExpr(left_, &sub_exprs);
3642       ExpandPlusIntExpr(right_, &sub_exprs);
3643       if (sub_exprs.size() >= 3) {
3644         std::vector<IntVar*> sub_vars(sub_exprs.size());
3645         for (int i = 0; i < sub_exprs.size(); ++i) {
3646           sub_vars[i] = sub_exprs[i]->Var();
3647         }
3648         return solver()->MakeSum(sub_vars)->Var();
3649       }
3650     }
3651     return BaseIntExpr::CastToVar();
3652   }
3653 
Accept(ModelVisitor * const visitor) const3654   void Accept(ModelVisitor* const visitor) const override {
3655     visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3656     visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3657     visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3658                                             right_);
3659     visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3660   }
3661 
3662  private:
3663   IntExpr* const left_;
3664   IntExpr* const right_;
3665 };
3666 
3667 class SafePlusIntExpr : public BaseIntExpr {
3668  public:
SafePlusIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)3669   SafePlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3670       : BaseIntExpr(s), left_(l), right_(r) {}
3671 
~SafePlusIntExpr()3672   ~SafePlusIntExpr() override {}
3673 
Min() const3674   int64_t Min() const override { return CapAdd(left_->Min(), right_->Min()); }
3675 
SetMin(int64_t m)3676   void SetMin(int64_t m) override {
3677     left_->SetMin(CapSub(m, right_->Max()));
3678     right_->SetMin(CapSub(m, left_->Max()));
3679   }
3680 
SetRange(int64_t l,int64_t u)3681   void SetRange(int64_t l, int64_t u) override {
3682     const int64_t left_min = left_->Min();
3683     const int64_t right_min = right_->Min();
3684     const int64_t left_max = left_->Max();
3685     const int64_t right_max = right_->Max();
3686     if (l > CapAdd(left_min, right_min)) {
3687       left_->SetMin(CapSub(l, right_max));
3688       right_->SetMin(CapSub(l, left_max));
3689     }
3690     if (u < CapAdd(left_max, right_max)) {
3691       left_->SetMax(CapSub(u, right_min));
3692       right_->SetMax(CapSub(u, left_min));
3693     }
3694   }
3695 
Max() const3696   int64_t Max() const override { return CapAdd(left_->Max(), right_->Max()); }
3697 
SetMax(int64_t m)3698   void SetMax(int64_t m) override {
3699     left_->SetMax(CapSub(m, right_->Min()));
3700     right_->SetMax(CapSub(m, left_->Min()));
3701   }
3702 
Bound() const3703   bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3704 
name() const3705   std::string name() const override {
3706     return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3707   }
3708 
DebugString() const3709   std::string DebugString() const override {
3710     return absl::StrFormat("(%s + %s)", left_->DebugString(),
3711                            right_->DebugString());
3712   }
3713 
WhenRange(Demon * d)3714   void WhenRange(Demon* d) override {
3715     left_->WhenRange(d);
3716     right_->WhenRange(d);
3717   }
3718 
Accept(ModelVisitor * const visitor) const3719   void Accept(ModelVisitor* const visitor) const override {
3720     visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3721     visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3722     visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3723                                             right_);
3724     visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3725   }
3726 
3727  private:
3728   IntExpr* const left_;
3729   IntExpr* const right_;
3730 };
3731 
3732 // ----- PlusIntCstExpr -----
3733 
3734 class PlusIntCstExpr : public BaseIntExpr {
3735  public:
PlusIntCstExpr(Solver * const s,IntExpr * const e,int64_t v)3736   PlusIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3737       : BaseIntExpr(s), expr_(e), value_(v) {}
~PlusIntCstExpr()3738   ~PlusIntCstExpr() override {}
Min() const3739   int64_t Min() const override { return CapAdd(expr_->Min(), value_); }
SetMin(int64_t m)3740   void SetMin(int64_t m) override { expr_->SetMin(CapSub(m, value_)); }
Max() const3741   int64_t Max() const override { return CapAdd(expr_->Max(), value_); }
SetMax(int64_t m)3742   void SetMax(int64_t m) override { expr_->SetMax(CapSub(m, value_)); }
Bound() const3743   bool Bound() const override { return (expr_->Bound()); }
name() const3744   std::string name() const override {
3745     return absl::StrFormat("(%s + %d)", expr_->name(), value_);
3746   }
DebugString() const3747   std::string DebugString() const override {
3748     return absl::StrFormat("(%s + %d)", expr_->DebugString(), value_);
3749   }
WhenRange(Demon * d)3750   void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3751   IntVar* CastToVar() override;
Accept(ModelVisitor * const visitor) const3752   void Accept(ModelVisitor* const visitor) const override {
3753     visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3754     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3755                                             expr_);
3756     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3757     visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3758   }
3759 
3760  private:
3761   IntExpr* const expr_;
3762   const int64_t value_;
3763 };
3764 
CastToVar()3765 IntVar* PlusIntCstExpr::CastToVar() {
3766   Solver* const s = solver();
3767   IntVar* const var = expr_->Var();
3768   IntVar* cast = nullptr;
3769   if (AddOverflows(value_, expr_->Max()) ||
3770       AddOverflows(value_, expr_->Min())) {
3771     return BaseIntExpr::CastToVar();
3772   }
3773   switch (var->VarType()) {
3774     case DOMAIN_INT_VAR:
3775       cast = s->RegisterIntVar(s->RevAlloc(new PlusCstDomainIntVar(
3776           s, reinterpret_cast<DomainIntVar*>(var), value_)));
3777       // FIXME: Break was inserted during fallthrough cleanup. Please check.
3778       break;
3779     default:
3780       cast = s->RegisterIntVar(s->RevAlloc(new PlusCstIntVar(s, var, value_)));
3781       break;
3782   }
3783   return cast;
3784 }
3785 
3786 // ----- SubIntExpr -----
3787 
3788 class SubIntExpr : public BaseIntExpr {
3789  public:
SubIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)3790   SubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3791       : BaseIntExpr(s), left_(l), right_(r) {}
3792 
~SubIntExpr()3793   ~SubIntExpr() override {}
3794 
Min() const3795   int64_t Min() const override { return left_->Min() - right_->Max(); }
3796 
SetMin(int64_t m)3797   void SetMin(int64_t m) override {
3798     left_->SetMin(CapAdd(m, right_->Min()));
3799     right_->SetMax(CapSub(left_->Max(), m));
3800   }
3801 
Max() const3802   int64_t Max() const override { return left_->Max() - right_->Min(); }
3803 
SetMax(int64_t m)3804   void SetMax(int64_t m) override {
3805     left_->SetMax(CapAdd(m, right_->Max()));
3806     right_->SetMin(CapSub(left_->Min(), m));
3807   }
3808 
Range(int64_t * mi,int64_t * ma)3809   void Range(int64_t* mi, int64_t* ma) override {
3810     *mi = left_->Min() - right_->Max();
3811     *ma = left_->Max() - right_->Min();
3812   }
3813 
SetRange(int64_t l,int64_t u)3814   void SetRange(int64_t l, int64_t u) override {
3815     const int64_t left_min = left_->Min();
3816     const int64_t right_min = right_->Min();
3817     const int64_t left_max = left_->Max();
3818     const int64_t right_max = right_->Max();
3819     if (l > left_min - right_max) {
3820       left_->SetMin(CapAdd(l, right_min));
3821       right_->SetMax(CapSub(left_max, l));
3822     }
3823     if (u < left_max - right_min) {
3824       left_->SetMax(CapAdd(u, right_max));
3825       right_->SetMin(CapSub(left_min, u));
3826     }
3827   }
3828 
Bound() const3829   bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3830 
name() const3831   std::string name() const override {
3832     return absl::StrFormat("(%s - %s)", left_->name(), right_->name());
3833   }
3834 
DebugString() const3835   std::string DebugString() const override {
3836     return absl::StrFormat("(%s - %s)", left_->DebugString(),
3837                            right_->DebugString());
3838   }
3839 
WhenRange(Demon * d)3840   void WhenRange(Demon* d) override {
3841     left_->WhenRange(d);
3842     right_->WhenRange(d);
3843   }
3844 
Accept(ModelVisitor * const visitor) const3845   void Accept(ModelVisitor* const visitor) const override {
3846     visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3847     visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3848     visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3849                                             right_);
3850     visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3851   }
3852 
left() const3853   IntExpr* left() const { return left_; }
right() const3854   IntExpr* right() const { return right_; }
3855 
3856  protected:
3857   IntExpr* const left_;
3858   IntExpr* const right_;
3859 };
3860 
3861 class SafeSubIntExpr : public SubIntExpr {
3862  public:
SafeSubIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)3863   SafeSubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3864       : SubIntExpr(s, l, r) {}
3865 
~SafeSubIntExpr()3866   ~SafeSubIntExpr() override {}
3867 
Min() const3868   int64_t Min() const override { return CapSub(left_->Min(), right_->Max()); }
3869 
SetMin(int64_t m)3870   void SetMin(int64_t m) override {
3871     left_->SetMin(CapAdd(m, right_->Min()));
3872     right_->SetMax(CapSub(left_->Max(), m));
3873   }
3874 
SetRange(int64_t l,int64_t u)3875   void SetRange(int64_t l, int64_t u) override {
3876     const int64_t left_min = left_->Min();
3877     const int64_t right_min = right_->Min();
3878     const int64_t left_max = left_->Max();
3879     const int64_t right_max = right_->Max();
3880     if (l > CapSub(left_min, right_max)) {
3881       left_->SetMin(CapAdd(l, right_min));
3882       right_->SetMax(CapSub(left_max, l));
3883     }
3884     if (u < CapSub(left_max, right_min)) {
3885       left_->SetMax(CapAdd(u, right_max));
3886       right_->SetMin(CapSub(left_min, u));
3887     }
3888   }
3889 
Range(int64_t * mi,int64_t * ma)3890   void Range(int64_t* mi, int64_t* ma) override {
3891     *mi = CapSub(left_->Min(), right_->Max());
3892     *ma = CapSub(left_->Max(), right_->Min());
3893   }
3894 
Max() const3895   int64_t Max() const override { return CapSub(left_->Max(), right_->Min()); }
3896 
SetMax(int64_t m)3897   void SetMax(int64_t m) override {
3898     left_->SetMax(CapAdd(m, right_->Max()));
3899     right_->SetMin(CapSub(left_->Min(), m));
3900   }
3901 };
3902 
3903 // l - r
3904 
3905 // ----- SubIntCstExpr -----
3906 
3907 class SubIntCstExpr : public BaseIntExpr {
3908  public:
SubIntCstExpr(Solver * const s,IntExpr * const e,int64_t v)3909   SubIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3910       : BaseIntExpr(s), expr_(e), value_(v) {}
~SubIntCstExpr()3911   ~SubIntCstExpr() override {}
Min() const3912   int64_t Min() const override { return CapSub(value_, expr_->Max()); }
SetMin(int64_t m)3913   void SetMin(int64_t m) override { expr_->SetMax(CapSub(value_, m)); }
Max() const3914   int64_t Max() const override { return CapSub(value_, expr_->Min()); }
SetMax(int64_t m)3915   void SetMax(int64_t m) override { expr_->SetMin(CapSub(value_, m)); }
Bound() const3916   bool Bound() const override { return (expr_->Bound()); }
name() const3917   std::string name() const override {
3918     return absl::StrFormat("(%d - %s)", value_, expr_->name());
3919   }
DebugString() const3920   std::string DebugString() const override {
3921     return absl::StrFormat("(%d - %s)", value_, expr_->DebugString());
3922   }
WhenRange(Demon * d)3923   void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3924   IntVar* CastToVar() override;
3925 
Accept(ModelVisitor * const visitor) const3926   void Accept(ModelVisitor* const visitor) const override {
3927     visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3928     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3929     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3930                                             expr_);
3931     visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3932   }
3933 
3934  private:
3935   IntExpr* const expr_;
3936   const int64_t value_;
3937 };
3938 
CastToVar()3939 IntVar* SubIntCstExpr::CastToVar() {
3940   if (SubOverflows(value_, expr_->Min()) ||
3941       SubOverflows(value_, expr_->Max())) {
3942     return BaseIntExpr::CastToVar();
3943   }
3944   Solver* const s = solver();
3945   IntVar* const var =
3946       s->RegisterIntVar(s->RevAlloc(new SubCstIntVar(s, expr_->Var(), value_)));
3947   return var;
3948 }
3949 
3950 // ----- OppIntExpr -----
3951 
3952 class OppIntExpr : public BaseIntExpr {
3953  public:
OppIntExpr(Solver * const s,IntExpr * const e)3954   OppIntExpr(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
~OppIntExpr()3955   ~OppIntExpr() override {}
Min() const3956   int64_t Min() const override { return (CapOpp(expr_->Max())); }
SetMin(int64_t m)3957   void SetMin(int64_t m) override { expr_->SetMax(CapOpp(m)); }
Max() const3958   int64_t Max() const override { return (CapOpp(expr_->Min())); }
SetMax(int64_t m)3959   void SetMax(int64_t m) override { expr_->SetMin(CapOpp(m)); }
Bound() const3960   bool Bound() const override { return (expr_->Bound()); }
name() const3961   std::string name() const override {
3962     return absl::StrFormat("(-%s)", expr_->name());
3963   }
DebugString() const3964   std::string DebugString() const override {
3965     return absl::StrFormat("(-%s)", expr_->DebugString());
3966   }
WhenRange(Demon * d)3967   void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3968   IntVar* CastToVar() override;
3969 
Accept(ModelVisitor * const visitor) const3970   void Accept(ModelVisitor* const visitor) const override {
3971     visitor->BeginVisitIntegerExpression(ModelVisitor::kOpposite, this);
3972     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3973                                             expr_);
3974     visitor->EndVisitIntegerExpression(ModelVisitor::kOpposite, this);
3975   }
3976 
3977  private:
3978   IntExpr* const expr_;
3979 };
3980 
CastToVar()3981 IntVar* OppIntExpr::CastToVar() {
3982   Solver* const s = solver();
3983   IntVar* const var =
3984       s->RegisterIntVar(s->RevAlloc(new OppIntVar(s, expr_->Var())));
3985   return var;
3986 }
3987 
3988 // ----- TimesIntCstExpr -----
3989 
3990 class TimesIntCstExpr : public BaseIntExpr {
3991  public:
TimesIntCstExpr(Solver * const s,IntExpr * const e,int64_t v)3992   TimesIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3993       : BaseIntExpr(s), expr_(e), value_(v) {}
3994 
~TimesIntCstExpr()3995   ~TimesIntCstExpr() override {}
3996 
Bound() const3997   bool Bound() const override { return (expr_->Bound()); }
3998 
name() const3999   std::string name() const override {
4000     return absl::StrFormat("(%s * %d)", expr_->name(), value_);
4001   }
4002 
DebugString() const4003   std::string DebugString() const override {
4004     return absl::StrFormat("(%s * %d)", expr_->DebugString(), value_);
4005   }
4006 
WhenRange(Demon * d)4007   void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4008 
Expr() const4009   IntExpr* Expr() const { return expr_; }
4010 
Constant() const4011   int64_t Constant() const { return value_; }
4012 
Accept(ModelVisitor * const visitor) const4013   void Accept(ModelVisitor* const visitor) const override {
4014     visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4015     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4016                                             expr_);
4017     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4018     visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4019   }
4020 
4021  protected:
4022   IntExpr* const expr_;
4023   const int64_t value_;
4024 };
4025 
4026 // ----- TimesPosIntCstExpr -----
4027 
4028 class TimesPosIntCstExpr : public TimesIntCstExpr {
4029  public:
TimesPosIntCstExpr(Solver * const s,IntExpr * const e,int64_t v)4030   TimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4031       : TimesIntCstExpr(s, e, v) {
4032     CHECK_GT(v, 0);
4033   }
4034 
~TimesPosIntCstExpr()4035   ~TimesPosIntCstExpr() override {}
4036 
Min() const4037   int64_t Min() const override { return expr_->Min() * value_; }
4038 
SetMin(int64_t m)4039   void SetMin(int64_t m) override { expr_->SetMin(PosIntDivUp(m, value_)); }
4040 
Max() const4041   int64_t Max() const override { return expr_->Max() * value_; }
4042 
SetMax(int64_t m)4043   void SetMax(int64_t m) override { expr_->SetMax(PosIntDivDown(m, value_)); }
4044 
CastToVar()4045   IntVar* CastToVar() override {
4046     Solver* const s = solver();
4047     IntVar* var = nullptr;
4048     if (expr_->IsVar() &&
4049         reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4050       var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4051           s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4052     } else {
4053       var = s->RegisterIntVar(
4054           s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4055     }
4056     return var;
4057   }
4058 };
4059 
4060 // This expressions adds safe arithmetic (w.r.t. overflows) compared
4061 // to the previous one.
4062 class SafeTimesPosIntCstExpr : public TimesIntCstExpr {
4063  public:
SafeTimesPosIntCstExpr(Solver * const s,IntExpr * const e,int64_t v)4064   SafeTimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4065       : TimesIntCstExpr(s, e, v) {
4066     CHECK_GT(v, 0);
4067   }
4068 
~SafeTimesPosIntCstExpr()4069   ~SafeTimesPosIntCstExpr() override {}
4070 
Min() const4071   int64_t Min() const override { return CapProd(expr_->Min(), value_); }
4072 
SetMin(int64_t m)4073   void SetMin(int64_t m) override {
4074     if (m != std::numeric_limits<int64_t>::min()) {
4075       expr_->SetMin(PosIntDivUp(m, value_));
4076     }
4077   }
4078 
Max() const4079   int64_t Max() const override { return CapProd(expr_->Max(), value_); }
4080 
SetMax(int64_t m)4081   void SetMax(int64_t m) override {
4082     if (m != std::numeric_limits<int64_t>::max()) {
4083       expr_->SetMax(PosIntDivDown(m, value_));
4084     }
4085   }
4086 
CastToVar()4087   IntVar* CastToVar() override {
4088     Solver* const s = solver();
4089     IntVar* var = nullptr;
4090     if (expr_->IsVar() &&
4091         reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4092       var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4093           s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4094     } else {
4095       // TODO(user): Check overflows.
4096       var = s->RegisterIntVar(
4097           s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4098     }
4099     return var;
4100   }
4101 };
4102 
4103 // ----- TimesIntNegCstExpr -----
4104 
4105 class TimesIntNegCstExpr : public TimesIntCstExpr {
4106  public:
TimesIntNegCstExpr(Solver * const s,IntExpr * const e,int64_t v)4107   TimesIntNegCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4108       : TimesIntCstExpr(s, e, v) {
4109     CHECK_LT(v, 0);
4110   }
4111 
~TimesIntNegCstExpr()4112   ~TimesIntNegCstExpr() override {}
4113 
Min() const4114   int64_t Min() const override { return CapProd(expr_->Max(), value_); }
4115 
SetMin(int64_t m)4116   void SetMin(int64_t m) override {
4117     if (m != std::numeric_limits<int64_t>::min()) {
4118       expr_->SetMax(PosIntDivDown(-m, -value_));
4119     }
4120   }
4121 
Max() const4122   int64_t Max() const override { return CapProd(expr_->Min(), value_); }
4123 
SetMax(int64_t m)4124   void SetMax(int64_t m) override {
4125     if (m != std::numeric_limits<int64_t>::max()) {
4126       expr_->SetMin(PosIntDivUp(-m, -value_));
4127     }
4128   }
4129 
CastToVar()4130   IntVar* CastToVar() override {
4131     Solver* const s = solver();
4132     IntVar* var = nullptr;
4133     var = s->RegisterIntVar(
4134         s->RevAlloc(new TimesNegCstIntVar(s, expr_->Var(), value_)));
4135     return var;
4136   }
4137 };
4138 
4139 // ----- Utilities for product expression -----
4140 
4141 // Propagates set_min on left * right, left and right >= 0.
SetPosPosMinExpr(IntExpr * const left,IntExpr * const right,int64_t m)4142 void SetPosPosMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4143   DCHECK_GE(left->Min(), 0);
4144   DCHECK_GE(right->Min(), 0);
4145   const int64_t lmax = left->Max();
4146   const int64_t rmax = right->Max();
4147   if (m > CapProd(lmax, rmax)) {
4148     left->solver()->Fail();
4149   }
4150   if (m > CapProd(left->Min(), right->Min())) {
4151     // Ok for m == 0 due to left and right being positive
4152     if (0 != rmax) {
4153       left->SetMin(PosIntDivUp(m, rmax));
4154     }
4155     if (0 != lmax) {
4156       right->SetMin(PosIntDivUp(m, lmax));
4157     }
4158   }
4159 }
4160 
4161 // Propagates set_max on left * right, left and right >= 0.
SetPosPosMaxExpr(IntExpr * const left,IntExpr * const right,int64_t m)4162 void SetPosPosMaxExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4163   DCHECK_GE(left->Min(), 0);
4164   DCHECK_GE(right->Min(), 0);
4165   const int64_t lmin = left->Min();
4166   const int64_t rmin = right->Min();
4167   if (m < CapProd(lmin, rmin)) {
4168     left->solver()->Fail();
4169   }
4170   if (m < CapProd(left->Max(), right->Max())) {
4171     if (0 != lmin) {
4172       right->SetMax(PosIntDivDown(m, lmin));
4173     }
4174     if (0 != rmin) {
4175       left->SetMax(PosIntDivDown(m, rmin));
4176     }
4177     // else do nothing: 0 is supporting any value from other expr.
4178   }
4179 }
4180 
4181 // Propagates set_min on left * right, left >= 0, right across 0.
SetPosGenMinExpr(IntExpr * const left,IntExpr * const right,int64_t m)4182 void SetPosGenMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4183   DCHECK_GE(left->Min(), 0);
4184   DCHECK_GT(right->Max(), 0);
4185   DCHECK_LT(right->Min(), 0);
4186   const int64_t lmax = left->Max();
4187   const int64_t rmax = right->Max();
4188   if (m > CapProd(lmax, rmax)) {
4189     left->solver()->Fail();
4190   }
4191   if (left->Max() == 0) {  // left is bound to 0, product is bound to 0.
4192     DCHECK_EQ(0, left->Min());
4193     DCHECK_LE(m, 0);
4194   } else {
4195     if (m > 0) {  // We deduce right > 0.
4196       left->SetMin(PosIntDivUp(m, rmax));
4197       right->SetMin(PosIntDivUp(m, lmax));
4198     } else if (m == 0) {
4199       const int64_t lmin = left->Min();
4200       if (lmin > 0) {
4201         right->SetMin(0);
4202       }
4203     } else {  // m < 0
4204       const int64_t lmin = left->Min();
4205       if (0 != lmin) {  // We cannot deduce anything if 0 is in the domain.
4206         right->SetMin(-PosIntDivDown(-m, lmin));
4207       }
4208     }
4209   }
4210 }
4211 
4212 // Propagates set_min on left * right, left and right across 0.
SetGenGenMinExpr(IntExpr * const left,IntExpr * const right,int64_t m)4213 void SetGenGenMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4214   DCHECK_LT(left->Min(), 0);
4215   DCHECK_GT(left->Max(), 0);
4216   DCHECK_GT(right->Max(), 0);
4217   DCHECK_LT(right->Min(), 0);
4218   const int64_t lmin = left->Min();
4219   const int64_t lmax = left->Max();
4220   const int64_t rmin = right->Min();
4221   const int64_t rmax = right->Max();
4222   if (m > std::max(CapProd(lmin, rmin), CapProd(lmax, rmax))) {
4223     left->solver()->Fail();
4224   }
4225   if (m > lmin * rmin) {  // Must be positive section * positive section.
4226     left->SetMin(PosIntDivUp(m, rmax));
4227     right->SetMin(PosIntDivUp(m, lmax));
4228   } else if (m > CapProd(lmax, rmax)) {  // Negative section * negative section.
4229     left->SetMax(-PosIntDivUp(m, -rmin));
4230     right->SetMax(-PosIntDivUp(m, -lmin));
4231   }
4232 }
4233 
TimesSetMin(IntExpr * const left,IntExpr * const right,IntExpr * const minus_left,IntExpr * const minus_right,int64_t m)4234 void TimesSetMin(IntExpr* const left, IntExpr* const right,
4235                  IntExpr* const minus_left, IntExpr* const minus_right,
4236                  int64_t m) {
4237   if (left->Min() >= 0) {
4238     if (right->Min() >= 0) {
4239       SetPosPosMinExpr(left, right, m);
4240     } else if (right->Max() <= 0) {
4241       SetPosPosMaxExpr(left, minus_right, -m);
4242     } else {  // right->Min() < 0 && right->Max() > 0
4243       SetPosGenMinExpr(left, right, m);
4244     }
4245   } else if (left->Max() <= 0) {
4246     if (right->Min() >= 0) {
4247       SetPosPosMaxExpr(right, minus_left, -m);
4248     } else if (right->Max() <= 0) {
4249       SetPosPosMinExpr(minus_left, minus_right, m);
4250     } else {  // right->Min() < 0 && right->Max() > 0
4251       SetPosGenMinExpr(minus_left, minus_right, m);
4252     }
4253   } else if (right->Min() >= 0) {  // left->Min() < 0 && left->Max() > 0
4254     SetPosGenMinExpr(right, left, m);
4255   } else if (right->Max() <= 0) {  // left->Min() < 0 && left->Max() > 0
4256     SetPosGenMinExpr(minus_right, minus_left, m);
4257   } else {  // left->Min() < 0 && left->Max() > 0 &&
4258             // right->Min() < 0 && right->Max() > 0
4259     SetGenGenMinExpr(left, right, m);
4260   }
4261 }
4262 
4263 class TimesIntExpr : public BaseIntExpr {
4264  public:
TimesIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)4265   TimesIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4266       : BaseIntExpr(s),
4267         left_(l),
4268         right_(r),
4269         minus_left_(s->MakeOpposite(left_)),
4270         minus_right_(s->MakeOpposite(right_)) {}
~TimesIntExpr()4271   ~TimesIntExpr() override {}
Min() const4272   int64_t Min() const override {
4273     const int64_t lmin = left_->Min();
4274     const int64_t lmax = left_->Max();
4275     const int64_t rmin = right_->Min();
4276     const int64_t rmax = right_->Max();
4277     return std::min(std::min(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4278                     std::min(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4279   }
4280   void SetMin(int64_t m) override;
Max() const4281   int64_t Max() const override {
4282     const int64_t lmin = left_->Min();
4283     const int64_t lmax = left_->Max();
4284     const int64_t rmin = right_->Min();
4285     const int64_t rmax = right_->Max();
4286     return std::max(std::max(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4287                     std::max(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4288   }
4289   void SetMax(int64_t m) override;
4290   bool Bound() const override;
name() const4291   std::string name() const override {
4292     return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4293   }
DebugString() const4294   std::string DebugString() const override {
4295     return absl::StrFormat("(%s * %s)", left_->DebugString(),
4296                            right_->DebugString());
4297   }
WhenRange(Demon * d)4298   void WhenRange(Demon* d) override {
4299     left_->WhenRange(d);
4300     right_->WhenRange(d);
4301   }
4302 
Accept(ModelVisitor * const visitor) const4303   void Accept(ModelVisitor* const visitor) const override {
4304     visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4305     visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4306     visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4307                                             right_);
4308     visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4309   }
4310 
4311  private:
4312   IntExpr* const left_;
4313   IntExpr* const right_;
4314   IntExpr* const minus_left_;
4315   IntExpr* const minus_right_;
4316 };
4317 
SetMin(int64_t m)4318 void TimesIntExpr::SetMin(int64_t m) {
4319   if (m != std::numeric_limits<int64_t>::min()) {
4320     TimesSetMin(left_, right_, minus_left_, minus_right_, m);
4321   }
4322 }
4323 
SetMax(int64_t m)4324 void TimesIntExpr::SetMax(int64_t m) {
4325   if (m != std::numeric_limits<int64_t>::max()) {
4326     TimesSetMin(left_, minus_right_, minus_left_, right_, CapOpp(m));
4327   }
4328 }
4329 
Bound() const4330 bool TimesIntExpr::Bound() const {
4331   const bool left_bound = left_->Bound();
4332   const bool right_bound = right_->Bound();
4333   return ((left_bound && left_->Max() == 0) ||
4334           (right_bound && right_->Max() == 0) || (left_bound && right_bound));
4335 }
4336 
4337 // ----- TimesPosIntExpr -----
4338 
4339 class TimesPosIntExpr : public BaseIntExpr {
4340  public:
TimesPosIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)4341   TimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4342       : BaseIntExpr(s), left_(l), right_(r) {}
~TimesPosIntExpr()4343   ~TimesPosIntExpr() override {}
Min() const4344   int64_t Min() const override { return (left_->Min() * right_->Min()); }
4345   void SetMin(int64_t m) override;
Max() const4346   int64_t Max() const override { return (left_->Max() * right_->Max()); }
4347   void SetMax(int64_t m) override;
4348   bool Bound() const override;
name() const4349   std::string name() const override {
4350     return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4351   }
DebugString() const4352   std::string DebugString() const override {
4353     return absl::StrFormat("(%s * %s)", left_->DebugString(),
4354                            right_->DebugString());
4355   }
WhenRange(Demon * d)4356   void WhenRange(Demon* d) override {
4357     left_->WhenRange(d);
4358     right_->WhenRange(d);
4359   }
4360 
Accept(ModelVisitor * const visitor) const4361   void Accept(ModelVisitor* const visitor) const override {
4362     visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4363     visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4364     visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4365                                             right_);
4366     visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4367   }
4368 
4369  private:
4370   IntExpr* const left_;
4371   IntExpr* const right_;
4372 };
4373 
SetMin(int64_t m)4374 void TimesPosIntExpr::SetMin(int64_t m) { SetPosPosMinExpr(left_, right_, m); }
4375 
SetMax(int64_t m)4376 void TimesPosIntExpr::SetMax(int64_t m) { SetPosPosMaxExpr(left_, right_, m); }
4377 
Bound() const4378 bool TimesPosIntExpr::Bound() const {
4379   return (left_->Max() == 0 || right_->Max() == 0 ||
4380           (left_->Bound() && right_->Bound()));
4381 }
4382 
4383 // ----- SafeTimesPosIntExpr -----
4384 
4385 class SafeTimesPosIntExpr : public BaseIntExpr {
4386  public:
SafeTimesPosIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)4387   SafeTimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4388       : BaseIntExpr(s), left_(l), right_(r) {}
~SafeTimesPosIntExpr()4389   ~SafeTimesPosIntExpr() override {}
Min() const4390   int64_t Min() const override { return CapProd(left_->Min(), right_->Min()); }
SetMin(int64_t m)4391   void SetMin(int64_t m) override {
4392     if (m != std::numeric_limits<int64_t>::min()) {
4393       SetPosPosMinExpr(left_, right_, m);
4394     }
4395   }
Max() const4396   int64_t Max() const override { return CapProd(left_->Max(), right_->Max()); }
SetMax(int64_t m)4397   void SetMax(int64_t m) override {
4398     if (m != std::numeric_limits<int64_t>::max()) {
4399       SetPosPosMaxExpr(left_, right_, m);
4400     }
4401   }
Bound() const4402   bool Bound() const override {
4403     return (left_->Max() == 0 || right_->Max() == 0 ||
4404             (left_->Bound() && right_->Bound()));
4405   }
name() const4406   std::string name() const override {
4407     return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4408   }
DebugString() const4409   std::string DebugString() const override {
4410     return absl::StrFormat("(%s * %s)", left_->DebugString(),
4411                            right_->DebugString());
4412   }
WhenRange(Demon * d)4413   void WhenRange(Demon* d) override {
4414     left_->WhenRange(d);
4415     right_->WhenRange(d);
4416   }
4417 
Accept(ModelVisitor * const visitor) const4418   void Accept(ModelVisitor* const visitor) const override {
4419     visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4420     visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4421     visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4422                                             right_);
4423     visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4424   }
4425 
4426  private:
4427   IntExpr* const left_;
4428   IntExpr* const right_;
4429 };
4430 
4431 // ----- TimesBooleanPosIntExpr -----
4432 
4433 class TimesBooleanPosIntExpr : public BaseIntExpr {
4434  public:
TimesBooleanPosIntExpr(Solver * const s,BooleanVar * const b,IntExpr * const e)4435   TimesBooleanPosIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4436       : BaseIntExpr(s), boolvar_(b), expr_(e) {}
~TimesBooleanPosIntExpr()4437   ~TimesBooleanPosIntExpr() override {}
Min() const4438   int64_t Min() const override {
4439     return (boolvar_->RawValue() == 1 ? expr_->Min() : 0);
4440   }
4441   void SetMin(int64_t m) override;
Max() const4442   int64_t Max() const override {
4443     return (boolvar_->RawValue() == 0 ? 0 : expr_->Max());
4444   }
4445   void SetMax(int64_t m) override;
4446   void Range(int64_t* mi, int64_t* ma) override;
4447   void SetRange(int64_t mi, int64_t ma) override;
4448   bool Bound() const override;
name() const4449   std::string name() const override {
4450     return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4451   }
DebugString() const4452   std::string DebugString() const override {
4453     return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4454                            expr_->DebugString());
4455   }
WhenRange(Demon * d)4456   void WhenRange(Demon* d) override {
4457     boolvar_->WhenRange(d);
4458     expr_->WhenRange(d);
4459   }
4460 
Accept(ModelVisitor * const visitor) const4461   void Accept(ModelVisitor* const visitor) const override {
4462     visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4463     visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4464                                             boolvar_);
4465     visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4466                                             expr_);
4467     visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4468   }
4469 
4470  private:
4471   BooleanVar* const boolvar_;
4472   IntExpr* const expr_;
4473 };
4474 
SetMin(int64_t m)4475 void TimesBooleanPosIntExpr::SetMin(int64_t m) {
4476   if (m > 0) {
4477     boolvar_->SetValue(1);
4478     expr_->SetMin(m);
4479   }
4480 }
4481 
SetMax(int64_t m)4482 void TimesBooleanPosIntExpr::SetMax(int64_t m) {
4483   if (m < 0) {
4484     solver()->Fail();
4485   }
4486   if (m < expr_->Min()) {
4487     boolvar_->SetValue(0);
4488   }
4489   if (boolvar_->RawValue() == 1) {
4490     expr_->SetMax(m);
4491   }
4492 }
4493 
Range(int64_t * mi,int64_t * ma)4494 void TimesBooleanPosIntExpr::Range(int64_t* mi, int64_t* ma) {
4495   const int value = boolvar_->RawValue();
4496   if (value == 0) {
4497     *mi = 0;
4498     *ma = 0;
4499   } else if (value == 1) {
4500     expr_->Range(mi, ma);
4501   } else {
4502     *mi = 0;
4503     *ma = expr_->Max();
4504   }
4505 }
4506 
SetRange(int64_t mi,int64_t ma)4507 void TimesBooleanPosIntExpr::SetRange(int64_t mi, int64_t ma) {
4508   if (ma < 0 || mi > ma) {
4509     solver()->Fail();
4510   }
4511   if (mi > 0) {
4512     boolvar_->SetValue(1);
4513     expr_->SetMin(mi);
4514   }
4515   if (ma < expr_->Min()) {
4516     boolvar_->SetValue(0);
4517   }
4518   if (boolvar_->RawValue() == 1) {
4519     expr_->SetMax(ma);
4520   }
4521 }
4522 
Bound() const4523 bool TimesBooleanPosIntExpr::Bound() const {
4524   return (boolvar_->RawValue() == 0 || expr_->Max() == 0 ||
4525           (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue &&
4526            expr_->Bound()));
4527 }
4528 
4529 // ----- TimesBooleanIntExpr -----
4530 
4531 class TimesBooleanIntExpr : public BaseIntExpr {
4532  public:
TimesBooleanIntExpr(Solver * const s,BooleanVar * const b,IntExpr * const e)4533   TimesBooleanIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4534       : BaseIntExpr(s), boolvar_(b), expr_(e) {}
~TimesBooleanIntExpr()4535   ~TimesBooleanIntExpr() override {}
Min() const4536   int64_t Min() const override {
4537     switch (boolvar_->RawValue()) {
4538       case 0: {
4539         return 0LL;
4540       }
4541       case 1: {
4542         return expr_->Min();
4543       }
4544       default: {
4545         DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4546         return std::min(int64_t{0}, expr_->Min());
4547       }
4548     }
4549   }
4550   void SetMin(int64_t m) override;
Max() const4551   int64_t Max() const override {
4552     switch (boolvar_->RawValue()) {
4553       case 0: {
4554         return 0LL;
4555       }
4556       case 1: {
4557         return expr_->Max();
4558       }
4559       default: {
4560         DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4561         return std::max(int64_t{0}, expr_->Max());
4562       }
4563     }
4564   }
4565   void SetMax(int64_t m) override;
4566   void Range(int64_t* mi, int64_t* ma) override;
4567   void SetRange(int64_t mi, int64_t ma) override;
4568   bool Bound() const override;
name() const4569   std::string name() const override {
4570     return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4571   }
DebugString() const4572   std::string DebugString() const override {
4573     return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4574                            expr_->DebugString());
4575   }
WhenRange(Demon * d)4576   void WhenRange(Demon* d) override {
4577     boolvar_->WhenRange(d);
4578     expr_->WhenRange(d);
4579   }
4580 
Accept(ModelVisitor * const visitor) const4581   void Accept(ModelVisitor* const visitor) const override {
4582     visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4583     visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4584                                             boolvar_);
4585     visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4586                                             expr_);
4587     visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4588   }
4589 
4590  private:
4591   BooleanVar* const boolvar_;
4592   IntExpr* const expr_;
4593 };
4594 
SetMin(int64_t m)4595 void TimesBooleanIntExpr::SetMin(int64_t m) {
4596   switch (boolvar_->RawValue()) {
4597     case 0: {
4598       if (m > 0) {
4599         solver()->Fail();
4600       }
4601       break;
4602     }
4603     case 1: {
4604       expr_->SetMin(m);
4605       break;
4606     }
4607     default: {
4608       DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4609       if (m > 0) {  // 0 is no longer possible for boolvar because min > 0.
4610         boolvar_->SetValue(1);
4611         expr_->SetMin(m);
4612       } else if (m <= 0 && expr_->Max() < m) {
4613         boolvar_->SetValue(0);
4614       }
4615     }
4616   }
4617 }
4618 
SetMax(int64_t m)4619 void TimesBooleanIntExpr::SetMax(int64_t m) {
4620   switch (boolvar_->RawValue()) {
4621     case 0: {
4622       if (m < 0) {
4623         solver()->Fail();
4624       }
4625       break;
4626     }
4627     case 1: {
4628       expr_->SetMax(m);
4629       break;
4630     }
4631     default: {
4632       DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4633       if (m < 0) {  // 0 is no longer possible for boolvar because max < 0.
4634         boolvar_->SetValue(1);
4635         expr_->SetMax(m);
4636       } else if (m >= 0 && expr_->Min() > m) {
4637         boolvar_->SetValue(0);
4638       }
4639     }
4640   }
4641 }
4642 
Range(int64_t * mi,int64_t * ma)4643 void TimesBooleanIntExpr::Range(int64_t* mi, int64_t* ma) {
4644   switch (boolvar_->RawValue()) {
4645     case 0: {
4646       *mi = 0;
4647       *ma = 0;
4648       break;
4649     }
4650     case 1: {
4651       *mi = expr_->Min();
4652       *ma = expr_->Max();
4653       break;
4654     }
4655     default: {
4656       DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4657       *mi = std::min(int64_t{0}, expr_->Min());
4658       *ma = std::max(int64_t{0}, expr_->Max());
4659       break;
4660     }
4661   }
4662 }
4663 
SetRange(int64_t mi,int64_t ma)4664 void TimesBooleanIntExpr::SetRange(int64_t mi, int64_t ma) {
4665   if (mi > ma) {
4666     solver()->Fail();
4667   }
4668   switch (boolvar_->RawValue()) {
4669     case 0: {
4670       if (mi > 0 || ma < 0) {
4671         solver()->Fail();
4672       }
4673       break;
4674     }
4675     case 1: {
4676       expr_->SetRange(mi, ma);
4677       break;
4678     }
4679     default: {
4680       DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4681       if (mi > 0) {
4682         boolvar_->SetValue(1);
4683         expr_->SetMin(mi);
4684       } else if (mi == 0 && expr_->Max() < 0) {
4685         boolvar_->SetValue(0);
4686       }
4687       if (ma < 0) {
4688         boolvar_->SetValue(1);
4689         expr_->SetMax(ma);
4690       } else if (ma == 0 && expr_->Min() > 0) {
4691         boolvar_->SetValue(0);
4692       }
4693       break;
4694     }
4695   }
4696 }
4697 
Bound() const4698 bool TimesBooleanIntExpr::Bound() const {
4699   return (boolvar_->RawValue() == 0 ||
4700           (expr_->Bound() &&
4701            (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue ||
4702             expr_->Max() == 0)));
4703 }
4704 
4705 // ----- DivPosIntCstExpr -----
4706 
4707 class DivPosIntCstExpr : public BaseIntExpr {
4708  public:
DivPosIntCstExpr(Solver * const s,IntExpr * const e,int64_t v)4709   DivPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4710       : BaseIntExpr(s), expr_(e), value_(v) {
4711     CHECK_GE(v, 0);
4712   }
~DivPosIntCstExpr()4713   ~DivPosIntCstExpr() override {}
4714 
Min() const4715   int64_t Min() const override { return expr_->Min() / value_; }
4716 
SetMin(int64_t m)4717   void SetMin(int64_t m) override {
4718     if (m > 0) {
4719       expr_->SetMin(m * value_);
4720     } else {
4721       expr_->SetMin((m - 1) * value_ + 1);
4722     }
4723   }
Max() const4724   int64_t Max() const override { return expr_->Max() / value_; }
4725 
SetMax(int64_t m)4726   void SetMax(int64_t m) override {
4727     if (m >= 0) {
4728       expr_->SetMax((m + 1) * value_ - 1);
4729     } else {
4730       expr_->SetMax(m * value_);
4731     }
4732   }
4733 
name() const4734   std::string name() const override {
4735     return absl::StrFormat("(%s div %d)", expr_->name(), value_);
4736   }
4737 
DebugString() const4738   std::string DebugString() const override {
4739     return absl::StrFormat("(%s div %d)", expr_->DebugString(), value_);
4740   }
4741 
WhenRange(Demon * d)4742   void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4743 
Accept(ModelVisitor * const visitor) const4744   void Accept(ModelVisitor* const visitor) const override {
4745     visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4746     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4747                                             expr_);
4748     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4749     visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4750   }
4751 
4752  private:
4753   IntExpr* const expr_;
4754   const int64_t value_;
4755 };
4756 
4757 // DivPosIntExpr
4758 
4759 class DivPosIntExpr : public BaseIntExpr {
4760  public:
DivPosIntExpr(Solver * const s,IntExpr * const num,IntExpr * const denom)4761   DivPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4762       : BaseIntExpr(s),
4763         num_(num),
4764         denom_(denom),
4765         opp_num_(s->MakeOpposite(num)) {}
4766 
~DivPosIntExpr()4767   ~DivPosIntExpr() override {}
4768 
Min() const4769   int64_t Min() const override {
4770     return num_->Min() >= 0
4771                ? num_->Min() / denom_->Max()
4772                : (denom_->Min() == 0 ? num_->Min()
4773                                      : num_->Min() / denom_->Min());
4774   }
4775 
Max() const4776   int64_t Max() const override {
4777     return num_->Max() >= 0 ? (denom_->Min() == 0 ? num_->Max()
4778                                                   : num_->Max() / denom_->Min())
4779                             : num_->Max() / denom_->Max();
4780   }
4781 
SetPosMin(IntExpr * const num,IntExpr * const denom,int64_t m)4782   static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64_t m) {
4783     num->SetMin(m * denom->Min());
4784     denom->SetMax(num->Max() / m);
4785   }
4786 
SetPosMax(IntExpr * const num,IntExpr * const denom,int64_t m)4787   static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64_t m) {
4788     num->SetMax((m + 1) * denom->Max() - 1);
4789     denom->SetMin(num->Min() / (m + 1) + 1);
4790   }
4791 
SetMin(int64_t m)4792   void SetMin(int64_t m) override {
4793     if (m > 0) {
4794       SetPosMin(num_, denom_, m);
4795     } else {
4796       SetPosMax(opp_num_, denom_, -m);
4797     }
4798   }
4799 
SetMax(int64_t m)4800   void SetMax(int64_t m) override {
4801     if (m >= 0) {
4802       SetPosMax(num_, denom_, m);
4803     } else {
4804       SetPosMin(opp_num_, denom_, -m);
4805     }
4806   }
4807 
name() const4808   std::string name() const override {
4809     return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4810   }
DebugString() const4811   std::string DebugString() const override {
4812     return absl::StrFormat("(%s div %s)", num_->DebugString(),
4813                            denom_->DebugString());
4814   }
WhenRange(Demon * d)4815   void WhenRange(Demon* d) override {
4816     num_->WhenRange(d);
4817     denom_->WhenRange(d);
4818   }
4819 
Accept(ModelVisitor * const visitor) const4820   void Accept(ModelVisitor* const visitor) const override {
4821     visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4822     visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4823     visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4824                                             denom_);
4825     visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4826   }
4827 
4828  private:
4829   IntExpr* const num_;
4830   IntExpr* const denom_;
4831   IntExpr* const opp_num_;
4832 };
4833 
4834 class DivPosPosIntExpr : public BaseIntExpr {
4835  public:
DivPosPosIntExpr(Solver * const s,IntExpr * const num,IntExpr * const denom)4836   DivPosPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4837       : BaseIntExpr(s), num_(num), denom_(denom) {}
4838 
~DivPosPosIntExpr()4839   ~DivPosPosIntExpr() override {}
4840 
Min() const4841   int64_t Min() const override {
4842     if (denom_->Max() == 0) {
4843       solver()->Fail();
4844     }
4845     return num_->Min() / denom_->Max();
4846   }
4847 
Max() const4848   int64_t Max() const override {
4849     if (denom_->Min() == 0) {
4850       return num_->Max();
4851     } else {
4852       return num_->Max() / denom_->Min();
4853     }
4854   }
4855 
SetMin(int64_t m)4856   void SetMin(int64_t m) override {
4857     if (m > 0) {
4858       num_->SetMin(m * denom_->Min());
4859       denom_->SetMax(num_->Max() / m);
4860     }
4861   }
4862 
SetMax(int64_t m)4863   void SetMax(int64_t m) override {
4864     if (m >= 0) {
4865       num_->SetMax((m + 1) * denom_->Max() - 1);
4866       denom_->SetMin(num_->Min() / (m + 1) + 1);
4867     } else {
4868       solver()->Fail();
4869     }
4870   }
4871 
name() const4872   std::string name() const override {
4873     return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4874   }
4875 
DebugString() const4876   std::string DebugString() const override {
4877     return absl::StrFormat("(%s div %s)", num_->DebugString(),
4878                            denom_->DebugString());
4879   }
4880 
WhenRange(Demon * d)4881   void WhenRange(Demon* d) override {
4882     num_->WhenRange(d);
4883     denom_->WhenRange(d);
4884   }
4885 
Accept(ModelVisitor * const visitor) const4886   void Accept(ModelVisitor* const visitor) const override {
4887     visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4888     visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4889     visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4890                                             denom_);
4891     visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4892   }
4893 
4894  private:
4895   IntExpr* const num_;
4896   IntExpr* const denom_;
4897 };
4898 
4899 // DivIntExpr
4900 
4901 class DivIntExpr : public BaseIntExpr {
4902  public:
DivIntExpr(Solver * const s,IntExpr * const num,IntExpr * const denom)4903   DivIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4904       : BaseIntExpr(s),
4905         num_(num),
4906         denom_(denom),
4907         opp_num_(s->MakeOpposite(num)) {}
4908 
~DivIntExpr()4909   ~DivIntExpr() override {}
4910 
Min() const4911   int64_t Min() const override {
4912     const int64_t num_min = num_->Min();
4913     const int64_t num_max = num_->Max();
4914     const int64_t denom_min = denom_->Min();
4915     const int64_t denom_max = denom_->Max();
4916 
4917     if (denom_min == 0 && denom_max == 0) {
4918       return std::numeric_limits<int64_t>::max();  // TODO(user): Check this
4919                                                    // convention.
4920     }
4921 
4922     if (denom_min >= 0) {  // Denominator strictly positive.
4923       DCHECK_GT(denom_max, 0);
4924       const int64_t adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
4925       return num_min >= 0 ? num_min / denom_max : num_min / adjusted_denom_min;
4926     } else if (denom_max <= 0) {  // Denominator strictly negative.
4927       DCHECK_LT(denom_min, 0);
4928       const int64_t adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
4929       return num_max >= 0 ? num_max / adjusted_denom_max : num_max / denom_min;
4930     } else {  // Denominator across 0.
4931       return std::min(num_min, -num_max);
4932     }
4933   }
4934 
Max() const4935   int64_t Max() const override {
4936     const int64_t num_min = num_->Min();
4937     const int64_t num_max = num_->Max();
4938     const int64_t denom_min = denom_->Min();
4939     const int64_t denom_max = denom_->Max();
4940 
4941     if (denom_min == 0 && denom_max == 0) {
4942       return std::numeric_limits<int64_t>::min();  // TODO(user): Check this
4943                                                    // convention.
4944     }
4945 
4946     if (denom_min >= 0) {  // Denominator strictly positive.
4947       DCHECK_GT(denom_max, 0);
4948       const int64_t adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
4949       return num_max >= 0 ? num_max / adjusted_denom_min : num_max / denom_max;
4950     } else if (denom_max <= 0) {  // Denominator strictly negative.
4951       DCHECK_LT(denom_min, 0);
4952       const int64_t adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
4953       return num_min >= 0 ? num_min / denom_min
4954                           : -num_min / -adjusted_denom_max;
4955     } else {  // Denominator across 0.
4956       return std::max(num_max, -num_min);
4957     }
4958   }
4959 
AdjustDenominator()4960   void AdjustDenominator() {
4961     if (denom_->Min() == 0) {
4962       denom_->SetMin(1);
4963     } else if (denom_->Max() == 0) {
4964       denom_->SetMax(-1);
4965     }
4966   }
4967 
4968   // m > 0.
SetPosMin(IntExpr * const num,IntExpr * const denom,int64_t m)4969   static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64_t m) {
4970     DCHECK_GT(m, 0);
4971     const int64_t num_min = num->Min();
4972     const int64_t num_max = num->Max();
4973     const int64_t denom_min = denom->Min();
4974     const int64_t denom_max = denom->Max();
4975     DCHECK_NE(denom_min, 0);
4976     DCHECK_NE(denom_max, 0);
4977     if (denom_min > 0) {  // Denominator strictly positive.
4978       num->SetMin(m * denom_min);
4979       denom->SetMax(num_max / m);
4980     } else if (denom_max < 0) {  // Denominator strictly negative.
4981       num->SetMax(m * denom_max);
4982       denom->SetMin(num_min / m);
4983     } else {  // Denominator across 0.
4984       if (num_min >= 0) {
4985         num->SetMin(m);
4986         denom->SetRange(1, num_max / m);
4987       } else if (num_max <= 0) {
4988         num->SetMax(-m);
4989         denom->SetRange(num_min / m, -1);
4990       } else {
4991         if (m > -num_min) {  // Denominator is forced positive.
4992           num->SetMin(m);
4993           denom->SetRange(1, num_max / m);
4994         } else if (m > num_max) {  // Denominator is forced negative.
4995           num->SetMax(-m);
4996           denom->SetRange(num_min / m, -1);
4997         } else {
4998           denom->SetRange(num_min / m, num_max / m);
4999         }
5000       }
5001     }
5002   }
5003 
5004   // m >= 0.
SetPosMax(IntExpr * const num,IntExpr * const denom,int64_t m)5005   static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64_t m) {
5006     DCHECK_GE(m, 0);
5007     const int64_t num_min = num->Min();
5008     const int64_t num_max = num->Max();
5009     const int64_t denom_min = denom->Min();
5010     const int64_t denom_max = denom->Max();
5011     DCHECK_NE(denom_min, 0);
5012     DCHECK_NE(denom_max, 0);
5013     if (denom_min > 0) {  // Denominator strictly positive.
5014       num->SetMax((m + 1) * denom_max - 1);
5015       denom->SetMin((num_min / (m + 1)) + 1);
5016     } else if (denom_max < 0) {
5017       num->SetMin((m + 1) * denom_min + 1);
5018       denom->SetMax(num_max / (m + 1) - 1);
5019     } else if (num_min > (m + 1) * denom_max - 1) {
5020       denom->SetMax(-1);
5021     } else if (num_max < (m + 1) * denom_min + 1) {
5022       denom->SetMin(1);
5023     }
5024   }
5025 
SetMin(int64_t m)5026   void SetMin(int64_t m) override {
5027     AdjustDenominator();
5028     if (m > 0) {
5029       SetPosMin(num_, denom_, m);
5030     } else {
5031       SetPosMax(opp_num_, denom_, -m);
5032     }
5033   }
5034 
SetMax(int64_t m)5035   void SetMax(int64_t m) override {
5036     AdjustDenominator();
5037     if (m >= 0) {
5038       SetPosMax(num_, denom_, m);
5039     } else {
5040       SetPosMin(opp_num_, denom_, -m);
5041     }
5042   }
5043 
name() const5044   std::string name() const override {
5045     return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
5046   }
DebugString() const5047   std::string DebugString() const override {
5048     return absl::StrFormat("(%s div %s)", num_->DebugString(),
5049                            denom_->DebugString());
5050   }
WhenRange(Demon * d)5051   void WhenRange(Demon* d) override {
5052     num_->WhenRange(d);
5053     denom_->WhenRange(d);
5054   }
5055 
Accept(ModelVisitor * const visitor) const5056   void Accept(ModelVisitor* const visitor) const override {
5057     visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
5058     visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
5059     visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5060                                             denom_);
5061     visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
5062   }
5063 
5064  private:
5065   IntExpr* const num_;
5066   IntExpr* const denom_;
5067   IntExpr* const opp_num_;
5068 };
5069 
5070 // ----- IntAbs And IntAbsConstraint ------
5071 
5072 class IntAbsConstraint : public CastConstraint {
5073  public:
IntAbsConstraint(Solver * const s,IntVar * const sub,IntVar * const target)5074   IntAbsConstraint(Solver* const s, IntVar* const sub, IntVar* const target)
5075       : CastConstraint(s, target), sub_(sub) {}
5076 
~IntAbsConstraint()5077   ~IntAbsConstraint() override {}
5078 
Post()5079   void Post() override {
5080     Demon* const sub_demon = MakeConstraintDemon0(
5081         solver(), this, &IntAbsConstraint::PropagateSub, "PropagateSub");
5082     sub_->WhenRange(sub_demon);
5083     Demon* const target_demon = MakeConstraintDemon0(
5084         solver(), this, &IntAbsConstraint::PropagateTarget, "PropagateTarget");
5085     target_var_->WhenRange(target_demon);
5086   }
5087 
InitialPropagate()5088   void InitialPropagate() override {
5089     PropagateSub();
5090     PropagateTarget();
5091   }
5092 
PropagateSub()5093   void PropagateSub() {
5094     const int64_t smin = sub_->Min();
5095     const int64_t smax = sub_->Max();
5096     if (smax <= 0) {
5097       target_var_->SetRange(-smax, -smin);
5098     } else if (smin >= 0) {
5099       target_var_->SetRange(smin, smax);
5100     } else {
5101       target_var_->SetRange(0, std::max(-smin, smax));
5102     }
5103   }
5104 
PropagateTarget()5105   void PropagateTarget() {
5106     const int64_t target_max = target_var_->Max();
5107     sub_->SetRange(-target_max, target_max);
5108     const int64_t target_min = target_var_->Min();
5109     if (target_min > 0) {
5110       if (sub_->Min() > -target_min) {
5111         sub_->SetMin(target_min);
5112       } else if (sub_->Max() < target_min) {
5113         sub_->SetMax(-target_min);
5114       }
5115     }
5116   }
5117 
DebugString() const5118   std::string DebugString() const override {
5119     return absl::StrFormat("IntAbsConstraint(%s, %s)", sub_->DebugString(),
5120                            target_var_->DebugString());
5121   }
5122 
Accept(ModelVisitor * const visitor) const5123   void Accept(ModelVisitor* const visitor) const override {
5124     visitor->BeginVisitConstraint(ModelVisitor::kAbsEqual, this);
5125     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5126                                             sub_);
5127     visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
5128                                             target_var_);
5129     visitor->EndVisitConstraint(ModelVisitor::kAbsEqual, this);
5130   }
5131 
5132  private:
5133   IntVar* const sub_;
5134 };
5135 
5136 class IntAbs : public BaseIntExpr {
5137  public:
IntAbs(Solver * const s,IntExpr * const e)5138   IntAbs(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5139 
~IntAbs()5140   ~IntAbs() override {}
5141 
Min() const5142   int64_t Min() const override {
5143     int64_t emin = 0;
5144     int64_t emax = 0;
5145     expr_->Range(&emin, &emax);
5146     if (emin >= 0) {
5147       return emin;
5148     }
5149     if (emax <= 0) {
5150       return -emax;
5151     }
5152     return 0;
5153   }
5154 
SetMin(int64_t m)5155   void SetMin(int64_t m) override {
5156     if (m > 0) {
5157       int64_t emin = 0;
5158       int64_t emax = 0;
5159       expr_->Range(&emin, &emax);
5160       if (emin > -m) {
5161         expr_->SetMin(m);
5162       } else if (emax < m) {
5163         expr_->SetMax(-m);
5164       }
5165     }
5166   }
5167 
Max() const5168   int64_t Max() const override {
5169     int64_t emin = 0;
5170     int64_t emax = 0;
5171     expr_->Range(&emin, &emax);
5172     return std::max(-emin, emax);
5173   }
5174 
SetMax(int64_t m)5175   void SetMax(int64_t m) override { expr_->SetRange(-m, m); }
5176 
SetRange(int64_t mi,int64_t ma)5177   void SetRange(int64_t mi, int64_t ma) override {
5178     expr_->SetRange(-ma, ma);
5179     if (mi > 0) {
5180       int64_t emin = 0;
5181       int64_t emax = 0;
5182       expr_->Range(&emin, &emax);
5183       if (emin > -mi) {
5184         expr_->SetMin(mi);
5185       } else if (emax < mi) {
5186         expr_->SetMax(-mi);
5187       }
5188     }
5189   }
5190 
Range(int64_t * mi,int64_t * ma)5191   void Range(int64_t* mi, int64_t* ma) override {
5192     int64_t emin = 0;
5193     int64_t emax = 0;
5194     expr_->Range(&emin, &emax);
5195     if (emin >= 0) {
5196       *mi = emin;
5197       *ma = emax;
5198     } else if (emax <= 0) {
5199       *mi = -emax;
5200       *ma = -emin;
5201     } else {
5202       *mi = 0;
5203       *ma = std::max(-emin, emax);
5204     }
5205   }
5206 
Bound() const5207   bool Bound() const override { return expr_->Bound(); }
5208 
WhenRange(Demon * d)5209   void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5210 
name() const5211   std::string name() const override {
5212     return absl::StrFormat("IntAbs(%s)", expr_->name());
5213   }
5214 
DebugString() const5215   std::string DebugString() const override {
5216     return absl::StrFormat("IntAbs(%s)", expr_->DebugString());
5217   }
5218 
Accept(ModelVisitor * const visitor) const5219   void Accept(ModelVisitor* const visitor) const override {
5220     visitor->BeginVisitIntegerExpression(ModelVisitor::kAbs, this);
5221     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5222                                             expr_);
5223     visitor->EndVisitIntegerExpression(ModelVisitor::kAbs, this);
5224   }
5225 
CastToVar()5226   IntVar* CastToVar() override {
5227     int64_t min_value = 0;
5228     int64_t max_value = 0;
5229     Range(&min_value, &max_value);
5230     Solver* const s = solver();
5231     const std::string name = absl::StrFormat("AbsVar(%s)", expr_->name());
5232     IntVar* const target = s->MakeIntVar(min_value, max_value, name);
5233     CastConstraint* const ct =
5234         s->RevAlloc(new IntAbsConstraint(s, expr_->Var(), target));
5235     s->AddCastConstraint(ct, target, this);
5236     return target;
5237   }
5238 
5239  private:
5240   IntExpr* const expr_;
5241 };
5242 
5243 // ----- Square -----
5244 
5245 // TODO(user): shouldn't we compare to kint32max^2 instead of kint64max?
5246 class IntSquare : public BaseIntExpr {
5247  public:
IntSquare(Solver * const s,IntExpr * const e)5248   IntSquare(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
~IntSquare()5249   ~IntSquare() override {}
5250 
Min() const5251   int64_t Min() const override {
5252     const int64_t emin = expr_->Min();
5253     if (emin >= 0) {
5254       return emin >= std::numeric_limits<int32_t>::max()
5255                  ? std::numeric_limits<int64_t>::max()
5256                  : emin * emin;
5257     }
5258     const int64_t emax = expr_->Max();
5259     if (emax < 0) {
5260       return emax <= -std::numeric_limits<int32_t>::max()
5261                  ? std::numeric_limits<int64_t>::max()
5262                  : emax * emax;
5263     }
5264     return 0LL;
5265   }
SetMin(int64_t m)5266   void SetMin(int64_t m) override {
5267     if (m <= 0) {
5268       return;
5269     }
5270     // TODO(user): What happens if m is kint64max?
5271     const int64_t emin = expr_->Min();
5272     const int64_t emax = expr_->Max();
5273     const int64_t root =
5274         static_cast<int64_t>(ceil(sqrt(static_cast<double>(m))));
5275     if (emin >= 0) {
5276       expr_->SetMin(root);
5277     } else if (emax <= 0) {
5278       expr_->SetMax(-root);
5279     } else if (expr_->IsVar()) {
5280       reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5281     }
5282   }
Max() const5283   int64_t Max() const override {
5284     const int64_t emax = expr_->Max();
5285     const int64_t emin = expr_->Min();
5286     if (emax >= std::numeric_limits<int32_t>::max() ||
5287         emin <= -std::numeric_limits<int32_t>::max()) {
5288       return std::numeric_limits<int64_t>::max();
5289     }
5290     return std::max(emin * emin, emax * emax);
5291   }
SetMax(int64_t m)5292   void SetMax(int64_t m) override {
5293     if (m < 0) {
5294       solver()->Fail();
5295     }
5296     if (m == std::numeric_limits<int64_t>::max()) {
5297       return;
5298     }
5299     const int64_t root =
5300         static_cast<int64_t>(floor(sqrt(static_cast<double>(m))));
5301     expr_->SetRange(-root, root);
5302   }
Bound() const5303   bool Bound() const override { return expr_->Bound(); }
WhenRange(Demon * d)5304   void WhenRange(Demon* d) override { expr_->WhenRange(d); }
name() const5305   std::string name() const override {
5306     return absl::StrFormat("IntSquare(%s)", expr_->name());
5307   }
DebugString() const5308   std::string DebugString() const override {
5309     return absl::StrFormat("IntSquare(%s)", expr_->DebugString());
5310   }
5311 
Accept(ModelVisitor * const visitor) const5312   void Accept(ModelVisitor* const visitor) const override {
5313     visitor->BeginVisitIntegerExpression(ModelVisitor::kSquare, this);
5314     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5315                                             expr_);
5316     visitor->EndVisitIntegerExpression(ModelVisitor::kSquare, this);
5317   }
5318 
expr() const5319   IntExpr* expr() const { return expr_; }
5320 
5321  protected:
5322   IntExpr* const expr_;
5323 };
5324 
5325 class PosIntSquare : public IntSquare {
5326  public:
PosIntSquare(Solver * const s,IntExpr * const e)5327   PosIntSquare(Solver* const s, IntExpr* const e) : IntSquare(s, e) {}
~PosIntSquare()5328   ~PosIntSquare() override {}
5329 
Min() const5330   int64_t Min() const override {
5331     const int64_t emin = expr_->Min();
5332     return emin >= std::numeric_limits<int32_t>::max()
5333                ? std::numeric_limits<int64_t>::max()
5334                : emin * emin;
5335   }
SetMin(int64_t m)5336   void SetMin(int64_t m) override {
5337     if (m <= 0) {
5338       return;
5339     }
5340     const int64_t root =
5341         static_cast<int64_t>(ceil(sqrt(static_cast<double>(m))));
5342     expr_->SetMin(root);
5343   }
Max() const5344   int64_t Max() const override {
5345     const int64_t emax = expr_->Max();
5346     return emax >= std::numeric_limits<int32_t>::max()
5347                ? std::numeric_limits<int64_t>::max()
5348                : emax * emax;
5349   }
SetMax(int64_t m)5350   void SetMax(int64_t m) override {
5351     if (m < 0) {
5352       solver()->Fail();
5353     }
5354     if (m == std::numeric_limits<int64_t>::max()) {
5355       return;
5356     }
5357     const int64_t root =
5358         static_cast<int64_t>(floor(sqrt(static_cast<double>(m))));
5359     expr_->SetMax(root);
5360   }
5361 };
5362 
5363 // ----- EvenPower -----
5364 
// Returns value raised to the integer exponent `power`.
//
// Uses exponentiation by squaring, so the number of multiplications is
// logarithmic in `power`. Every intermediate product has a magnitude no larger
// than the final result, so the computation does not overflow unless the
// result itself does not fit in an int64_t; callers are expected to rule that
// case out beforehand.
//
// `power` == 0 yields 1. A negative `power` has no integer meaning; `value` is
// returned unchanged in that case, as the previous implementation did.
int64_t IntPower(int64_t value, int64_t power) {
  if (power < 0) {
    return value;
  }
  int64_t result = 1;
  int64_t base = value;
  int64_t remaining = power;
  while (remaining > 0) {
    if ((remaining & 1) != 0) {
      result *= base;
    }
    remaining >>= 1;
    // Square only when another bit is left, so that the last (unused) square
    // can never overflow.
    if (remaining > 0) {
      base *= base;
    }
  }
  return result;
}
5373 
// Returns the rounded-down `power`-th root of kint64max, computed in double
// precision as exp(log(kint64max) / power).
int64_t OverflowLimit(int64_t power) {
  const double log_of_max =
      log(static_cast<double>(std::numeric_limits<int64_t>::max()));
  const double real_root = exp(log_of_max / power);
  return static_cast<int64_t>(floor(real_root));
}
5378 
5379 class BasePower : public BaseIntExpr {
5380  public:
BasePower(Solver * const s,IntExpr * const e,int64_t n)5381   BasePower(Solver* const s, IntExpr* const e, int64_t n)
5382       : BaseIntExpr(s), expr_(e), pow_(n), limit_(OverflowLimit(n)) {
5383     CHECK_GT(n, 0);
5384   }
5385 
~BasePower()5386   ~BasePower() override {}
5387 
Bound() const5388   bool Bound() const override { return expr_->Bound(); }
5389 
expr() const5390   IntExpr* expr() const { return expr_; }
5391 
exponant() const5392   int64_t exponant() const { return pow_; }
5393 
WhenRange(Demon * d)5394   void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5395 
name() const5396   std::string name() const override {
5397     return absl::StrFormat("IntPower(%s, %d)", expr_->name(), pow_);
5398   }
5399 
DebugString() const5400   std::string DebugString() const override {
5401     return absl::StrFormat("IntPower(%s, %d)", expr_->DebugString(), pow_);
5402   }
5403 
Accept(ModelVisitor * const visitor) const5404   void Accept(ModelVisitor* const visitor) const override {
5405     visitor->BeginVisitIntegerExpression(ModelVisitor::kPower, this);
5406     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5407                                             expr_);
5408     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, pow_);
5409     visitor->EndVisitIntegerExpression(ModelVisitor::kPower, this);
5410   }
5411 
5412  protected:
Pown(int64_t value) const5413   int64_t Pown(int64_t value) const {
5414     if (value >= limit_) {
5415       return std::numeric_limits<int64_t>::max();
5416     }
5417     if (value <= -limit_) {
5418       if (pow_ % 2 == 0) {
5419         return std::numeric_limits<int64_t>::max();
5420       } else {
5421         return std::numeric_limits<int64_t>::min();
5422       }
5423     }
5424     return IntPower(value, pow_);
5425   }
5426 
SqrnDown(int64_t value) const5427   int64_t SqrnDown(int64_t value) const {
5428     if (value == std::numeric_limits<int64_t>::min()) {
5429       return std::numeric_limits<int64_t>::min();
5430     }
5431     if (value == std::numeric_limits<int64_t>::max()) {
5432       return std::numeric_limits<int64_t>::max();
5433     }
5434     int64_t res = 0;
5435     const double d_value = static_cast<double>(value);
5436     if (value >= 0) {
5437       const double sq = exp(log(d_value) / pow_);
5438       res = static_cast<int64_t>(floor(sq));
5439     } else {
5440       CHECK_EQ(1, pow_ % 2);
5441       const double sq = exp(log(-d_value) / pow_);
5442       res = -static_cast<int64_t>(ceil(sq));
5443     }
5444     const int64_t pow_res = Pown(res + 1);
5445     if (pow_res <= value) {
5446       return res + 1;
5447     } else {
5448       return res;
5449     }
5450   }
5451 
SqrnUp(int64_t value) const5452   int64_t SqrnUp(int64_t value) const {
5453     if (value == std::numeric_limits<int64_t>::min()) {
5454       return std::numeric_limits<int64_t>::min();
5455     }
5456     if (value == std::numeric_limits<int64_t>::max()) {
5457       return std::numeric_limits<int64_t>::max();
5458     }
5459     int64_t res = 0;
5460     const double d_value = static_cast<double>(value);
5461     if (value >= 0) {
5462       const double sq = exp(log(d_value) / pow_);
5463       res = static_cast<int64_t>(ceil(sq));
5464     } else {
5465       CHECK_EQ(1, pow_ % 2);
5466       const double sq = exp(log(-d_value) / pow_);
5467       res = -static_cast<int64_t>(floor(sq));
5468     }
5469     const int64_t pow_res = Pown(res - 1);
5470     if (pow_res >= value) {
5471       return res - 1;
5472     } else {
5473       return res;
5474     }
5475   }
5476 
5477   IntExpr* const expr_;
5478   const int64_t pow_;
5479   const int64_t limit_;
5480 };
5481 
5482 class IntEvenPower : public BasePower {
5483  public:
IntEvenPower(Solver * const s,IntExpr * const e,int64_t n)5484   IntEvenPower(Solver* const s, IntExpr* const e, int64_t n)
5485       : BasePower(s, e, n) {
5486     CHECK_EQ(0, n % 2);
5487   }
5488 
~IntEvenPower()5489   ~IntEvenPower() override {}
5490 
Min() const5491   int64_t Min() const override {
5492     int64_t emin = 0;
5493     int64_t emax = 0;
5494     expr_->Range(&emin, &emax);
5495     if (emin >= 0) {
5496       return Pown(emin);
5497     }
5498     if (emax < 0) {
5499       return Pown(emax);
5500     }
5501     return 0LL;
5502   }
SetMin(int64_t m)5503   void SetMin(int64_t m) override {
5504     if (m <= 0) {
5505       return;
5506     }
5507     int64_t emin = 0;
5508     int64_t emax = 0;
5509     expr_->Range(&emin, &emax);
5510     const int64_t root = SqrnUp(m);
5511     if (emin > -root) {
5512       expr_->SetMin(root);
5513     } else if (emax < root) {
5514       expr_->SetMax(-root);
5515     } else if (expr_->IsVar()) {
5516       reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5517     }
5518   }
5519 
Max() const5520   int64_t Max() const override {
5521     return std::max(Pown(expr_->Min()), Pown(expr_->Max()));
5522   }
5523 
SetMax(int64_t m)5524   void SetMax(int64_t m) override {
5525     if (m < 0) {
5526       solver()->Fail();
5527     }
5528     if (m == std::numeric_limits<int64_t>::max()) {
5529       return;
5530     }
5531     const int64_t root = SqrnDown(m);
5532     expr_->SetRange(-root, root);
5533   }
5534 };
5535 
5536 class PosIntEvenPower : public BasePower {
5537  public:
PosIntEvenPower(Solver * const s,IntExpr * const e,int64_t pow)5538   PosIntEvenPower(Solver* const s, IntExpr* const e, int64_t pow)
5539       : BasePower(s, e, pow) {
5540     CHECK_EQ(0, pow % 2);
5541   }
5542 
~PosIntEvenPower()5543   ~PosIntEvenPower() override {}
5544 
Min() const5545   int64_t Min() const override { return Pown(expr_->Min()); }
5546 
SetMin(int64_t m)5547   void SetMin(int64_t m) override {
5548     if (m <= 0) {
5549       return;
5550     }
5551     expr_->SetMin(SqrnUp(m));
5552   }
Max() const5553   int64_t Max() const override { return Pown(expr_->Max()); }
5554 
SetMax(int64_t m)5555   void SetMax(int64_t m) override {
5556     if (m < 0) {
5557       solver()->Fail();
5558     }
5559     if (m == std::numeric_limits<int64_t>::max()) {
5560       return;
5561     }
5562     expr_->SetMax(SqrnDown(m));
5563   }
5564 };
5565 
5566 class IntOddPower : public BasePower {
5567  public:
IntOddPower(Solver * const s,IntExpr * const e,int64_t n)5568   IntOddPower(Solver* const s, IntExpr* const e, int64_t n)
5569       : BasePower(s, e, n) {
5570     CHECK_EQ(1, n % 2);
5571   }
5572 
~IntOddPower()5573   ~IntOddPower() override {}
5574 
Min() const5575   int64_t Min() const override { return Pown(expr_->Min()); }
5576 
SetMin(int64_t m)5577   void SetMin(int64_t m) override { expr_->SetMin(SqrnUp(m)); }
5578 
Max() const5579   int64_t Max() const override { return Pown(expr_->Max()); }
5580 
SetMax(int64_t m)5581   void SetMax(int64_t m) override { expr_->SetMax(SqrnDown(m)); }
5582 };
5583 
5584 // ----- Min(expr, expr) -----
5585 
5586 class MinIntExpr : public BaseIntExpr {
5587  public:
MinIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)5588   MinIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5589       : BaseIntExpr(s), left_(l), right_(r) {}
~MinIntExpr()5590   ~MinIntExpr() override {}
Min() const5591   int64_t Min() const override {
5592     const int64_t lmin = left_->Min();
5593     const int64_t rmin = right_->Min();
5594     return std::min(lmin, rmin);
5595   }
SetMin(int64_t m)5596   void SetMin(int64_t m) override {
5597     left_->SetMin(m);
5598     right_->SetMin(m);
5599   }
Max() const5600   int64_t Max() const override {
5601     const int64_t lmax = left_->Max();
5602     const int64_t rmax = right_->Max();
5603     return std::min(lmax, rmax);
5604   }
SetMax(int64_t m)5605   void SetMax(int64_t m) override {
5606     if (left_->Min() > m) {
5607       right_->SetMax(m);
5608     }
5609     if (right_->Min() > m) {
5610       left_->SetMax(m);
5611     }
5612   }
name() const5613   std::string name() const override {
5614     return absl::StrFormat("MinIntExpr(%s, %s)", left_->name(), right_->name());
5615   }
DebugString() const5616   std::string DebugString() const override {
5617     return absl::StrFormat("MinIntExpr(%s, %s)", left_->DebugString(),
5618                            right_->DebugString());
5619   }
WhenRange(Demon * d)5620   void WhenRange(Demon* d) override {
5621     left_->WhenRange(d);
5622     right_->WhenRange(d);
5623   }
5624 
Accept(ModelVisitor * const visitor) const5625   void Accept(ModelVisitor* const visitor) const override {
5626     visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5627     visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5628     visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5629                                             right_);
5630     visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5631   }
5632 
5633  private:
5634   IntExpr* const left_;
5635   IntExpr* const right_;
5636 };
5637 
5638 // ----- Min(expr, constant) -----
5639 
5640 class MinCstIntExpr : public BaseIntExpr {
5641  public:
MinCstIntExpr(Solver * const s,IntExpr * const e,int64_t v)5642   MinCstIntExpr(Solver* const s, IntExpr* const e, int64_t v)
5643       : BaseIntExpr(s), expr_(e), value_(v) {}
5644 
~MinCstIntExpr()5645   ~MinCstIntExpr() override {}
5646 
Min() const5647   int64_t Min() const override { return std::min(expr_->Min(), value_); }
5648 
SetMin(int64_t m)5649   void SetMin(int64_t m) override {
5650     if (m > value_) {
5651       solver()->Fail();
5652     }
5653     expr_->SetMin(m);
5654   }
5655 
Max() const5656   int64_t Max() const override { return std::min(expr_->Max(), value_); }
5657 
SetMax(int64_t m)5658   void SetMax(int64_t m) override {
5659     if (value_ > m) {
5660       expr_->SetMax(m);
5661     }
5662   }
5663 
Bound() const5664   bool Bound() const override {
5665     return (expr_->Bound() || expr_->Min() >= value_);
5666   }
5667 
name() const5668   std::string name() const override {
5669     return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->name(), value_);
5670   }
5671 
DebugString() const5672   std::string DebugString() const override {
5673     return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->DebugString(),
5674                            value_);
5675   }
5676 
WhenRange(Demon * d)5677   void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5678 
Accept(ModelVisitor * const visitor) const5679   void Accept(ModelVisitor* const visitor) const override {
5680     visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5681     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5682                                             expr_);
5683     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5684     visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5685   }
5686 
5687  private:
5688   IntExpr* const expr_;
5689   const int64_t value_;
5690 };
5691 
5692 // ----- Max(expr, expr) -----
5693 
5694 class MaxIntExpr : public BaseIntExpr {
5695  public:
MaxIntExpr(Solver * const s,IntExpr * const l,IntExpr * const r)5696   MaxIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5697       : BaseIntExpr(s), left_(l), right_(r) {}
5698 
~MaxIntExpr()5699   ~MaxIntExpr() override {}
5700 
Min() const5701   int64_t Min() const override { return std::max(left_->Min(), right_->Min()); }
5702 
SetMin(int64_t m)5703   void SetMin(int64_t m) override {
5704     if (left_->Max() < m) {
5705       right_->SetMin(m);
5706     } else {
5707       if (right_->Max() < m) {
5708         left_->SetMin(m);
5709       }
5710     }
5711   }
5712 
Max() const5713   int64_t Max() const override { return std::max(left_->Max(), right_->Max()); }
5714 
SetMax(int64_t m)5715   void SetMax(int64_t m) override {
5716     left_->SetMax(m);
5717     right_->SetMax(m);
5718   }
5719 
name() const5720   std::string name() const override {
5721     return absl::StrFormat("MaxIntExpr(%s, %s)", left_->name(), right_->name());
5722   }
5723 
DebugString() const5724   std::string DebugString() const override {
5725     return absl::StrFormat("MaxIntExpr(%s, %s)", left_->DebugString(),
5726                            right_->DebugString());
5727   }
5728 
WhenRange(Demon * d)5729   void WhenRange(Demon* d) override {
5730     left_->WhenRange(d);
5731     right_->WhenRange(d);
5732   }
5733 
Accept(ModelVisitor * const visitor) const5734   void Accept(ModelVisitor* const visitor) const override {
5735     visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5736     visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5737     visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5738                                             right_);
5739     visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5740   }
5741 
5742  private:
5743   IntExpr* const left_;
5744   IntExpr* const right_;
5745 };
5746 
5747 // ----- Max(expr, constant) -----
5748 
5749 class MaxCstIntExpr : public BaseIntExpr {
5750  public:
MaxCstIntExpr(Solver * const s,IntExpr * const e,int64_t v)5751   MaxCstIntExpr(Solver* const s, IntExpr* const e, int64_t v)
5752       : BaseIntExpr(s), expr_(e), value_(v) {}
5753 
~MaxCstIntExpr()5754   ~MaxCstIntExpr() override {}
5755 
Min() const5756   int64_t Min() const override { return std::max(expr_->Min(), value_); }
5757 
SetMin(int64_t m)5758   void SetMin(int64_t m) override {
5759     if (value_ < m) {
5760       expr_->SetMin(m);
5761     }
5762   }
5763 
Max() const5764   int64_t Max() const override { return std::max(expr_->Max(), value_); }
5765 
SetMax(int64_t m)5766   void SetMax(int64_t m) override {
5767     if (m < value_) {
5768       solver()->Fail();
5769     }
5770     expr_->SetMax(m);
5771   }
5772 
Bound() const5773   bool Bound() const override {
5774     return (expr_->Bound() || expr_->Max() <= value_);
5775   }
5776 
name() const5777   std::string name() const override {
5778     return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->name(), value_);
5779   }
5780 
DebugString() const5781   std::string DebugString() const override {
5782     return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->DebugString(),
5783                            value_);
5784   }
5785 
WhenRange(Demon * d)5786   void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5787 
Accept(ModelVisitor * const visitor) const5788   void Accept(ModelVisitor* const visitor) const override {
5789     visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5790     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5791                                             expr_);
5792     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5793     visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5794   }
5795 
5796  private:
5797   IntExpr* const expr_;
5798   const int64_t value_;
5799 };
5800 
5801 // ----- Convex Piecewise -----
5802 
5803 // This class is a very simple convex piecewise linear function.  The
5804 // argument of the function is the expression.  Between early_date and
5805 // late_date, the value of the function is 0.  Before early date, it
5806 // is affine and the cost is early_cost * (early_date - x). After
5807 // late_date, the cost is late_cost * (x - late_date).
5808 
5809 class SimpleConvexPiecewiseExpr : public BaseIntExpr {
5810  public:
SimpleConvexPiecewiseExpr(Solver * const s,IntExpr * const e,int64_t ec,int64_t ed,int64_t ld,int64_t lc)5811   SimpleConvexPiecewiseExpr(Solver* const s, IntExpr* const e, int64_t ec,
5812                             int64_t ed, int64_t ld, int64_t lc)
5813       : BaseIntExpr(s),
5814         expr_(e),
5815         early_cost_(ec),
5816         early_date_(ec == 0 ? std::numeric_limits<int64_t>::min() : ed),
5817         late_date_(lc == 0 ? std::numeric_limits<int64_t>::max() : ld),
5818         late_cost_(lc) {
5819     DCHECK_GE(ec, int64_t{0});
5820     DCHECK_GE(lc, int64_t{0});
5821     DCHECK_GE(ld, ed);
5822 
5823     // If the penalty is 0, we can push the "confort zone or zone
5824     // of no cost towards infinity.
5825   }
5826 
~SimpleConvexPiecewiseExpr()5827   ~SimpleConvexPiecewiseExpr() override {}
5828 
Min() const5829   int64_t Min() const override {
5830     const int64_t vmin = expr_->Min();
5831     const int64_t vmax = expr_->Max();
5832     if (vmin >= late_date_) {
5833       return (vmin - late_date_) * late_cost_;
5834     } else if (vmax <= early_date_) {
5835       return (early_date_ - vmax) * early_cost_;
5836     } else {
5837       return 0LL;
5838     }
5839   }
5840 
SetMin(int64_t m)5841   void SetMin(int64_t m) override {
5842     if (m <= 0) {
5843       return;
5844     }
5845     int64_t vmin = 0;
5846     int64_t vmax = 0;
5847     expr_->Range(&vmin, &vmax);
5848 
5849     const int64_t rb =
5850         (late_cost_ == 0 ? vmax : late_date_ + PosIntDivUp(m, late_cost_) - 1);
5851     const int64_t lb =
5852         (early_cost_ == 0 ? vmin
5853                           : early_date_ - PosIntDivUp(m, early_cost_) + 1);
5854 
5855     if (expr_->IsVar()) {
5856       expr_->Var()->RemoveInterval(lb, rb);
5857     }
5858   }
5859 
Max() const5860   int64_t Max() const override {
5861     const int64_t vmin = expr_->Min();
5862     const int64_t vmax = expr_->Max();
5863     const int64_t mr = vmax > late_date_ ? (vmax - late_date_) * late_cost_ : 0;
5864     const int64_t ml =
5865         vmin < early_date_ ? (early_date_ - vmin) * early_cost_ : 0;
5866     return std::max(mr, ml);
5867   }
5868 
SetMax(int64_t m)5869   void SetMax(int64_t m) override {
5870     if (m < 0) {
5871       solver()->Fail();
5872     }
5873     if (late_cost_ != 0LL) {
5874       const int64_t rb = late_date_ + PosIntDivDown(m, late_cost_);
5875       if (early_cost_ != 0LL) {
5876         const int64_t lb = early_date_ - PosIntDivDown(m, early_cost_);
5877         expr_->SetRange(lb, rb);
5878       } else {
5879         expr_->SetMax(rb);
5880       }
5881     } else {
5882       if (early_cost_ != 0LL) {
5883         const int64_t lb = early_date_ - PosIntDivDown(m, early_cost_);
5884         expr_->SetMin(lb);
5885       }
5886     }
5887   }
5888 
name() const5889   std::string name() const override {
5890     return absl::StrFormat(
5891         "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5892         expr_->name(), early_cost_, early_date_, late_date_, late_cost_);
5893   }
5894 
DebugString() const5895   std::string DebugString() const override {
5896     return absl::StrFormat(
5897         "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5898         expr_->DebugString(), early_cost_, early_date_, late_date_, late_cost_);
5899   }
5900 
WhenRange(Demon * d)5901   void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5902 
Accept(ModelVisitor * const visitor) const5903   void Accept(ModelVisitor* const visitor) const override {
5904     visitor->BeginVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5905     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5906                                             expr_);
5907     visitor->VisitIntegerArgument(ModelVisitor::kEarlyCostArgument,
5908                                   early_cost_);
5909     visitor->VisitIntegerArgument(ModelVisitor::kEarlyDateArgument,
5910                                   early_date_);
5911     visitor->VisitIntegerArgument(ModelVisitor::kLateCostArgument, late_cost_);
5912     visitor->VisitIntegerArgument(ModelVisitor::kLateDateArgument, late_date_);
5913     visitor->EndVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5914   }
5915 
5916  private:
5917   IntExpr* const expr_;
5918   const int64_t early_cost_;
5919   const int64_t early_date_;
5920   const int64_t late_date_;
5921   const int64_t late_cost_;
5922 };
5923 
5924 // ----- Semi Continuous -----
5925 
5926 class SemiContinuousExpr : public BaseIntExpr {
5927  public:
SemiContinuousExpr(Solver * const s,IntExpr * const e,int64_t fixed_charge,int64_t step)5928   SemiContinuousExpr(Solver* const s, IntExpr* const e, int64_t fixed_charge,
5929                      int64_t step)
5930       : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge), step_(step) {
5931     DCHECK_GE(fixed_charge, int64_t{0});
5932     DCHECK_GT(step, int64_t{0});
5933   }
5934 
~SemiContinuousExpr()5935   ~SemiContinuousExpr() override {}
5936 
Value(int64_t x) const5937   int64_t Value(int64_t x) const {
5938     if (x <= 0) {
5939       return 0;
5940     } else {
5941       return CapAdd(fixed_charge_, CapProd(x, step_));
5942     }
5943   }
5944 
Min() const5945   int64_t Min() const override { return Value(expr_->Min()); }
5946 
SetMin(int64_t m)5947   void SetMin(int64_t m) override {
5948     if (m >= CapAdd(fixed_charge_, step_)) {
5949       const int64_t y = PosIntDivUp(CapSub(m, fixed_charge_), step_);
5950       expr_->SetMin(y);
5951     } else if (m > 0) {
5952       expr_->SetMin(1);
5953     }
5954   }
5955 
Max() const5956   int64_t Max() const override { return Value(expr_->Max()); }
5957 
SetMax(int64_t m)5958   void SetMax(int64_t m) override {
5959     if (m < 0) {
5960       solver()->Fail();
5961     }
5962     if (m == std::numeric_limits<int64_t>::max()) {
5963       return;
5964     }
5965     if (m < CapAdd(fixed_charge_, step_)) {
5966       expr_->SetMax(0);
5967     } else {
5968       const int64_t y = PosIntDivDown(CapSub(m, fixed_charge_), step_);
5969       expr_->SetMax(y);
5970     }
5971   }
5972 
name() const5973   std::string name() const override {
5974     return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
5975                            expr_->name(), fixed_charge_, step_);
5976   }
5977 
DebugString() const5978   std::string DebugString() const override {
5979     return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
5980                            expr_->DebugString(), fixed_charge_, step_);
5981   }
5982 
WhenRange(Demon * d)5983   void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5984 
Accept(ModelVisitor * const visitor) const5985   void Accept(ModelVisitor* const visitor) const override {
5986     visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
5987     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5988                                             expr_);
5989     visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
5990                                   fixed_charge_);
5991     visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, step_);
5992     visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
5993   }
5994 
5995  private:
5996   IntExpr* const expr_;
5997   const int64_t fixed_charge_;
5998   const int64_t step_;
5999 };
6000 
6001 class SemiContinuousStepOneExpr : public BaseIntExpr {
6002  public:
SemiContinuousStepOneExpr(Solver * const s,IntExpr * const e,int64_t fixed_charge)6003   SemiContinuousStepOneExpr(Solver* const s, IntExpr* const e,
6004                             int64_t fixed_charge)
6005       : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6006     DCHECK_GE(fixed_charge, int64_t{0});
6007   }
6008 
~SemiContinuousStepOneExpr()6009   ~SemiContinuousStepOneExpr() override {}
6010 
Value(int64_t x) const6011   int64_t Value(int64_t x) const {
6012     if (x <= 0) {
6013       return 0;
6014     } else {
6015       return fixed_charge_ + x;
6016     }
6017   }
6018 
Min() const6019   int64_t Min() const override { return Value(expr_->Min()); }
6020 
SetMin(int64_t m)6021   void SetMin(int64_t m) override {
6022     if (m >= fixed_charge_ + 1) {
6023       expr_->SetMin(m - fixed_charge_);
6024     } else if (m > 0) {
6025       expr_->SetMin(1);
6026     }
6027   }
6028 
Max() const6029   int64_t Max() const override { return Value(expr_->Max()); }
6030 
SetMax(int64_t m)6031   void SetMax(int64_t m) override {
6032     if (m < 0) {
6033       solver()->Fail();
6034     }
6035     if (m < fixed_charge_ + 1) {
6036       expr_->SetMax(0);
6037     } else {
6038       expr_->SetMax(m - fixed_charge_);
6039     }
6040   }
6041 
name() const6042   std::string name() const override {
6043     return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6044                            expr_->name(), fixed_charge_);
6045   }
6046 
DebugString() const6047   std::string DebugString() const override {
6048     return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6049                            expr_->DebugString(), fixed_charge_);
6050   }
6051 
WhenRange(Demon * d)6052   void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6053 
Accept(ModelVisitor * const visitor) const6054   void Accept(ModelVisitor* const visitor) const override {
6055     visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6056     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6057                                             expr_);
6058     visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6059                                   fixed_charge_);
6060     visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 1);
6061     visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6062   }
6063 
6064  private:
6065   IntExpr* const expr_;
6066   const int64_t fixed_charge_;
6067 };
6068 
6069 class SemiContinuousStepZeroExpr : public BaseIntExpr {
6070  public:
SemiContinuousStepZeroExpr(Solver * const s,IntExpr * const e,int64_t fixed_charge)6071   SemiContinuousStepZeroExpr(Solver* const s, IntExpr* const e,
6072                              int64_t fixed_charge)
6073       : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6074     DCHECK_GT(fixed_charge, int64_t{0});
6075   }
6076 
~SemiContinuousStepZeroExpr()6077   ~SemiContinuousStepZeroExpr() override {}
6078 
Value(int64_t x) const6079   int64_t Value(int64_t x) const {
6080     if (x <= 0) {
6081       return 0;
6082     } else {
6083       return fixed_charge_;
6084     }
6085   }
6086 
Min() const6087   int64_t Min() const override { return Value(expr_->Min()); }
6088 
SetMin(int64_t m)6089   void SetMin(int64_t m) override {
6090     if (m >= fixed_charge_) {
6091       solver()->Fail();
6092     } else if (m > 0) {
6093       expr_->SetMin(1);
6094     }
6095   }
6096 
Max() const6097   int64_t Max() const override { return Value(expr_->Max()); }
6098 
SetMax(int64_t m)6099   void SetMax(int64_t m) override {
6100     if (m < 0) {
6101       solver()->Fail();
6102     }
6103     if (m < fixed_charge_) {
6104       expr_->SetMax(0);
6105     }
6106   }
6107 
name() const6108   std::string name() const override {
6109     return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6110                            expr_->name(), fixed_charge_);
6111   }
6112 
DebugString() const6113   std::string DebugString() const override {
6114     return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6115                            expr_->DebugString(), fixed_charge_);
6116   }
6117 
WhenRange(Demon * d)6118   void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6119 
Accept(ModelVisitor * const visitor) const6120   void Accept(ModelVisitor* const visitor) const override {
6121     visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6122     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6123                                             expr_);
6124     visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6125                                   fixed_charge_);
6126     visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 0);
6127     visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6128   }
6129 
6130  private:
6131   IntExpr* const expr_;
6132   const int64_t fixed_charge_;
6133 };
6134 
// This constraint links an expression and the variable it is cast into.
6136 class LinkExprAndVar : public CastConstraint {
6137  public:
LinkExprAndVar(Solver * const s,IntExpr * const expr,IntVar * const var)6138   LinkExprAndVar(Solver* const s, IntExpr* const expr, IntVar* const var)
6139       : CastConstraint(s, var), expr_(expr) {}
6140 
~LinkExprAndVar()6141   ~LinkExprAndVar() override {}
6142 
Post()6143   void Post() override {
6144     Solver* const s = solver();
6145     Demon* d = s->MakeConstraintInitialPropagateCallback(this);
6146     expr_->WhenRange(d);
6147     target_var_->WhenRange(d);
6148   }
6149 
InitialPropagate()6150   void InitialPropagate() override {
6151     expr_->SetRange(target_var_->Min(), target_var_->Max());
6152     int64_t l, u;
6153     expr_->Range(&l, &u);
6154     target_var_->SetRange(l, u);
6155   }
6156 
DebugString() const6157   std::string DebugString() const override {
6158     return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6159                            target_var_->DebugString());
6160   }
6161 
Accept(ModelVisitor * const visitor) const6162   void Accept(ModelVisitor* const visitor) const override {
6163     visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6164     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6165                                             expr_);
6166     visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6167                                             target_var_);
6168     visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6169   }
6170 
6171  private:
6172   IntExpr* const expr_;
6173 };
6174 
6175 // ----- Conditional Expression -----
6176 
6177 class ExprWithEscapeValue : public BaseIntExpr {
6178  public:
ExprWithEscapeValue(Solver * const s,IntVar * const c,IntExpr * const e,int64_t unperformed_value)6179   ExprWithEscapeValue(Solver* const s, IntVar* const c, IntExpr* const e,
6180                       int64_t unperformed_value)
6181       : BaseIntExpr(s),
6182         condition_(c),
6183         expression_(e),
6184         unperformed_value_(unperformed_value) {}
6185 
~ExprWithEscapeValue()6186   ~ExprWithEscapeValue() override {}
6187 
Min() const6188   int64_t Min() const override {
6189     if (condition_->Min() == 1) {
6190       return expression_->Min();
6191     } else if (condition_->Max() == 1) {
6192       return std::min(unperformed_value_, expression_->Min());
6193     } else {
6194       return unperformed_value_;
6195     }
6196   }
6197 
SetMin(int64_t m)6198   void SetMin(int64_t m) override {
6199     if (m > unperformed_value_) {
6200       condition_->SetValue(1);
6201       expression_->SetMin(m);
6202     } else if (condition_->Min() == 1) {
6203       expression_->SetMin(m);
6204     } else if (m > expression_->Max()) {
6205       condition_->SetValue(0);
6206     }
6207   }
6208 
Max() const6209   int64_t Max() const override {
6210     if (condition_->Min() == 1) {
6211       return expression_->Max();
6212     } else if (condition_->Max() == 1) {
6213       return std::max(unperformed_value_, expression_->Max());
6214     } else {
6215       return unperformed_value_;
6216     }
6217   }
6218 
SetMax(int64_t m)6219   void SetMax(int64_t m) override {
6220     if (m < unperformed_value_) {
6221       condition_->SetValue(1);
6222       expression_->SetMax(m);
6223     } else if (condition_->Min() == 1) {
6224       expression_->SetMax(m);
6225     } else if (m < expression_->Min()) {
6226       condition_->SetValue(0);
6227     }
6228   }
6229 
SetRange(int64_t mi,int64_t ma)6230   void SetRange(int64_t mi, int64_t ma) override {
6231     if (ma < unperformed_value_ || mi > unperformed_value_) {
6232       condition_->SetValue(1);
6233       expression_->SetRange(mi, ma);
6234     } else if (condition_->Min() == 1) {
6235       expression_->SetRange(mi, ma);
6236     } else if (ma < expression_->Min() || mi > expression_->Max()) {
6237       condition_->SetValue(0);
6238     }
6239   }
6240 
SetValue(int64_t v)6241   void SetValue(int64_t v) override {
6242     if (v != unperformed_value_) {
6243       condition_->SetValue(1);
6244       expression_->SetValue(v);
6245     } else if (condition_->Min() == 1) {
6246       expression_->SetValue(v);
6247     } else if (v < expression_->Min() || v > expression_->Max()) {
6248       condition_->SetValue(0);
6249     }
6250   }
6251 
Bound() const6252   bool Bound() const override {
6253     return condition_->Max() == 0 || expression_->Bound();
6254   }
6255 
WhenRange(Demon * d)6256   void WhenRange(Demon* d) override {
6257     expression_->WhenRange(d);
6258     condition_->WhenBound(d);
6259   }
6260 
DebugString() const6261   std::string DebugString() const override {
6262     return absl::StrFormat("ConditionExpr(%s, %s, %d)",
6263                            condition_->DebugString(),
6264                            expression_->DebugString(), unperformed_value_);
6265   }
6266 
Accept(ModelVisitor * const visitor) const6267   void Accept(ModelVisitor* const visitor) const override {
6268     visitor->BeginVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6269     visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
6270                                             condition_);
6271     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6272                                             expression_);
6273     visitor->VisitIntegerArgument(ModelVisitor::kValueArgument,
6274                                   unperformed_value_);
6275     visitor->EndVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6276   }
6277 
6278  private:
6279   IntVar* const condition_;
6280   IntExpr* const expression_;
6281   const int64_t unperformed_value_;
6282   DISALLOW_COPY_AND_ASSIGN(ExprWithEscapeValue);
6283 };
6284 
6285 // ----- This is a specialized case when the variable exact type is known -----
6286 class LinkExprAndDomainIntVar : public CastConstraint {
6287  public:
LinkExprAndDomainIntVar(Solver * const s,IntExpr * const expr,DomainIntVar * const var)6288   LinkExprAndDomainIntVar(Solver* const s, IntExpr* const expr,
6289                           DomainIntVar* const var)
6290       : CastConstraint(s, var),
6291         expr_(expr),
6292         cached_min_(std::numeric_limits<int64_t>::min()),
6293         cached_max_(std::numeric_limits<int64_t>::max()),
6294         fail_stamp_(uint64_t{0}) {}
6295 
~LinkExprAndDomainIntVar()6296   ~LinkExprAndDomainIntVar() override {}
6297 
var() const6298   DomainIntVar* var() const {
6299     return reinterpret_cast<DomainIntVar*>(target_var_);
6300   }
6301 
Post()6302   void Post() override {
6303     Solver* const s = solver();
6304     Demon* const d = s->MakeConstraintInitialPropagateCallback(this);
6305     expr_->WhenRange(d);
6306     Demon* const target_var_demon = MakeConstraintDemon0(
6307         solver(), this, &LinkExprAndDomainIntVar::Propagate, "Propagate");
6308     target_var_->WhenRange(target_var_demon);
6309   }
6310 
InitialPropagate()6311   void InitialPropagate() override {
6312     expr_->SetRange(var()->min_.Value(), var()->max_.Value());
6313     expr_->Range(&cached_min_, &cached_max_);
6314     var()->DomainIntVar::SetRange(cached_min_, cached_max_);
6315   }
6316 
Propagate()6317   void Propagate() {
6318     if (var()->min_.Value() > cached_min_ ||
6319         var()->max_.Value() < cached_max_ ||
6320         solver()->fail_stamp() != fail_stamp_) {
6321       InitialPropagate();
6322       fail_stamp_ = solver()->fail_stamp();
6323     }
6324   }
6325 
DebugString() const6326   std::string DebugString() const override {
6327     return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6328                            target_var_->DebugString());
6329   }
6330 
Accept(ModelVisitor * const visitor) const6331   void Accept(ModelVisitor* const visitor) const override {
6332     visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6333     visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6334                                             expr_);
6335     visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6336                                             target_var_);
6337     visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6338   }
6339 
6340  private:
6341   IntExpr* const expr_;
6342   int64_t cached_min_;
6343   int64_t cached_max_;
6344   uint64_t fail_stamp_;
6345 };
6346 }  //  namespace
6347 
6348 // ----- Misc -----
6349 
// Returns a newly allocated EmptyIterator as the hole iterator of a boolean
// variable; `reversible` is forwarded to COND_REV_ALLOC to select the
// allocation mode.
IntVarIterator* BooleanVar::MakeHoleIterator(bool reversible) const {
  return COND_REV_ALLOC(reversible, new EmptyIterator());
}
// Returns a newly allocated RangeIterator over this variable as its domain
// iterator; `reversible` is forwarded to COND_REV_ALLOC to select the
// allocation mode.
IntVarIterator* BooleanVar::MakeDomainIterator(bool reversible) const {
  return COND_REV_ALLOC(reversible, new RangeIterator(this));
}
6356 
6357 // ----- API -----
6358 
CleanVariableOnFail(IntVar * const var)6359 void CleanVariableOnFail(IntVar* const var) {
6360   DCHECK_EQ(DOMAIN_INT_VAR, var->VarType());
6361   DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6362   dvar->CleanInProcess();
6363 }
6364 
SetIsEqual(IntVar * const var,const std::vector<int64_t> & values,const std::vector<IntVar * > & vars)6365 Constraint* SetIsEqual(IntVar* const var, const std::vector<int64_t>& values,
6366                        const std::vector<IntVar*>& vars) {
6367   DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6368   CHECK(dvar != nullptr);
6369   return dvar->SetIsEqual(values, vars);
6370 }
6371 
SetIsGreaterOrEqual(IntVar * const var,const std::vector<int64_t> & values,const std::vector<IntVar * > & vars)6372 Constraint* SetIsGreaterOrEqual(IntVar* const var,
6373                                 const std::vector<int64_t>& values,
6374                                 const std::vector<IntVar*>& vars) {
6375   DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6376   CHECK(dvar != nullptr);
6377   return dvar->SetIsGreaterOrEqual(values, vars);
6378 }
6379 
RestoreBoolValue(IntVar * const var)6380 void RestoreBoolValue(IntVar* const var) {
6381   DCHECK_EQ(BOOLEAN_VAR, var->VarType());
6382   BooleanVar* const boolean_var = reinterpret_cast<BooleanVar*>(var);
6383   boolean_var->RestoreValue();
6384 }
6385 
6386 // ----- API -----
6387 
MakeIntVar(int64_t min,int64_t max,const std::string & name)6388 IntVar* Solver::MakeIntVar(int64_t min, int64_t max, const std::string& name) {
6389   if (min == max) {
6390     return MakeIntConst(min, name);
6391   }
6392   if (min == 0 && max == 1) {
6393     return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6394   } else if (CapSub(max, min) == 1) {
6395     const std::string inner_name = "inner_" + name;
6396     return RegisterIntVar(
6397         MakeSum(RevAlloc(new ConcreteBooleanVar(this, inner_name)), min)
6398             ->VarWithName(name));
6399   } else {
6400     return RegisterIntVar(RevAlloc(new DomainIntVar(this, min, max, name)));
6401   }
6402 }
6403 
MakeIntVar(int64_t min,int64_t max)6404 IntVar* Solver::MakeIntVar(int64_t min, int64_t max) {
6405   return MakeIntVar(min, max, "");
6406 }
6407 
MakeBoolVar(const std::string & name)6408 IntVar* Solver::MakeBoolVar(const std::string& name) {
6409   return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6410 }
6411 
MakeBoolVar()6412 IntVar* Solver::MakeBoolVar() {
6413   return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, "")));
6414 }
6415 
MakeIntVar(const std::vector<int64_t> & values,const std::string & name)6416 IntVar* Solver::MakeIntVar(const std::vector<int64_t>& values,
6417                            const std::string& name) {
6418   DCHECK(!values.empty());
6419   // Fast-track the case where we have a single value.
6420   if (values.size() == 1) return MakeIntConst(values[0], name);
6421   // Sort and remove duplicates.
6422   std::vector<int64_t> unique_sorted_values = values;
6423   gtl::STLSortAndRemoveDuplicates(&unique_sorted_values);
6424   // Case when we have a single value, after clean-up.
6425   if (unique_sorted_values.size() == 1) return MakeIntConst(values[0], name);
6426   // Case when the values are a dense interval of integers.
6427   if (unique_sorted_values.size() ==
6428       unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6429     return MakeIntVar(unique_sorted_values.front(), unique_sorted_values.back(),
6430                       name);
6431   }
6432   // Compute the GCD: if it's not 1, we can express the variable's domain as
6433   // the product of the GCD and of a domain with smaller values.
6434   int64_t gcd = 0;
6435   for (const int64_t v : unique_sorted_values) {
6436     if (gcd == 0) {
6437       gcd = std::abs(v);
6438     } else {
6439       gcd = MathUtil::GCD64(gcd, std::abs(v));  // Supports v==0.
6440     }
6441     if (gcd == 1) {
6442       // If it's 1, though, we can't do anything special, so we
6443       // immediately return a new DomainIntVar.
6444       return RegisterIntVar(
6445           RevAlloc(new DomainIntVar(this, unique_sorted_values, name)));
6446     }
6447   }
6448   DCHECK_GT(gcd, 1);
6449   for (int64_t& v : unique_sorted_values) {
6450     DCHECK_EQ(0, v % gcd);
6451     v /= gcd;
6452   }
6453   const std::string new_name = name.empty() ? "" : "inner_" + name;
6454   // Catch the case where the divided values are a dense set of integers.
6455   IntVar* inner_intvar = nullptr;
6456   if (unique_sorted_values.size() ==
6457       unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6458     inner_intvar = MakeIntVar(unique_sorted_values.front(),
6459                               unique_sorted_values.back(), new_name);
6460   } else {
6461     inner_intvar = RegisterIntVar(
6462         RevAlloc(new DomainIntVar(this, unique_sorted_values, new_name)));
6463   }
6464   return MakeProd(inner_intvar, gcd)->Var();
6465 }
6466 
MakeIntVar(const std::vector<int64_t> & values)6467 IntVar* Solver::MakeIntVar(const std::vector<int64_t>& values) {
6468   return MakeIntVar(values, "");
6469 }
6470 
MakeIntVar(const std::vector<int> & values,const std::string & name)6471 IntVar* Solver::MakeIntVar(const std::vector<int>& values,
6472                            const std::string& name) {
6473   return MakeIntVar(ToInt64Vector(values), name);
6474 }
6475 
MakeIntVar(const std::vector<int> & values)6476 IntVar* Solver::MakeIntVar(const std::vector<int>& values) {
6477   return MakeIntVar(values, "");
6478 }
6479 
MakeIntConst(int64_t val,const std::string & name)6480 IntVar* Solver::MakeIntConst(int64_t val, const std::string& name) {
6481   // If IntConst is going to be named after its creation,
6482   // cp_share_int_consts should be set to false otherwise names can potentially
6483   // be overwritten.
6484   if (absl::GetFlag(FLAGS_cp_share_int_consts) && name.empty() &&
6485       val >= MIN_CACHED_INT_CONST && val <= MAX_CACHED_INT_CONST) {
6486     return cached_constants_[val - MIN_CACHED_INT_CONST];
6487   }
6488   return RevAlloc(new IntConst(this, val, name));
6489 }
6490 
MakeIntConst(int64_t val)6491 IntVar* Solver::MakeIntConst(int64_t val) { return MakeIntConst(val, ""); }
6492 
6493 // ----- Int Var and associated methods -----
6494 
namespace {
// Returns `prefix` immediately followed by the decimal representation of
// `index`, without any padding. The last parameter (the number of indices) is
// kept for the callers but is unused: it only served a zero-padded variant
// that was compiled out.
std::string IndexedName(const std::string& prefix, int index,
                        int /*max_index*/) {
  return prefix + std::to_string(index);
}
}  // namespace
6511 
MakeIntVarArray(int var_count,int64_t vmin,int64_t vmax,const std::string & name,std::vector<IntVar * > * vars)6512 void Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6513                              const std::string& name,
6514                              std::vector<IntVar*>* vars) {
6515   for (int i = 0; i < var_count; ++i) {
6516     vars->push_back(MakeIntVar(vmin, vmax, IndexedName(name, i, var_count)));
6517   }
6518 }
6519 
MakeIntVarArray(int var_count,int64_t vmin,int64_t vmax,std::vector<IntVar * > * vars)6520 void Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6521                              std::vector<IntVar*>* vars) {
6522   for (int i = 0; i < var_count; ++i) {
6523     vars->push_back(MakeIntVar(vmin, vmax));
6524   }
6525 }
6526 
MakeIntVarArray(int var_count,int64_t vmin,int64_t vmax,const std::string & name)6527 IntVar** Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6528                                  const std::string& name) {
6529   IntVar** vars = new IntVar*[var_count];
6530   for (int i = 0; i < var_count; ++i) {
6531     vars[i] = MakeIntVar(vmin, vmax, IndexedName(name, i, var_count));
6532   }
6533   return vars;
6534 }
6535 
MakeBoolVarArray(int var_count,const std::string & name,std::vector<IntVar * > * vars)6536 void Solver::MakeBoolVarArray(int var_count, const std::string& name,
6537                               std::vector<IntVar*>* vars) {
6538   for (int i = 0; i < var_count; ++i) {
6539     vars->push_back(MakeBoolVar(IndexedName(name, i, var_count)));
6540   }
6541 }
6542 
MakeBoolVarArray(int var_count,std::vector<IntVar * > * vars)6543 void Solver::MakeBoolVarArray(int var_count, std::vector<IntVar*>* vars) {
6544   for (int i = 0; i < var_count; ++i) {
6545     vars->push_back(MakeBoolVar());
6546   }
6547 }
6548 
MakeBoolVarArray(int var_count,const std::string & name)6549 IntVar** Solver::MakeBoolVarArray(int var_count, const std::string& name) {
6550   IntVar** vars = new IntVar*[var_count];
6551   for (int i = 0; i < var_count; ++i) {
6552     vars[i] = MakeBoolVar(IndexedName(name, i, var_count));
6553   }
6554   return vars;
6555 }
6556 
InitCachedIntConstants()6557 void Solver::InitCachedIntConstants() {
6558   for (int i = MIN_CACHED_INT_CONST; i <= MAX_CACHED_INT_CONST; ++i) {
6559     cached_constants_[i - MIN_CACHED_INT_CONST] =
6560         RevAlloc(new IntConst(this, i, ""));  // note the empty name
6561   }
6562 }
6563 
MakeSum(IntExpr * const left,IntExpr * const right)6564 IntExpr* Solver::MakeSum(IntExpr* const left, IntExpr* const right) {
6565   CHECK_EQ(this, left->solver());
6566   CHECK_EQ(this, right->solver());
6567   if (right->Bound()) {
6568     return MakeSum(left, right->Min());
6569   }
6570   if (left->Bound()) {
6571     return MakeSum(right, left->Min());
6572   }
6573   if (left == right) {
6574     return MakeProd(left, 2);
6575   }
6576   IntExpr* cache = model_cache_->FindExprExprExpression(
6577       left, right, ModelCache::EXPR_EXPR_SUM);
6578   if (cache == nullptr) {
6579     cache = model_cache_->FindExprExprExpression(right, left,
6580                                                  ModelCache::EXPR_EXPR_SUM);
6581   }
6582   if (cache != nullptr) {
6583     return cache;
6584   } else {
6585     IntExpr* const result =
6586         AddOverflows(left->Max(), right->Max()) ||
6587                 AddOverflows(left->Min(), right->Min())
6588             ? RegisterIntExpr(RevAlloc(new SafePlusIntExpr(this, left, right)))
6589             : RegisterIntExpr(RevAlloc(new PlusIntExpr(this, left, right)));
6590     model_cache_->InsertExprExprExpression(result, left, right,
6591                                            ModelCache::EXPR_EXPR_SUM);
6592     return result;
6593   }
6594 }
6595 
MakeSum(IntExpr * const expr,int64_t value)6596 IntExpr* Solver::MakeSum(IntExpr* const expr, int64_t value) {
6597   CHECK_EQ(this, expr->solver());
6598   if (expr->Bound()) {
6599     return MakeIntConst(expr->Min() + value);
6600   }
6601   if (value == 0) {
6602     return expr;
6603   }
6604   IntExpr* result = Cache()->FindExprConstantExpression(
6605       expr, value, ModelCache::EXPR_CONSTANT_SUM);
6606   if (result == nullptr) {
6607     if (expr->IsVar() && !AddOverflows(value, expr->Max()) &&
6608         !AddOverflows(value, expr->Min())) {
6609       IntVar* const var = expr->Var();
6610       switch (var->VarType()) {
6611         case DOMAIN_INT_VAR: {
6612           result = RegisterIntExpr(RevAlloc(new PlusCstDomainIntVar(
6613               this, reinterpret_cast<DomainIntVar*>(var), value)));
6614           break;
6615         }
6616         case CONST_VAR: {
6617           result = RegisterIntExpr(MakeIntConst(var->Min() + value));
6618           break;
6619         }
6620         case VAR_ADD_CST: {
6621           PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6622           IntVar* const sub_var = add_var->SubVar();
6623           const int64_t new_constant = value + add_var->Constant();
6624           if (new_constant == 0) {
6625             result = sub_var;
6626           } else {
6627             if (sub_var->VarType() == DOMAIN_INT_VAR) {
6628               DomainIntVar* const dvar =
6629                   reinterpret_cast<DomainIntVar*>(sub_var);
6630               result = RegisterIntExpr(
6631                   RevAlloc(new PlusCstDomainIntVar(this, dvar, new_constant)));
6632             } else {
6633               result = RegisterIntExpr(
6634                   RevAlloc(new PlusCstIntVar(this, sub_var, new_constant)));
6635             }
6636           }
6637           break;
6638         }
6639         case CST_SUB_VAR: {
6640           SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6641           IntVar* const sub_var = add_var->SubVar();
6642           const int64_t new_constant = value + add_var->Constant();
6643           result = RegisterIntExpr(
6644               RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6645           break;
6646         }
6647         case OPP_VAR: {
6648           OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6649           IntVar* const sub_var = add_var->SubVar();
6650           result =
6651               RegisterIntExpr(RevAlloc(new SubCstIntVar(this, sub_var, value)));
6652           break;
6653         }
6654         default:
6655           result =
6656               RegisterIntExpr(RevAlloc(new PlusCstIntVar(this, var, value)));
6657       }
6658     } else {
6659       result = RegisterIntExpr(RevAlloc(new PlusIntCstExpr(this, expr, value)));
6660     }
6661     Cache()->InsertExprConstantExpression(result, expr, value,
6662                                           ModelCache::EXPR_CONSTANT_SUM);
6663   }
6664   return result;
6665 }
6666 
MakeDifference(IntExpr * const left,IntExpr * const right)6667 IntExpr* Solver::MakeDifference(IntExpr* const left, IntExpr* const right) {
6668   CHECK_EQ(this, left->solver());
6669   CHECK_EQ(this, right->solver());
6670   if (left->Bound()) {
6671     return MakeDifference(left->Min(), right);
6672   }
6673   if (right->Bound()) {
6674     return MakeSum(left, -right->Min());
6675   }
6676   IntExpr* sub_left = nullptr;
6677   IntExpr* sub_right = nullptr;
6678   int64_t left_coef = 1;
6679   int64_t right_coef = 1;
6680   if (IsProduct(left, &sub_left, &left_coef) &&
6681       IsProduct(right, &sub_right, &right_coef)) {
6682     const int64_t abs_gcd =
6683         MathUtil::GCD64(std::abs(left_coef), std::abs(right_coef));
6684     if (abs_gcd != 0 && abs_gcd != 1) {
6685       return MakeProd(MakeDifference(MakeProd(sub_left, left_coef / abs_gcd),
6686                                      MakeProd(sub_right, right_coef / abs_gcd)),
6687                       abs_gcd);
6688     }
6689   }
6690 
6691   IntExpr* result = Cache()->FindExprExprExpression(
6692       left, right, ModelCache::EXPR_EXPR_DIFFERENCE);
6693   if (result == nullptr) {
6694     if (!SubOverflows(left->Min(), right->Max()) &&
6695         !SubOverflows(left->Max(), right->Min())) {
6696       result = RegisterIntExpr(RevAlloc(new SubIntExpr(this, left, right)));
6697     } else {
6698       result = RegisterIntExpr(RevAlloc(new SafeSubIntExpr(this, left, right)));
6699     }
6700     Cache()->InsertExprExprExpression(result, left, right,
6701                                       ModelCache::EXPR_EXPR_DIFFERENCE);
6702   }
6703   return result;
6704 }
6705 
6706 // warning: this is 'value - expr'.
MakeDifference(int64_t value,IntExpr * const expr)6707 IntExpr* Solver::MakeDifference(int64_t value, IntExpr* const expr) {
6708   CHECK_EQ(this, expr->solver());
6709   if (expr->Bound()) {
6710     return MakeIntConst(value - expr->Min());
6711   }
6712   if (value == 0) {
6713     return MakeOpposite(expr);
6714   }
6715   IntExpr* result = Cache()->FindExprConstantExpression(
6716       expr, value, ModelCache::EXPR_CONSTANT_DIFFERENCE);
6717   if (result == nullptr) {
6718     if (expr->IsVar() && expr->Min() != std::numeric_limits<int64_t>::min() &&
6719         !SubOverflows(value, expr->Min()) &&
6720         !SubOverflows(value, expr->Max())) {
6721       IntVar* const var = expr->Var();
6722       switch (var->VarType()) {
6723         case VAR_ADD_CST: {
6724           PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6725           IntVar* const sub_var = add_var->SubVar();
6726           const int64_t new_constant = value - add_var->Constant();
6727           if (new_constant == 0) {
6728             result = sub_var;
6729           } else {
6730             result = RegisterIntExpr(
6731                 RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6732           }
6733           break;
6734         }
6735         case CST_SUB_VAR: {
6736           SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6737           IntVar* const sub_var = add_var->SubVar();
6738           const int64_t new_constant = value - add_var->Constant();
6739           result = MakeSum(sub_var, new_constant);
6740           break;
6741         }
6742         case OPP_VAR: {
6743           OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6744           IntVar* const sub_var = add_var->SubVar();
6745           result = MakeSum(sub_var, value);
6746           break;
6747         }
6748         default:
6749           result =
6750               RegisterIntExpr(RevAlloc(new SubCstIntVar(this, var, value)));
6751       }
6752     } else {
6753       result = RegisterIntExpr(RevAlloc(new SubIntCstExpr(this, expr, value)));
6754     }
6755     Cache()->InsertExprConstantExpression(result, expr, value,
6756                                           ModelCache::EXPR_CONSTANT_DIFFERENCE);
6757   }
6758   return result;
6759 }
6760 
MakeOpposite(IntExpr * const expr)6761 IntExpr* Solver::MakeOpposite(IntExpr* const expr) {
6762   CHECK_EQ(this, expr->solver());
6763   if (expr->Bound()) {
6764     return MakeIntConst(-expr->Min());
6765   }
6766   IntExpr* result =
6767       Cache()->FindExprExpression(expr, ModelCache::EXPR_OPPOSITE);
6768   if (result == nullptr) {
6769     if (expr->IsVar()) {
6770       result = RegisterIntVar(RevAlloc(new OppIntExpr(this, expr))->Var());
6771     } else {
6772       result = RegisterIntExpr(RevAlloc(new OppIntExpr(this, expr)));
6773     }
6774     Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_OPPOSITE);
6775   }
6776   return result;
6777 }
6778 
MakeProd(IntExpr * const expr,int64_t value)6779 IntExpr* Solver::MakeProd(IntExpr* const expr, int64_t value) {
6780   CHECK_EQ(this, expr->solver());
6781   IntExpr* result = Cache()->FindExprConstantExpression(
6782       expr, value, ModelCache::EXPR_CONSTANT_PROD);
6783   if (result != nullptr) {
6784     return result;
6785   } else {
6786     IntExpr* m_expr = nullptr;
6787     int64_t coefficient = 1;
6788     if (IsProduct(expr, &m_expr, &coefficient)) {
6789       coefficient *= value;
6790     } else {
6791       m_expr = expr;
6792       coefficient = value;
6793     }
6794     if (m_expr->Bound()) {
6795       return MakeIntConst(coefficient * m_expr->Min());
6796     } else if (coefficient == 1) {
6797       return m_expr;
6798     } else if (coefficient == -1) {
6799       return MakeOpposite(m_expr);
6800     } else if (coefficient > 0) {
6801       if (m_expr->Max() > std::numeric_limits<int64_t>::max() / coefficient ||
6802           m_expr->Min() < std::numeric_limits<int64_t>::min() / coefficient) {
6803         result = RegisterIntExpr(
6804             RevAlloc(new SafeTimesPosIntCstExpr(this, m_expr, coefficient)));
6805       } else {
6806         result = RegisterIntExpr(
6807             RevAlloc(new TimesPosIntCstExpr(this, m_expr, coefficient)));
6808       }
6809     } else if (coefficient == 0) {
6810       result = MakeIntConst(0);
6811     } else {  // coefficient < 0.
6812       result = RegisterIntExpr(
6813           RevAlloc(new TimesIntNegCstExpr(this, m_expr, coefficient)));
6814     }
6815     if (m_expr->IsVar() &&
6816         !absl::GetFlag(FLAGS_cp_disable_expression_optimization)) {
6817       result = result->Var();
6818     }
6819     Cache()->InsertExprConstantExpression(result, expr, value,
6820                                           ModelCache::EXPR_CONSTANT_PROD);
6821     return result;
6822   }
6823 }
6824 
6825 namespace {
ExtractPower(IntExpr ** const expr,int64_t * const exponant)6826 void ExtractPower(IntExpr** const expr, int64_t* const exponant) {
6827   if (dynamic_cast<BasePower*>(*expr) != nullptr) {
6828     BasePower* const power = dynamic_cast<BasePower*>(*expr);
6829     *expr = power->expr();
6830     *exponant = power->exponant();
6831   }
6832   if (dynamic_cast<IntSquare*>(*expr) != nullptr) {
6833     IntSquare* const power = dynamic_cast<IntSquare*>(*expr);
6834     *expr = power->expr();
6835     *exponant = 2;
6836   }
6837   if ((*expr)->IsVar()) {
6838     IntVar* const var = (*expr)->Var();
6839     IntExpr* const sub = var->solver()->CastExpression(var);
6840     if (sub != nullptr && dynamic_cast<BasePower*>(sub) != nullptr) {
6841       BasePower* const power = dynamic_cast<BasePower*>(sub);
6842       *expr = power->expr();
6843       *exponant = power->exponant();
6844     }
6845     if (sub != nullptr && dynamic_cast<IntSquare*>(sub) != nullptr) {
6846       IntSquare* const power = dynamic_cast<IntSquare*>(sub);
6847       *expr = power->expr();
6848       *exponant = 2;
6849     }
6850   }
6851 }
6852 
ExtractProduct(IntExpr ** const expr,int64_t * const coefficient,bool * modified)6853 void ExtractProduct(IntExpr** const expr, int64_t* const coefficient,
6854                     bool* modified) {
6855   if (dynamic_cast<TimesCstIntVar*>(*expr) != nullptr) {
6856     TimesCstIntVar* const left_prod = dynamic_cast<TimesCstIntVar*>(*expr);
6857     *coefficient *= left_prod->Constant();
6858     *expr = left_prod->SubVar();
6859     *modified = true;
6860   } else if (dynamic_cast<TimesIntCstExpr*>(*expr) != nullptr) {
6861     TimesIntCstExpr* const left_prod = dynamic_cast<TimesIntCstExpr*>(*expr);
6862     *coefficient *= left_prod->Constant();
6863     *expr = left_prod->Expr();
6864     *modified = true;
6865   }
6866 }
6867 }  // namespace
6868 
MakeProd(IntExpr * const left,IntExpr * const right)6869 IntExpr* Solver::MakeProd(IntExpr* const left, IntExpr* const right) {
6870   if (left->Bound()) {
6871     return MakeProd(right, left->Min());
6872   }
6873 
6874   if (right->Bound()) {
6875     return MakeProd(left, right->Min());
6876   }
6877 
6878   // ----- Discover squares and powers -----
6879 
6880   IntExpr* m_left = left;
6881   IntExpr* m_right = right;
6882   int64_t left_exponant = 1;
6883   int64_t right_exponant = 1;
6884   ExtractPower(&m_left, &left_exponant);
6885   ExtractPower(&m_right, &right_exponant);
6886 
6887   if (m_left == m_right) {
6888     return MakePower(m_left, left_exponant + right_exponant);
6889   }
6890 
6891   // ----- Discover nested products -----
6892 
6893   m_left = left;
6894   m_right = right;
6895   int64_t coefficient = 1;
6896   bool modified = false;
6897 
6898   ExtractProduct(&m_left, &coefficient, &modified);
6899   ExtractProduct(&m_right, &coefficient, &modified);
6900   if (modified) {
6901     return MakeProd(MakeProd(m_left, m_right), coefficient);
6902   }
6903 
6904   // ----- Standard build -----
6905 
6906   CHECK_EQ(this, left->solver());
6907   CHECK_EQ(this, right->solver());
6908   IntExpr* result = model_cache_->FindExprExprExpression(
6909       left, right, ModelCache::EXPR_EXPR_PROD);
6910   if (result == nullptr) {
6911     result = model_cache_->FindExprExprExpression(right, left,
6912                                                   ModelCache::EXPR_EXPR_PROD);
6913   }
6914   if (result != nullptr) {
6915     return result;
6916   }
6917   if (left->IsVar() && left->Var()->VarType() == BOOLEAN_VAR) {
6918     if (right->Min() >= 0) {
6919       result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6920           this, reinterpret_cast<BooleanVar*>(left), right)));
6921     } else {
6922       result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6923           this, reinterpret_cast<BooleanVar*>(left), right)));
6924     }
6925   } else if (right->IsVar() &&
6926              reinterpret_cast<IntVar*>(right)->VarType() == BOOLEAN_VAR) {
6927     if (left->Min() >= 0) {
6928       result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6929           this, reinterpret_cast<BooleanVar*>(right), left)));
6930     } else {
6931       result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6932           this, reinterpret_cast<BooleanVar*>(right), left)));
6933     }
6934   } else if (left->Min() >= 0 && right->Min() >= 0) {
6935     if (CapProd(left->Max(), right->Max()) ==
6936         std::numeric_limits<int64_t>::max()) {  // Potential overflow.
6937       result =
6938           RegisterIntExpr(RevAlloc(new SafeTimesPosIntExpr(this, left, right)));
6939     } else {
6940       result =
6941           RegisterIntExpr(RevAlloc(new TimesPosIntExpr(this, left, right)));
6942     }
6943   } else {
6944     result = RegisterIntExpr(RevAlloc(new TimesIntExpr(this, left, right)));
6945   }
6946   model_cache_->InsertExprExprExpression(result, left, right,
6947                                          ModelCache::EXPR_EXPR_PROD);
6948   return result;
6949 }
6950 
MakeDiv(IntExpr * const numerator,IntExpr * const denominator)6951 IntExpr* Solver::MakeDiv(IntExpr* const numerator, IntExpr* const denominator) {
6952   CHECK(numerator != nullptr);
6953   CHECK(denominator != nullptr);
6954   if (denominator->Bound()) {
6955     return MakeDiv(numerator, denominator->Min());
6956   }
6957   IntExpr* result = model_cache_->FindExprExprExpression(
6958       numerator, denominator, ModelCache::EXPR_EXPR_DIV);
6959   if (result != nullptr) {
6960     return result;
6961   }
6962 
6963   if (denominator->Min() <= 0 && denominator->Max() >= 0) {
6964     AddConstraint(MakeNonEquality(denominator, 0));
6965   }
6966 
6967   if (denominator->Min() >= 0) {
6968     if (numerator->Min() >= 0) {
6969       result = RevAlloc(new DivPosPosIntExpr(this, numerator, denominator));
6970     } else {
6971       result = RevAlloc(new DivPosIntExpr(this, numerator, denominator));
6972     }
6973   } else if (denominator->Max() <= 0) {
6974     if (numerator->Max() <= 0) {
6975       result = RevAlloc(new DivPosPosIntExpr(this, MakeOpposite(numerator),
6976                                              MakeOpposite(denominator)));
6977     } else {
6978       result = MakeOpposite(RevAlloc(
6979           new DivPosIntExpr(this, numerator, MakeOpposite(denominator))));
6980     }
6981   } else {
6982     result = RevAlloc(new DivIntExpr(this, numerator, denominator));
6983   }
6984   model_cache_->InsertExprExprExpression(result, numerator, denominator,
6985                                          ModelCache::EXPR_EXPR_DIV);
6986   return result;
6987 }
6988 
MakeDiv(IntExpr * const expr,int64_t value)6989 IntExpr* Solver::MakeDiv(IntExpr* const expr, int64_t value) {
6990   CHECK(expr != nullptr);
6991   CHECK_EQ(this, expr->solver());
6992   if (expr->Bound()) {
6993     return MakeIntConst(expr->Min() / value);
6994   } else if (value == 1) {
6995     return expr;
6996   } else if (value == -1) {
6997     return MakeOpposite(expr);
6998   } else if (value > 0) {
6999     return RegisterIntExpr(RevAlloc(new DivPosIntCstExpr(this, expr, value)));
7000   } else if (value == 0) {
7001     LOG(FATAL) << "Cannot divide by 0";
7002     return nullptr;
7003   } else {
7004     return RegisterIntExpr(
7005         MakeOpposite(RevAlloc(new DivPosIntCstExpr(this, expr, -value))));
7006     // TODO(user) : implement special case.
7007   }
7008 }
7009 
MakeAbsEquality(IntVar * const var,IntVar * const abs_var)7010 Constraint* Solver::MakeAbsEquality(IntVar* const var, IntVar* const abs_var) {
7011   if (Cache()->FindExprExpression(var, ModelCache::EXPR_ABS) == nullptr) {
7012     Cache()->InsertExprExpression(abs_var, var, ModelCache::EXPR_ABS);
7013   }
7014   return RevAlloc(new IntAbsConstraint(this, var, abs_var));
7015 }
7016 
MakeAbs(IntExpr * const e)7017 IntExpr* Solver::MakeAbs(IntExpr* const e) {
7018   CHECK_EQ(this, e->solver());
7019   if (e->Min() >= 0) {
7020     return e;
7021   } else if (e->Max() <= 0) {
7022     return MakeOpposite(e);
7023   }
7024   IntExpr* result = Cache()->FindExprExpression(e, ModelCache::EXPR_ABS);
7025   if (result == nullptr) {
7026     int64_t coefficient = 1;
7027     IntExpr* expr = nullptr;
7028     if (IsProduct(e, &expr, &coefficient)) {
7029       result = MakeProd(MakeAbs(expr), std::abs(coefficient));
7030     } else {
7031       result = RegisterIntExpr(RevAlloc(new IntAbs(this, e)));
7032     }
7033     Cache()->InsertExprExpression(result, e, ModelCache::EXPR_ABS);
7034   }
7035   return result;
7036 }
7037 
MakeSquare(IntExpr * const expr)7038 IntExpr* Solver::MakeSquare(IntExpr* const expr) {
7039   CHECK_EQ(this, expr->solver());
7040   if (expr->Bound()) {
7041     const int64_t v = expr->Min();
7042     return MakeIntConst(v * v);
7043   }
7044   IntExpr* result = Cache()->FindExprExpression(expr, ModelCache::EXPR_SQUARE);
7045   if (result == nullptr) {
7046     if (expr->Min() >= 0) {
7047       result = RegisterIntExpr(RevAlloc(new PosIntSquare(this, expr)));
7048     } else {
7049       result = RegisterIntExpr(RevAlloc(new IntSquare(this, expr)));
7050     }
7051     Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_SQUARE);
7052   }
7053   return result;
7054 }
7055 
MakePower(IntExpr * const expr,int64_t n)7056 IntExpr* Solver::MakePower(IntExpr* const expr, int64_t n) {
7057   CHECK_EQ(this, expr->solver());
7058   CHECK_GE(n, 0);
7059   if (expr->Bound()) {
7060     const int64_t v = expr->Min();
7061     if (v >= OverflowLimit(n)) {  // Overflow.
7062       return MakeIntConst(std::numeric_limits<int64_t>::max());
7063     }
7064     return MakeIntConst(IntPower(v, n));
7065   }
7066   switch (n) {
7067     case 0:
7068       return MakeIntConst(1);
7069     case 1:
7070       return expr;
7071     case 2:
7072       return MakeSquare(expr);
7073     default: {
7074       IntExpr* result = nullptr;
7075       if (n % 2 == 0) {  // even.
7076         if (expr->Min() >= 0) {
7077           result =
7078               RegisterIntExpr(RevAlloc(new PosIntEvenPower(this, expr, n)));
7079         } else {
7080           result = RegisterIntExpr(RevAlloc(new IntEvenPower(this, expr, n)));
7081         }
7082       } else {
7083         result = RegisterIntExpr(RevAlloc(new IntOddPower(this, expr, n)));
7084       }
7085       return result;
7086     }
7087   }
7088 }
7089 
MakeMin(IntExpr * const left,IntExpr * const right)7090 IntExpr* Solver::MakeMin(IntExpr* const left, IntExpr* const right) {
7091   CHECK_EQ(this, left->solver());
7092   CHECK_EQ(this, right->solver());
7093   if (left->Bound()) {
7094     return MakeMin(right, left->Min());
7095   }
7096   if (right->Bound()) {
7097     return MakeMin(left, right->Min());
7098   }
7099   if (left->Min() >= right->Max()) {
7100     return right;
7101   }
7102   if (right->Min() >= left->Max()) {
7103     return left;
7104   }
7105   return RegisterIntExpr(RevAlloc(new MinIntExpr(this, left, right)));
7106 }
7107 
MakeMin(IntExpr * const expr,int64_t value)7108 IntExpr* Solver::MakeMin(IntExpr* const expr, int64_t value) {
7109   CHECK_EQ(this, expr->solver());
7110   if (value <= expr->Min()) {
7111     return MakeIntConst(value);
7112   }
7113   if (expr->Bound()) {
7114     return MakeIntConst(std::min(expr->Min(), value));
7115   }
7116   if (expr->Max() <= value) {
7117     return expr;
7118   }
7119   return RegisterIntExpr(RevAlloc(new MinCstIntExpr(this, expr, value)));
7120 }
7121 
MakeMin(IntExpr * const expr,int value)7122 IntExpr* Solver::MakeMin(IntExpr* const expr, int value) {
7123   return MakeMin(expr, static_cast<int64_t>(value));
7124 }
7125 
MakeMax(IntExpr * const left,IntExpr * const right)7126 IntExpr* Solver::MakeMax(IntExpr* const left, IntExpr* const right) {
7127   CHECK_EQ(this, left->solver());
7128   CHECK_EQ(this, right->solver());
7129   if (left->Bound()) {
7130     return MakeMax(right, left->Min());
7131   }
7132   if (right->Bound()) {
7133     return MakeMax(left, right->Min());
7134   }
7135   if (left->Min() >= right->Max()) {
7136     return left;
7137   }
7138   if (right->Min() >= left->Max()) {
7139     return right;
7140   }
7141   return RegisterIntExpr(RevAlloc(new MaxIntExpr(this, left, right)));
7142 }
7143 
MakeMax(IntExpr * const expr,int64_t value)7144 IntExpr* Solver::MakeMax(IntExpr* const expr, int64_t value) {
7145   CHECK_EQ(this, expr->solver());
7146   if (expr->Bound()) {
7147     return MakeIntConst(std::max(expr->Min(), value));
7148   }
7149   if (value <= expr->Min()) {
7150     return expr;
7151   }
7152   if (expr->Max() <= value) {
7153     return MakeIntConst(value);
7154   }
7155   return RegisterIntExpr(RevAlloc(new MaxCstIntExpr(this, expr, value)));
7156 }
7157 
MakeMax(IntExpr * const expr,int value)7158 IntExpr* Solver::MakeMax(IntExpr* const expr, int value) {
7159   return MakeMax(expr, static_cast<int64_t>(value));
7160 }
7161 
MakeConvexPiecewiseExpr(IntExpr * expr,int64_t early_cost,int64_t early_date,int64_t late_date,int64_t late_cost)7162 IntExpr* Solver::MakeConvexPiecewiseExpr(IntExpr* expr, int64_t early_cost,
7163                                          int64_t early_date, int64_t late_date,
7164                                          int64_t late_cost) {
7165   return RegisterIntExpr(RevAlloc(new SimpleConvexPiecewiseExpr(
7166       this, expr, early_cost, early_date, late_date, late_cost)));
7167 }
7168 
MakeSemiContinuousExpr(IntExpr * const expr,int64_t fixed_charge,int64_t step)7169 IntExpr* Solver::MakeSemiContinuousExpr(IntExpr* const expr,
7170                                         int64_t fixed_charge, int64_t step) {
7171   if (step == 0) {
7172     if (fixed_charge == 0) {
7173       return MakeIntConst(int64_t{0});
7174     } else {
7175       return RegisterIntExpr(
7176           RevAlloc(new SemiContinuousStepZeroExpr(this, expr, fixed_charge)));
7177     }
7178   } else if (step == 1) {
7179     return RegisterIntExpr(
7180         RevAlloc(new SemiContinuousStepOneExpr(this, expr, fixed_charge)));
7181   } else {
7182     return RegisterIntExpr(
7183         RevAlloc(new SemiContinuousExpr(this, expr, fixed_charge, step)));
7184   }
7185   // TODO(user) : benchmark with virtualization of
7186   // PosIntDivDown and PosIntDivUp - or function pointers.
7187 }
7188 
7189 // ----- Piecewise Linear -----
7190 
7191 class PiecewiseLinearExpr : public BaseIntExpr {
7192  public:
PiecewiseLinearExpr(Solver * solver,IntExpr * expr,const PiecewiseLinearFunction & f)7193   PiecewiseLinearExpr(Solver* solver, IntExpr* expr,
7194                       const PiecewiseLinearFunction& f)
7195       : BaseIntExpr(solver), expr_(expr), f_(f) {}
~PiecewiseLinearExpr()7196   ~PiecewiseLinearExpr() override {}
Min() const7197   int64_t Min() const override {
7198     return f_.GetMinimum(expr_->Min(), expr_->Max());
7199   }
SetMin(int64_t m)7200   void SetMin(int64_t m) override {
7201     const auto& range =
7202         f_.GetSmallestRangeGreaterThanValue(expr_->Min(), expr_->Max(), m);
7203     expr_->SetRange(range.first, range.second);
7204   }
7205 
Max() const7206   int64_t Max() const override {
7207     return f_.GetMaximum(expr_->Min(), expr_->Max());
7208   }
7209 
SetMax(int64_t m)7210   void SetMax(int64_t m) override {
7211     const auto& range =
7212         f_.GetSmallestRangeLessThanValue(expr_->Min(), expr_->Max(), m);
7213     expr_->SetRange(range.first, range.second);
7214   }
7215 
SetRange(int64_t l,int64_t u)7216   void SetRange(int64_t l, int64_t u) override {
7217     const auto& range =
7218         f_.GetSmallestRangeInValueRange(expr_->Min(), expr_->Max(), l, u);
7219     expr_->SetRange(range.first, range.second);
7220   }
name() const7221   std::string name() const override {
7222     return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->name(),
7223                            f_.DebugString());
7224   }
7225 
DebugString() const7226   std::string DebugString() const override {
7227     return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->DebugString(),
7228                            f_.DebugString());
7229   }
7230 
WhenRange(Demon * d)7231   void WhenRange(Demon* d) override { expr_->WhenRange(d); }
7232 
Accept(ModelVisitor * const visitor) const7233   void Accept(ModelVisitor* const visitor) const override {
7234     // TODO(user): Implement visitor.
7235   }
7236 
7237  private:
7238   IntExpr* const expr_;
7239   const PiecewiseLinearFunction f_;
7240 };
7241 
MakePiecewiseLinearExpr(IntExpr * expr,const PiecewiseLinearFunction & f)7242 IntExpr* Solver::MakePiecewiseLinearExpr(IntExpr* expr,
7243                                          const PiecewiseLinearFunction& f) {
7244   return RegisterIntExpr(RevAlloc(new PiecewiseLinearExpr(this, expr, f)));
7245 }
7246 
7247 // ----- Conditional Expression -----
7248 
MakeConditionalExpression(IntVar * const condition,IntExpr * const expr,int64_t unperformed_value)7249 IntExpr* Solver::MakeConditionalExpression(IntVar* const condition,
7250                                            IntExpr* const expr,
7251                                            int64_t unperformed_value) {
7252   if (condition->Min() == 1) {
7253     return expr;
7254   } else if (condition->Max() == 0) {
7255     return MakeIntConst(unperformed_value);
7256   } else {
7257     IntExpr* cache = Cache()->FindExprExprConstantExpression(
7258         condition, expr, unperformed_value,
7259         ModelCache::EXPR_EXPR_CONSTANT_CONDITIONAL);
7260     if (cache == nullptr) {
7261       cache = RevAlloc(
7262           new ExprWithEscapeValue(this, condition, expr, unperformed_value));
7263       Cache()->InsertExprExprConstantExpression(
7264           cache, condition, expr, unperformed_value,
7265           ModelCache::EXPR_EXPR_CONSTANT_CONDITIONAL);
7266     }
7267     return cache;
7268   }
7269 }
7270 
7271 // ----- Modulo -----
7272 
MakeModulo(IntExpr * const x,int64_t mod)7273 IntExpr* Solver::MakeModulo(IntExpr* const x, int64_t mod) {
7274   IntVar* const result =
7275       MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7276   if (mod >= 0) {
7277     AddConstraint(MakeBetweenCt(result, 0, mod - 1));
7278   } else {
7279     AddConstraint(MakeBetweenCt(result, mod + 1, 0));
7280   }
7281   return result;
7282 }
7283 
MakeModulo(IntExpr * const x,IntExpr * const mod)7284 IntExpr* Solver::MakeModulo(IntExpr* const x, IntExpr* const mod) {
7285   if (mod->Bound()) {
7286     return MakeModulo(x, mod->Min());
7287   }
7288   IntVar* const result =
7289       MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7290   AddConstraint(MakeLess(result, MakeAbs(mod)));
7291   AddConstraint(MakeGreater(result, MakeOpposite(MakeAbs(mod))));
7292   return result;
7293 }
7294 
7295 // --------- IntVar ---------
7296 
VarType() const7297 int IntVar::VarType() const { return UNSPECIFIED; }
7298 
RemoveValues(const std::vector<int64_t> & values)7299 void IntVar::RemoveValues(const std::vector<int64_t>& values) {
7300   // TODO(user): Check and maybe inline this code.
7301   const int size = values.size();
7302   DCHECK_GE(size, 0);
7303   switch (size) {
7304     case 0: {
7305       return;
7306     }
7307     case 1: {
7308       RemoveValue(values[0]);
7309       return;
7310     }
7311     case 2: {
7312       RemoveValue(values[0]);
7313       RemoveValue(values[1]);
7314       return;
7315     }
7316     case 3: {
7317       RemoveValue(values[0]);
7318       RemoveValue(values[1]);
7319       RemoveValue(values[2]);
7320       return;
7321     }
7322     default: {
7323       // 4 values, let's start doing some more clever things.
7324       // TODO(user) : Sort values!
7325       int start_index = 0;
7326       int64_t new_min = Min();
7327       if (values[start_index] <= new_min) {
7328         while (start_index < size - 1 &&
7329                values[start_index + 1] == values[start_index] + 1) {
7330           new_min = values[start_index + 1] + 1;
7331           start_index++;
7332         }
7333       }
7334       int end_index = size - 1;
7335       int64_t new_max = Max();
7336       if (values[end_index] >= new_max) {
7337         while (end_index > start_index + 1 &&
7338                values[end_index - 1] == values[end_index] - 1) {
7339           new_max = values[end_index - 1] - 1;
7340           end_index--;
7341         }
7342       }
7343       SetRange(new_min, new_max);
7344       for (int i = start_index; i <= end_index; ++i) {
7345         RemoveValue(values[i]);
7346       }
7347     }
7348   }
7349 }
7350 
Accept(ModelVisitor * const visitor) const7351 void IntVar::Accept(ModelVisitor* const visitor) const {
7352   IntExpr* const casted = solver()->CastExpression(this);
7353   visitor->VisitIntegerVariable(this, casted);
7354 }
7355 
SetValues(const std::vector<int64_t> & values)7356 void IntVar::SetValues(const std::vector<int64_t>& values) {
7357   switch (values.size()) {
7358     case 0: {
7359       solver()->Fail();
7360       break;
7361     }
7362     case 1: {
7363       SetValue(values.back());
7364       break;
7365     }
7366     case 2: {
7367       if (Contains(values[0])) {
7368         if (Contains(values[1])) {
7369           const int64_t l = std::min(values[0], values[1]);
7370           const int64_t u = std::max(values[0], values[1]);
7371           SetRange(l, u);
7372           if (u > l + 1) {
7373             RemoveInterval(l + 1, u - 1);
7374           }
7375         } else {
7376           SetValue(values[0]);
7377         }
7378       } else {
7379         SetValue(values[1]);
7380       }
7381       break;
7382     }
7383     default: {
7384       // TODO(user): use a clean and safe SortedUniqueCopy() class
7385       // that uses a global, static shared (and locked) storage.
7386       // TODO(user): [optional] consider porting
7387       // STLSortAndRemoveDuplicates from ortools/base/stl_util.h to the
7388       // existing open_source/base/stl_util.h and using it here.
7389       // TODO(user): We could filter out values not in the var.
7390       std::vector<int64_t>& tmp = solver()->tmp_vector_;
7391       tmp.clear();
7392       tmp.insert(tmp.end(), values.begin(), values.end());
7393       std::sort(tmp.begin(), tmp.end());
7394       tmp.erase(std::unique(tmp.begin(), tmp.end()), tmp.end());
7395       const int size = tmp.size();
7396       const int64_t vmin = Min();
7397       const int64_t vmax = Max();
7398       int first = 0;
7399       int last = size - 1;
7400       if (tmp.front() > vmax || tmp.back() < vmin) {
7401         solver()->Fail();
7402       }
7403       // TODO(user) : We could find the first position >= vmin by dichotomy.
7404       while (tmp[first] < vmin || !Contains(tmp[first])) {
7405         ++first;
7406         if (first > last || tmp[first] > vmax) {
7407           solver()->Fail();
7408         }
7409       }
7410       while (last > first && (tmp[last] > vmax || !Contains(tmp[last]))) {
7411         // Note that last >= first implies tmp[last] >= vmin.
7412         --last;
7413       }
7414       DCHECK_GE(last, first);
7415       SetRange(tmp[first], tmp[last]);
7416       while (first < last) {
7417         const int64_t start = tmp[first] + 1;
7418         const int64_t end = tmp[first + 1] - 1;
7419         if (start <= end) {
7420           RemoveInterval(start, end);
7421         }
7422         first++;
7423       }
7424     }
7425   }
7426 }
7427 // ---------- BaseIntExpr ---------
7428 
LinkVarExpr(Solver * const s,IntExpr * const expr,IntVar * const var)7429 void LinkVarExpr(Solver* const s, IntExpr* const expr, IntVar* const var) {
7430   if (!var->Bound()) {
7431     if (var->VarType() == DOMAIN_INT_VAR) {
7432       DomainIntVar* dvar = reinterpret_cast<DomainIntVar*>(var);
7433       s->AddCastConstraint(
7434           s->RevAlloc(new LinkExprAndDomainIntVar(s, expr, dvar)), dvar, expr);
7435     } else {
7436       s->AddCastConstraint(s->RevAlloc(new LinkExprAndVar(s, expr, var)), var,
7437                            expr);
7438     }
7439   }
7440 }
7441 
Var()7442 IntVar* BaseIntExpr::Var() {
7443   if (var_ == nullptr) {
7444     solver()->SaveValue(reinterpret_cast<void**>(&var_));
7445     var_ = CastToVar();
7446   }
7447   return var_;
7448 }
7449 
CastToVar()7450 IntVar* BaseIntExpr::CastToVar() {
7451   int64_t vmin, vmax;
7452   Range(&vmin, &vmax);
7453   IntVar* const var = solver()->MakeIntVar(vmin, vmax);
7454   LinkVarExpr(solver(), this, var);
7455   return var;
7456 }
7457 
7458 // Discovery methods
IsADifference(IntExpr * expr,IntExpr ** const left,IntExpr ** const right)7459 bool Solver::IsADifference(IntExpr* expr, IntExpr** const left,
7460                            IntExpr** const right) {
7461   if (expr->IsVar()) {
7462     IntVar* const expr_var = expr->Var();
7463     expr = CastExpression(expr_var);
7464   }
7465   // This is a dynamic cast to check the type of expr.
7466   // It returns nullptr is expr is not a subclass of SubIntExpr.
7467   SubIntExpr* const sub_expr = dynamic_cast<SubIntExpr*>(expr);
7468   if (sub_expr != nullptr) {
7469     *left = sub_expr->left();
7470     *right = sub_expr->right();
7471     return true;
7472   }
7473   return false;
7474 }
7475 
IsBooleanVar(IntExpr * const expr,IntVar ** inner_var,bool * is_negated) const7476 bool Solver::IsBooleanVar(IntExpr* const expr, IntVar** inner_var,
7477                           bool* is_negated) const {
7478   if (expr->IsVar() && expr->Var()->VarType() == BOOLEAN_VAR) {
7479     *inner_var = expr->Var();
7480     *is_negated = false;
7481     return true;
7482   } else if (expr->IsVar() && expr->Var()->VarType() == CST_SUB_VAR) {
7483     SubCstIntVar* const sub_var = reinterpret_cast<SubCstIntVar*>(expr);
7484     if (sub_var != nullptr && sub_var->Constant() == 1 &&
7485         sub_var->SubVar()->VarType() == BOOLEAN_VAR) {
7486       *is_negated = true;
7487       *inner_var = sub_var->SubVar();
7488       return true;
7489     }
7490   }
7491   return false;
7492 }
7493 
IsProduct(IntExpr * const expr,IntExpr ** inner_expr,int64_t * coefficient)7494 bool Solver::IsProduct(IntExpr* const expr, IntExpr** inner_expr,
7495                        int64_t* coefficient) {
7496   if (dynamic_cast<TimesCstIntVar*>(expr) != nullptr) {
7497     TimesCstIntVar* const var = dynamic_cast<TimesCstIntVar*>(expr);
7498     *coefficient = var->Constant();
7499     *inner_expr = var->SubVar();
7500     return true;
7501   } else if (dynamic_cast<TimesIntCstExpr*>(expr) != nullptr) {
7502     TimesIntCstExpr* const prod = dynamic_cast<TimesIntCstExpr*>(expr);
7503     *coefficient = prod->Constant();
7504     *inner_expr = prod->Expr();
7505     return true;
7506   }
7507   *inner_expr = expr;
7508   *coefficient = 1;
7509   return false;
7510 }
7511 
7512 #undef COND_REV_ALLOC
7513 
7514 }  // namespace operations_research
7515