1 // Copyright 2010-2021 Google LLC
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
#include <cstdint>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

#include "ortools/base/commandlineflags.h"
#include "ortools/base/integral_types.h"
#include "ortools/base/logging.h"
#include "ortools/base/stl_util.h"
#include "ortools/constraint_solver/constraint_solver.h"
#include "ortools/constraint_solver/constraint_solveri.h"
24 
// Initial number of buckets allocated by each hash table (Cache1/2/3) of the
// model cache. The flag is only declared here; its definition lives in another
// translation unit.
ABSL_DECLARE_FLAG(int, cache_initial_size);
// Consulted by the Insert*() methods of NonReversibleCache below: when true,
// they do not record newly built model objects.
ABSL_FLAG(bool, cp_disable_cache, false, "Disable caching of model objects");
27 
28 namespace operations_research {
29 // ----- ModelCache -----
30 
ModelCache(Solver * const solver)31 ModelCache::ModelCache(Solver* const solver) : solver_(solver) {}
32 
~ModelCache()33 ModelCache::~ModelCache() {}
34 
solver() const35 Solver* ModelCache::solver() const { return solver_; }
36 
37 namespace {
38 // ----- Helpers -----
39 
// Key comparison used by the cache cells: two keys match exactly when
// operator== reports them equal.
template <class T>
bool IsEqual(const T& a1, const T& a2) {
  const bool same = (a1 == a2);
  return same;
}
44 
// Key comparison for arrays of pointers: two arrays match when they have the
// same length and hold the same pointers (identity, not pointee equality) in
// the same order. std::vector::operator== does exactly this, and it avoids the
// signed/unsigned index comparison of a hand-written loop.
template <class T>
bool IsEqual(const std::vector<T*>& a1, const std::vector<T*>& a2) {
  return a1 == a2;
}
57 
58 template <class A1, class A2>
Hash2(const A1 & a1,const A2 & a2)59 uint64_t Hash2(const A1& a1, const A2& a2) {
60   uint64_t a = Hash1(a1);
61   uint64_t b = uint64_t{0xe08c1d668b756f82};  // more of the golden ratio
62   uint64_t c = Hash1(a2);
63   mix(a, b, c);
64   return c;
65 }
66 
67 template <class A1, class A2, class A3>
Hash3(const A1 & a1,const A2 & a2,const A3 & a3)68 uint64_t Hash3(const A1& a1, const A2& a2, const A3& a3) {
69   uint64_t a = Hash1(a1);
70   uint64_t b = Hash1(a2);
71   uint64_t c = Hash1(a3);
72   mix(a, b, c);
73   return c;
74 }
75 
76 template <class A1, class A2, class A3, class A4>
Hash4(const A1 & a1,const A2 & a2,const A3 & a3,const A4 & a4)77 uint64_t Hash4(const A1& a1, const A2& a2, const A3& a3, const A4& a4) {
78   uint64_t a = Hash1(a1);
79   uint64_t b = Hash1(a2);
80   uint64_t c = Hash2(a3, a4);
81   mix(a, b, c);
82   return c;
83 }
84 
85 template <class C>
Double(C *** array_ptr,int * size_ptr)86 void Double(C*** array_ptr, int* size_ptr) {
87   DCHECK(array_ptr != nullptr);
88   DCHECK(size_ptr != nullptr);
89   C** const old_cell_array = *array_ptr;
90   const int old_size = *size_ptr;
91   (*size_ptr) *= 2;
92   (*array_ptr) = new C*[(*size_ptr)];
93   memset(*array_ptr, 0, (*size_ptr) * sizeof(**array_ptr));
94   for (int i = 0; i < old_size; ++i) {
95     C* tmp = old_cell_array[i];
96     while (tmp != nullptr) {
97       C* const to_reinsert = tmp;
98       tmp = tmp->next();
99       const uint64_t position = to_reinsert->Hash() % (*size_ptr);
100       to_reinsert->set_next((*array_ptr)[position]);
101       (*array_ptr)[position] = to_reinsert;
102     }
103   }
104   delete[](old_cell_array);
105 }
106 
107 // ----- Cache objects built with 1 object -----
108 
109 template <class C, class A1>
110 class Cache1 {
111  public:
Cache1()112   Cache1()
113       : array_(new Cell*[absl::GetFlag(FLAGS_cache_initial_size)]),
114         size_(absl::GetFlag(FLAGS_cache_initial_size)),
115         num_items_(0) {
116     memset(array_, 0, sizeof(*array_) * size_);
117   }
118 
~Cache1()119   ~Cache1() {
120     for (int i = 0; i < size_; ++i) {
121       Cell* tmp = array_[i];
122       while (tmp != nullptr) {
123         Cell* const to_delete = tmp;
124         tmp = tmp->next();
125         delete to_delete;
126       }
127     }
128     delete[] array_;
129   }
130 
Clear()131   void Clear() {
132     for (int i = 0; i < size_; ++i) {
133       Cell* tmp = array_[i];
134       while (tmp != nullptr) {
135         Cell* const to_delete = tmp;
136         tmp = tmp->next();
137         delete to_delete;
138       }
139       array_[i] = nullptr;
140     }
141   }
142 
Find(const A1 & a1) const143   C* Find(const A1& a1) const {
144     uint64_t code = Hash1(a1) % size_;
145     Cell* tmp = array_[code];
146     while (tmp) {
147       C* const result = tmp->ReturnsIfEqual(a1);
148       if (result != nullptr) {
149         return result;
150       }
151       tmp = tmp->next();
152     }
153     return nullptr;
154   }
155 
UnsafeInsert(const A1 & a1,C * const c)156   void UnsafeInsert(const A1& a1, C* const c) {
157     const int position = Hash1(a1) % size_;
158     Cell* const cell = new Cell(a1, c, array_[position]);
159     array_[position] = cell;
160     if (++num_items_ > 2 * size_) {
161       Double(&array_, &size_);
162     }
163   }
164 
165  private:
166   class Cell {
167    public:
Cell(const A1 & a1,C * const container,Cell * const next)168     Cell(const A1& a1, C* const container, Cell* const next)
169         : a1_(a1), container_(container), next_(next) {}
170 
ReturnsIfEqual(const A1 & a1) const171     C* ReturnsIfEqual(const A1& a1) const {
172       if (IsEqual(a1_, a1)) {
173         return container_;
174       }
175       return nullptr;
176     }
177 
Hash() const178     uint64_t Hash() const { return Hash1(a1_); }
179 
set_next(Cell * const next)180     void set_next(Cell* const next) { next_ = next; }
181 
next() const182     Cell* next() const { return next_; }
183 
184    private:
185     const A1 a1_;
186     C* const container_;
187     Cell* next_;
188   };
189 
190   Cell** array_;
191   int size_;
192   int num_items_;
193 };
194 
195 // ----- Cache objects built with 2 objects -----
196 
197 template <class C, class A1, class A2>
198 class Cache2 {
199  public:
Cache2()200   Cache2()
201       : array_(new Cell*[absl::GetFlag(FLAGS_cache_initial_size)]),
202         size_(absl::GetFlag(FLAGS_cache_initial_size)),
203         num_items_(0) {
204     memset(array_, 0, sizeof(*array_) * size_);
205   }
206 
~Cache2()207   ~Cache2() {
208     for (int i = 0; i < size_; ++i) {
209       Cell* tmp = array_[i];
210       while (tmp != nullptr) {
211         Cell* const to_delete = tmp;
212         tmp = tmp->next();
213         delete to_delete;
214       }
215     }
216     delete[] array_;
217   }
218 
Clear()219   void Clear() {
220     for (int i = 0; i < size_; ++i) {
221       Cell* tmp = array_[i];
222       while (tmp != nullptr) {
223         Cell* const to_delete = tmp;
224         tmp = tmp->next();
225         delete to_delete;
226       }
227       array_[i] = nullptr;
228     }
229   }
230 
Find(const A1 & a1,const A2 & a2) const231   C* Find(const A1& a1, const A2& a2) const {
232     uint64_t code = Hash2(a1, a2) % size_;
233     Cell* tmp = array_[code];
234     while (tmp) {
235       C* const result = tmp->ReturnsIfEqual(a1, a2);
236       if (result != nullptr) {
237         return result;
238       }
239       tmp = tmp->next();
240     }
241     return nullptr;
242   }
243 
UnsafeInsert(const A1 & a1,const A2 & a2,C * const c)244   void UnsafeInsert(const A1& a1, const A2& a2, C* const c) {
245     const int position = Hash2(a1, a2) % size_;
246     Cell* const cell = new Cell(a1, a2, c, array_[position]);
247     array_[position] = cell;
248     if (++num_items_ > 2 * size_) {
249       Double(&array_, &size_);
250     }
251   }
252 
253  private:
254   class Cell {
255    public:
Cell(const A1 & a1,const A2 & a2,C * const container,Cell * const next)256     Cell(const A1& a1, const A2& a2, C* const container, Cell* const next)
257         : a1_(a1), a2_(a2), container_(container), next_(next) {}
258 
ReturnsIfEqual(const A1 & a1,const A2 & a2) const259     C* ReturnsIfEqual(const A1& a1, const A2& a2) const {
260       if (IsEqual(a1_, a1) && IsEqual(a2_, a2)) {
261         return container_;
262       }
263       return nullptr;
264     }
265 
Hash() const266     uint64_t Hash() const { return Hash2(a1_, a2_); }
267 
set_next(Cell * const next)268     void set_next(Cell* const next) { next_ = next; }
269 
next() const270     Cell* next() const { return next_; }
271 
272    private:
273     const A1 a1_;
274     const A2 a2_;
275     C* const container_;
276     Cell* next_;
277   };
278 
279   Cell** array_;
280   int size_;
281   int num_items_;
282 };
283 
// ----- Cache objects built with 3 objects -----
285 
286 template <class C, class A1, class A2, class A3>
287 class Cache3 {
288  public:
Cache3()289   Cache3()
290       : array_(new Cell*[absl::GetFlag(FLAGS_cache_initial_size)]),
291         size_(absl::GetFlag(FLAGS_cache_initial_size)),
292         num_items_(0) {
293     memset(array_, 0, sizeof(*array_) * size_);
294   }
295 
~Cache3()296   ~Cache3() {
297     for (int i = 0; i < size_; ++i) {
298       Cell* tmp = array_[i];
299       while (tmp != nullptr) {
300         Cell* const to_delete = tmp;
301         tmp = tmp->next();
302         delete to_delete;
303       }
304     }
305     delete[] array_;
306   }
307 
Clear()308   void Clear() {
309     for (int i = 0; i < size_; ++i) {
310       Cell* tmp = array_[i];
311       while (tmp != nullptr) {
312         Cell* const to_delete = tmp;
313         tmp = tmp->next();
314         delete to_delete;
315       }
316       array_[i] = nullptr;
317     }
318   }
319 
Find(const A1 & a1,const A2 & a2,const A3 & a3) const320   C* Find(const A1& a1, const A2& a2, const A3& a3) const {
321     uint64_t code = Hash3(a1, a2, a3) % size_;
322     Cell* tmp = array_[code];
323     while (tmp) {
324       C* const result = tmp->ReturnsIfEqual(a1, a2, a3);
325       if (result != nullptr) {
326         return result;
327       }
328       tmp = tmp->next();
329     }
330     return nullptr;
331   }
332 
UnsafeInsert(const A1 & a1,const A2 & a2,const A3 & a3,C * const c)333   void UnsafeInsert(const A1& a1, const A2& a2, const A3& a3, C* const c) {
334     const int position = Hash3(a1, a2, a3) % size_;
335     Cell* const cell = new Cell(a1, a2, a3, c, array_[position]);
336     array_[position] = cell;
337     if (++num_items_ > 2 * size_) {
338       Double(&array_, &size_);
339     }
340   }
341 
342  private:
343   class Cell {
344    public:
Cell(const A1 & a1,const A2 & a2,const A3 & a3,C * const container,Cell * const next)345     Cell(const A1& a1, const A2& a2, const A3& a3, C* const container,
346          Cell* const next)
347         : a1_(a1), a2_(a2), a3_(a3), container_(container), next_(next) {}
348 
ReturnsIfEqual(const A1 & a1,const A2 & a2,const A3 & a3) const349     C* ReturnsIfEqual(const A1& a1, const A2& a2, const A3& a3) const {
350       if (IsEqual(a1_, a1) && IsEqual(a2_, a2) && IsEqual(a3_, a3)) {
351         return container_;
352       }
353       return nullptr;
354     }
355 
Hash() const356     uint64_t Hash() const { return Hash3(a1_, a2_, a3_); }
357 
set_next(Cell * const next)358     void set_next(Cell* const next) { next_ = next; }
359 
next() const360     Cell* next() const { return next_; }
361 
362    private:
363     const A1 a1_;
364     const A2 a2_;
365     const A3 a3_;
366     C* const container_;
367     Cell* next_;
368   };
369 
370   Cell** array_;
371   int size_;
372   int num_items_;
373 };
374 
375 // ----- Model Cache -----
376 
377 class NonReversibleCache : public ModelCache {
378  public:
379   typedef Cache1<IntExpr, IntExpr*> ExprIntExprCache;
380   typedef Cache1<IntExpr, std::vector<IntVar*> > VarArrayIntExprCache;
381 
382   typedef Cache2<Constraint, IntVar*, int64_t> VarConstantConstraintCache;
383   typedef Cache2<Constraint, IntExpr*, IntExpr*> ExprExprConstraintCache;
384   typedef Cache2<IntExpr, IntVar*, int64_t> VarConstantIntExprCache;
385   typedef Cache2<IntExpr, IntExpr*, int64_t> ExprConstantIntExprCache;
386   typedef Cache2<IntExpr, IntExpr*, IntExpr*> ExprExprIntExprCache;
387   typedef Cache2<IntExpr, IntVar*, const std::vector<int64_t>&>
388       VarConstantArrayIntExprCache;
389   typedef Cache2<IntExpr, std::vector<IntVar*>, const std::vector<int64_t>&>
390       VarArrayConstantArrayIntExprCache;
391   typedef Cache2<IntExpr, std::vector<IntVar*>, int64_t>
392       VarArrayConstantIntExprCache;
393 
394   typedef Cache3<IntExpr, IntVar*, int64_t, int64_t>
395       VarConstantConstantIntExprCache;
396   typedef Cache3<Constraint, IntVar*, int64_t, int64_t>
397       VarConstantConstantConstraintCache;
398   typedef Cache3<IntExpr, IntExpr*, IntExpr*, int64_t>
399       ExprExprConstantIntExprCache;
400 
NonReversibleCache(Solver * const solver)401   explicit NonReversibleCache(Solver* const solver)
402       : ModelCache(solver), void_constraints_(VOID_CONSTRAINT_MAX, nullptr) {
403     for (int i = 0; i < VAR_CONSTANT_CONSTRAINT_MAX; ++i) {
404       var_constant_constraints_.push_back(new VarConstantConstraintCache);
405     }
406     for (int i = 0; i < EXPR_EXPR_CONSTRAINT_MAX; ++i) {
407       expr_expr_constraints_.push_back(new ExprExprConstraintCache);
408     }
409     for (int i = 0; i < VAR_CONSTANT_CONSTANT_CONSTRAINT_MAX; ++i) {
410       var_constant_constant_constraints_.push_back(
411           new VarConstantConstantConstraintCache);
412     }
413     for (int i = 0; i < EXPR_EXPRESSION_MAX; ++i) {
414       expr_expressions_.push_back(new ExprIntExprCache);
415     }
416     for (int i = 0; i < EXPR_CONSTANT_EXPRESSION_MAX; ++i) {
417       expr_constant_expressions_.push_back(new ExprConstantIntExprCache);
418     }
419     for (int i = 0; i < EXPR_EXPR_EXPRESSION_MAX; ++i) {
420       expr_expr_expressions_.push_back(new ExprExprIntExprCache);
421     }
422     for (int i = 0; i < VAR_CONSTANT_CONSTANT_EXPRESSION_MAX; ++i) {
423       var_constant_constant_expressions_.push_back(
424           new VarConstantConstantIntExprCache);
425     }
426     for (int i = 0; i < VAR_CONSTANT_ARRAY_EXPRESSION_MAX; ++i) {
427       var_constant_array_expressions_.push_back(
428           new VarConstantArrayIntExprCache);
429     }
430     for (int i = 0; i < VAR_ARRAY_EXPRESSION_MAX; ++i) {
431       var_array_expressions_.push_back(new VarArrayIntExprCache);
432     }
433     for (int i = 0; i < VAR_ARRAY_CONSTANT_ARRAY_EXPRESSION_MAX; ++i) {
434       var_array_constant_array_expressions_.push_back(
435           new VarArrayConstantArrayIntExprCache);
436     }
437     for (int i = 0; i < VAR_ARRAY_CONSTANT_EXPRESSION_MAX; ++i) {
438       var_array_constant_expressions_.push_back(
439           new VarArrayConstantIntExprCache);
440     }
441     for (int i = 0; i < EXPR_EXPR_CONSTANT_EXPRESSION_MAX; ++i) {
442       expr_expr_constant_expressions_.push_back(
443           new ExprExprConstantIntExprCache);
444     }
445   }
446 
~NonReversibleCache()447   ~NonReversibleCache() override {
448     gtl::STLDeleteElements(&var_constant_constraints_);
449     gtl::STLDeleteElements(&expr_expr_constraints_);
450     gtl::STLDeleteElements(&var_constant_constant_constraints_);
451     gtl::STLDeleteElements(&expr_expressions_);
452     gtl::STLDeleteElements(&expr_constant_expressions_);
453     gtl::STLDeleteElements(&expr_expr_expressions_);
454     gtl::STLDeleteElements(&var_constant_constant_expressions_);
455     gtl::STLDeleteElements(&var_constant_array_expressions_);
456     gtl::STLDeleteElements(&var_array_expressions_);
457     gtl::STLDeleteElements(&var_array_constant_array_expressions_);
458     gtl::STLDeleteElements(&var_array_constant_expressions_);
459     gtl::STLDeleteElements(&expr_expr_constant_expressions_);
460   }
461 
Clear()462   void Clear() override {
463     for (int i = 0; i < VAR_CONSTANT_CONSTRAINT_MAX; ++i) {
464       var_constant_constraints_[i]->Clear();
465     }
466     for (int i = 0; i < EXPR_EXPR_CONSTRAINT_MAX; ++i) {
467       expr_expr_constraints_[i]->Clear();
468     }
469     for (int i = 0; i < VAR_CONSTANT_CONSTANT_CONSTRAINT_MAX; ++i) {
470       var_constant_constant_constraints_[i]->Clear();
471     }
472     for (int i = 0; i < EXPR_EXPRESSION_MAX; ++i) {
473       expr_expressions_[i]->Clear();
474     }
475     for (int i = 0; i < EXPR_CONSTANT_EXPRESSION_MAX; ++i) {
476       expr_constant_expressions_[i]->Clear();
477     }
478     for (int i = 0; i < EXPR_EXPR_EXPRESSION_MAX; ++i) {
479       expr_expr_expressions_[i]->Clear();
480     }
481     for (int i = 0; i < VAR_CONSTANT_CONSTANT_EXPRESSION_MAX; ++i) {
482       var_constant_constant_expressions_[i]->Clear();
483     }
484     for (int i = 0; i < VAR_CONSTANT_ARRAY_EXPRESSION_MAX; ++i) {
485       var_constant_array_expressions_[i]->Clear();
486     }
487     for (int i = 0; i < VAR_ARRAY_EXPRESSION_MAX; ++i) {
488       var_array_expressions_[i]->Clear();
489     }
490     for (int i = 0; i < VAR_ARRAY_CONSTANT_ARRAY_EXPRESSION_MAX; ++i) {
491       var_array_constant_array_expressions_[i]->Clear();
492     }
493     for (int i = 0; i < VAR_ARRAY_CONSTANT_EXPRESSION_MAX; ++i) {
494       var_array_constant_expressions_[i]->Clear();
495     }
496     for (int i = 0; i < EXPR_EXPR_CONSTANT_EXPRESSION_MAX; ++i) {
497       expr_expr_constant_expressions_[i]->Clear();
498     }
499   }
500 
501   // Void Constraint.-
502 
FindVoidConstraint(VoidConstraintType type) const503   Constraint* FindVoidConstraint(VoidConstraintType type) const override {
504     DCHECK_GE(type, 0);
505     DCHECK_LT(type, VOID_CONSTRAINT_MAX);
506     return void_constraints_[type];
507   }
508 
InsertVoidConstraint(Constraint * const ct,VoidConstraintType type)509   void InsertVoidConstraint(Constraint* const ct,
510                             VoidConstraintType type) override {
511     DCHECK_GE(type, 0);
512     DCHECK_LT(type, VOID_CONSTRAINT_MAX);
513     DCHECK(ct != nullptr);
514     if (solver()->state() == Solver::OUTSIDE_SEARCH &&
515         !absl::GetFlag(FLAGS_cp_disable_cache)) {
516       void_constraints_[type] = ct;
517     }
518   }
519 
520   // VarConstantConstraint.
521 
FindVarConstantConstraint(IntVar * const var,int64_t value,VarConstantConstraintType type) const522   Constraint* FindVarConstantConstraint(
523       IntVar* const var, int64_t value,
524       VarConstantConstraintType type) const override {
525     DCHECK(var != nullptr);
526     DCHECK_GE(type, 0);
527     DCHECK_LT(type, VAR_CONSTANT_CONSTRAINT_MAX);
528     return var_constant_constraints_[type]->Find(var, value);
529   }
530 
InsertVarConstantConstraint(Constraint * const ct,IntVar * const var,int64_t value,VarConstantConstraintType type)531   void InsertVarConstantConstraint(Constraint* const ct, IntVar* const var,
532                                    int64_t value,
533                                    VarConstantConstraintType type) override {
534     DCHECK(ct != nullptr);
535     DCHECK(var != nullptr);
536     DCHECK_GE(type, 0);
537     DCHECK_LT(type, VAR_CONSTANT_CONSTRAINT_MAX);
538     if (solver()->state() == Solver::OUTSIDE_SEARCH &&
539         !absl::GetFlag(FLAGS_cp_disable_cache) &&
540         var_constant_constraints_[type]->Find(var, value) == nullptr) {
541       var_constant_constraints_[type]->UnsafeInsert(var, value, ct);
542     }
543   }
544 
545   // Var Constant Constant Constraint.
546 
FindVarConstantConstantConstraint(IntVar * const var,int64_t value1,int64_t value2,VarConstantConstantConstraintType type) const547   Constraint* FindVarConstantConstantConstraint(
548       IntVar* const var, int64_t value1, int64_t value2,
549       VarConstantConstantConstraintType type) const override {
550     DCHECK(var != nullptr);
551     DCHECK_GE(type, 0);
552     DCHECK_LT(type, VAR_CONSTANT_CONSTANT_CONSTRAINT_MAX);
553     return var_constant_constant_constraints_[type]->Find(var, value1, value2);
554   }
555 
InsertVarConstantConstantConstraint(Constraint * const ct,IntVar * const var,int64_t value1,int64_t value2,VarConstantConstantConstraintType type)556   void InsertVarConstantConstantConstraint(
557       Constraint* const ct, IntVar* const var, int64_t value1, int64_t value2,
558       VarConstantConstantConstraintType type) override {
559     DCHECK(ct != nullptr);
560     DCHECK(var != nullptr);
561     DCHECK_GE(type, 0);
562     DCHECK_LT(type, VAR_CONSTANT_CONSTANT_CONSTRAINT_MAX);
563     if (solver()->state() == Solver::OUTSIDE_SEARCH &&
564         !absl::GetFlag(FLAGS_cp_disable_cache) &&
565         var_constant_constant_constraints_[type]->Find(var, value1, value2) ==
566             nullptr) {
567       var_constant_constant_constraints_[type]->UnsafeInsert(var, value1,
568                                                              value2, ct);
569     }
570   }
571 
572   // Var Var Constraint.
573 
FindExprExprConstraint(IntExpr * const var1,IntExpr * const var2,ExprExprConstraintType type) const574   Constraint* FindExprExprConstraint(
575       IntExpr* const var1, IntExpr* const var2,
576       ExprExprConstraintType type) const override {
577     DCHECK(var1 != nullptr);
578     DCHECK(var2 != nullptr);
579     DCHECK_GE(type, 0);
580     DCHECK_LT(type, EXPR_EXPR_CONSTRAINT_MAX);
581     return expr_expr_constraints_[type]->Find(var1, var2);
582   }
583 
InsertExprExprConstraint(Constraint * const ct,IntExpr * const var1,IntExpr * const var2,ExprExprConstraintType type)584   void InsertExprExprConstraint(Constraint* const ct, IntExpr* const var1,
585                                 IntExpr* const var2,
586                                 ExprExprConstraintType type) override {
587     DCHECK(ct != nullptr);
588     DCHECK(var1 != nullptr);
589     DCHECK(var2 != nullptr);
590     DCHECK_GE(type, 0);
591     DCHECK_LT(type, EXPR_EXPR_CONSTRAINT_MAX);
592     if (solver()->state() == Solver::OUTSIDE_SEARCH &&
593         !absl::GetFlag(FLAGS_cp_disable_cache) &&
594         expr_expr_constraints_[type]->Find(var1, var2) == nullptr) {
595       expr_expr_constraints_[type]->UnsafeInsert(var1, var2, ct);
596     }
597   }
598 
599   // Expr Expression.
600 
FindExprExpression(IntExpr * const expr,ExprExpressionType type) const601   IntExpr* FindExprExpression(IntExpr* const expr,
602                               ExprExpressionType type) const override {
603     DCHECK(expr != nullptr);
604     DCHECK_GE(type, 0);
605     DCHECK_LT(type, EXPR_EXPRESSION_MAX);
606     return expr_expressions_[type]->Find(expr);
607   }
608 
InsertExprExpression(IntExpr * const expression,IntExpr * const expr,ExprExpressionType type)609   void InsertExprExpression(IntExpr* const expression, IntExpr* const expr,
610                             ExprExpressionType type) override {
611     DCHECK(expression != nullptr);
612     DCHECK(expr != nullptr);
613     DCHECK_GE(type, 0);
614     DCHECK_LT(type, EXPR_EXPRESSION_MAX);
615     if (solver()->state() == Solver::OUTSIDE_SEARCH &&
616         !absl::GetFlag(FLAGS_cp_disable_cache) &&
617         expr_expressions_[type]->Find(expr) == nullptr) {
618       expr_expressions_[type]->UnsafeInsert(expr, expression);
619     }
620   }
621 
622   // Expr Constant Expressions.
623 
FindExprConstantExpression(IntExpr * const expr,int64_t value,ExprConstantExpressionType type) const624   IntExpr* FindExprConstantExpression(
625       IntExpr* const expr, int64_t value,
626       ExprConstantExpressionType type) const override {
627     DCHECK(expr != nullptr);
628     DCHECK_GE(type, 0);
629     DCHECK_LT(type, EXPR_CONSTANT_EXPRESSION_MAX);
630     return expr_constant_expressions_[type]->Find(expr, value);
631   }
632 
InsertExprConstantExpression(IntExpr * const expression,IntExpr * const expr,int64_t value,ExprConstantExpressionType type)633   void InsertExprConstantExpression(IntExpr* const expression,
634                                     IntExpr* const expr, int64_t value,
635                                     ExprConstantExpressionType type) override {
636     DCHECK(expression != nullptr);
637     DCHECK(expr != nullptr);
638     DCHECK_GE(type, 0);
639     DCHECK_LT(type, EXPR_CONSTANT_EXPRESSION_MAX);
640     if (solver()->state() == Solver::OUTSIDE_SEARCH &&
641         !absl::GetFlag(FLAGS_cp_disable_cache) &&
642         expr_constant_expressions_[type]->Find(expr, value) == nullptr) {
643       expr_constant_expressions_[type]->UnsafeInsert(expr, value, expression);
644     }
645   }
646 
647   // Expr Expr Expression.
648 
FindExprExprExpression(IntExpr * const var1,IntExpr * const var2,ExprExprExpressionType type) const649   IntExpr* FindExprExprExpression(IntExpr* const var1, IntExpr* const var2,
650                                   ExprExprExpressionType type) const override {
651     DCHECK(var1 != nullptr);
652     DCHECK(var2 != nullptr);
653     DCHECK_GE(type, 0);
654     DCHECK_LT(type, EXPR_EXPR_EXPRESSION_MAX);
655     return expr_expr_expressions_[type]->Find(var1, var2);
656   }
657 
InsertExprExprExpression(IntExpr * const expression,IntExpr * const var1,IntExpr * const var2,ExprExprExpressionType type)658   void InsertExprExprExpression(IntExpr* const expression, IntExpr* const var1,
659                                 IntExpr* const var2,
660                                 ExprExprExpressionType type) override {
661     DCHECK(expression != nullptr);
662     DCHECK(var1 != nullptr);
663     DCHECK(var2 != nullptr);
664     DCHECK_GE(type, 0);
665     DCHECK_LT(type, EXPR_EXPR_EXPRESSION_MAX);
666     if (solver()->state() == Solver::OUTSIDE_SEARCH &&
667         !absl::GetFlag(FLAGS_cp_disable_cache) &&
668         expr_expr_expressions_[type]->Find(var1, var2) == nullptr) {
669       expr_expr_expressions_[type]->UnsafeInsert(var1, var2, expression);
670     }
671   }
672 
673   // Expr Expr Constant Expression.
674 
FindExprExprConstantExpression(IntExpr * const var1,IntExpr * const var2,int64_t constant,ExprExprConstantExpressionType type) const675   IntExpr* FindExprExprConstantExpression(
676       IntExpr* const var1, IntExpr* const var2, int64_t constant,
677       ExprExprConstantExpressionType type) const override {
678     DCHECK(var1 != nullptr);
679     DCHECK(var2 != nullptr);
680     DCHECK_GE(type, 0);
681     DCHECK_LT(type, EXPR_EXPR_CONSTANT_EXPRESSION_MAX);
682     return expr_expr_constant_expressions_[type]->Find(var1, var2, constant);
683   }
684 
InsertExprExprConstantExpression(IntExpr * const expression,IntExpr * const var1,IntExpr * const var2,int64_t constant,ExprExprConstantExpressionType type)685   void InsertExprExprConstantExpression(
686       IntExpr* const expression, IntExpr* const var1, IntExpr* const var2,
687       int64_t constant, ExprExprConstantExpressionType type) override {
688     DCHECK(expression != nullptr);
689     DCHECK(var1 != nullptr);
690     DCHECK(var2 != nullptr);
691     DCHECK_GE(type, 0);
692     DCHECK_LT(type, EXPR_EXPR_CONSTANT_EXPRESSION_MAX);
693     if (solver()->state() == Solver::OUTSIDE_SEARCH &&
694         !absl::GetFlag(FLAGS_cp_disable_cache) &&
695         expr_expr_constant_expressions_[type]->Find(var1, var2, constant) ==
696             nullptr) {
697       expr_expr_constant_expressions_[type]->UnsafeInsert(var1, var2, constant,
698                                                           expression);
699     }
700   }
701 
702   // Var Constant Constant Expression.
703 
FindVarConstantConstantExpression(IntVar * const var,int64_t value1,int64_t value2,VarConstantConstantExpressionType type) const704   IntExpr* FindVarConstantConstantExpression(
705       IntVar* const var, int64_t value1, int64_t value2,
706       VarConstantConstantExpressionType type) const override {
707     DCHECK(var != nullptr);
708     DCHECK_GE(type, 0);
709     DCHECK_LT(type, VAR_CONSTANT_CONSTANT_EXPRESSION_MAX);
710     return var_constant_constant_expressions_[type]->Find(var, value1, value2);
711   }
712 
InsertVarConstantConstantExpression(IntExpr * const expression,IntVar * const var,int64_t value1,int64_t value2,VarConstantConstantExpressionType type)713   void InsertVarConstantConstantExpression(
714       IntExpr* const expression, IntVar* const var, int64_t value1,
715       int64_t value2, VarConstantConstantExpressionType type) override {
716     DCHECK(expression != nullptr);
717     DCHECK(var != nullptr);
718     DCHECK_GE(type, 0);
719     DCHECK_LT(type, VAR_CONSTANT_CONSTANT_EXPRESSION_MAX);
720     if (solver()->state() == Solver::OUTSIDE_SEARCH &&
721         !absl::GetFlag(FLAGS_cp_disable_cache) &&
722         var_constant_constant_expressions_[type]->Find(var, value1, value2) ==
723             nullptr) {
724       var_constant_constant_expressions_[type]->UnsafeInsert(
725           var, value1, value2, expression);
726     }
727   }
728 
729   // Var Constant Array Expression.
730 
FindVarConstantArrayExpression(IntVar * const var,const std::vector<int64_t> & values,VarConstantArrayExpressionType type) const731   IntExpr* FindVarConstantArrayExpression(
732       IntVar* const var, const std::vector<int64_t>& values,
733       VarConstantArrayExpressionType type) const override {
734     DCHECK(var != nullptr);
735     DCHECK_GE(type, 0);
736     DCHECK_LT(type, VAR_CONSTANT_ARRAY_EXPRESSION_MAX);
737     return var_constant_array_expressions_[type]->Find(var, values);
738   }
739 
InsertVarConstantArrayExpression(IntExpr * const expression,IntVar * const var,const std::vector<int64_t> & values,VarConstantArrayExpressionType type)740   void InsertVarConstantArrayExpression(
741       IntExpr* const expression, IntVar* const var,
742       const std::vector<int64_t>& values,
743       VarConstantArrayExpressionType type) override {
744     DCHECK(expression != nullptr);
745     DCHECK(var != nullptr);
746     DCHECK_GE(type, 0);
747     DCHECK_LT(type, VAR_CONSTANT_ARRAY_EXPRESSION_MAX);
748     if (solver()->state() == Solver::OUTSIDE_SEARCH &&
749         !absl::GetFlag(FLAGS_cp_disable_cache) &&
750         var_constant_array_expressions_[type]->Find(var, values) == nullptr) {
751       var_constant_array_expressions_[type]->UnsafeInsert(var, values,
752                                                           expression);
753     }
754   }
755 
756   // Var Array Expression.
757 
FindVarArrayExpression(const std::vector<IntVar * > & vars,VarArrayExpressionType type) const758   IntExpr* FindVarArrayExpression(const std::vector<IntVar*>& vars,
759                                   VarArrayExpressionType type) const override {
760     DCHECK_GE(type, 0);
761     DCHECK_LT(type, VAR_ARRAY_EXPRESSION_MAX);
762     return var_array_expressions_[type]->Find(vars);
763   }
764 
InsertVarArrayExpression(IntExpr * const expression,const std::vector<IntVar * > & vars,VarArrayExpressionType type)765   void InsertVarArrayExpression(IntExpr* const expression,
766                                 const std::vector<IntVar*>& vars,
767                                 VarArrayExpressionType type) override {
768     DCHECK(expression != nullptr);
769     DCHECK_GE(type, 0);
770     DCHECK_LT(type, VAR_ARRAY_EXPRESSION_MAX);
771     if (solver()->state() == Solver::OUTSIDE_SEARCH &&
772         !absl::GetFlag(FLAGS_cp_disable_cache) &&
773         var_array_expressions_[type]->Find(vars) == nullptr) {
774       var_array_expressions_[type]->UnsafeInsert(vars, expression);
775     }
776   }
777 
778   // Var Array Constant Array Expressions.
779 
FindVarArrayConstantArrayExpression(const std::vector<IntVar * > & vars,const std::vector<int64_t> & values,VarArrayConstantArrayExpressionType type) const780   IntExpr* FindVarArrayConstantArrayExpression(
781       const std::vector<IntVar*>& vars, const std::vector<int64_t>& values,
782       VarArrayConstantArrayExpressionType type) const override {
783     DCHECK_GE(type, 0);
784     DCHECK_LT(type, VAR_ARRAY_CONSTANT_ARRAY_EXPRESSION_MAX);
785     return var_array_constant_array_expressions_[type]->Find(vars, values);
786   }
787 
InsertVarArrayConstantArrayExpression(IntExpr * const expression,const std::vector<IntVar * > & vars,const std::vector<int64_t> & values,VarArrayConstantArrayExpressionType type)788   void InsertVarArrayConstantArrayExpression(
789       IntExpr* const expression, const std::vector<IntVar*>& vars,
790       const std::vector<int64_t>& values,
791       VarArrayConstantArrayExpressionType type) override {
792     DCHECK(expression != nullptr);
793     DCHECK_GE(type, 0);
794     DCHECK_LT(type, VAR_ARRAY_CONSTANT_ARRAY_EXPRESSION_MAX);
795     if (solver()->state() != Solver::IN_SEARCH &&
796         var_array_constant_array_expressions_[type]->Find(vars, values) ==
797             nullptr) {
798       var_array_constant_array_expressions_[type]->UnsafeInsert(vars, values,
799                                                                 expression);
800     }
801   }
802 
803   // Var Array Constant Expressions.
804 
FindVarArrayConstantExpression(const std::vector<IntVar * > & vars,int64_t value,VarArrayConstantExpressionType type) const805   IntExpr* FindVarArrayConstantExpression(
806       const std::vector<IntVar*>& vars, int64_t value,
807       VarArrayConstantExpressionType type) const override {
808     DCHECK_GE(type, 0);
809     DCHECK_LT(type, VAR_ARRAY_CONSTANT_EXPRESSION_MAX);
810     return var_array_constant_expressions_[type]->Find(vars, value);
811   }
812 
InsertVarArrayConstantExpression(IntExpr * const expression,const std::vector<IntVar * > & vars,int64_t value,VarArrayConstantExpressionType type)813   void InsertVarArrayConstantExpression(
814       IntExpr* const expression, const std::vector<IntVar*>& vars,
815       int64_t value, VarArrayConstantExpressionType type) override {
816     DCHECK(expression != nullptr);
817     DCHECK_GE(type, 0);
818     DCHECK_LT(type, VAR_ARRAY_CONSTANT_EXPRESSION_MAX);
819     if (solver()->state() != Solver::IN_SEARCH &&
820         var_array_constant_expressions_[type]->Find(vars, value) == nullptr) {
821       var_array_constant_expressions_[type]->UnsafeInsert(vars, value,
822                                                           expression);
823     }
824   }
825 
826  private:
827   std::vector<Constraint*> void_constraints_;
828   std::vector<VarConstantConstraintCache*> var_constant_constraints_;
829   std::vector<ExprExprConstraintCache*> expr_expr_constraints_;
830   std::vector<VarConstantConstantConstraintCache*>
831       var_constant_constant_constraints_;
832   std::vector<ExprIntExprCache*> expr_expressions_;
833   std::vector<ExprConstantIntExprCache*> expr_constant_expressions_;
834   std::vector<ExprExprIntExprCache*> expr_expr_expressions_;
835   std::vector<VarConstantConstantIntExprCache*>
836       var_constant_constant_expressions_;
837   std::vector<VarConstantArrayIntExprCache*> var_constant_array_expressions_;
838   std::vector<VarArrayIntExprCache*> var_array_expressions_;
839   std::vector<VarArrayConstantArrayIntExprCache*>
840       var_array_constant_array_expressions_;
841   std::vector<VarArrayConstantIntExprCache*> var_array_constant_expressions_;
842   std::vector<ExprExprConstantIntExprCache*> expr_expr_constant_expressions_;
843 };
844 }  // namespace
845 
BuildModelCache(Solver * const solver)846 ModelCache* BuildModelCache(Solver* const solver) {
847   return new NonReversibleCache(solver);
848 }
849 
Cache() const850 ModelCache* Solver::Cache() const { return model_cache_.get(); }
851 }  // namespace operations_research
852