1 // Copyright 2010-2021 Google LLC
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 #include <algorithm>
15 #include <cstdint>
16 #include <limits>
17 #include <memory>
18 #include <numeric>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/str_join.h"
25 #include "ortools/base/integral_types.h"
26 #include "ortools/base/logging.h"
27 #include "ortools/constraint_solver/constraint_solver.h"
28 #include "ortools/constraint_solver/constraint_solveri.h"
29 #include "ortools/util/range_minimum_query.h"
30 #include "ortools/util/string_array.h"
31 
// When true (the default), BuildElement() neither looks up nor inserts
// element expressions in the solver's model cache.
ABSL_FLAG(bool, cp_disable_element_cache, true,
          "If true, caching for IntElement is disabled.");
34 
35 namespace operations_research {
36 
// Defined elsewhere in the constraint solver. Used by
// IncreasingIntExprElement::CastToVar() right after creating `var` from the
// expression's values; presumably it constrains `var` to equal `expr`
// (NOTE(review): confirm against the definition).
void LinkVarExpr(Solver* const s, IntExpr* const expr, IntVar* const var);
39 
40 namespace {
41 
// Comparator over indices: x is ordered before y iff (*values)[x] is smaller
// than (*values)[y]. The vector is only borrowed (a raw pointer is kept), so
// it must outlive the comparator.
template <class T>
class VectorLess {
 public:
  explicit VectorLess(const std::vector<T>* values) : values_(values) {}

  bool operator()(const T& x, const T& y) const {
    const std::vector<T>& table = *values_;
    const T& lhs = table[x];
    const T& rhs = table[y];
    return lhs < rhs;
  }

 private:
  // Borrowed table addressed by the compared keys.
  const std::vector<T>* values_;
};
53 
// Comparator over indices: x is ordered before y iff (*values)[x] is larger
// than (*values)[y]. The vector is only borrowed (a raw pointer is kept), so
// it must outlive the comparator.
template <class T>
class VectorGreater {
 public:
  explicit VectorGreater(const std::vector<T>* values) : values_(values) {}

  bool operator()(const T& x, const T& y) const {
    const std::vector<T>& table = *values_;
    const T& lhs = table[x];
    const T& rhs = table[y];
    return lhs > rhs;
  }

 private:
  // Borrowed table addressed by the compared keys.
  const std::vector<T>* values_;
};
65 
66 // ----- BaseIntExprElement -----
67 
// Common base for element expressions f(expr_): subclasses provide f through
// ElementValue() and the index range to scan through ExprMin()/ExprMax().
// Min()/Max() are cached together with the indices ("supports") where they
// are attained, and recomputed on first use or when a support leaves the
// domain of expr_ (see UpdateSupports()).
class BaseIntExprElement : public BaseIntExpr {
 public:
  BaseIntExprElement(Solver* const s, IntVar* const e);
  ~BaseIntExprElement() override {}
  int64_t Min() const override;
  int64_t Max() const override;
  void Range(int64_t* mi, int64_t* ma) override;
  void SetMin(int64_t m) override;
  void SetMax(int64_t m) override;
  void SetRange(int64_t mi, int64_t ma) override;
  bool Bound() const override { return (expr_->Bound()); }
  // TODO(user) : improve me, the previous test is not always true
  // (Bound() only reports true once expr_ is fixed; the element value may be
  // fixed earlier, e.g. when every remaining index maps to the same value.)
  void WhenRange(Demon* d) override { expr_->WhenRange(d); }

 protected:
  // Value of the element at position `index`.
  virtual int64_t ElementValue(int index) const = 0;
  // Bounds of the index range scanned by propagation and UpdateSupports().
  virtual int64_t ExprMin() const = 0;
  virtual int64_t ExprMax() const = 0;

  // The index variable.
  IntVar* const expr_;

 private:
  // Refreshes min_/max_ and their supports when needed.
  void UpdateSupports() const;

  // Cached bounds and the indices attaining them. They are mutable because
  // const accessors refresh them, and they are written with
  // Solver::SaveAndSetValue so the solver can restore them on backtrack.
  mutable int64_t min_;
  mutable int min_support_;
  mutable int64_t max_;
  mutable int max_support_;
  // True until the first computation performed by UpdateSupports().
  mutable bool initial_update_;
  // Domain iterator over expr_, used by UpdateSupports() when it cannot scan
  // the index range densely.
  IntVarIterator* const expr_iterator_;
};
99 
BaseIntExprElement(Solver * const s,IntVar * const e)100 BaseIntExprElement::BaseIntExprElement(Solver* const s, IntVar* const e)
101     : BaseIntExpr(s),
102       expr_(e),
103       min_(0),
104       min_support_(-1),
105       max_(0),
106       max_support_(-1),
107       initial_update_(true),
108       expr_iterator_(expr_->MakeDomainIterator(true)) {
109   CHECK(s != nullptr);
110   CHECK(e != nullptr);
111 }
112 
Min() const113 int64_t BaseIntExprElement::Min() const {
114   UpdateSupports();
115   return min_;
116 }
117 
Max() const118 int64_t BaseIntExprElement::Max() const {
119   UpdateSupports();
120   return max_;
121 }
122 
Range(int64_t * mi,int64_t * ma)123 void BaseIntExprElement::Range(int64_t* mi, int64_t* ma) {
124   UpdateSupports();
125   *mi = min_;
126   *ma = max_;
127 }
128 
// Tightens the domain of expr_ from both ends. Starting at ExprMin() (resp.
// ExprMax()), indices are skipped while `test` holds, where `test` is a
// boolean expression over `value`, the element at the index being scanned.
// expr_ is then restricted to the remaining [nmin, nmax]. Interior holes are
// not removed.
// Fails when [ExprMin(), ExprMax()] is empty, e.g. when the domain of expr_
// lies entirely outside the array. This check comes first so ElementValue()
// is never called on an out-of-range index. It also fails when `test` holds
// for every index in the range.
// `test` is parenthesized at every use so compound conditions expand safely.
#define UPDATE_BASE_ELEMENT_INDEX_BOUNDS(test) \
  const int64_t emin = ExprMin();              \
  const int64_t emax = ExprMax();              \
  if (emin > emax) {                           \
    solver()->Fail();                          \
  }                                            \
  int64_t nmin = emin;                         \
  int64_t value = ElementValue(nmin);          \
  while (nmin < emax && (test)) {              \
    nmin++;                                    \
    value = ElementValue(nmin);                \
  }                                            \
  if (nmin == emax && (test)) {                \
    solver()->Fail();                          \
  }                                            \
  int64_t nmax = emax;                         \
  value = ElementValue(nmax);                  \
  while (nmax >= nmin && (test)) {             \
    nmax--;                                    \
    value = ElementValue(nmax);                \
  }                                            \
  expr_->SetRange(nmin, nmax);
148 
SetMin(int64_t m)149 void BaseIntExprElement::SetMin(int64_t m) {
150   UPDATE_BASE_ELEMENT_INDEX_BOUNDS(value < m);
151 }
152 
SetMax(int64_t m)153 void BaseIntExprElement::SetMax(int64_t m) {
154   UPDATE_BASE_ELEMENT_INDEX_BOUNDS(value > m);
155 }
156 
SetRange(int64_t mi,int64_t ma)157 void BaseIntExprElement::SetRange(int64_t mi, int64_t ma) {
158   if (mi > ma) {
159     solver()->Fail();
160   }
161   UPDATE_BASE_ELEMENT_INDEX_BOUNDS((value < mi || value > ma));
162 }
163 
164 #undef UPDATE_BASE_ELEMENT_INDEX_BOUNDS
165 
UpdateSupports() const166 void BaseIntExprElement::UpdateSupports() const {
167   if (initial_update_ || !expr_->Contains(min_support_) ||
168       !expr_->Contains(max_support_)) {
169     const int64_t emin = ExprMin();
170     const int64_t emax = ExprMax();
171     int64_t min_value = ElementValue(emax);
172     int64_t max_value = min_value;
173     int min_support = emax;
174     int max_support = emax;
175     const uint64_t expr_size = expr_->Size();
176     if (expr_size > 1) {
177       if (expr_size == emax - emin + 1) {
178         // Value(emax) already stored in min_value, max_value.
179         for (int64_t index = emin; index < emax; ++index) {
180           const int64_t value = ElementValue(index);
181           if (value > max_value) {
182             max_value = value;
183             max_support = index;
184           } else if (value < min_value) {
185             min_value = value;
186             min_support = index;
187           }
188         }
189       } else {
190         for (const int64_t index : InitAndGetValues(expr_iterator_)) {
191           if (index >= emin && index <= emax) {
192             const int64_t value = ElementValue(index);
193             if (value > max_value) {
194               max_value = value;
195               max_support = index;
196             } else if (value < min_value) {
197               min_value = value;
198               min_support = index;
199             }
200           }
201         }
202       }
203     }
204     Solver* s = solver();
205     s->SaveAndSetValue(&min_, min_value);
206     s->SaveAndSetValue(&min_support_, min_support);
207     s->SaveAndSetValue(&max_, max_value);
208     s->SaveAndSetValue(&max_support_, max_support);
209     s->SaveAndSetValue(&initial_update_, false);
210   }
211 }
212 
213 // ----- IntElementConstraint -----
214 
215 // This constraint implements 'elem' == 'values'['index'].
216 // It scans the bounds of 'elem' to propagate on the domain of 'index'.
217 // It scans the domain of 'index' to compute the new bounds of 'elem'.
218 class IntElementConstraint : public CastConstraint {
219  public:
IntElementConstraint(Solver * const s,const std::vector<int64_t> & values,IntVar * const index,IntVar * const elem)220   IntElementConstraint(Solver* const s, const std::vector<int64_t>& values,
221                        IntVar* const index, IntVar* const elem)
222       : CastConstraint(s, elem),
223         values_(values),
224         index_(index),
225         index_iterator_(index_->MakeDomainIterator(true)) {
226     CHECK(index != nullptr);
227   }
228 
Post()229   void Post() override {
230     Demon* const d =
231         solver()->MakeDelayedConstraintInitialPropagateCallback(this);
232     index_->WhenDomain(d);
233     target_var_->WhenRange(d);
234   }
235 
InitialPropagate()236   void InitialPropagate() override {
237     index_->SetRange(0, values_.size() - 1);
238     const int64_t target_var_min = target_var_->Min();
239     const int64_t target_var_max = target_var_->Max();
240     int64_t new_min = target_var_max;
241     int64_t new_max = target_var_min;
242     to_remove_.clear();
243     for (const int64_t index : InitAndGetValues(index_iterator_)) {
244       const int64_t value = values_[index];
245       if (value < target_var_min || value > target_var_max) {
246         to_remove_.push_back(index);
247       } else {
248         if (value < new_min) {
249           new_min = value;
250         }
251         if (value > new_max) {
252           new_max = value;
253         }
254       }
255     }
256     target_var_->SetRange(new_min, new_max);
257     if (!to_remove_.empty()) {
258       index_->RemoveValues(to_remove_);
259     }
260   }
261 
DebugString() const262   std::string DebugString() const override {
263     return absl::StrFormat("IntElementConstraint(%s, %s, %s)",
264                            absl::StrJoin(values_, ", "), index_->DebugString(),
265                            target_var_->DebugString());
266   }
267 
Accept(ModelVisitor * const visitor) const268   void Accept(ModelVisitor* const visitor) const override {
269     visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
270     visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
271     visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
272                                             index_);
273     visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
274                                             target_var_);
275     visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
276   }
277 
278  private:
279   const std::vector<int64_t> values_;
280   IntVar* const index_;
281   IntVarIterator* const index_iterator_;
282   std::vector<int64_t> to_remove_;
283 };
284 
// ----- IntExprElement -----

