1 // Copyright 2010-2021 Google LLC
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 #include <algorithm>
15 #include <cstdint>
16 #include <limits>
17 #include <memory>
18 #include <numeric>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/str_join.h"
25 #include "ortools/base/integral_types.h"
26 #include "ortools/base/logging.h"
27 #include "ortools/constraint_solver/constraint_solver.h"
28 #include "ortools/constraint_solver/constraint_solveri.h"
29 #include "ortools/util/range_minimum_query.h"
30 #include "ortools/util/string_array.h"
31
// When true (the default), BuildElement() neither looks up nor stores element
// expressions in the solver's model cache.
ABSL_FLAG(bool, cp_disable_element_cache, true,
          "If true, caching for IntElement is disabled.");
34
35 namespace operations_research {
36
37 // ----- IntExprElement -----
// Declared here, defined elsewhere in the project. Used by
// IncreasingIntExprElement::CastToVar() to connect the freshly created `var`
// to `expr`; the exact linking semantics live in the definition.
void LinkVarExpr(Solver* const s, IntExpr* const expr, IntVar* const var);
39
40 namespace {
41
// Comparator on indices: `x` orders before `y` when the entry it selects in
// the borrowed vector is smaller, i.e. (*values)[x] < (*values)[y].
// The vector is not owned and must outlive the comparator.
template <class T>
class VectorLess {
 public:
  explicit VectorLess(const std::vector<T>* values) : values_(values) {}

  bool operator()(const T& x, const T& y) const {
    const std::vector<T>& table = *values_;
    return table[x] < table[y];
  }

 private:
  const std::vector<T>* values_;  // Borrowed; never deleted here.
};
53
// Comparator on indices: `x` orders before `y` when the entry it selects in
// the borrowed vector is larger, i.e. (*values)[x] > (*values)[y].
// The vector is not owned and must outlive the comparator.
template <class T>
class VectorGreater {
 public:
  explicit VectorGreater(const std::vector<T>* values) : values_(values) {}

  bool operator()(const T& x, const T& y) const {
    const std::vector<T>& table = *values_;
    return table[x] > table[y];
  }

 private:
  const std::vector<T>* values_;  // Borrowed; never deleted here.
};
65
66 // ----- BaseIntExprElement -----
67
68 class BaseIntExprElement : public BaseIntExpr {
69 public:
70 BaseIntExprElement(Solver* const s, IntVar* const e);
~BaseIntExprElement()71 ~BaseIntExprElement() override {}
72 int64_t Min() const override;
73 int64_t Max() const override;
74 void Range(int64_t* mi, int64_t* ma) override;
75 void SetMin(int64_t m) override;
76 void SetMax(int64_t m) override;
77 void SetRange(int64_t mi, int64_t ma) override;
Bound() const78 bool Bound() const override { return (expr_->Bound()); }
79 // TODO(user) : improve me, the previous test is not always true
WhenRange(Demon * d)80 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
81
82 protected:
83 virtual int64_t ElementValue(int index) const = 0;
84 virtual int64_t ExprMin() const = 0;
85 virtual int64_t ExprMax() const = 0;
86
87 IntVar* const expr_;
88
89 private:
90 void UpdateSupports() const;
91
92 mutable int64_t min_;
93 mutable int min_support_;
94 mutable int64_t max_;
95 mutable int max_support_;
96 mutable bool initial_update_;
97 IntVarIterator* const expr_iterator_;
98 };
99
BaseIntExprElement(Solver * const s,IntVar * const e)100 BaseIntExprElement::BaseIntExprElement(Solver* const s, IntVar* const e)
101 : BaseIntExpr(s),
102 expr_(e),
103 min_(0),
104 min_support_(-1),
105 max_(0),
106 max_support_(-1),
107 initial_update_(true),
108 expr_iterator_(expr_->MakeDomainIterator(true)) {
109 CHECK(s != nullptr);
110 CHECK(e != nullptr);
111 }
112
Min() const113 int64_t BaseIntExprElement::Min() const {
114 UpdateSupports();
115 return min_;
116 }
117
Max() const118 int64_t BaseIntExprElement::Max() const {
119 UpdateSupports();
120 return max_;
121 }
122
Range(int64_t * mi,int64_t * ma)123 void BaseIntExprElement::Range(int64_t* mi, int64_t* ma) {
124 UpdateSupports();
125 *mi = min_;
126 *ma = max_;
127 }
128
129 #define UPDATE_BASE_ELEMENT_INDEX_BOUNDS(test) \
130 const int64_t emin = ExprMin(); \
131 const int64_t emax = ExprMax(); \
132 int64_t nmin = emin; \
133 int64_t value = ElementValue(nmin); \
134 while (nmin < emax && test) { \
135 nmin++; \
136 value = ElementValue(nmin); \
137 } \
138 if (nmin == emax && test) { \
139 solver()->Fail(); \
140 } \
141 int64_t nmax = emax; \
142 value = ElementValue(nmax); \
143 while (nmax >= nmin && test) { \
144 nmax--; \
145 value = ElementValue(nmax); \
146 } \
147 expr_->SetRange(nmin, nmax);
148
SetMin(int64_t m)149 void BaseIntExprElement::SetMin(int64_t m) {
150 UPDATE_BASE_ELEMENT_INDEX_BOUNDS(value < m);
151 }
152
SetMax(int64_t m)153 void BaseIntExprElement::SetMax(int64_t m) {
154 UPDATE_BASE_ELEMENT_INDEX_BOUNDS(value > m);
155 }
156
SetRange(int64_t mi,int64_t ma)157 void BaseIntExprElement::SetRange(int64_t mi, int64_t ma) {
158 if (mi > ma) {
159 solver()->Fail();
160 }
161 UPDATE_BASE_ELEMENT_INDEX_BOUNDS((value < mi || value > ma));
162 }
163
164 #undef UPDATE_BASE_ELEMENT_INDEX_BOUNDS
165
UpdateSupports() const166 void BaseIntExprElement::UpdateSupports() const {
167 if (initial_update_ || !expr_->Contains(min_support_) ||
168 !expr_->Contains(max_support_)) {
169 const int64_t emin = ExprMin();
170 const int64_t emax = ExprMax();
171 int64_t min_value = ElementValue(emax);
172 int64_t max_value = min_value;
173 int min_support = emax;
174 int max_support = emax;
175 const uint64_t expr_size = expr_->Size();
176 if (expr_size > 1) {
177 if (expr_size == emax - emin + 1) {
178 // Value(emax) already stored in min_value, max_value.
179 for (int64_t index = emin; index < emax; ++index) {
180 const int64_t value = ElementValue(index);
181 if (value > max_value) {
182 max_value = value;
183 max_support = index;
184 } else if (value < min_value) {
185 min_value = value;
186 min_support = index;
187 }
188 }
189 } else {
190 for (const int64_t index : InitAndGetValues(expr_iterator_)) {
191 if (index >= emin && index <= emax) {
192 const int64_t value = ElementValue(index);
193 if (value > max_value) {
194 max_value = value;
195 max_support = index;
196 } else if (value < min_value) {
197 min_value = value;
198 min_support = index;
199 }
200 }
201 }
202 }
203 }
204 Solver* s = solver();
205 s->SaveAndSetValue(&min_, min_value);
206 s->SaveAndSetValue(&min_support_, min_support);
207 s->SaveAndSetValue(&max_, max_value);
208 s->SaveAndSetValue(&max_support_, max_support);
209 s->SaveAndSetValue(&initial_update_, false);
210 }
211 }
212
213 // ----- IntElementConstraint -----
214
215 // This constraint implements 'elem' == 'values'['index'].
216 // It scans the bounds of 'elem' to propagate on the domain of 'index'.
217 // It scans the domain of 'index' to compute the new bounds of 'elem'.
218 class IntElementConstraint : public CastConstraint {
219 public:
IntElementConstraint(Solver * const s,const std::vector<int64_t> & values,IntVar * const index,IntVar * const elem)220 IntElementConstraint(Solver* const s, const std::vector<int64_t>& values,
221 IntVar* const index, IntVar* const elem)
222 : CastConstraint(s, elem),
223 values_(values),
224 index_(index),
225 index_iterator_(index_->MakeDomainIterator(true)) {
226 CHECK(index != nullptr);
227 }
228
Post()229 void Post() override {
230 Demon* const d =
231 solver()->MakeDelayedConstraintInitialPropagateCallback(this);
232 index_->WhenDomain(d);
233 target_var_->WhenRange(d);
234 }
235
InitialPropagate()236 void InitialPropagate() override {
237 index_->SetRange(0, values_.size() - 1);
238 const int64_t target_var_min = target_var_->Min();
239 const int64_t target_var_max = target_var_->Max();
240 int64_t new_min = target_var_max;
241 int64_t new_max = target_var_min;
242 to_remove_.clear();
243 for (const int64_t index : InitAndGetValues(index_iterator_)) {
244 const int64_t value = values_[index];
245 if (value < target_var_min || value > target_var_max) {
246 to_remove_.push_back(index);
247 } else {
248 if (value < new_min) {
249 new_min = value;
250 }
251 if (value > new_max) {
252 new_max = value;
253 }
254 }
255 }
256 target_var_->SetRange(new_min, new_max);
257 if (!to_remove_.empty()) {
258 index_->RemoveValues(to_remove_);
259 }
260 }
261
DebugString() const262 std::string DebugString() const override {
263 return absl::StrFormat("IntElementConstraint(%s, %s, %s)",
264 absl::StrJoin(values_, ", "), index_->DebugString(),
265 target_var_->DebugString());
266 }
267
Accept(ModelVisitor * const visitor) const268 void Accept(ModelVisitor* const visitor) const override {
269 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
270 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
271 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
272 index_);
273 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
274 target_var_);
275 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
276 }
277
278 private:
279 const std::vector<int64_t> values_;
280 IntVar* const index_;
281 IntVarIterator* const index_iterator_;
282 std::vector<int64_t> to_remove_;
283 };
284
// ----- IntExprElement -----
286
// Forward declaration only; neither a definition nor a call appears in the
// visible part of this file. NOTE(review): it is declared inside the anonymous
// namespace, so if it is used it must be defined in this translation unit —
// confirm it is still needed.
IntVar* BuildDomainIntVar(Solver* const solver, std::vector<int64_t>* values);
288
289 class IntExprElement : public BaseIntExprElement {
290 public:
IntExprElement(Solver * const s,const std::vector<int64_t> & vals,IntVar * const expr)291 IntExprElement(Solver* const s, const std::vector<int64_t>& vals,
292 IntVar* const expr)
293 : BaseIntExprElement(s, expr), values_(vals) {}
294
~IntExprElement()295 ~IntExprElement() override {}
296
name() const297 std::string name() const override {
298 const int size = values_.size();
299 if (size > 10) {
300 return absl::StrFormat("IntElement(array of size %d, %s)", size,
301 expr_->name());
302 } else {
303 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
304 expr_->name());
305 }
306 }
307
DebugString() const308 std::string DebugString() const override {
309 const int size = values_.size();
310 if (size > 10) {
311 return absl::StrFormat("IntElement(array of size %d, %s)", size,
312 expr_->DebugString());
313 } else {
314 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
315 expr_->DebugString());
316 }
317 }
318
CastToVar()319 IntVar* CastToVar() override {
320 Solver* const s = solver();
321 IntVar* const var = s->MakeIntVar(values_);
322 s->AddCastConstraint(
323 s->RevAlloc(new IntElementConstraint(s, values_, expr_, var)), var,
324 this);
325 return var;
326 }
327
Accept(ModelVisitor * const visitor) const328 void Accept(ModelVisitor* const visitor) const override {
329 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
330 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
331 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
332 expr_);
333 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
334 }
335
336 protected:
ElementValue(int index) const337 int64_t ElementValue(int index) const override {
338 DCHECK_LT(index, values_.size());
339 return values_[index];
340 }
ExprMin() const341 int64_t ExprMin() const override {
342 return std::max<int64_t>(0, expr_->Min());
343 }
ExprMax() const344 int64_t ExprMax() const override {
345 return values_.empty()
346 ? 0
347 : std::min<int64_t>(values_.size() - 1, expr_->Max());
348 }
349
350 private:
351 const std::vector<int64_t> values_;
352 };
353
354 // ----- Range Minimum Query-based Element -----
355
356 class RangeMinimumQueryExprElement : public BaseIntExpr {
357 public:
358 RangeMinimumQueryExprElement(Solver* solver,
359 const std::vector<int64_t>& values,
360 IntVar* index);
~RangeMinimumQueryExprElement()361 ~RangeMinimumQueryExprElement() override {}
362 int64_t Min() const override;
363 int64_t Max() const override;
364 void Range(int64_t* mi, int64_t* ma) override;
365 void SetMin(int64_t m) override;
366 void SetMax(int64_t m) override;
367 void SetRange(int64_t mi, int64_t ma) override;
Bound() const368 bool Bound() const override { return (index_->Bound()); }
369 // TODO(user) : improve me, the previous test is not always true
WhenRange(Demon * d)370 void WhenRange(Demon* d) override { index_->WhenRange(d); }
CastToVar()371 IntVar* CastToVar() override {
372 // TODO(user): Should we try to make holes in the domain of index_, as we
373 // do here, or should we only propagate bounds as we do in
374 // IncreasingIntExprElement ?
375 IntVar* const var = solver()->MakeIntVar(min_rmq_.array());
376 solver()->AddCastConstraint(solver()->RevAlloc(new IntElementConstraint(
377 solver(), min_rmq_.array(), index_, var)),
378 var, this);
379 return var;
380 }
Accept(ModelVisitor * const visitor) const381 void Accept(ModelVisitor* const visitor) const override {
382 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
383 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
384 min_rmq_.array());
385 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
386 index_);
387 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
388 }
389
390 private:
IndexMin() const391 int64_t IndexMin() const { return std::max<int64_t>(0, index_->Min()); }
IndexMax() const392 int64_t IndexMax() const {
393 return std::min<int64_t>(min_rmq_.array().size() - 1, index_->Max());
394 }
395
396 IntVar* const index_;
397 const RangeMinimumQuery<int64_t, std::less<int64_t>> min_rmq_;
398 const RangeMinimumQuery<int64_t, std::greater<int64_t>> max_rmq_;
399 };
400
RangeMinimumQueryExprElement(Solver * solver,const std::vector<int64_t> & values,IntVar * index)401 RangeMinimumQueryExprElement::RangeMinimumQueryExprElement(
402 Solver* solver, const std::vector<int64_t>& values, IntVar* index)
403 : BaseIntExpr(solver), index_(index), min_rmq_(values), max_rmq_(values) {
404 CHECK(solver != nullptr);
405 CHECK(index != nullptr);
406 }
407
Min() const408 int64_t RangeMinimumQueryExprElement::Min() const {
409 return min_rmq_.GetMinimumFromRange(IndexMin(), IndexMax() + 1);
410 }
411
Max() const412 int64_t RangeMinimumQueryExprElement::Max() const {
413 return max_rmq_.GetMinimumFromRange(IndexMin(), IndexMax() + 1);
414 }
415
Range(int64_t * mi,int64_t * ma)416 void RangeMinimumQueryExprElement::Range(int64_t* mi, int64_t* ma) {
417 const int64_t range_min = IndexMin();
418 const int64_t range_max = IndexMax() + 1;
419 *mi = min_rmq_.GetMinimumFromRange(range_min, range_max);
420 *ma = max_rmq_.GetMinimumFromRange(range_min, range_max);
421 }
422
423 #define UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(test) \
424 const std::vector<int64_t>& values = min_rmq_.array(); \
425 int64_t index_min = IndexMin(); \
426 int64_t index_max = IndexMax(); \
427 int64_t value = values[index_min]; \
428 while (index_min < index_max && (test)) { \
429 index_min++; \
430 value = values[index_min]; \
431 } \
432 if (index_min == index_max && (test)) { \
433 solver()->Fail(); \
434 } \
435 value = values[index_max]; \
436 while (index_max >= index_min && (test)) { \
437 index_max--; \
438 value = values[index_max]; \
439 } \
440 index_->SetRange(index_min, index_max);
441
SetMin(int64_t m)442 void RangeMinimumQueryExprElement::SetMin(int64_t m) {
443 UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(value < m);
444 }
445
SetMax(int64_t m)446 void RangeMinimumQueryExprElement::SetMax(int64_t m) {
447 UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(value > m);
448 }
449
SetRange(int64_t mi,int64_t ma)450 void RangeMinimumQueryExprElement::SetRange(int64_t mi, int64_t ma) {
451 if (mi > ma) {
452 solver()->Fail();
453 }
454 UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(value < mi || value > ma);
455 }
456
457 #undef UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS
458
459 // ----- Increasing Element -----
460
461 class IncreasingIntExprElement : public BaseIntExpr {
462 public:
463 IncreasingIntExprElement(Solver* const s, const std::vector<int64_t>& values,
464 IntVar* const index);
~IncreasingIntExprElement()465 ~IncreasingIntExprElement() override {}
466
467 int64_t Min() const override;
468 void SetMin(int64_t m) override;
469 int64_t Max() const override;
470 void SetMax(int64_t m) override;
471 void SetRange(int64_t mi, int64_t ma) override;
Bound() const472 bool Bound() const override { return (index_->Bound()); }
473 // TODO(user) : improve me, the previous test is not always true
name() const474 std::string name() const override {
475 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
476 index_->name());
477 }
DebugString() const478 std::string DebugString() const override {
479 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
480 index_->DebugString());
481 }
482
Accept(ModelVisitor * const visitor) const483 void Accept(ModelVisitor* const visitor) const override {
484 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
485 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
486 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
487 index_);
488 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
489 }
490
WhenRange(Demon * d)491 void WhenRange(Demon* d) override { index_->WhenRange(d); }
492
CastToVar()493 IntVar* CastToVar() override {
494 Solver* const s = solver();
495 IntVar* const var = s->MakeIntVar(values_);
496 LinkVarExpr(s, this, var);
497 return var;
498 }
499
500 private:
501 const std::vector<int64_t> values_;
502 IntVar* const index_;
503 };
504
IncreasingIntExprElement(Solver * const s,const std::vector<int64_t> & values,IntVar * const index)505 IncreasingIntExprElement::IncreasingIntExprElement(
506 Solver* const s, const std::vector<int64_t>& values, IntVar* const index)
507 : BaseIntExpr(s), values_(values), index_(index) {
508 DCHECK(index);
509 DCHECK(s);
510 }
511
Min() const512 int64_t IncreasingIntExprElement::Min() const {
513 const int64_t expression_min = std::max<int64_t>(0, index_->Min());
514 return (expression_min < values_.size()
515 ? values_[expression_min]
516 : std::numeric_limits<int64_t>::max());
517 }
518
SetMin(int64_t m)519 void IncreasingIntExprElement::SetMin(int64_t m) {
520 const int64_t index_min = std::max<int64_t>(0, index_->Min());
521 const int64_t index_max =
522 std::min<int64_t>(values_.size() - 1, index_->Max());
523
524 if (index_min > index_max || m > values_[index_max]) {
525 solver()->Fail();
526 }
527
528 const std::vector<int64_t>::const_iterator first =
529 std::lower_bound(values_.begin(), values_.end(), m);
530 const int64_t new_index_min = first - values_.begin();
531 index_->SetMin(new_index_min);
532 }
533
Max() const534 int64_t IncreasingIntExprElement::Max() const {
535 const int64_t expression_max =
536 std::min<int64_t>(values_.size() - 1, index_->Max());
537 return (expression_max >= 0 ? values_[expression_max]
538 : std::numeric_limits<int64_t>::max());
539 }
540
SetMax(int64_t m)541 void IncreasingIntExprElement::SetMax(int64_t m) {
542 int64_t index_min = std::max<int64_t>(0, index_->Min());
543 if (m < values_[index_min]) {
544 solver()->Fail();
545 }
546
547 const std::vector<int64_t>::const_iterator last_after =
548 std::upper_bound(values_.begin(), values_.end(), m);
549 const int64_t new_index_max = (last_after - values_.begin()) - 1;
550 index_->SetRange(0, new_index_max);
551 }
552
SetRange(int64_t mi,int64_t ma)553 void IncreasingIntExprElement::SetRange(int64_t mi, int64_t ma) {
554 if (mi > ma) {
555 solver()->Fail();
556 }
557 const int64_t index_min = std::max<int64_t>(0, index_->Min());
558 const int64_t index_max =
559 std::min<int64_t>(values_.size() - 1, index_->Max());
560
561 if (mi > ma || ma < values_[index_min] || mi > values_[index_max]) {
562 solver()->Fail();
563 }
564
565 const std::vector<int64_t>::const_iterator first =
566 std::lower_bound(values_.begin(), values_.end(), mi);
567 const int64_t new_index_min = first - values_.begin();
568
569 const std::vector<int64_t>::const_iterator last_after =
570 std::upper_bound(first, values_.end(), ma);
571 const int64_t new_index_max = (last_after - values_.begin()) - 1;
572
573 // Assign.
574 index_->SetRange(new_index_min, new_index_max);
575 }
576
577 // ----- Solver::MakeElement(int array, int var) -----
BuildElement(Solver * const solver,const std::vector<int64_t> & values,IntVar * const index)578 IntExpr* BuildElement(Solver* const solver, const std::vector<int64_t>& values,
579 IntVar* const index) {
580 // Various checks.
581 // Is array constant?
582 if (IsArrayConstant(values, values[0])) {
583 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
584 return solver->MakeIntConst(values[0]);
585 }
586 // Is array built with booleans only?
587 // TODO(user): We could maintain the index of the first one.
588 if (IsArrayBoolean(values)) {
589 std::vector<int64_t> ones;
590 int first_zero = -1;
591 for (int i = 0; i < values.size(); ++i) {
592 if (values[i] == 1) {
593 ones.push_back(i);
594 } else {
595 first_zero = i;
596 }
597 }
598 if (ones.size() == 1) {
599 DCHECK_EQ(int64_t{1}, values[ones.back()]);
600 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
601 return solver->MakeIsEqualCstVar(index, ones.back());
602 } else if (ones.size() == values.size() - 1) {
603 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
604 return solver->MakeIsDifferentCstVar(index, first_zero);
605 } else if (ones.size() == ones.back() - ones.front() + 1) { // contiguous.
606 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
607 IntVar* const b = solver->MakeBoolVar("ContiguousBooleanElementVar");
608 solver->AddConstraint(
609 solver->MakeIsBetweenCt(index, ones.front(), ones.back(), b));
610 return b;
611 } else {
612 IntVar* const b = solver->MakeBoolVar("NonContiguousBooleanElementVar");
613 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
614 solver->AddConstraint(solver->MakeIsMemberCt(index, ones, b));
615 return b;
616 }
617 }
618 IntExpr* cache = nullptr;
619 if (!absl::GetFlag(FLAGS_cp_disable_element_cache)) {
620 cache = solver->Cache()->FindVarConstantArrayExpression(
621 index, values, ModelCache::VAR_CONSTANT_ARRAY_ELEMENT);
622 }
623 if (cache != nullptr) {
624 return cache;
625 } else {
626 IntExpr* result = nullptr;
627 if (values.size() >= 2 && index->Min() == 0 && index->Max() == 1) {
628 result = solver->MakeSum(solver->MakeProd(index, values[1] - values[0]),
629 values[0]);
630 } else if (values.size() == 2 && index->Contains(0) && index->Contains(1)) {
631 solver->AddConstraint(solver->MakeBetweenCt(index, 0, 1));
632 result = solver->MakeSum(solver->MakeProd(index, values[1] - values[0]),
633 values[0]);
634 } else if (IsIncreasingContiguous(values)) {
635 result = solver->MakeSum(index, values[0]);
636 } else if (IsIncreasing(values)) {
637 result = solver->RegisterIntExpr(solver->RevAlloc(
638 new IncreasingIntExprElement(solver, values, index)));
639 } else {
640 if (solver->parameters().use_element_rmq()) {
641 result = solver->RegisterIntExpr(solver->RevAlloc(
642 new RangeMinimumQueryExprElement(solver, values, index)));
643 } else {
644 result = solver->RegisterIntExpr(
645 solver->RevAlloc(new IntExprElement(solver, values, index)));
646 }
647 }
648 if (!absl::GetFlag(FLAGS_cp_disable_element_cache)) {
649 solver->Cache()->InsertVarConstantArrayExpression(
650 result, index, values, ModelCache::VAR_CONSTANT_ARRAY_ELEMENT);
651 }
652 return result;
653 }
654 }
655 } // namespace
656
MakeElement(const std::vector<int64_t> & values,IntVar * const index)657 IntExpr* Solver::MakeElement(const std::vector<int64_t>& values,
658 IntVar* const index) {
659 DCHECK(index);
660 DCHECK_EQ(this, index->solver());
661 if (index->Bound()) {
662 return MakeIntConst(values[index->Min()]);
663 }
664 return BuildElement(this, values, index);
665 }
666
MakeElement(const std::vector<int> & values,IntVar * const index)667 IntExpr* Solver::MakeElement(const std::vector<int>& values,
668 IntVar* const index) {
669 DCHECK(index);
670 DCHECK_EQ(this, index->solver());
671 if (index->Bound()) {
672 return MakeIntConst(values[index->Min()]);
673 }
674 return BuildElement(this, ToInt64Vector(values), index);
675 }
676
677 // ----- IntExprFunctionElement -----
678
679 namespace {
680 class IntExprFunctionElement : public BaseIntExprElement {
681 public:
682 IntExprFunctionElement(Solver* const s, Solver::IndexEvaluator1 values,
683 IntVar* const e);
684 ~IntExprFunctionElement() override;
685
name() const686 std::string name() const override {
687 return absl::StrFormat("IntFunctionElement(%s)", expr_->name());
688 }
689
DebugString() const690 std::string DebugString() const override {
691 return absl::StrFormat("IntFunctionElement(%s)", expr_->DebugString());
692 }
693
Accept(ModelVisitor * const visitor) const694 void Accept(ModelVisitor* const visitor) const override {
695 // Warning: This will expand all values into a vector.
696 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
697 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
698 expr_);
699 visitor->VisitInt64ToInt64Extension(values_, expr_->Min(), expr_->Max());
700 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
701 }
702
703 protected:
ElementValue(int index) const704 int64_t ElementValue(int index) const override { return values_(index); }
ExprMin() const705 int64_t ExprMin() const override { return expr_->Min(); }
ExprMax() const706 int64_t ExprMax() const override { return expr_->Max(); }
707
708 private:
709 Solver::IndexEvaluator1 values_;
710 };
711
IntExprFunctionElement(Solver * const s,Solver::IndexEvaluator1 values,IntVar * const e)712 IntExprFunctionElement::IntExprFunctionElement(Solver* const s,
713 Solver::IndexEvaluator1 values,
714 IntVar* const e)
715 : BaseIntExprElement(s, e), values_(std::move(values)) {
716 CHECK(values_ != nullptr);
717 }
718
~IntExprFunctionElement()719 IntExprFunctionElement::~IntExprFunctionElement() {}
720
// ----- Increasing Function Element -----
722
723 class IncreasingIntExprFunctionElement : public BaseIntExpr {
724 public:
IncreasingIntExprFunctionElement(Solver * const s,Solver::IndexEvaluator1 values,IntVar * const index)725 IncreasingIntExprFunctionElement(Solver* const s,
726 Solver::IndexEvaluator1 values,
727 IntVar* const index)
728 : BaseIntExpr(s), values_(std::move(values)), index_(index) {
729 DCHECK(values_ != nullptr);
730 DCHECK(index);
731 DCHECK(s);
732 }
733
~IncreasingIntExprFunctionElement()734 ~IncreasingIntExprFunctionElement() override {}
735
Min() const736 int64_t Min() const override { return values_(index_->Min()); }
737
SetMin(int64_t m)738 void SetMin(int64_t m) override {
739 const int64_t index_min = index_->Min();
740 const int64_t index_max = index_->Max();
741 if (m > values_(index_max)) {
742 solver()->Fail();
743 }
744 const int64_t new_index_min = FindNewIndexMin(index_min, index_max, m);
745 index_->SetMin(new_index_min);
746 }
747
Max() const748 int64_t Max() const override { return values_(index_->Max()); }
749
SetMax(int64_t m)750 void SetMax(int64_t m) override {
751 int64_t index_min = index_->Min();
752 int64_t index_max = index_->Max();
753 if (m < values_(index_min)) {
754 solver()->Fail();
755 }
756 const int64_t new_index_max = FindNewIndexMax(index_min, index_max, m);
757 index_->SetMax(new_index_max);
758 }
759
SetRange(int64_t mi,int64_t ma)760 void SetRange(int64_t mi, int64_t ma) override {
761 const int64_t index_min = index_->Min();
762 const int64_t index_max = index_->Max();
763 const int64_t value_min = values_(index_min);
764 const int64_t value_max = values_(index_max);
765 if (mi > ma || ma < value_min || mi > value_max) {
766 solver()->Fail();
767 }
768 if (mi <= value_min && ma >= value_max) {
769 // Nothing to do.
770 return;
771 }
772
773 const int64_t new_index_min = FindNewIndexMin(index_min, index_max, mi);
774 const int64_t new_index_max = FindNewIndexMax(new_index_min, index_max, ma);
775 // Assign.
776 index_->SetRange(new_index_min, new_index_max);
777 }
778
name() const779 std::string name() const override {
780 return absl::StrFormat("IncreasingIntExprFunctionElement(values, %s)",
781 index_->name());
782 }
783
DebugString() const784 std::string DebugString() const override {
785 return absl::StrFormat("IncreasingIntExprFunctionElement(values, %s)",
786 index_->DebugString());
787 }
788
WhenRange(Demon * d)789 void WhenRange(Demon* d) override { index_->WhenRange(d); }
790
Accept(ModelVisitor * const visitor) const791 void Accept(ModelVisitor* const visitor) const override {
792 // Warning: This will expand all values into a vector.
793 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
794 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
795 index_);
796 if (index_->Min() == 0) {
797 visitor->VisitInt64ToInt64AsArray(values_, ModelVisitor::kValuesArgument,
798 index_->Max());
799 } else {
800 visitor->VisitInt64ToInt64Extension(values_, index_->Min(),
801 index_->Max());
802 }
803 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
804 }
805
806 private:
FindNewIndexMin(int64_t index_min,int64_t index_max,int64_t m)807 int64_t FindNewIndexMin(int64_t index_min, int64_t index_max, int64_t m) {
808 if (m <= values_(index_min)) {
809 return index_min;
810 }
811
812 DCHECK_LT(values_(index_min), m);
813 DCHECK_GE(values_(index_max), m);
814
815 int64_t index_lower_bound = index_min;
816 int64_t index_upper_bound = index_max;
817 while (index_upper_bound - index_lower_bound > 1) {
818 DCHECK_LT(values_(index_lower_bound), m);
819 DCHECK_GE(values_(index_upper_bound), m);
820 const int64_t pivot = (index_lower_bound + index_upper_bound) / 2;
821 const int64_t pivot_value = values_(pivot);
822 if (pivot_value < m) {
823 index_lower_bound = pivot;
824 } else {
825 index_upper_bound = pivot;
826 }
827 }
828 DCHECK(values_(index_upper_bound) >= m);
829 return index_upper_bound;
830 }
831
FindNewIndexMax(int64_t index_min,int64_t index_max,int64_t m)832 int64_t FindNewIndexMax(int64_t index_min, int64_t index_max, int64_t m) {
833 if (m >= values_(index_max)) {
834 return index_max;
835 }
836
837 DCHECK_LE(values_(index_min), m);
838 DCHECK_GT(values_(index_max), m);
839
840 int64_t index_lower_bound = index_min;
841 int64_t index_upper_bound = index_max;
842 while (index_upper_bound - index_lower_bound > 1) {
843 DCHECK_LE(values_(index_lower_bound), m);
844 DCHECK_GT(values_(index_upper_bound), m);
845 const int64_t pivot = (index_lower_bound + index_upper_bound) / 2;
846 const int64_t pivot_value = values_(pivot);
847 if (pivot_value > m) {
848 index_upper_bound = pivot;
849 } else {
850 index_lower_bound = pivot;
851 }
852 }
853 DCHECK(values_(index_lower_bound) <= m);
854 return index_lower_bound;
855 }
856
857 Solver::IndexEvaluator1 values_;
858 IntVar* const index_;
859 };
860 } // namespace
861
MakeElement(Solver::IndexEvaluator1 values,IntVar * const index)862 IntExpr* Solver::MakeElement(Solver::IndexEvaluator1 values,
863 IntVar* const index) {
864 CHECK_EQ(this, index->solver());
865 return RegisterIntExpr(
866 RevAlloc(new IntExprFunctionElement(this, std::move(values), index)));
867 }
868
MakeMonotonicElement(Solver::IndexEvaluator1 values,bool increasing,IntVar * const index)869 IntExpr* Solver::MakeMonotonicElement(Solver::IndexEvaluator1 values,
870 bool increasing, IntVar* const index) {
871 CHECK_EQ(this, index->solver());
872 if (increasing) {
873 return RegisterIntExpr(
874 RevAlloc(new IncreasingIntExprFunctionElement(this, values, index)));
875 } else {
876 // You need to pass by copy such that opposite_value does not include a
877 // dandling reference when leaving this scope.
878 Solver::IndexEvaluator1 opposite_values = [values](int64_t i) {
879 return -values(i);
880 };
881 return RegisterIntExpr(MakeOpposite(RevAlloc(
882 new IncreasingIntExprFunctionElement(this, opposite_values, index))));
883 }
884 }
885
886 // ----- IntIntExprFunctionElement -----
887
888 namespace {
889 class IntIntExprFunctionElement : public BaseIntExpr {
890 public:
891 IntIntExprFunctionElement(Solver* const s, Solver::IndexEvaluator2 values,
892 IntVar* const expr1, IntVar* const expr2);
893 ~IntIntExprFunctionElement() override;
DebugString() const894 std::string DebugString() const override {
895 return absl::StrFormat("IntIntFunctionElement(%s,%s)",
896 expr1_->DebugString(), expr2_->DebugString());
897 }
898 int64_t Min() const override;
899 int64_t Max() const override;
900 void Range(int64_t* lower_bound, int64_t* upper_bound) override;
901 void SetMin(int64_t lower_bound) override;
902 void SetMax(int64_t upper_bound) override;
903 void SetRange(int64_t lower_bound, int64_t upper_bound) override;
Bound() const904 bool Bound() const override { return expr1_->Bound() && expr2_->Bound(); }
905 // TODO(user) : improve me, the previous test is not always true
WhenRange(Demon * d)906 void WhenRange(Demon* d) override {
907 expr1_->WhenRange(d);
908 expr2_->WhenRange(d);
909 }
910
Accept(ModelVisitor * const visitor) const911 void Accept(ModelVisitor* const visitor) const override {
912 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
913 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
914 expr1_);
915 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndex2Argument,
916 expr2_);
917 // Warning: This will expand all values into a vector.
918 const int64_t expr1_min = expr1_->Min();
919 const int64_t expr1_max = expr1_->Max();
920 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, expr1_min);
921 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, expr1_max);
922 for (int i = expr1_min; i <= expr1_max; ++i) {
923 visitor->VisitInt64ToInt64Extension(
924 [this, i](int64_t j) { return values_(i, j); }, expr2_->Min(),
925 expr2_->Max());
926 }
927 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
928 }
929
930 private:
ElementValue(int index1,int index2) const931 int64_t ElementValue(int index1, int index2) const {
932 return values_(index1, index2);
933 }
934 void UpdateSupports() const;
935
936 IntVar* const expr1_;
937 IntVar* const expr2_;
938 mutable int64_t min_;
939 mutable int min_support1_;
940 mutable int min_support2_;
941 mutable int64_t max_;
942 mutable int max_support1_;
943 mutable int max_support2_;
944 mutable bool initial_update_;
945 Solver::IndexEvaluator2 values_;
946 IntVarIterator* const expr1_iterator_;
947 IntVarIterator* const expr2_iterator_;
948 };
949
IntIntExprFunctionElement(Solver * const s,Solver::IndexEvaluator2 values,IntVar * const expr1,IntVar * const expr2)950 IntIntExprFunctionElement::IntIntExprFunctionElement(
951 Solver* const s, Solver::IndexEvaluator2 values, IntVar* const expr1,
952 IntVar* const expr2)
953 : BaseIntExpr(s),
954 expr1_(expr1),
955 expr2_(expr2),
956 min_(0),
957 min_support1_(-1),
958 min_support2_(-1),
959 max_(0),
960 max_support1_(-1),
961 max_support2_(-1),
962 initial_update_(true),
963 values_(std::move(values)),
964 expr1_iterator_(expr1_->MakeDomainIterator(true)),
965 expr2_iterator_(expr2_->MakeDomainIterator(true)) {
966 CHECK(values_ != nullptr);
967 }
968
~IntIntExprFunctionElement()969 IntIntExprFunctionElement::~IntIntExprFunctionElement() {}
970
Min() const971 int64_t IntIntExprFunctionElement::Min() const {
972 UpdateSupports();
973 return min_;
974 }
975
Max() const976 int64_t IntIntExprFunctionElement::Max() const {
977 UpdateSupports();
978 return max_;
979 }
980
Range(int64_t * lower_bound,int64_t * upper_bound)981 void IntIntExprFunctionElement::Range(int64_t* lower_bound,
982 int64_t* upper_bound) {
983 UpdateSupports();
984 *lower_bound = min_;
985 *upper_bound = max_;
986 }
987
988 #define UPDATE_ELEMENT_INDEX_BOUNDS(test) \
989 const int64_t emin1 = expr1_->Min(); \
990 const int64_t emax1 = expr1_->Max(); \
991 const int64_t emin2 = expr2_->Min(); \
992 const int64_t emax2 = expr2_->Max(); \
993 int64_t nmin1 = emin1; \
994 bool found = false; \
995 while (nmin1 <= emax1 && !found) { \
996 for (int i = emin2; i <= emax2; ++i) { \
997 int64_t value = ElementValue(nmin1, i); \
998 if (test) { \
999 found = true; \
1000 break; \
1001 } \
1002 } \
1003 if (!found) { \
1004 nmin1++; \
1005 } \
1006 } \
1007 if (nmin1 > emax1) { \
1008 solver()->Fail(); \
1009 } \
1010 int64_t nmin2 = emin2; \
1011 found = false; \
1012 while (nmin2 <= emax2 && !found) { \
1013 for (int i = emin1; i <= emax1; ++i) { \
1014 int64_t value = ElementValue(i, nmin2); \
1015 if (test) { \
1016 found = true; \
1017 break; \
1018 } \
1019 } \
1020 if (!found) { \
1021 nmin2++; \
1022 } \
1023 } \
1024 if (nmin2 > emax2) { \
1025 solver()->Fail(); \
1026 } \
1027 int64_t nmax1 = emax1; \
1028 found = false; \
1029 while (nmax1 >= nmin1 && !found) { \
1030 for (int i = emin2; i <= emax2; ++i) { \
1031 int64_t value = ElementValue(nmax1, i); \
1032 if (test) { \
1033 found = true; \
1034 break; \
1035 } \
1036 } \
1037 if (!found) { \
1038 nmax1--; \
1039 } \
1040 } \
1041 int64_t nmax2 = emax2; \
1042 found = false; \
1043 while (nmax2 >= nmin2 && !found) { \
1044 for (int i = emin1; i <= emax1; ++i) { \
1045 int64_t value = ElementValue(i, nmax2); \
1046 if (test) { \
1047 found = true; \
1048 break; \
1049 } \
1050 } \
1051 if (!found) { \
1052 nmax2--; \
1053 } \
1054 } \
1055 expr1_->SetRange(nmin1, nmax1); \
1056 expr2_->SetRange(nmin2, nmax2);
1057
SetMin(int64_t lower_bound)1058 void IntIntExprFunctionElement::SetMin(int64_t lower_bound) {
1059 UPDATE_ELEMENT_INDEX_BOUNDS(value >= lower_bound);
1060 }
1061
SetMax(int64_t upper_bound)1062 void IntIntExprFunctionElement::SetMax(int64_t upper_bound) {
1063 UPDATE_ELEMENT_INDEX_BOUNDS(value <= upper_bound);
1064 }
1065
SetRange(int64_t lower_bound,int64_t upper_bound)1066 void IntIntExprFunctionElement::SetRange(int64_t lower_bound,
1067 int64_t upper_bound) {
1068 if (lower_bound > upper_bound) {
1069 solver()->Fail();
1070 }
1071 UPDATE_ELEMENT_INDEX_BOUNDS(value >= lower_bound && value <= upper_bound);
1072 }
1073
1074 #undef UPDATE_ELEMENT_INDEX_BOUNDS
1075
UpdateSupports() const1076 void IntIntExprFunctionElement::UpdateSupports() const {
1077 if (initial_update_ || !expr1_->Contains(min_support1_) ||
1078 !expr1_->Contains(max_support1_) || !expr2_->Contains(min_support2_) ||
1079 !expr2_->Contains(max_support2_)) {
1080 const int64_t emax1 = expr1_->Max();
1081 const int64_t emax2 = expr2_->Max();
1082 int64_t min_value = ElementValue(emax1, emax2);
1083 int64_t max_value = min_value;
1084 int min_support1 = emax1;
1085 int max_support1 = emax1;
1086 int min_support2 = emax2;
1087 int max_support2 = emax2;
1088 for (const int64_t index1 : InitAndGetValues(expr1_iterator_)) {
1089 for (const int64_t index2 : InitAndGetValues(expr2_iterator_)) {
1090 const int64_t value = ElementValue(index1, index2);
1091 if (value > max_value) {
1092 max_value = value;
1093 max_support1 = index1;
1094 max_support2 = index2;
1095 } else if (value < min_value) {
1096 min_value = value;
1097 min_support1 = index1;
1098 min_support2 = index2;
1099 }
1100 }
1101 }
1102 Solver* s = solver();
1103 s->SaveAndSetValue(&min_, min_value);
1104 s->SaveAndSetValue(&min_support1_, min_support1);
1105 s->SaveAndSetValue(&min_support2_, min_support2);
1106 s->SaveAndSetValue(&max_, max_value);
1107 s->SaveAndSetValue(&max_support1_, max_support1);
1108 s->SaveAndSetValue(&max_support2_, max_support2);
1109 s->SaveAndSetValue(&initial_update_, false);
1110 }
1111 }
1112 } // namespace
1113
MakeElement(Solver::IndexEvaluator2 values,IntVar * const index1,IntVar * const index2)1114 IntExpr* Solver::MakeElement(Solver::IndexEvaluator2 values,
1115 IntVar* const index1, IntVar* const index2) {
1116 CHECK_EQ(this, index1->solver());
1117 CHECK_EQ(this, index2->solver());
1118 return RegisterIntExpr(RevAlloc(
1119 new IntIntExprFunctionElement(this, std::move(values), index1, index2)));
1120 }
1121
1122 // ---------- Generalized element ----------
1123
1124 // ----- IfThenElseCt -----
1125
1126 class IfThenElseCt : public CastConstraint {
1127 public:
IfThenElseCt(Solver * const solver,IntVar * const condition,IntExpr * const one,IntExpr * const zero,IntVar * const target)1128 IfThenElseCt(Solver* const solver, IntVar* const condition,
1129 IntExpr* const one, IntExpr* const zero, IntVar* const target)
1130 : CastConstraint(solver, target),
1131 condition_(condition),
1132 zero_(zero),
1133 one_(one) {}
1134
~IfThenElseCt()1135 ~IfThenElseCt() override {}
1136
Post()1137 void Post() override {
1138 Demon* const demon = solver()->MakeConstraintInitialPropagateCallback(this);
1139 condition_->WhenBound(demon);
1140 one_->WhenRange(demon);
1141 zero_->WhenRange(demon);
1142 target_var_->WhenRange(demon);
1143 }
1144
InitialPropagate()1145 void InitialPropagate() override {
1146 condition_->SetRange(0, 1);
1147 const int64_t target_var_min = target_var_->Min();
1148 const int64_t target_var_max = target_var_->Max();
1149 int64_t new_min = std::numeric_limits<int64_t>::min();
1150 int64_t new_max = std::numeric_limits<int64_t>::max();
1151 if (condition_->Max() == 0) {
1152 zero_->SetRange(target_var_min, target_var_max);
1153 zero_->Range(&new_min, &new_max);
1154 } else if (condition_->Min() == 1) {
1155 one_->SetRange(target_var_min, target_var_max);
1156 one_->Range(&new_min, &new_max);
1157 } else {
1158 if (target_var_max < zero_->Min() || target_var_min > zero_->Max()) {
1159 condition_->SetValue(1);
1160 one_->SetRange(target_var_min, target_var_max);
1161 one_->Range(&new_min, &new_max);
1162 } else if (target_var_max < one_->Min() || target_var_min > one_->Max()) {
1163 condition_->SetValue(0);
1164 zero_->SetRange(target_var_min, target_var_max);
1165 zero_->Range(&new_min, &new_max);
1166 } else {
1167 int64_t zl = 0;
1168 int64_t zu = 0;
1169 int64_t ol = 0;
1170 int64_t ou = 0;
1171 zero_->Range(&zl, &zu);
1172 one_->Range(&ol, &ou);
1173 new_min = std::min(zl, ol);
1174 new_max = std::max(zu, ou);
1175 }
1176 }
1177 target_var_->SetRange(new_min, new_max);
1178 }
1179
DebugString() const1180 std::string DebugString() const override {
1181 return absl::StrFormat("(%s ? %s : %s) == %s", condition_->DebugString(),
1182 one_->DebugString(), zero_->DebugString(),
1183 target_var_->DebugString());
1184 }
1185
Accept(ModelVisitor * const visitor) const1186 void Accept(ModelVisitor* const visitor) const override {}
1187
1188 private:
1189 IntVar* const condition_;
1190 IntExpr* const zero_;
1191 IntExpr* const one_;
1192 };
1193
1194 // ----- IntExprEvaluatorElementCt -----
1195
1196 // This constraint implements evaluator(index) == var. It is delayed such
1197 // that propagation only occurs when all variables have been touched.
1198 // The range of the evaluator is [range_start, range_end).
1199
1200 namespace {
1201 class IntExprEvaluatorElementCt : public CastConstraint {
1202 public:
1203 IntExprEvaluatorElementCt(Solver* const s, Solver::Int64ToIntVar evaluator,
1204 int64_t range_start, int64_t range_end,
1205 IntVar* const index, IntVar* const target_var);
~IntExprEvaluatorElementCt()1206 ~IntExprEvaluatorElementCt() override {}
1207
1208 void Post() override;
1209 void InitialPropagate() override;
1210
1211 void Propagate();
1212 void Update(int index);
1213 void UpdateExpr();
1214
1215 std::string DebugString() const override;
1216 void Accept(ModelVisitor* const visitor) const override;
1217
1218 protected:
1219 IntVar* const index_;
1220
1221 private:
1222 const Solver::Int64ToIntVar evaluator_;
1223 const int64_t range_start_;
1224 const int64_t range_end_;
1225 int min_support_;
1226 int max_support_;
1227 };
1228
IntExprEvaluatorElementCt(Solver * const s,Solver::Int64ToIntVar evaluator,int64_t range_start,int64_t range_end,IntVar * const index,IntVar * const target_var)1229 IntExprEvaluatorElementCt::IntExprEvaluatorElementCt(
1230 Solver* const s, Solver::Int64ToIntVar evaluator, int64_t range_start,
1231 int64_t range_end, IntVar* const index, IntVar* const target_var)
1232 : CastConstraint(s, target_var),
1233 index_(index),
1234 evaluator_(std::move(evaluator)),
1235 range_start_(range_start),
1236 range_end_(range_end),
1237 min_support_(-1),
1238 max_support_(-1) {}
1239
Post()1240 void IntExprEvaluatorElementCt::Post() {
1241 Demon* const delayed_propagate_demon = MakeDelayedConstraintDemon0(
1242 solver(), this, &IntExprEvaluatorElementCt::Propagate, "Propagate");
1243 for (int i = range_start_; i < range_end_; ++i) {
1244 IntVar* const current_var = evaluator_(i);
1245 current_var->WhenRange(delayed_propagate_demon);
1246 Demon* const update_demon = MakeConstraintDemon1(
1247 solver(), this, &IntExprEvaluatorElementCt::Update, "Update", i);
1248 current_var->WhenRange(update_demon);
1249 }
1250 index_->WhenRange(delayed_propagate_demon);
1251 Demon* const update_expr_demon = MakeConstraintDemon0(
1252 solver(), this, &IntExprEvaluatorElementCt::UpdateExpr, "UpdateExpr");
1253 index_->WhenRange(update_expr_demon);
1254 Demon* const update_var_demon = MakeConstraintDemon0(
1255 solver(), this, &IntExprEvaluatorElementCt::Propagate, "UpdateVar");
1256
1257 target_var_->WhenRange(update_var_demon);
1258 }
1259
InitialPropagate()1260 void IntExprEvaluatorElementCt::InitialPropagate() { Propagate(); }
1261
Propagate()1262 void IntExprEvaluatorElementCt::Propagate() {
1263 const int64_t emin = std::max(range_start_, index_->Min());
1264 const int64_t emax = std::min<int64_t>(range_end_ - 1, index_->Max());
1265 const int64_t vmin = target_var_->Min();
1266 const int64_t vmax = target_var_->Max();
1267 if (emin == emax) {
1268 index_->SetValue(emin); // in case it was reduced by the above min/max.
1269 evaluator_(emin)->SetRange(vmin, vmax);
1270 } else {
1271 int64_t nmin = emin;
1272 for (; nmin <= emax; nmin++) {
1273 // break if the intersection of
1274 // [evaluator_(nmin)->Min(), evaluator_(nmin)->Max()] and [vmin, vmax]
1275 // is non-empty.
1276 IntVar* const nmin_var = evaluator_(nmin);
1277 if (nmin_var->Min() <= vmax && nmin_var->Max() >= vmin) break;
1278 }
1279 int64_t nmax = emax;
1280 for (; nmin <= nmax; nmax--) {
1281 // break if the intersection of
1282 // [evaluator_(nmin)->Min(), evaluator_(nmin)->Max()] and [vmin, vmax]
1283 // is non-empty.
1284 IntExpr* const nmax_var = evaluator_(nmax);
1285 if (nmax_var->Min() <= vmax && nmax_var->Max() >= vmin) break;
1286 }
1287 index_->SetRange(nmin, nmax);
1288 if (nmin == nmax) {
1289 evaluator_(nmin)->SetRange(vmin, vmax);
1290 }
1291 }
1292 if (min_support_ == -1 || max_support_ == -1) {
1293 int min_support = -1;
1294 int max_support = -1;
1295 int64_t gmin = std::numeric_limits<int64_t>::max();
1296 int64_t gmax = std::numeric_limits<int64_t>::min();
1297 for (int i = index_->Min(); i <= index_->Max(); ++i) {
1298 IntExpr* const var_i = evaluator_(i);
1299 const int64_t vmin = var_i->Min();
1300 if (vmin < gmin) {
1301 gmin = vmin;
1302 }
1303 const int64_t vmax = var_i->Max();
1304 if (vmax > gmax) {
1305 gmax = vmax;
1306 }
1307 }
1308 solver()->SaveAndSetValue(&min_support_, min_support);
1309 solver()->SaveAndSetValue(&max_support_, max_support);
1310 target_var_->SetRange(gmin, gmax);
1311 }
1312 }
1313
Update(int index)1314 void IntExprEvaluatorElementCt::Update(int index) {
1315 if (index == min_support_ || index == max_support_) {
1316 solver()->SaveAndSetValue(&min_support_, -1);
1317 solver()->SaveAndSetValue(&max_support_, -1);
1318 }
1319 }
1320
UpdateExpr()1321 void IntExprEvaluatorElementCt::UpdateExpr() {
1322 if (!index_->Contains(min_support_) || !index_->Contains(max_support_)) {
1323 solver()->SaveAndSetValue(&min_support_, -1);
1324 solver()->SaveAndSetValue(&max_support_, -1);
1325 }
1326 }
1327
1328 namespace {
StringifyEvaluatorBare(const Solver::Int64ToIntVar & evaluator,int64_t range_start,int64_t range_end)1329 std::string StringifyEvaluatorBare(const Solver::Int64ToIntVar& evaluator,
1330 int64_t range_start, int64_t range_end) {
1331 std::string out;
1332 for (int64_t i = range_start; i < range_end; ++i) {
1333 if (i != range_start) {
1334 out += ", ";
1335 }
1336 out += absl::StrFormat("%d -> %s", i, evaluator(i)->DebugString());
1337 }
1338 return out;
1339 }
1340
StringifyInt64ToIntVar(const Solver::Int64ToIntVar & evaluator,int64_t range_begin,int64_t range_end)1341 std::string StringifyInt64ToIntVar(const Solver::Int64ToIntVar& evaluator,
1342 int64_t range_begin, int64_t range_end) {
1343 std::string out;
1344 if (range_end - range_begin > 10) {
1345 out = absl::StrFormat(
1346 "IntToIntVar(%s, ...%s)",
1347 StringifyEvaluatorBare(evaluator, range_begin, range_begin + 5),
1348 StringifyEvaluatorBare(evaluator, range_end - 5, range_end));
1349 } else {
1350 out = absl::StrFormat(
1351 "IntToIntVar(%s)",
1352 StringifyEvaluatorBare(evaluator, range_begin, range_end));
1353 }
1354 return out;
1355 }
1356 } // namespace
1357
DebugString() const1358 std::string IntExprEvaluatorElementCt::DebugString() const {
1359 return StringifyInt64ToIntVar(evaluator_, range_start_, range_end_);
1360 }
1361
Accept(ModelVisitor * const visitor) const1362 void IntExprEvaluatorElementCt::Accept(ModelVisitor* const visitor) const {
1363 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1364 visitor->VisitIntegerVariableEvaluatorArgument(
1365 ModelVisitor::kEvaluatorArgument, evaluator_);
1366 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument, index_);
1367 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1368 target_var_);
1369 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1370 }
1371
1372 // ----- IntExprArrayElementCt -----
1373
1374 // This constraint implements vars[index] == var. It is delayed such
1375 // that propagation only occurs when all variables have been touched.
1376
1377 class IntExprArrayElementCt : public IntExprEvaluatorElementCt {
1378 public:
1379 IntExprArrayElementCt(Solver* const s, std::vector<IntVar*> vars,
1380 IntVar* const index, IntVar* const target_var);
1381
1382 std::string DebugString() const override;
1383 void Accept(ModelVisitor* const visitor) const override;
1384
1385 private:
1386 const std::vector<IntVar*> vars_;
1387 };
1388
IntExprArrayElementCt(Solver * const s,std::vector<IntVar * > vars,IntVar * const index,IntVar * const target_var)1389 IntExprArrayElementCt::IntExprArrayElementCt(Solver* const s,
1390 std::vector<IntVar*> vars,
1391 IntVar* const index,
1392 IntVar* const target_var)
1393 : IntExprEvaluatorElementCt(
1394 s, [this](int64_t idx) { return vars_[idx]; }, 0, vars.size(), index,
1395 target_var),
1396 vars_(std::move(vars)) {}
1397
DebugString() const1398 std::string IntExprArrayElementCt::DebugString() const {
1399 int64_t size = vars_.size();
1400 if (size > 10) {
1401 return absl::StrFormat(
1402 "IntExprArrayElement(var array of size %d, %s) == %s", size,
1403 index_->DebugString(), target_var_->DebugString());
1404 } else {
1405 return absl::StrFormat("IntExprArrayElement([%s], %s) == %s",
1406 JoinDebugStringPtr(vars_, ", "),
1407 index_->DebugString(), target_var_->DebugString());
1408 }
1409 }
1410
Accept(ModelVisitor * const visitor) const1411 void IntExprArrayElementCt::Accept(ModelVisitor* const visitor) const {
1412 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1413 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1414 vars_);
1415 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument, index_);
1416 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1417 target_var_);
1418 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1419 }
1420
1421 // ----- IntExprArrayElementCstCt -----
1422
1423 // This constraint implements vars[index] == constant.
1424
1425 class IntExprArrayElementCstCt : public Constraint {
1426 public:
IntExprArrayElementCstCt(Solver * const s,const std::vector<IntVar * > & vars,IntVar * const index,int64_t target)1427 IntExprArrayElementCstCt(Solver* const s, const std::vector<IntVar*>& vars,
1428 IntVar* const index, int64_t target)
1429 : Constraint(s),
1430 vars_(vars),
1431 index_(index),
1432 target_(target),
1433 demons_(vars.size()) {}
1434
~IntExprArrayElementCstCt()1435 ~IntExprArrayElementCstCt() override {}
1436
Post()1437 void Post() override {
1438 for (int i = 0; i < vars_.size(); ++i) {
1439 demons_[i] = MakeConstraintDemon1(
1440 solver(), this, &IntExprArrayElementCstCt::Propagate, "Propagate", i);
1441 vars_[i]->WhenDomain(demons_[i]);
1442 }
1443 Demon* const index_demon = MakeConstraintDemon0(
1444 solver(), this, &IntExprArrayElementCstCt::PropagateIndex,
1445 "PropagateIndex");
1446 index_->WhenBound(index_demon);
1447 }
1448
InitialPropagate()1449 void InitialPropagate() override {
1450 for (int i = 0; i < vars_.size(); ++i) {
1451 Propagate(i);
1452 }
1453 PropagateIndex();
1454 }
1455
Propagate(int index)1456 void Propagate(int index) {
1457 if (!vars_[index]->Contains(target_)) {
1458 index_->RemoveValue(index);
1459 demons_[index]->inhibit(solver());
1460 }
1461 }
1462
PropagateIndex()1463 void PropagateIndex() {
1464 if (index_->Bound()) {
1465 vars_[index_->Min()]->SetValue(target_);
1466 }
1467 }
1468
DebugString() const1469 std::string DebugString() const override {
1470 return absl::StrFormat("IntExprArrayElement([%s], %s) == %d",
1471 JoinDebugStringPtr(vars_, ", "),
1472 index_->DebugString(), target_);
1473 }
1474
Accept(ModelVisitor * const visitor) const1475 void Accept(ModelVisitor* const visitor) const override {
1476 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1477 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1478 vars_);
1479 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
1480 index_);
1481 visitor->VisitIntegerArgument(ModelVisitor::kTargetArgument, target_);
1482 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1483 }
1484
1485 private:
1486 const std::vector<IntVar*> vars_;
1487 IntVar* const index_;
1488 const int64_t target_;
1489 std::vector<Demon*> demons_;
1490 };
1491
1492 // This constraint implements index == position(constant in vars).
1493
1494 class IntExprIndexOfCt : public Constraint {
1495 public:
IntExprIndexOfCt(Solver * const s,const std::vector<IntVar * > & vars,IntVar * const index,int64_t target)1496 IntExprIndexOfCt(Solver* const s, const std::vector<IntVar*>& vars,
1497 IntVar* const index, int64_t target)
1498 : Constraint(s),
1499 vars_(vars),
1500 index_(index),
1501 target_(target),
1502 demons_(vars_.size()),
1503 index_iterator_(index->MakeHoleIterator(true)) {}
1504
~IntExprIndexOfCt()1505 ~IntExprIndexOfCt() override {}
1506
Post()1507 void Post() override {
1508 for (int i = 0; i < vars_.size(); ++i) {
1509 demons_[i] = MakeConstraintDemon1(
1510 solver(), this, &IntExprIndexOfCt::Propagate, "Propagate", i);
1511 vars_[i]->WhenDomain(demons_[i]);
1512 }
1513 Demon* const index_demon = MakeConstraintDemon0(
1514 solver(), this, &IntExprIndexOfCt::PropagateIndex, "PropagateIndex");
1515 index_->WhenDomain(index_demon);
1516 }
1517
InitialPropagate()1518 void InitialPropagate() override {
1519 for (int i = 0; i < vars_.size(); ++i) {
1520 if (!index_->Contains(i)) {
1521 vars_[i]->RemoveValue(target_);
1522 } else if (!vars_[i]->Contains(target_)) {
1523 index_->RemoveValue(i);
1524 demons_[i]->inhibit(solver());
1525 } else if (vars_[i]->Bound()) {
1526 index_->SetValue(i);
1527 demons_[i]->inhibit(solver());
1528 }
1529 }
1530 }
1531
Propagate(int index)1532 void Propagate(int index) {
1533 if (!vars_[index]->Contains(target_)) {
1534 index_->RemoveValue(index);
1535 demons_[index]->inhibit(solver());
1536 } else if (vars_[index]->Bound()) {
1537 index_->SetValue(index);
1538 }
1539 }
1540
PropagateIndex()1541 void PropagateIndex() {
1542 const int64_t oldmax = index_->OldMax();
1543 const int64_t vmin = index_->Min();
1544 const int64_t vmax = index_->Max();
1545 for (int64_t value = index_->OldMin(); value < vmin; ++value) {
1546 vars_[value]->RemoveValue(target_);
1547 demons_[value]->inhibit(solver());
1548 }
1549 for (const int64_t value : InitAndGetValues(index_iterator_)) {
1550 vars_[value]->RemoveValue(target_);
1551 demons_[value]->inhibit(solver());
1552 }
1553 for (int64_t value = vmax + 1; value <= oldmax; ++value) {
1554 vars_[value]->RemoveValue(target_);
1555 demons_[value]->inhibit(solver());
1556 }
1557 if (index_->Bound()) {
1558 vars_[index_->Min()]->SetValue(target_);
1559 }
1560 }
1561
DebugString() const1562 std::string DebugString() const override {
1563 return absl::StrFormat("IntExprIndexOf([%s], %s) == %d",
1564 JoinDebugStringPtr(vars_, ", "),
1565 index_->DebugString(), target_);
1566 }
1567
Accept(ModelVisitor * const visitor) const1568 void Accept(ModelVisitor* const visitor) const override {
1569 visitor->BeginVisitConstraint(ModelVisitor::kIndexOf, this);
1570 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1571 vars_);
1572 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
1573 index_);
1574 visitor->VisitIntegerArgument(ModelVisitor::kTargetArgument, target_);
1575 visitor->EndVisitConstraint(ModelVisitor::kIndexOf, this);
1576 }
1577
1578 private:
1579 const std::vector<IntVar*> vars_;
1580 IntVar* const index_;
1581 const int64_t target_;
1582 std::vector<Demon*> demons_;
1583 IntVarIterator* const index_iterator_;
1584 };
1585
1586 // Factory helper.
1587
MakeElementEqualityFunc(Solver * const solver,const std::vector<int64_t> & vals,IntVar * const index,IntVar * const target)1588 Constraint* MakeElementEqualityFunc(Solver* const solver,
1589 const std::vector<int64_t>& vals,
1590 IntVar* const index, IntVar* const target) {
1591 if (index->Bound()) {
1592 const int64_t val = index->Min();
1593 if (val < 0 || val >= vals.size()) {
1594 return solver->MakeFalseConstraint();
1595 } else {
1596 return solver->MakeEquality(target, vals[val]);
1597 }
1598 } else {
1599 if (IsIncreasingContiguous(vals)) {
1600 return solver->MakeEquality(target, solver->MakeSum(index, vals[0]));
1601 } else {
1602 return solver->RevAlloc(
1603 new IntElementConstraint(solver, vals, index, target));
1604 }
1605 }
1606 }
1607 } // namespace
1608
MakeIfThenElseCt(IntVar * const condition,IntExpr * const then_expr,IntExpr * const else_expr,IntVar * const target_var)1609 Constraint* Solver::MakeIfThenElseCt(IntVar* const condition,
1610 IntExpr* const then_expr,
1611 IntExpr* const else_expr,
1612 IntVar* const target_var) {
1613 return RevAlloc(
1614 new IfThenElseCt(this, condition, then_expr, else_expr, target_var));
1615 }
1616
MakeElement(const std::vector<IntVar * > & vars,IntVar * const index)1617 IntExpr* Solver::MakeElement(const std::vector<IntVar*>& vars,
1618 IntVar* const index) {
1619 if (index->Bound()) {
1620 return vars[index->Min()];
1621 }
1622 const int size = vars.size();
1623 if (AreAllBound(vars)) {
1624 std::vector<int64_t> values(size);
1625 for (int i = 0; i < size; ++i) {
1626 values[i] = vars[i]->Value();
1627 }
1628 return MakeElement(values, index);
1629 }
1630 if (index->Size() == 2 && index->Min() + 1 == index->Max() &&
1631 index->Min() >= 0 && index->Max() < vars.size()) {
1632 // Let's get the index between 0 and 1.
1633 IntVar* const scaled_index = MakeSum(index, -index->Min())->Var();
1634 IntVar* const zero = vars[index->Min()];
1635 IntVar* const one = vars[index->Max()];
1636 const std::string name = absl::StrFormat(
1637 "ElementVar([%s], %s)", JoinNamePtr(vars, ", "), index->name());
1638 IntVar* const target = MakeIntVar(std::min(zero->Min(), one->Min()),
1639 std::max(zero->Max(), one->Max()), name);
1640 AddConstraint(
1641 RevAlloc(new IfThenElseCt(this, scaled_index, one, zero, target)));
1642 return target;
1643 }
1644 int64_t emin = std::numeric_limits<int64_t>::max();
1645 int64_t emax = std::numeric_limits<int64_t>::min();
1646 std::unique_ptr<IntVarIterator> iterator(index->MakeDomainIterator(false));
1647 for (const int64_t index_value : InitAndGetValues(iterator.get())) {
1648 if (index_value >= 0 && index_value < size) {
1649 emin = std::min(emin, vars[index_value]->Min());
1650 emax = std::max(emax, vars[index_value]->Max());
1651 }
1652 }
1653 const std::string vname =
1654 size > 10 ? absl::StrFormat("ElementVar(var array of size %d, %s)", size,
1655 index->DebugString())
1656 : absl::StrFormat("ElementVar([%s], %s)",
1657 JoinNamePtr(vars, ", "), index->name());
1658 IntVar* const element_var = MakeIntVar(emin, emax, vname);
1659 AddConstraint(
1660 RevAlloc(new IntExprArrayElementCt(this, vars, index, element_var)));
1661 return element_var;
1662 }
1663
MakeElement(Int64ToIntVar vars,int64_t range_start,int64_t range_end,IntVar * argument)1664 IntExpr* Solver::MakeElement(Int64ToIntVar vars, int64_t range_start,
1665 int64_t range_end, IntVar* argument) {
1666 const std::string index_name =
1667 !argument->name().empty() ? argument->name() : argument->DebugString();
1668 const std::string vname = absl::StrFormat(
1669 "ElementVar(%s, %s)",
1670 StringifyInt64ToIntVar(vars, range_start, range_end), index_name);
1671 IntVar* const element_var =
1672 MakeIntVar(std::numeric_limits<int64_t>::min(),
1673 std::numeric_limits<int64_t>::max(), vname);
1674 IntExprEvaluatorElementCt* evaluation_ct = new IntExprEvaluatorElementCt(
1675 this, std::move(vars), range_start, range_end, argument, element_var);
1676 AddConstraint(RevAlloc(evaluation_ct));
1677 evaluation_ct->Propagate();
1678 return element_var;
1679 }
1680
MakeElementEquality(const std::vector<int64_t> & vals,IntVar * const index,IntVar * const target)1681 Constraint* Solver::MakeElementEquality(const std::vector<int64_t>& vals,
1682 IntVar* const index,
1683 IntVar* const target) {
1684 return MakeElementEqualityFunc(this, vals, index, target);
1685 }
1686
MakeElementEquality(const std::vector<int> & vals,IntVar * const index,IntVar * const target)1687 Constraint* Solver::MakeElementEquality(const std::vector<int>& vals,
1688 IntVar* const index,
1689 IntVar* const target) {
1690 return MakeElementEqualityFunc(this, ToInt64Vector(vals), index, target);
1691 }
1692
MakeElementEquality(const std::vector<IntVar * > & vars,IntVar * const index,IntVar * const target)1693 Constraint* Solver::MakeElementEquality(const std::vector<IntVar*>& vars,
1694 IntVar* const index,
1695 IntVar* const target) {
1696 if (AreAllBound(vars)) {
1697 std::vector<int64_t> values(vars.size());
1698 for (int i = 0; i < vars.size(); ++i) {
1699 values[i] = vars[i]->Value();
1700 }
1701 return MakeElementEquality(values, index, target);
1702 }
1703 if (index->Bound()) {
1704 const int64_t val = index->Min();
1705 if (val < 0 || val >= vars.size()) {
1706 return MakeFalseConstraint();
1707 } else {
1708 return MakeEquality(target, vars[val]);
1709 }
1710 } else {
1711 if (target->Bound()) {
1712 return RevAlloc(
1713 new IntExprArrayElementCstCt(this, vars, index, target->Min()));
1714 } else {
1715 return RevAlloc(new IntExprArrayElementCt(this, vars, index, target));
1716 }
1717 }
1718 }
1719
MakeElementEquality(const std::vector<IntVar * > & vars,IntVar * const index,int64_t target)1720 Constraint* Solver::MakeElementEquality(const std::vector<IntVar*>& vars,
1721 IntVar* const index, int64_t target) {
1722 if (AreAllBound(vars)) {
1723 std::vector<int> valid_indices;
1724 for (int i = 0; i < vars.size(); ++i) {
1725 if (vars[i]->Value() == target) {
1726 valid_indices.push_back(i);
1727 }
1728 }
1729 return MakeMemberCt(index, valid_indices);
1730 }
1731 if (index->Bound()) {
1732 const int64_t pos = index->Min();
1733 if (pos >= 0 && pos < vars.size()) {
1734 IntVar* const var = vars[pos];
1735 return MakeEquality(var, target);
1736 } else {
1737 return MakeFalseConstraint();
1738 }
1739 } else {
1740 return RevAlloc(new IntExprArrayElementCstCt(this, vars, index, target));
1741 }
1742 }
1743
MakeIndexOfConstraint(const std::vector<IntVar * > & vars,IntVar * const index,int64_t target)1744 Constraint* Solver::MakeIndexOfConstraint(const std::vector<IntVar*>& vars,
1745 IntVar* const index, int64_t target) {
1746 if (index->Bound()) {
1747 const int64_t pos = index->Min();
1748 if (pos >= 0 && pos < vars.size()) {
1749 IntVar* const var = vars[pos];
1750 return MakeEquality(var, target);
1751 } else {
1752 return MakeFalseConstraint();
1753 }
1754 } else {
1755 return RevAlloc(new IntExprIndexOfCt(this, vars, index, target));
1756 }
1757 }
1758
MakeIndexExpression(const std::vector<IntVar * > & vars,int64_t value)1759 IntExpr* Solver::MakeIndexExpression(const std::vector<IntVar*>& vars,
1760 int64_t value) {
1761 IntExpr* const cache = model_cache_->FindVarArrayConstantExpression(
1762 vars, value, ModelCache::VAR_ARRAY_CONSTANT_INDEX);
1763 if (cache != nullptr) {
1764 return cache->Var();
1765 } else {
1766 const std::string name =
1767 absl::StrFormat("Index(%s, %d)", JoinNamePtr(vars, ", "), value);
1768 IntVar* const index = MakeIntVar(0, vars.size() - 1, name);
1769 AddConstraint(MakeIndexOfConstraint(vars, index, value));
1770 model_cache_->InsertVarArrayConstantExpression(
1771 index, vars, value, ModelCache::VAR_ARRAY_CONSTANT_INDEX);
1772 return index;
1773 }
1774 }
1775 } // namespace operations_research
1776