// Declared here and defined elsewhere in the constraint solver.
// NOTE(review): no caller is visible in this section — confirm it is needed.
IntVar* BuildDomainIntVar(Solver* const solver, std::vector<int64_t>* values);
288 
289 class IntExprElement : public BaseIntExprElement {
290  public:
IntExprElement(Solver * const s,const std::vector<int64_t> & vals,IntVar * const expr)291   IntExprElement(Solver* const s, const std::vector<int64_t>& vals,
292                  IntVar* const expr)
293       : BaseIntExprElement(s, expr), values_(vals) {}
294 
~IntExprElement()295   ~IntExprElement() override {}
296 
name() const297   std::string name() const override {
298     const int size = values_.size();
299     if (size > 10) {
300       return absl::StrFormat("IntElement(array of size %d, %s)", size,
301                              expr_->name());
302     } else {
303       return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
304                              expr_->name());
305     }
306   }
307 
DebugString() const308   std::string DebugString() const override {
309     const int size = values_.size();
310     if (size > 10) {
311       return absl::StrFormat("IntElement(array of size %d, %s)", size,
312                              expr_->DebugString());
313     } else {
314       return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
315                              expr_->DebugString());
316     }
317   }
318 
CastToVar()319   IntVar* CastToVar() override {
320     Solver* const s = solver();
321     IntVar* const var = s->MakeIntVar(values_);
322     s->AddCastConstraint(
323         s->RevAlloc(new IntElementConstraint(s, values_, expr_, var)), var,
324         this);
325     return var;
326   }
327 
Accept(ModelVisitor * const visitor) const328   void Accept(ModelVisitor* const visitor) const override {
329     visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
330     visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
331     visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
332                                             expr_);
333     visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
334   }
335 
336  protected:
ElementValue(int index) const337   int64_t ElementValue(int index) const override {
338     DCHECK_LT(index, values_.size());
339     return values_[index];
340   }
ExprMin() const341   int64_t ExprMin() const override {
342     return std::max<int64_t>(0, expr_->Min());
343   }
ExprMax() const344   int64_t ExprMax() const override {
345     return values_.empty()
346                ? 0
347                : std::min<int64_t>(values_.size() - 1, expr_->Max());
348   }
349 
350  private:
351   const std::vector<int64_t> values_;
352 };
353 
354 // ----- Range Minimum Query-based Element -----
355 
356 class RangeMinimumQueryExprElement : public BaseIntExpr {
357  public:
358   RangeMinimumQueryExprElement(Solver* solver,
359                                const std::vector<int64_t>& values,
360                                IntVar* index);
~RangeMinimumQueryExprElement()361   ~RangeMinimumQueryExprElement() override {}
362   int64_t Min() const override;
363   int64_t Max() const override;
364   void Range(int64_t* mi, int64_t* ma) override;
365   void SetMin(int64_t m) override;
366   void SetMax(int64_t m) override;
367   void SetRange(int64_t mi, int64_t ma) override;
Bound() const368   bool Bound() const override { return (index_->Bound()); }
369   // TODO(user) : improve me, the previous test is not always true
WhenRange(Demon * d)370   void WhenRange(Demon* d) override { index_->WhenRange(d); }
CastToVar()371   IntVar* CastToVar() override {
372     // TODO(user): Should we try to make holes in the domain of index_, as we
373     // do here, or should we only propagate bounds as we do in
374     // IncreasingIntExprElement ?
375     IntVar* const var = solver()->MakeIntVar(min_rmq_.array());
376     solver()->AddCastConstraint(solver()->RevAlloc(new IntElementConstraint(
377                                     solver(), min_rmq_.array(), index_, var)),
378                                 var, this);
379     return var;
380   }
Accept(ModelVisitor * const visitor) const381   void Accept(ModelVisitor* const visitor) const override {
382     visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
383     visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
384                                        min_rmq_.array());
385     visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
386                                             index_);
387     visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
388   }
389 
390  private:
IndexMin() const391   int64_t IndexMin() const { return std::max<int64_t>(0, index_->Min()); }
IndexMax() const392   int64_t IndexMax() const {
393     return std::min<int64_t>(min_rmq_.array().size() - 1, index_->Max());
394   }
395 
396   IntVar* const index_;
397   const RangeMinimumQuery<int64_t, std::less<int64_t>> min_rmq_;
398   const RangeMinimumQuery<int64_t, std::greater<int64_t>> max_rmq_;
399 };
400 
RangeMinimumQueryExprElement(Solver * solver,const std::vector<int64_t> & values,IntVar * index)401 RangeMinimumQueryExprElement::RangeMinimumQueryExprElement(
402     Solver* solver, const std::vector<int64_t>& values, IntVar* index)
403     : BaseIntExpr(solver), index_(index), min_rmq_(values), max_rmq_(values) {
404   CHECK(solver != nullptr);
405   CHECK(index != nullptr);
406 }
407 
Min() const408 int64_t RangeMinimumQueryExprElement::Min() const {
409   return min_rmq_.GetMinimumFromRange(IndexMin(), IndexMax() + 1);
410 }
411 
Max() const412 int64_t RangeMinimumQueryExprElement::Max() const {
413   return max_rmq_.GetMinimumFromRange(IndexMin(), IndexMax() + 1);
414 }
415 
Range(int64_t * mi,int64_t * ma)416 void RangeMinimumQueryExprElement::Range(int64_t* mi, int64_t* ma) {
417   const int64_t range_min = IndexMin();
418   const int64_t range_max = IndexMax() + 1;
419   *mi = min_rmq_.GetMinimumFromRange(range_min, range_max);
420   *ma = max_rmq_.GetMinimumFromRange(range_min, range_max);
421 }
422 
// Tightens the domain of index_ from both ends, reading values from the
// array stored in min_rmq_. Starting at IndexMin() (resp. IndexMax()),
// positions are skipped while `test` holds for `value`, the array entry at
// the position being scanned. index_ is then restricted to the remaining
// range.
// Fails when the clamped range is empty (the domain of index_ lies entirely
// outside the array). This check comes first so values[] is never read out
// of bounds. It also fails when `test` holds for every position.
#define UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(test)       \
  const std::vector<int64_t>& values = min_rmq_.array(); \
  int64_t index_min = IndexMin();                        \
  int64_t index_max = IndexMax();                        \
  if (index_min > index_max) {                           \
    solver()->Fail();                                    \
  }                                                      \
  int64_t value = values[index_min];                     \
  while (index_min < index_max && (test)) {              \
    index_min++;                                         \
    value = values[index_min];                           \
  }                                                      \
  if (index_min == index_max && (test)) {                \
    solver()->Fail();                                    \
  }                                                      \
  value = values[index_max];                             \
  while (index_max >= index_min && (test)) {             \
    index_max--;                                         \
    value = values[index_max];                           \
  }                                                      \
  index_->SetRange(index_min, index_max);
441 
SetMin(int64_t m)442 void RangeMinimumQueryExprElement::SetMin(int64_t m) {
443   UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(value < m);
444 }
445 
SetMax(int64_t m)446 void RangeMinimumQueryExprElement::SetMax(int64_t m) {
447   UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(value > m);
448 }
449 
SetRange(int64_t mi,int64_t ma)450 void RangeMinimumQueryExprElement::SetRange(int64_t mi, int64_t ma) {
451   if (mi > ma) {
452     solver()->Fail();
453   }
454   UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(value < mi || value > ma);
455 }
456 
457 #undef UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS
458 
459 // ----- Increasing Element -----
460 
461 class IncreasingIntExprElement : public BaseIntExpr {
462  public:
463   IncreasingIntExprElement(Solver* const s, const std::vector<int64_t>& values,
464                            IntVar* const index);
~IncreasingIntExprElement()465   ~IncreasingIntExprElement() override {}
466 
467   int64_t Min() const override;
468   void SetMin(int64_t m) override;
469   int64_t Max() const override;
470   void SetMax(int64_t m) override;
471   void SetRange(int64_t mi, int64_t ma) override;
Bound() const472   bool Bound() const override { return (index_->Bound()); }
473   // TODO(user) : improve me, the previous test is not always true
name() const474   std::string name() const override {
475     return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
476                            index_->name());
477   }
DebugString() const478   std::string DebugString() const override {
479     return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
480                            index_->DebugString());
481   }
482 
Accept(ModelVisitor * const visitor) const483   void Accept(ModelVisitor* const visitor) const override {
484     visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
485     visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
486     visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
487                                             index_);
488     visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
489   }
490 
WhenRange(Demon * d)491   void WhenRange(Demon* d) override { index_->WhenRange(d); }
492 
CastToVar()493   IntVar* CastToVar() override {
494     Solver* const s = solver();
495     IntVar* const var = s->MakeIntVar(values_);
496     LinkVarExpr(s, this, var);
497     return var;
498   }
499 
500  private:
501   const std::vector<int64_t> values_;
502   IntVar* const index_;
503 };
504 
IncreasingIntExprElement(Solver * const s,const std::vector<int64_t> & values,IntVar * const index)505 IncreasingIntExprElement::IncreasingIntExprElement(
506     Solver* const s, const std::vector<int64_t>& values, IntVar* const index)
507     : BaseIntExpr(s), values_(values), index_(index) {
508   DCHECK(index);
509   DCHECK(s);
510 }
511 
Min() const512 int64_t IncreasingIntExprElement::Min() const {
513   const int64_t expression_min = std::max<int64_t>(0, index_->Min());
514   return (expression_min < values_.size()
515               ? values_[expression_min]
516               : std::numeric_limits<int64_t>::max());
517 }
518 
SetMin(int64_t m)519 void IncreasingIntExprElement::SetMin(int64_t m) {
520   const int64_t index_min = std::max<int64_t>(0, index_->Min());
521   const int64_t index_max =
522       std::min<int64_t>(values_.size() - 1, index_->Max());
523 
524   if (index_min > index_max || m > values_[index_max]) {
525     solver()->Fail();
526   }
527 
528   const std::vector<int64_t>::const_iterator first =
529       std::lower_bound(values_.begin(), values_.end(), m);
530   const int64_t new_index_min = first - values_.begin();
531   index_->SetMin(new_index_min);
532 }
533 
Max() const534 int64_t IncreasingIntExprElement::Max() const {
535   const int64_t expression_max =
536       std::min<int64_t>(values_.size() - 1, index_->Max());
537   return (expression_max >= 0 ? values_[expression_max]
538                               : std::numeric_limits<int64_t>::max());
539 }
540 
SetMax(int64_t m)541 void IncreasingIntExprElement::SetMax(int64_t m) {
542   int64_t index_min = std::max<int64_t>(0, index_->Min());
543   if (m < values_[index_min]) {
544     solver()->Fail();
545   }
546 
547   const std::vector<int64_t>::const_iterator last_after =
548       std::upper_bound(values_.begin(), values_.end(), m);
549   const int64_t new_index_max = (last_after - values_.begin()) - 1;
550   index_->SetRange(0, new_index_max);
551 }
552 
SetRange(int64_t mi,int64_t ma)553 void IncreasingIntExprElement::SetRange(int64_t mi, int64_t ma) {
554   if (mi > ma) {
555     solver()->Fail();
556   }
557   const int64_t index_min = std::max<int64_t>(0, index_->Min());
558   const int64_t index_max =
559       std::min<int64_t>(values_.size() - 1, index_->Max());
560 
561   if (mi > ma || ma < values_[index_min] || mi > values_[index_max]) {
562     solver()->Fail();
563   }
564 
565   const std::vector<int64_t>::const_iterator first =
566       std::lower_bound(values_.begin(), values_.end(), mi);
567   const int64_t new_index_min = first - values_.begin();
568 
569   const std::vector<int64_t>::const_iterator last_after =
570       std::upper_bound(first, values_.end(), ma);
571   const int64_t new_index_max = (last_after - values_.begin()) - 1;
572 
573   // Assign.
574   index_->SetRange(new_index_min, new_index_max);
575 }
576 
577 // ----- Solver::MakeElement(int array, int var) -----
BuildElement(Solver * const solver,const std::vector<int64_t> & values,IntVar * const index)578 IntExpr* BuildElement(Solver* const solver, const std::vector<int64_t>& values,
579                       IntVar* const index) {
580   // Various checks.
581   // Is array constant?
582   if (IsArrayConstant(values, values[0])) {
583     solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
584     return solver->MakeIntConst(values[0]);
585   }
586   // Is array built with booleans only?
587   // TODO(user): We could maintain the index of the first one.
588   if (IsArrayBoolean(values)) {
589     std::vector<int64_t> ones;
590     int first_zero = -1;
591     for (int i = 0; i < values.size(); ++i) {
592       if (values[i] == 1) {
593         ones.push_back(i);
594       } else {
595         first_zero = i;
596       }
597     }
598     if (ones.size() == 1) {
599       DCHECK_EQ(int64_t{1}, values[ones.back()]);
600       solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
601       return solver->MakeIsEqualCstVar(index, ones.back());
602     } else if (ones.size() == values.size() - 1) {
603       solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
604       return solver->MakeIsDifferentCstVar(index, first_zero);
605     } else if (ones.size() == ones.back() - ones.front() + 1) {  // contiguous.
606       solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
607       IntVar* const b = solver->MakeBoolVar("ContiguousBooleanElementVar");
608       solver->AddConstraint(
609           solver->MakeIsBetweenCt(index, ones.front(), ones.back(), b));
610       return b;
611     } else {
612       IntVar* const b = solver->MakeBoolVar("NonContiguousBooleanElementVar");
613       solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
614       solver->AddConstraint(solver->MakeIsMemberCt(index, ones, b));
615       return b;
616     }
617   }
618   IntExpr* cache = nullptr;
619   if (!absl::GetFlag(FLAGS_cp_disable_element_cache)) {
620     cache = solver->Cache()->FindVarConstantArrayExpression(
621         index, values, ModelCache::VAR_CONSTANT_ARRAY_ELEMENT);
622   }
623   if (cache != nullptr) {
624     return cache;
625   } else {
626     IntExpr* result = nullptr;
627     if (values.size() >= 2 && index->Min() == 0 && index->Max() == 1) {
628       result = solver->MakeSum(solver->MakeProd(index, values[1] - values[0]),
629                                values[0]);
630     } else if (values.size() == 2 && index->Contains(0) && index->Contains(1)) {
631       solver->AddConstraint(solver->MakeBetweenCt(index, 0, 1));
632       result = solver->MakeSum(solver->MakeProd(index, values[1] - values[0]),
633                                values[0]);
634     } else if (IsIncreasingContiguous(values)) {
635       result = solver->MakeSum(index, values[0]);
636     } else if (IsIncreasing(values)) {
637       result = solver->RegisterIntExpr(solver->RevAlloc(
638           new IncreasingIntExprElement(solver, values, index)));
639     } else {
640       if (solver->parameters().use_element_rmq()) {
641         result = solver->RegisterIntExpr(solver->RevAlloc(
642             new RangeMinimumQueryExprElement(solver, values, index)));
643       } else {
644         result = solver->RegisterIntExpr(
645             solver->RevAlloc(new IntExprElement(solver, values, index)));
646       }
647     }
648     if (!absl::GetFlag(FLAGS_cp_disable_element_cache)) {
649       solver->Cache()->InsertVarConstantArrayExpression(
650           result, index, values, ModelCache::VAR_CONSTANT_ARRAY_ELEMENT);
651     }
652     return result;
653   }
654 }
655 }  // namespace
656 
MakeElement(const std::vector<int64_t> & values,IntVar * const index)657 IntExpr* Solver::MakeElement(const std::vector<int64_t>& values,
658                              IntVar* const index) {
659   DCHECK(index);
660   DCHECK_EQ(this, index->solver());
661   if (index->Bound()) {
662     return MakeIntConst(values[index->Min()]);
663   }
664   return BuildElement(this, values, index);
665 }
666 
MakeElement(const std::vector<int> & values,IntVar * const index)667 IntExpr* Solver::MakeElement(const std::vector<int>& values,
668                              IntVar* const index) {
669   DCHECK(index);
670   DCHECK_EQ(this, index->solver());
671   if (index->Bound()) {
672     return MakeIntConst(values[index->Min()]);
673   }
674   return BuildElement(this, ToInt64Vector(values), index);
675 }
676 
677 // ----- IntExprFunctionElement -----
678 
679 namespace {
680 class IntExprFunctionElement : public BaseIntExprElement {
681  public:
682   IntExprFunctionElement(Solver* const s, Solver::IndexEvaluator1 values,
683                          IntVar* const e);
684   ~IntExprFunctionElement() override;
685 
name() const686   std::string name() const override {
687     return absl::StrFormat("IntFunctionElement(%s)", expr_->name());
688   }
689 
DebugString() const690   std::string DebugString() const override {
691     return absl::StrFormat("IntFunctionElement(%s)", expr_->DebugString());
692   }
693 
Accept(ModelVisitor * const visitor) const694   void Accept(ModelVisitor* const visitor) const override {
695     // Warning: This will expand all values into a vector.
696     visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
697     visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
698                                             expr_);
699     visitor->VisitInt64ToInt64Extension(values_, expr_->Min(), expr_->Max());
700     visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
701   }
702 
703  protected:
ElementValue(int index) const704   int64_t ElementValue(int index) const override { return values_(index); }
ExprMin() const705   int64_t ExprMin() const override { return expr_->Min(); }
ExprMax() const706   int64_t ExprMax() const override { return expr_->Max(); }
707 
708  private:
709   Solver::IndexEvaluator1 values_;
710 };
711 
IntExprFunctionElement(Solver * const s,Solver::IndexEvaluator1 values,IntVar * const e)712 IntExprFunctionElement::IntExprFunctionElement(Solver* const s,
713                                                Solver::IndexEvaluator1 values,
714                                                IntVar* const e)
715     : BaseIntExprElement(s, e), values_(std::move(values)) {
716   CHECK(values_ != nullptr);
717 }
718 
~IntExprFunctionElement()719 IntExprFunctionElement::~IntExprFunctionElement() {}
720 
721 // ----- Increasing Element -----
722 
723 class IncreasingIntExprFunctionElement : public BaseIntExpr {
724  public:
IncreasingIntExprFunctionElement(Solver * const s,Solver::IndexEvaluator1 values,IntVar * const index)725   IncreasingIntExprFunctionElement(Solver* const s,
726                                    Solver::IndexEvaluator1 values,
727                                    IntVar* const index)
728       : BaseIntExpr(s), values_(std::move(values)), index_(index) {
729     DCHECK(values_ != nullptr);
730     DCHECK(index);
731     DCHECK(s);
732   }
733 
~IncreasingIntExprFunctionElement()734   ~IncreasingIntExprFunctionElement() override {}
735 
Min() const736   int64_t Min() const override { return values_(index_->Min()); }
737 
SetMin(int64_t m)738   void SetMin(int64_t m) override {
739     const int64_t index_min = index_->Min();
740     const int64_t index_max = index_->Max();
741     if (m > values_(index_max)) {
742       solver()->Fail();
743     }
744     const int64_t new_index_min = FindNewIndexMin(index_min, index_max, m);
745     index_->SetMin(new_index_min);
746   }
747 
Max() const748   int64_t Max() const override { return values_(index_->Max()); }
749 
SetMax(int64_t m)750   void SetMax(int64_t m) override {
751     int64_t index_min = index_->Min();
752     int64_t index_max = index_->Max();
753     if (m < values_(index_min)) {
754       solver()->Fail();
755     }
756     const int64_t new_index_max = FindNewIndexMax(index_min, index_max, m);
757     index_->SetMax(new_index_max);
758   }
759 
SetRange(int64_t mi,int64_t ma)760   void SetRange(int64_t mi, int64_t ma) override {
761     const int64_t index_min = index_->Min();
762     const int64_t index_max = index_->Max();
763     const int64_t value_min = values_(index_min);
764     const int64_t value_max = values_(index_max);
765     if (mi > ma || ma < value_min || mi > value_max) {
766       solver()->Fail();
767     }
768     if (mi <= value_min && ma >= value_max) {
769       // Nothing to do.
770       return;
771     }
772 
773     const int64_t new_index_min = FindNewIndexMin(index_min, index_max, mi);
774     const int64_t new_index_max = FindNewIndexMax(new_index_min, index_max, ma);
775     // Assign.
776     index_->SetRange(new_index_min, new_index_max);
777   }
778 
name() const779   std::string name() const override {
780     return absl::StrFormat("IncreasingIntExprFunctionElement(values, %s)",
781                            index_->name());
782   }
783 
DebugString() const784   std::string DebugString() const override {
785     return absl::StrFormat("IncreasingIntExprFunctionElement(values, %s)",
786                            index_->DebugString());
787   }
788 
WhenRange(Demon * d)789   void WhenRange(Demon* d) override { index_->WhenRange(d); }
790 
Accept(ModelVisitor * const visitor) const791   void Accept(ModelVisitor* const visitor) const override {
792     // Warning: This will expand all values into a vector.
793     visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
794     visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
795                                             index_);
796     if (index_->Min() == 0) {
797       visitor->VisitInt64ToInt64AsArray(values_, ModelVisitor::kValuesArgument,
798                                         index_->Max());
799     } else {
800       visitor->VisitInt64ToInt64Extension(values_, index_->Min(),
801                                           index_->Max());
802     }
803     visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
804   }
805 
806  private:
FindNewIndexMin(int64_t index_min,int64_t index_max,int64_t m)807   int64_t FindNewIndexMin(int64_t index_min, int64_t index_max, int64_t m) {
808     if (m <= values_(index_min)) {
809       return index_min;
810     }
811 
812     DCHECK_LT(values_(index_min), m);
813     DCHECK_GE(values_(index_max), m);
814 
815     int64_t index_lower_bound = index_min;
816     int64_t index_upper_bound = index_max;
817     while (index_upper_bound - index_lower_bound > 1) {
818       DCHECK_LT(values_(index_lower_bound), m);
819       DCHECK_GE(values_(index_upper_bound), m);
820       const int64_t pivot = (index_lower_bound + index_upper_bound) / 2;
821       const int64_t pivot_value = values_(pivot);
822       if (pivot_value < m) {
823         index_lower_bound = pivot;
824       } else {
825         index_upper_bound = pivot;
826       }
827     }
828     DCHECK(values_(index_upper_bound) >= m);
829     return index_upper_bound;
830   }
831 
FindNewIndexMax(int64_t index_min,int64_t index_max,int64_t m)832   int64_t FindNewIndexMax(int64_t index_min, int64_t index_max, int64_t m) {
833     if (m >= values_(index_max)) {
834       return index_max;
835     }
836 
837     DCHECK_LE(values_(index_min), m);
838     DCHECK_GT(values_(index_max), m);
839 
840     int64_t index_lower_bound = index_min;
841     int64_t index_upper_bound = index_max;
842     while (index_upper_bound - index_lower_bound > 1) {
843       DCHECK_LE(values_(index_lower_bound), m);
844       DCHECK_GT(values_(index_upper_bound), m);
845       const int64_t pivot = (index_lower_bound + index_upper_bound) / 2;
846       const int64_t pivot_value = values_(pivot);
847       if (pivot_value > m) {
848         index_upper_bound = pivot;
849       } else {
850         index_lower_bound = pivot;
851       }
852     }
853     DCHECK(values_(index_lower_bound) <= m);
854     return index_lower_bound;
855   }
856 
857   Solver::IndexEvaluator1 values_;
858   IntVar* const index_;
859 };
860 }  // namespace
861 
MakeElement(Solver::IndexEvaluator1 values,IntVar * const index)862 IntExpr* Solver::MakeElement(Solver::IndexEvaluator1 values,
863                              IntVar* const index) {
864   CHECK_EQ(this, index->solver());
865   return RegisterIntExpr(
866       RevAlloc(new IntExprFunctionElement(this, std::move(values), index)));
867 }
868 
MakeMonotonicElement(Solver::IndexEvaluator1 values,bool increasing,IntVar * const index)869 IntExpr* Solver::MakeMonotonicElement(Solver::IndexEvaluator1 values,
870                                       bool increasing, IntVar* const index) {
871   CHECK_EQ(this, index->solver());
872   if (increasing) {
873     return RegisterIntExpr(
874         RevAlloc(new IncreasingIntExprFunctionElement(this, values, index)));
875   } else {
876     // You need to pass by copy such that opposite_value does not include a
877     // dandling reference when leaving this scope.
878     Solver::IndexEvaluator1 opposite_values = [values](int64_t i) {
879       return -values(i);
880     };
881     return RegisterIntExpr(MakeOpposite(RevAlloc(
882         new IncreasingIntExprFunctionElement(this, opposite_values, index))));
883   }
884 }
885 
886 // ----- IntIntExprFunctionElement -----
887 
888 namespace {
889 class IntIntExprFunctionElement : public BaseIntExpr {
890  public:
891   IntIntExprFunctionElement(Solver* const s, Solver::IndexEvaluator2 values,
892                             IntVar* const expr1, IntVar* const expr2);
893   ~IntIntExprFunctionElement() override;
DebugString() const894   std::string DebugString() const override {
895     return absl::StrFormat("IntIntFunctionElement(%s,%s)",
896                            expr1_->DebugString(), expr2_->DebugString());
897   }
898   int64_t Min() const override;
899   int64_t Max() const override;
900   void Range(int64_t* lower_bound, int64_t* upper_bound) override;
901   void SetMin(int64_t lower_bound) override;
902   void SetMax(int64_t upper_bound) override;
903   void SetRange(int64_t lower_bound, int64_t upper_bound) override;
Bound() const904   bool Bound() const override { return expr1_->Bound() && expr2_->Bound(); }
905   // TODO(user) : improve me, the previous test is not always true
WhenRange(Demon * d)906   void WhenRange(Demon* d) override {
907     expr1_->WhenRange(d);
908     expr2_->WhenRange(d);
909   }
910 
Accept(ModelVisitor * const visitor) const911   void Accept(ModelVisitor* const visitor) const override {
912     visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
913     visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
914                                             expr1_);
915     visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndex2Argument,
916                                             expr2_);
917     // Warning: This will expand all values into a vector.
918     const int64_t expr1_min = expr1_->Min();
919     const int64_t expr1_max = expr1_->Max();
920     visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, expr1_min);
921     visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, expr1_max);
922     for (int i = expr1_min; i <= expr1_max; ++i) {
923       visitor->VisitInt64ToInt64Extension(
924           [this, i](int64_t j) { return values_(i, j); }, expr2_->Min(),
925           expr2_->Max());
926     }
927     visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
928   }
929 
930  private:
ElementValue(int index1,int index2) const931   int64_t ElementValue(int index1, int index2) const {
932     return values_(index1, index2);
933   }
934   void UpdateSupports() const;
935 
936   IntVar* const expr1_;
937   IntVar* const expr2_;
938   mutable int64_t min_;
939   mutable int min_support1_;
940   mutable int min_support2_;
941   mutable int64_t max_;
942   mutable int max_support1_;
943   mutable int max_support2_;
944   mutable bool initial_update_;
945   Solver::IndexEvaluator2 values_;
946   IntVarIterator* const expr1_iterator_;
947   IntVarIterator* const expr2_iterator_;
948 };
949 
IntIntExprFunctionElement(Solver * const s,Solver::IndexEvaluator2 values,IntVar * const expr1,IntVar * const expr2)950 IntIntExprFunctionElement::IntIntExprFunctionElement(
951     Solver* const s, Solver::IndexEvaluator2 values, IntVar* const expr1,
952     IntVar* const expr2)
953     : BaseIntExpr(s),
954       expr1_(expr1),
955       expr2_(expr2),
956       min_(0),
957       min_support1_(-1),
958       min_support2_(-1),
959       max_(0),
960       max_support1_(-1),
961       max_support2_(-1),
962       initial_update_(true),
963       values_(std::move(values)),
964       expr1_iterator_(expr1_->MakeDomainIterator(true)),
965       expr2_iterator_(expr2_->MakeDomainIterator(true)) {
966   CHECK(values_ != nullptr);
967 }
968 
~IntIntExprFunctionElement()969 IntIntExprFunctionElement::~IntIntExprFunctionElement() {}
970 
Min() const971 int64_t IntIntExprFunctionElement::Min() const {
972   UpdateSupports();
973   return min_;
974 }
975 
Max() const976 int64_t IntIntExprFunctionElement::Max() const {
977   UpdateSupports();
978   return max_;
979 }
980 
Range(int64_t * lower_bound,int64_t * upper_bound)981 void IntIntExprFunctionElement::Range(int64_t* lower_bound,
982                                       int64_t* upper_bound) {
983   UpdateSupports();
984   *lower_bound = min_;
985   *upper_bound = max_;
986 }
987 
// Shared body of IntIntExprFunctionElement::SetMin/SetMax/SetRange. `test` is
// a boolean expression over a local named `value`, the element value at a
// candidate (index1, index2) pair. The macro shrinks expr1_ to
// [nmin1, nmax1] and expr2_ to [nmin2, nmax2]: the tightest bounds for which
// some pair in the current range box satisfies `test`. It fails if no pair
// does. Holes inside the ranges are scanned too, so this is bounds
// reasoning only. Loop counters are int64_t to match the int64_t bounds they
// iterate over; `int` counters overflow/truncate beyond the int range.
// (No // comments inside the macro: line splicing would swallow the rest.)
#define UPDATE_ELEMENT_INDEX_BOUNDS(test)         \
  const int64_t emin1 = expr1_->Min();            \
  const int64_t emax1 = expr1_->Max();            \
  const int64_t emin2 = expr2_->Min();            \
  const int64_t emax2 = expr2_->Max();            \
  int64_t nmin1 = emin1;                          \
  bool found = false;                             \
  while (nmin1 <= emax1 && !found) {              \
    for (int64_t i = emin2; i <= emax2; ++i) {    \
      int64_t value = ElementValue(nmin1, i);     \
      if (test) {                                 \
        found = true;                             \
        break;                                    \
      }                                           \
    }                                             \
    if (!found) {                                 \
      nmin1++;                                    \
    }                                             \
  }                                               \
  if (nmin1 > emax1) {                            \
    solver()->Fail();                             \
  }                                               \
  int64_t nmin2 = emin2;                          \
  found = false;                                  \
  while (nmin2 <= emax2 && !found) {              \
    for (int64_t i = emin1; i <= emax1; ++i) {    \
      int64_t value = ElementValue(i, nmin2);     \
      if (test) {                                 \
        found = true;                             \
        break;                                    \
      }                                           \
    }                                             \
    if (!found) {                                 \
      nmin2++;                                    \
    }                                             \
  }                                               \
  if (nmin2 > emax2) {                            \
    solver()->Fail();                             \
  }                                               \
  int64_t nmax1 = emax1;                          \
  found = false;                                  \
  while (nmax1 >= nmin1 && !found) {              \
    for (int64_t i = emin2; i <= emax2; ++i) {    \
      int64_t value = ElementValue(nmax1, i);     \
      if (test) {                                 \
        found = true;                             \
        break;                                    \
      }                                           \
    }                                             \
    if (!found) {                                 \
      nmax1--;                                    \
    }                                             \
  }                                               \
  int64_t nmax2 = emax2;                          \
  found = false;                                  \
  while (nmax2 >= nmin2 && !found) {              \
    for (int64_t i = emin1; i <= emax1; ++i) {    \
      int64_t value = ElementValue(i, nmax2);     \
      if (test) {                                 \
        found = true;                             \
        break;                                    \
      }                                           \
    }                                             \
    if (!found) {                                 \
      nmax2--;                                    \
    }                                             \
  }                                               \
  expr1_->SetRange(nmin1, nmax1);                 \
  expr2_->SetRange(nmin2, nmax2);
1057 
SetMin(int64_t lower_bound)1058 void IntIntExprFunctionElement::SetMin(int64_t lower_bound) {
1059   UPDATE_ELEMENT_INDEX_BOUNDS(value >= lower_bound);
1060 }
1061 
SetMax(int64_t upper_bound)1062 void IntIntExprFunctionElement::SetMax(int64_t upper_bound) {
1063   UPDATE_ELEMENT_INDEX_BOUNDS(value <= upper_bound);
1064 }
1065 
SetRange(int64_t lower_bound,int64_t upper_bound)1066 void IntIntExprFunctionElement::SetRange(int64_t lower_bound,
1067                                          int64_t upper_bound) {
1068   if (lower_bound > upper_bound) {
1069     solver()->Fail();
1070   }
1071   UPDATE_ELEMENT_INDEX_BOUNDS(value >= lower_bound && value <= upper_bound);
1072 }
1073 
1074 #undef UPDATE_ELEMENT_INDEX_BOUNDS
1075 
UpdateSupports() const1076 void IntIntExprFunctionElement::UpdateSupports() const {
1077   if (initial_update_ || !expr1_->Contains(min_support1_) ||
1078       !expr1_->Contains(max_support1_) || !expr2_->Contains(min_support2_) ||
1079       !expr2_->Contains(max_support2_)) {
1080     const int64_t emax1 = expr1_->Max();
1081     const int64_t emax2 = expr2_->Max();
1082     int64_t min_value = ElementValue(emax1, emax2);
1083     int64_t max_value = min_value;
1084     int min_support1 = emax1;
1085     int max_support1 = emax1;
1086     int min_support2 = emax2;
1087     int max_support2 = emax2;
1088     for (const int64_t index1 : InitAndGetValues(expr1_iterator_)) {
1089       for (const int64_t index2 : InitAndGetValues(expr2_iterator_)) {
1090         const int64_t value = ElementValue(index1, index2);
1091         if (value > max_value) {
1092           max_value = value;
1093           max_support1 = index1;
1094           max_support2 = index2;
1095         } else if (value < min_value) {
1096           min_value = value;
1097           min_support1 = index1;
1098           min_support2 = index2;
1099         }
1100       }
1101     }
1102     Solver* s = solver();
1103     s->SaveAndSetValue(&min_, min_value);
1104     s->SaveAndSetValue(&min_support1_, min_support1);
1105     s->SaveAndSetValue(&min_support2_, min_support2);
1106     s->SaveAndSetValue(&max_, max_value);
1107     s->SaveAndSetValue(&max_support1_, max_support1);
1108     s->SaveAndSetValue(&max_support2_, max_support2);
1109     s->SaveAndSetValue(&initial_update_, false);
1110   }
1111 }
1112 }  // namespace
1113 
MakeElement(Solver::IndexEvaluator2 values,IntVar * const index1,IntVar * const index2)1114 IntExpr* Solver::MakeElement(Solver::IndexEvaluator2 values,
1115                              IntVar* const index1, IntVar* const index2) {
1116   CHECK_EQ(this, index1->solver());
1117   CHECK_EQ(this, index2->solver());
1118   return RegisterIntExpr(RevAlloc(
1119       new IntIntExprFunctionElement(this, std::move(values), index1, index2)));
1120 }
1121 
1122 // ---------- Generalized element ----------
1123 
1124 // ----- IfThenElseCt -----
1125 
1126 class IfThenElseCt : public CastConstraint {
1127  public:
IfThenElseCt(Solver * const solver,IntVar * const condition,IntExpr * const one,IntExpr * const zero,IntVar * const target)1128   IfThenElseCt(Solver* const solver, IntVar* const condition,
1129                IntExpr* const one, IntExpr* const zero, IntVar* const target)
1130       : CastConstraint(solver, target),
1131         condition_(condition),
1132         zero_(zero),
1133         one_(one) {}
1134 
~IfThenElseCt()1135   ~IfThenElseCt() override {}
1136 
Post()1137   void Post() override {
1138     Demon* const demon = solver()->MakeConstraintInitialPropagateCallback(this);
1139     condition_->WhenBound(demon);
1140     one_->WhenRange(demon);
1141     zero_->WhenRange(demon);
1142     target_var_->WhenRange(demon);
1143   }
1144 
InitialPropagate()1145   void InitialPropagate() override {
1146     condition_->SetRange(0, 1);
1147     const int64_t target_var_min = target_var_->Min();
1148     const int64_t target_var_max = target_var_->Max();
1149     int64_t new_min = std::numeric_limits<int64_t>::min();
1150     int64_t new_max = std::numeric_limits<int64_t>::max();
1151     if (condition_->Max() == 0) {
1152       zero_->SetRange(target_var_min, target_var_max);
1153       zero_->Range(&new_min, &new_max);
1154     } else if (condition_->Min() == 1) {
1155       one_->SetRange(target_var_min, target_var_max);
1156       one_->Range(&new_min, &new_max);
1157     } else {
1158       if (target_var_max < zero_->Min() || target_var_min > zero_->Max()) {
1159         condition_->SetValue(1);
1160         one_->SetRange(target_var_min, target_var_max);
1161         one_->Range(&new_min, &new_max);
1162       } else if (target_var_max < one_->Min() || target_var_min > one_->Max()) {
1163         condition_->SetValue(0);
1164         zero_->SetRange(target_var_min, target_var_max);
1165         zero_->Range(&new_min, &new_max);
1166       } else {
1167         int64_t zl = 0;
1168         int64_t zu = 0;
1169         int64_t ol = 0;
1170         int64_t ou = 0;
1171         zero_->Range(&zl, &zu);
1172         one_->Range(&ol, &ou);
1173         new_min = std::min(zl, ol);
1174         new_max = std::max(zu, ou);
1175       }
1176     }
1177     target_var_->SetRange(new_min, new_max);
1178   }
1179 
DebugString() const1180   std::string DebugString() const override {
1181     return absl::StrFormat("(%s ? %s : %s) == %s", condition_->DebugString(),
1182                            one_->DebugString(), zero_->DebugString(),
1183                            target_var_->DebugString());
1184   }
1185 
Accept(ModelVisitor * const visitor) const1186   void Accept(ModelVisitor* const visitor) const override {}
1187 
1188  private:
1189   IntVar* const condition_;
1190   IntExpr* const zero_;
1191   IntExpr* const one_;
1192 };
1193 
1194 // ----- IntExprEvaluatorElementCt -----
1195 
1196 // This constraint implements evaluator(index) == var. It is delayed such
1197 // that propagation only occurs when all variables have been touched.
1198 // The range of the evaluator is [range_start, range_end).
1199 
1200 namespace {
1201 class IntExprEvaluatorElementCt : public CastConstraint {
1202  public:
1203   IntExprEvaluatorElementCt(Solver* const s, Solver::Int64ToIntVar evaluator,
1204                             int64_t range_start, int64_t range_end,
1205                             IntVar* const index, IntVar* const target_var);
~IntExprEvaluatorElementCt()1206   ~IntExprEvaluatorElementCt() override {}
1207 
1208   void Post() override;
1209   void InitialPropagate() override;
1210 
1211   void Propagate();
1212   void Update(int index);
1213   void UpdateExpr();
1214 
1215   std::string DebugString() const override;
1216   void Accept(ModelVisitor* const visitor) const override;
1217 
1218  protected:
1219   IntVar* const index_;
1220 
1221  private:
1222   const Solver::Int64ToIntVar evaluator_;
1223   const int64_t range_start_;
1224   const int64_t range_end_;
1225   int min_support_;
1226   int max_support_;
1227 };
1228 
IntExprEvaluatorElementCt(Solver * const s,Solver::Int64ToIntVar evaluator,int64_t range_start,int64_t range_end,IntVar * const index,IntVar * const target_var)1229 IntExprEvaluatorElementCt::IntExprEvaluatorElementCt(
1230     Solver* const s, Solver::Int64ToIntVar evaluator, int64_t range_start,
1231     int64_t range_end, IntVar* const index, IntVar* const target_var)
1232     : CastConstraint(s, target_var),
1233       index_(index),
1234       evaluator_(std::move(evaluator)),
1235       range_start_(range_start),
1236       range_end_(range_end),
1237       min_support_(-1),
1238       max_support_(-1) {}
1239 
Post()1240 void IntExprEvaluatorElementCt::Post() {
1241   Demon* const delayed_propagate_demon = MakeDelayedConstraintDemon0(
1242       solver(), this, &IntExprEvaluatorElementCt::Propagate, "Propagate");
1243   for (int i = range_start_; i < range_end_; ++i) {
1244     IntVar* const current_var = evaluator_(i);
1245     current_var->WhenRange(delayed_propagate_demon);
1246     Demon* const update_demon = MakeConstraintDemon1(
1247         solver(), this, &IntExprEvaluatorElementCt::Update, "Update", i);
1248     current_var->WhenRange(update_demon);
1249   }
1250   index_->WhenRange(delayed_propagate_demon);
1251   Demon* const update_expr_demon = MakeConstraintDemon0(
1252       solver(), this, &IntExprEvaluatorElementCt::UpdateExpr, "UpdateExpr");
1253   index_->WhenRange(update_expr_demon);
1254   Demon* const update_var_demon = MakeConstraintDemon0(
1255       solver(), this, &IntExprEvaluatorElementCt::Propagate, "UpdateVar");
1256 
1257   target_var_->WhenRange(update_var_demon);
1258 }
1259 
InitialPropagate()1260 void IntExprEvaluatorElementCt::InitialPropagate() { Propagate(); }
1261 
Propagate()1262 void IntExprEvaluatorElementCt::Propagate() {
1263   const int64_t emin = std::max(range_start_, index_->Min());
1264   const int64_t emax = std::min<int64_t>(range_end_ - 1, index_->Max());
1265   const int64_t vmin = target_var_->Min();
1266   const int64_t vmax = target_var_->Max();
1267   if (emin == emax) {
1268     index_->SetValue(emin);  // in case it was reduced by the above min/max.
1269     evaluator_(emin)->SetRange(vmin, vmax);
1270   } else {
1271     int64_t nmin = emin;
1272     for (; nmin <= emax; nmin++) {
1273       // break if the intersection of
1274       // [evaluator_(nmin)->Min(), evaluator_(nmin)->Max()] and [vmin, vmax]
1275       // is non-empty.
1276       IntVar* const nmin_var = evaluator_(nmin);
1277       if (nmin_var->Min() <= vmax && nmin_var->Max() >= vmin) break;
1278     }
1279     int64_t nmax = emax;
1280     for (; nmin <= nmax; nmax--) {
1281       // break if the intersection of
1282       // [evaluator_(nmin)->Min(), evaluator_(nmin)->Max()] and [vmin, vmax]
1283       // is non-empty.
1284       IntExpr* const nmax_var = evaluator_(nmax);
1285       if (nmax_var->Min() <= vmax && nmax_var->Max() >= vmin) break;
1286     }
1287     index_->SetRange(nmin, nmax);
1288     if (nmin == nmax) {
1289       evaluator_(nmin)->SetRange(vmin, vmax);
1290     }
1291   }
1292   if (min_support_ == -1 || max_support_ == -1) {
1293     int min_support = -1;
1294     int max_support = -1;
1295     int64_t gmin = std::numeric_limits<int64_t>::max();
1296     int64_t gmax = std::numeric_limits<int64_t>::min();
1297     for (int i = index_->Min(); i <= index_->Max(); ++i) {
1298       IntExpr* const var_i = evaluator_(i);
1299       const int64_t vmin = var_i->Min();
1300       if (vmin < gmin) {
1301         gmin = vmin;
1302       }
1303       const int64_t vmax = var_i->Max();
1304       if (vmax > gmax) {
1305         gmax = vmax;
1306       }
1307     }
1308     solver()->SaveAndSetValue(&min_support_, min_support);
1309     solver()->SaveAndSetValue(&max_support_, max_support);
1310     target_var_->SetRange(gmin, gmax);
1311   }
1312 }
1313 
Update(int index)1314 void IntExprEvaluatorElementCt::Update(int index) {
1315   if (index == min_support_ || index == max_support_) {
1316     solver()->SaveAndSetValue(&min_support_, -1);
1317     solver()->SaveAndSetValue(&max_support_, -1);
1318   }
1319 }
1320 
UpdateExpr()1321 void IntExprEvaluatorElementCt::UpdateExpr() {
1322   if (!index_->Contains(min_support_) || !index_->Contains(max_support_)) {
1323     solver()->SaveAndSetValue(&min_support_, -1);
1324     solver()->SaveAndSetValue(&max_support_, -1);
1325   }
1326 }
1327 
1328 namespace {
StringifyEvaluatorBare(const Solver::Int64ToIntVar & evaluator,int64_t range_start,int64_t range_end)1329 std::string StringifyEvaluatorBare(const Solver::Int64ToIntVar& evaluator,
1330                                    int64_t range_start, int64_t range_end) {
1331   std::string out;
1332   for (int64_t i = range_start; i < range_end; ++i) {
1333     if (i != range_start) {
1334       out += ", ";
1335     }
1336     out += absl::StrFormat("%d -> %s", i, evaluator(i)->DebugString());
1337   }
1338   return out;
1339 }
1340 
StringifyInt64ToIntVar(const Solver::Int64ToIntVar & evaluator,int64_t range_begin,int64_t range_end)1341 std::string StringifyInt64ToIntVar(const Solver::Int64ToIntVar& evaluator,
1342                                    int64_t range_begin, int64_t range_end) {
1343   std::string out;
1344   if (range_end - range_begin > 10) {
1345     out = absl::StrFormat(
1346         "IntToIntVar(%s, ...%s)",
1347         StringifyEvaluatorBare(evaluator, range_begin, range_begin + 5),
1348         StringifyEvaluatorBare(evaluator, range_end - 5, range_end));
1349   } else {
1350     out = absl::StrFormat(
1351         "IntToIntVar(%s)",
1352         StringifyEvaluatorBare(evaluator, range_begin, range_end));
1353   }
1354   return out;
1355 }
1356 }  // namespace
1357 
DebugString() const1358 std::string IntExprEvaluatorElementCt::DebugString() const {
1359   return StringifyInt64ToIntVar(evaluator_, range_start_, range_end_);
1360 }
1361 
Accept(ModelVisitor * const visitor) const1362 void IntExprEvaluatorElementCt::Accept(ModelVisitor* const visitor) const {
1363   visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1364   visitor->VisitIntegerVariableEvaluatorArgument(
1365       ModelVisitor::kEvaluatorArgument, evaluator_);
1366   visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument, index_);
1367   visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1368                                           target_var_);
1369   visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1370 }
1371 
1372 // ----- IntExprArrayElementCt -----
1373 
1374 // This constraint implements vars[index] == var. It is delayed such
1375 // that propagation only occurs when all variables have been touched.
1376 
1377 class IntExprArrayElementCt : public IntExprEvaluatorElementCt {
1378  public:
1379   IntExprArrayElementCt(Solver* const s, std::vector<IntVar*> vars,
1380                         IntVar* const index, IntVar* const target_var);
1381 
1382   std::string DebugString() const override;
1383   void Accept(ModelVisitor* const visitor) const override;
1384 
1385  private:
1386   const std::vector<IntVar*> vars_;
1387 };
1388 
IntExprArrayElementCt(Solver * const s,std::vector<IntVar * > vars,IntVar * const index,IntVar * const target_var)1389 IntExprArrayElementCt::IntExprArrayElementCt(Solver* const s,
1390                                              std::vector<IntVar*> vars,
1391                                              IntVar* const index,
1392                                              IntVar* const target_var)
1393     : IntExprEvaluatorElementCt(
1394           s, [this](int64_t idx) { return vars_[idx]; }, 0, vars.size(), index,
1395           target_var),
1396       vars_(std::move(vars)) {}
1397 
DebugString() const1398 std::string IntExprArrayElementCt::DebugString() const {
1399   int64_t size = vars_.size();
1400   if (size > 10) {
1401     return absl::StrFormat(
1402         "IntExprArrayElement(var array of size %d, %s) == %s", size,
1403         index_->DebugString(), target_var_->DebugString());
1404   } else {
1405     return absl::StrFormat("IntExprArrayElement([%s], %s) == %s",
1406                            JoinDebugStringPtr(vars_, ", "),
1407                            index_->DebugString(), target_var_->DebugString());
1408   }
1409 }
1410 
Accept(ModelVisitor * const visitor) const1411 void IntExprArrayElementCt::Accept(ModelVisitor* const visitor) const {
1412   visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1413   visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1414                                              vars_);
1415   visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument, index_);
1416   visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1417                                           target_var_);
1418   visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1419 }
1420 
1421 // ----- IntExprArrayElementCstCt -----
1422 
1423 // This constraint implements vars[index] == constant.
1424 
1425 class IntExprArrayElementCstCt : public Constraint {
1426  public:
IntExprArrayElementCstCt(Solver * const s,const std::vector<IntVar * > & vars,IntVar * const index,int64_t target)1427   IntExprArrayElementCstCt(Solver* const s, const std::vector<IntVar*>& vars,
1428                            IntVar* const index, int64_t target)
1429       : Constraint(s),
1430         vars_(vars),
1431         index_(index),
1432         target_(target),
1433         demons_(vars.size()) {}
1434 
~IntExprArrayElementCstCt()1435   ~IntExprArrayElementCstCt() override {}
1436 
Post()1437   void Post() override {
1438     for (int i = 0; i < vars_.size(); ++i) {
1439       demons_[i] = MakeConstraintDemon1(
1440           solver(), this, &IntExprArrayElementCstCt::Propagate, "Propagate", i);
1441       vars_[i]->WhenDomain(demons_[i]);
1442     }
1443     Demon* const index_demon = MakeConstraintDemon0(
1444         solver(), this, &IntExprArrayElementCstCt::PropagateIndex,
1445         "PropagateIndex");
1446     index_->WhenBound(index_demon);
1447   }
1448 
InitialPropagate()1449   void InitialPropagate() override {
1450     for (int i = 0; i < vars_.size(); ++i) {
1451       Propagate(i);
1452     }
1453     PropagateIndex();
1454   }
1455 
Propagate(int index)1456   void Propagate(int index) {
1457     if (!vars_[index]->Contains(target_)) {
1458       index_->RemoveValue(index);
1459       demons_[index]->inhibit(solver());
1460     }
1461   }
1462 
PropagateIndex()1463   void PropagateIndex() {
1464     if (index_->Bound()) {
1465       vars_[index_->Min()]->SetValue(target_);
1466     }
1467   }
1468 
DebugString() const1469   std::string DebugString() const override {
1470     return absl::StrFormat("IntExprArrayElement([%s], %s) == %d",
1471                            JoinDebugStringPtr(vars_, ", "),
1472                            index_->DebugString(), target_);
1473   }
1474 
Accept(ModelVisitor * const visitor) const1475   void Accept(ModelVisitor* const visitor) const override {
1476     visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1477     visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1478                                                vars_);
1479     visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
1480                                             index_);
1481     visitor->VisitIntegerArgument(ModelVisitor::kTargetArgument, target_);
1482     visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1483   }
1484 
1485  private:
1486   const std::vector<IntVar*> vars_;
1487   IntVar* const index_;
1488   const int64_t target_;
1489   std::vector<Demon*> demons_;
1490 };
1491 
1492 // This constraint implements index == position(constant in vars).
1493 
1494 class IntExprIndexOfCt : public Constraint {
1495  public:
// One (initially null) demon slot per variable, filled in Post(). The hole
// iterator lets PropagateIndex() visit values removed from inside the index
// domain.
IntExprIndexOfCt(Solver* const s, const std::vector<IntVar*>& vars,
                 IntVar* const index, int64_t target)
    : Constraint(s),
      vars_(vars),
      index_(index),
      target_(target),
      demons_(vars.size()),
      index_iterator_(index->MakeHoleIterator(true)) {}
1504 
~IntExprIndexOfCt() override = default;
1506 
Post()1507   void Post() override {
1508     for (int i = 0; i < vars_.size(); ++i) {
1509       demons_[i] = MakeConstraintDemon1(
1510           solver(), this, &IntExprIndexOfCt::Propagate, "Propagate", i);
1511       vars_[i]->WhenDomain(demons_[i]);
1512     }
1513     Demon* const index_demon = MakeConstraintDemon0(
1514         solver(), this, &IntExprIndexOfCt::PropagateIndex, "PropagateIndex");
1515     index_->WhenDomain(index_demon);
1516   }
1517 
InitialPropagate()1518   void InitialPropagate() override {
1519     for (int i = 0; i < vars_.size(); ++i) {
1520       if (!index_->Contains(i)) {
1521         vars_[i]->RemoveValue(target_);
1522       } else if (!vars_[i]->Contains(target_)) {
1523         index_->RemoveValue(i);
1524         demons_[i]->inhibit(solver());
1525       } else if (vars_[i]->Bound()) {
1526         index_->SetValue(i);
1527         demons_[i]->inhibit(solver());
1528       }
1529     }
1530   }
1531 
Propagate(int index)1532   void Propagate(int index) {
1533     if (!vars_[index]->Contains(target_)) {
1534       index_->RemoveValue(index);
1535       demons_[index]->inhibit(solver());
1536     } else if (vars_[index]->Bound()) {
1537       index_->SetValue(index);
1538     }
1539   }
1540 
PropagateIndex()1541   void PropagateIndex() {
1542     const int64_t oldmax = index_->OldMax();
1543     const int64_t vmin = index_->Min();
1544     const int64_t vmax = index_->Max();
1545     for (int64_t value = index_->OldMin(); value < vmin; ++value) {
1546       vars_[value]->RemoveValue(target_);
1547       demons_[value]->inhibit(solver());
1548     }
1549     for (const int64_t value : InitAndGetValues(index_iterator_)) {
1550       vars_[value]->RemoveValue(target_);
1551       demons_[value]->inhibit(solver());
1552     }
1553     for (int64_t value = vmax + 1; value <= oldmax; ++value) {
1554       vars_[value]->RemoveValue(target_);
1555       demons_[value]->inhibit(solver());
1556     }
1557     if (index_->Bound()) {
1558       vars_[index_->Min()]->SetValue(target_);
1559     }
1560   }
1561 
DebugString() const1562   std::string DebugString() const override {
1563     return absl::StrFormat("IntExprIndexOf([%s], %s) == %d",
1564                            JoinDebugStringPtr(vars_, ", "),
1565                            index_->DebugString(), target_);
1566   }
1567 
Accept(ModelVisitor * const visitor) const1568   void Accept(ModelVisitor* const visitor) const override {
1569     visitor->BeginVisitConstraint(ModelVisitor::kIndexOf, this);
1570     visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1571                                                vars_);
1572     visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
1573                                             index_);
1574     visitor->VisitIntegerArgument(ModelVisitor::kTargetArgument, target_);
1575     visitor->EndVisitConstraint(ModelVisitor::kIndexOf, this);
1576   }
1577 
1578  private:
1579   const std::vector<IntVar*> vars_;
1580   IntVar* const index_;
1581   const int64_t target_;
1582   std::vector<Demon*> demons_;
1583   IntVarIterator* const index_iterator_;
1584 };
1585 
1586 // Factory helper.
1587 
MakeElementEqualityFunc(Solver * const solver,const std::vector<int64_t> & vals,IntVar * const index,IntVar * const target)1588 Constraint* MakeElementEqualityFunc(Solver* const solver,
1589                                     const std::vector<int64_t>& vals,
1590                                     IntVar* const index, IntVar* const target) {
1591   if (index->Bound()) {
1592     const int64_t val = index->Min();
1593     if (val < 0 || val >= vals.size()) {
1594       return solver->MakeFalseConstraint();
1595     } else {
1596       return solver->MakeEquality(target, vals[val]);
1597     }
1598   } else {
1599     if (IsIncreasingContiguous(vals)) {
1600       return solver->MakeEquality(target, solver->MakeSum(index, vals[0]));
1601     } else {
1602       return solver->RevAlloc(
1603           new IntElementConstraint(solver, vals, index, target));
1604     }
1605   }
1606 }
1607 }  // namespace
1608 
MakeIfThenElseCt(IntVar * const condition,IntExpr * const then_expr,IntExpr * const else_expr,IntVar * const target_var)1609 Constraint* Solver::MakeIfThenElseCt(IntVar* const condition,
1610                                      IntExpr* const then_expr,
1611                                      IntExpr* const else_expr,
1612                                      IntVar* const target_var) {
1613   return RevAlloc(
1614       new IfThenElseCt(this, condition, then_expr, else_expr, target_var));
1615 }
1616 
MakeElement(const std::vector<IntVar * > & vars,IntVar * const index)1617 IntExpr* Solver::MakeElement(const std::vector<IntVar*>& vars,
1618                              IntVar* const index) {
1619   if (index->Bound()) {
1620     return vars[index->Min()];
1621   }
1622   const int size = vars.size();
1623   if (AreAllBound(vars)) {
1624     std::vector<int64_t> values(size);
1625     for (int i = 0; i < size; ++i) {
1626       values[i] = vars[i]->Value();
1627     }
1628     return MakeElement(values, index);
1629   }
1630   if (index->Size() == 2 && index->Min() + 1 == index->Max() &&
1631       index->Min() >= 0 && index->Max() < vars.size()) {
1632     // Let's get the index between 0 and 1.
1633     IntVar* const scaled_index = MakeSum(index, -index->Min())->Var();
1634     IntVar* const zero = vars[index->Min()];
1635     IntVar* const one = vars[index->Max()];
1636     const std::string name = absl::StrFormat(
1637         "ElementVar([%s], %s)", JoinNamePtr(vars, ", "), index->name());
1638     IntVar* const target = MakeIntVar(std::min(zero->Min(), one->Min()),
1639                                       std::max(zero->Max(), one->Max()), name);
1640     AddConstraint(
1641         RevAlloc(new IfThenElseCt(this, scaled_index, one, zero, target)));
1642     return target;
1643   }
1644   int64_t emin = std::numeric_limits<int64_t>::max();
1645   int64_t emax = std::numeric_limits<int64_t>::min();
1646   std::unique_ptr<IntVarIterator> iterator(index->MakeDomainIterator(false));
1647   for (const int64_t index_value : InitAndGetValues(iterator.get())) {
1648     if (index_value >= 0 && index_value < size) {
1649       emin = std::min(emin, vars[index_value]->Min());
1650       emax = std::max(emax, vars[index_value]->Max());
1651     }
1652   }
1653   const std::string vname =
1654       size > 10 ? absl::StrFormat("ElementVar(var array of size %d, %s)", size,
1655                                   index->DebugString())
1656                 : absl::StrFormat("ElementVar([%s], %s)",
1657                                   JoinNamePtr(vars, ", "), index->name());
1658   IntVar* const element_var = MakeIntVar(emin, emax, vname);
1659   AddConstraint(
1660       RevAlloc(new IntExprArrayElementCt(this, vars, index, element_var)));
1661   return element_var;
1662 }
1663 
MakeElement(Int64ToIntVar vars,int64_t range_start,int64_t range_end,IntVar * argument)1664 IntExpr* Solver::MakeElement(Int64ToIntVar vars, int64_t range_start,
1665                              int64_t range_end, IntVar* argument) {
1666   const std::string index_name =
1667       !argument->name().empty() ? argument->name() : argument->DebugString();
1668   const std::string vname = absl::StrFormat(
1669       "ElementVar(%s, %s)",
1670       StringifyInt64ToIntVar(vars, range_start, range_end), index_name);
1671   IntVar* const element_var =
1672       MakeIntVar(std::numeric_limits<int64_t>::min(),
1673                  std::numeric_limits<int64_t>::max(), vname);
1674   IntExprEvaluatorElementCt* evaluation_ct = new IntExprEvaluatorElementCt(
1675       this, std::move(vars), range_start, range_end, argument, element_var);
1676   AddConstraint(RevAlloc(evaluation_ct));
1677   evaluation_ct->Propagate();
1678   return element_var;
1679 }
1680 
MakeElementEquality(const std::vector<int64_t> & vals,IntVar * const index,IntVar * const target)1681 Constraint* Solver::MakeElementEquality(const std::vector<int64_t>& vals,
1682                                         IntVar* const index,
1683                                         IntVar* const target) {
1684   return MakeElementEqualityFunc(this, vals, index, target);
1685 }
1686 
MakeElementEquality(const std::vector<int> & vals,IntVar * const index,IntVar * const target)1687 Constraint* Solver::MakeElementEquality(const std::vector<int>& vals,
1688                                         IntVar* const index,
1689                                         IntVar* const target) {
1690   return MakeElementEqualityFunc(this, ToInt64Vector(vals), index, target);
1691 }
1692 
MakeElementEquality(const std::vector<IntVar * > & vars,IntVar * const index,IntVar * const target)1693 Constraint* Solver::MakeElementEquality(const std::vector<IntVar*>& vars,
1694                                         IntVar* const index,
1695                                         IntVar* const target) {
1696   if (AreAllBound(vars)) {
1697     std::vector<int64_t> values(vars.size());
1698     for (int i = 0; i < vars.size(); ++i) {
1699       values[i] = vars[i]->Value();
1700     }
1701     return MakeElementEquality(values, index, target);
1702   }
1703   if (index->Bound()) {
1704     const int64_t val = index->Min();
1705     if (val < 0 || val >= vars.size()) {
1706       return MakeFalseConstraint();
1707     } else {
1708       return MakeEquality(target, vars[val]);
1709     }
1710   } else {
1711     if (target->Bound()) {
1712       return RevAlloc(
1713           new IntExprArrayElementCstCt(this, vars, index, target->Min()));
1714     } else {
1715       return RevAlloc(new IntExprArrayElementCt(this, vars, index, target));
1716     }
1717   }
1718 }
1719 
MakeElementEquality(const std::vector<IntVar * > & vars,IntVar * const index,int64_t target)1720 Constraint* Solver::MakeElementEquality(const std::vector<IntVar*>& vars,
1721                                         IntVar* const index, int64_t target) {
1722   if (AreAllBound(vars)) {
1723     std::vector<int> valid_indices;
1724     for (int i = 0; i < vars.size(); ++i) {
1725       if (vars[i]->Value() == target) {
1726         valid_indices.push_back(i);
1727       }
1728     }
1729     return MakeMemberCt(index, valid_indices);
1730   }
1731   if (index->Bound()) {
1732     const int64_t pos = index->Min();
1733     if (pos >= 0 && pos < vars.size()) {
1734       IntVar* const var = vars[pos];
1735       return MakeEquality(var, target);
1736     } else {
1737       return MakeFalseConstraint();
1738     }
1739   } else {
1740     return RevAlloc(new IntExprArrayElementCstCt(this, vars, index, target));
1741   }
1742 }
1743 
MakeIndexOfConstraint(const std::vector<IntVar * > & vars,IntVar * const index,int64_t target)1744 Constraint* Solver::MakeIndexOfConstraint(const std::vector<IntVar*>& vars,
1745                                           IntVar* const index, int64_t target) {
1746   if (index->Bound()) {
1747     const int64_t pos = index->Min();
1748     if (pos >= 0 && pos < vars.size()) {
1749       IntVar* const var = vars[pos];
1750       return MakeEquality(var, target);
1751     } else {
1752       return MakeFalseConstraint();
1753     }
1754   } else {
1755     return RevAlloc(new IntExprIndexOfCt(this, vars, index, target));
1756   }
1757 }
1758 
MakeIndexExpression(const std::vector<IntVar * > & vars,int64_t value)1759 IntExpr* Solver::MakeIndexExpression(const std::vector<IntVar*>& vars,
1760                                      int64_t value) {
1761   IntExpr* const cache = model_cache_->FindVarArrayConstantExpression(
1762       vars, value, ModelCache::VAR_ARRAY_CONSTANT_INDEX);
1763   if (cache != nullptr) {
1764     return cache->Var();
1765   } else {
1766     const std::string name =
1767         absl::StrFormat("Index(%s, %d)", JoinNamePtr(vars, ", "), value);
1768     IntVar* const index = MakeIntVar(0, vars.size() - 1, name);
1769     AddConstraint(MakeIndexOfConstraint(vars, index, value));
1770     model_cache_->InsertVarArrayConstantExpression(
1771         index, vars, value, ModelCache::VAR_ARRAY_CONSTANT_INDEX);
1772     return index;
1773   }
1774 }
1775 }  // namespace operations_research
1776