1 // Copyright 2010-2021 Google LLC
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 #ifndef OR_TOOLS_SAT_INTEGER_EXPR_H_
15 #define OR_TOOLS_SAT_INTEGER_EXPR_H_
16 
17 #include <cstdint>
18 #include <functional>
19 #include <vector>
20 
21 #include "ortools/base/int_type.h"
22 #include "ortools/base/integral_types.h"
23 #include "ortools/base/logging.h"
24 #include "ortools/base/macros.h"
25 #include "ortools/base/mathutil.h"
26 #include "ortools/sat/integer.h"
27 #include "ortools/sat/linear_constraint.h"
28 #include "ortools/sat/model.h"
29 #include "ortools/sat/precedences.h"
30 #include "ortools/sat/sat_base.h"
31 #include "ortools/sat/sat_solver.h"
32 #include "ortools/util/rev.h"
33 
34 namespace operations_research {
35 namespace sat {
36 
37 // A really basic implementation of an upper-bounded sum of integer variables.
38 // The complexity is in O(num_variables) at each propagation.
39 //
40 // Note that we assume that there can be NO integer overflow. This must be
41 // checked at model validation time before this is even created.
42 //
43 // TODO(user): If one has many such constraint, it will be more efficient to
44 // propagate all of them at once rather than doing it one at the time.
45 //
46 // TODO(user): Explore tree structure to get a log(n) complexity.
47 //
48 // TODO(user): When the variables are Boolean, use directly the pseudo-Boolean
49 // constraint implementation. But we do need support for enforcement literals
50 // there.
51 class IntegerSumLE : public PropagatorInterface {
52  public:
53   // If refied_literal is kNoLiteralIndex then this is a normal constraint,
54   // otherwise we enforce the implication refied_literal => constraint is true.
55   // Note that we don't do the reverse implication here, it is usually done by
56   // another IntegerSumLE constraint on the negated variables.
57   IntegerSumLE(const std::vector<Literal>& enforcement_literals,
58                const std::vector<IntegerVariable>& vars,
59                const std::vector<IntegerValue>& coeffs,
60                IntegerValue upper_bound, Model* model);
61 
62   // We propagate:
63   // - If the sum of the individual lower-bound is > upper_bound, we fail.
64   // - For all i, upper-bound of i
65   //      <= upper_bound - Sum {individual lower-bound excluding i).
66   bool Propagate() final;
67   void RegisterWith(GenericLiteralWatcher* watcher);
68 
69   // Same as Propagate() but only consider current root level bounds. This is
70   // mainly useful for the LP propagator since it can find relevant optimal
71   // really late in the search tree.
72   bool PropagateAtLevelZero();
73 
74   // This is a pretty usage specific function. Returns the implied lower bound
75   // on target_var if the given integer literal is false (resp. true). If the
76   // variables do not appear both in the linear inequality, this returns two
77   // kMinIntegerValue.
78   std::pair<IntegerValue, IntegerValue> ConditionalLb(
79       IntegerLiteral integer_literal, IntegerVariable target_var) const;
80 
81  private:
82   // Fills integer_reason_ with all the current lower_bounds. The real
83   // explanation may require removing one of them, but as an optimization, we
84   // always keep all the IntegerLiteral in integer_reason_, and swap them as
85   // needed just before pushing something.
86   void FillIntegerReason();
87 
88   const std::vector<Literal> enforcement_literals_;
89   const IntegerValue upper_bound_;
90 
91   Trail* trail_;
92   IntegerTrail* integer_trail_;
93   TimeLimit* time_limit_;
94   RevIntegerValueRepository* rev_integer_value_repository_;
95 
96   // Reversible sum of the lower bound of the fixed variables.
97   bool is_registered_ = false;
98   IntegerValue rev_lb_fixed_vars_;
99 
100   // Reversible number of fixed variables.
101   int rev_num_fixed_vars_;
102 
103   // Those vectors are shuffled during search to ensure that the variables
104   // (resp. coefficients) contained in the range [0, rev_num_fixed_vars_) of
105   // vars_ (resp. coeffs_) are fixed (resp. belong to fixed variables).
106   std::vector<IntegerVariable> vars_;
107   std::vector<IntegerValue> coeffs_;
108   std::vector<IntegerValue> max_variations_;
109 
110   std::vector<Literal> literal_reason_;
111 
112   // Parallel vectors.
113   std::vector<IntegerLiteral> integer_reason_;
114   std::vector<IntegerValue> reason_coeffs_;
115 
116   DISALLOW_COPY_AND_ASSIGN(IntegerSumLE);
117 };
118 
119 // This assumes target = SUM_i coeffs[i] * vars[i], and detects that the target
120 // must be of the form (a*X + b).
121 //
122 // This propagator is quite specific and runs only at level zero. For now, this
123 // is mainly used for the objective variable. As we fix terms with high
124 // objective coefficient, it is possible the only terms left have a common
125 // divisor. This close app2-2.mps in less than a second instead of running
126 // forever to prove the optimal (in single thread).
127 class LevelZeroEquality : PropagatorInterface {
128  public:
129   LevelZeroEquality(IntegerVariable target,
130                     const std::vector<IntegerVariable>& vars,
131                     const std::vector<IntegerValue>& coeffs, Model* model);
132 
133   bool Propagate() final;
134 
135  private:
136   const IntegerVariable target_;
137   const std::vector<IntegerVariable> vars_;
138   const std::vector<IntegerValue> coeffs_;
139 
140   IntegerValue gcd_ = IntegerValue(1);
141 
142   Trail* trail_;
143   IntegerTrail* integer_trail_;
144 };
145 
146 // A min (resp max) constraint of the form min == MIN(vars) can be decomposed
147 // into two inequalities:
148 //   1/ min <= MIN(vars), which is the same as for all v in vars, "min <= v".
149 //      This can be taken care of by the LowerOrEqual(min, v) constraint.
150 //   2/ min >= MIN(vars).
151 //
152 // And in turn, 2/ can be decomposed in:
153 //   a) lb(min) >= lb(MIN(vars)) = MIN(lb(var));
154 //   b) ub(min) >= ub(MIN(vars)) and we can't propagate anything here unless
155 //      there is just one possible variable 'v' that can be the min:
156 //         for all u != v, lb(u) > ub(min);
157 //      In this case, ub(min) >= ub(v).
158 //
159 // This constraint take care of a) and b). That is:
160 // - If the min of the lower bound of the vars increase, then the lower bound of
161 //   the min_var will be >= to it.
162 // - If there is only one candidate for the min, then if the ub(min) decrease,
163 //   the ub of the only candidate will be <= to it.
164 //
165 // Complexity: This is a basic implementation in O(num_vars) on each call to
166 // Propagate(), which will happen each time one or more variables in vars_
167 // changed.
168 //
169 // TODO(user): Implement a more efficient algorithm when the need arise.
170 class MinPropagator : public PropagatorInterface {
171  public:
172   MinPropagator(const std::vector<IntegerVariable>& vars,
173                 IntegerVariable min_var, IntegerTrail* integer_trail);
174 
175   bool Propagate() final;
176   void RegisterWith(GenericLiteralWatcher* watcher);
177 
178  private:
179   const std::vector<IntegerVariable> vars_;
180   const IntegerVariable min_var_;
181   IntegerTrail* integer_trail_;
182 
183   std::vector<IntegerLiteral> integer_reason_;
184 
185   DISALLOW_COPY_AND_ASSIGN(MinPropagator);
186 };
187 
188 // Same as MinPropagator except this works on min = MIN(exprs) where exprs are
189 // linear expressions. It uses IntegerSumLE to propagate bounds on the exprs.
190 // Assumes Canonical expressions (all positive coefficients).
191 class LinMinPropagator : public PropagatorInterface {
192  public:
193   LinMinPropagator(const std::vector<LinearExpression>& exprs,
194                    IntegerVariable min_var, Model* model);
195   LinMinPropagator(const LinMinPropagator&) = delete;
196   LinMinPropagator& operator=(const LinMinPropagator&) = delete;
197 
198   bool Propagate() final;
199   void RegisterWith(GenericLiteralWatcher* watcher);
200 
201  private:
202   // Lighter version of IntegerSumLE. This uses the current value of
203   // integer_reason_ in addition to the reason for propagating the linear
204   // constraint. The coeffs are assumed to be positive here.
205   bool PropagateLinearUpperBound(const std::vector<IntegerVariable>& vars,
206                                  const std::vector<IntegerValue>& coeffs,
207                                  IntegerValue upper_bound);
208 
209   const std::vector<LinearExpression> exprs_;
210   const IntegerVariable min_var_;
211   std::vector<IntegerValue> expr_lbs_;
212   Model* model_;
213   IntegerTrail* integer_trail_;
214   std::vector<IntegerLiteral> integer_reason_for_unique_candidate_;
215   int rev_unique_candidate_ = 0;
216 };
217 
218 // Propagates a * b = p.
219 //
220 // The bounds [min, max] of a and b will be propagated perfectly, but not
221 // the bounds on p as this require more complex arithmetics.
222 class ProductPropagator : public PropagatorInterface {
223  public:
224   ProductPropagator(AffineExpression a, AffineExpression b, AffineExpression p,
225                     IntegerTrail* integer_trail);
226 
227   bool Propagate() final;
228   void RegisterWith(GenericLiteralWatcher* watcher);
229 
230  private:
231   // Maybe replace a_, b_ or c_ by their negation to simplify the cases.
232   bool CanonicalizeCases();
233 
234   // Special case when all are >= 0.
235   // We use faster code and better reasons than the generic code.
236   bool PropagateWhenAllNonNegative();
237 
238   // Internal helper, see code for more details.
239   bool PropagateMaxOnPositiveProduct(AffineExpression a, AffineExpression b,
240                                      IntegerValue min_p, IntegerValue max_p);
241 
242   // Note that we might negate any two terms in CanonicalizeCases() during
243   // each propagation. This is fine.
244   AffineExpression a_;
245   AffineExpression b_;
246   AffineExpression p_;
247 
248   IntegerTrail* integer_trail_;
249 
250   DISALLOW_COPY_AND_ASSIGN(ProductPropagator);
251 };
252 
253 // Propagates num / denom = div. Basic version, we don't extract any special
254 // cases, and we only propagates the bounds. It expects denom to be > 0.
255 //
256 // TODO(user): Deal with overflow.
257 class DivisionPropagator : public PropagatorInterface {
258  public:
259   DivisionPropagator(AffineExpression num, AffineExpression denom,
260                      AffineExpression div, IntegerTrail* integer_trail);
261 
262   bool Propagate() final;
263   void RegisterWith(GenericLiteralWatcher* watcher);
264 
265  private:
266   // Propagates the fact that the signs of each domain, if fixed, are
267   // compatible.
268   bool PropagateSigns();
269 
270   // If both num and div >= 0, we can propagate their upper bounds.
271   bool PropagateUpperBounds(AffineExpression num, AffineExpression denom,
272                             AffineExpression div);
273 
274   // When the sign of all 3 expressions are fixed, we can do morel propagation.
275   //
276   // By using negated expressions, we can make sure the domains of num, denom,
277   // and div are positive.
278   bool PropagatePositiveDomains(AffineExpression num, AffineExpression denom,
279                                 AffineExpression div);
280 
281   const AffineExpression num_;
282   const AffineExpression denom_;
283   const AffineExpression div_;
284   const AffineExpression negated_num_;
285   const AffineExpression negated_div_;
286   IntegerTrail* integer_trail_;
287 
288   DISALLOW_COPY_AND_ASSIGN(DivisionPropagator);
289 };
290 
291 // Propagates var_a / cst_b = var_c. Basic version, we don't extract any special
292 // cases, and we only propagates the bounds. cst_b must be > 0.
293 class FixedDivisionPropagator : public PropagatorInterface {
294  public:
295   FixedDivisionPropagator(AffineExpression a, IntegerValue b,
296                           AffineExpression c, IntegerTrail* integer_trail);
297 
298   bool Propagate() final;
299   void RegisterWith(GenericLiteralWatcher* watcher);
300 
301  private:
302   const AffineExpression a_;
303   const IntegerValue b_;
304   const AffineExpression c_;
305 
306   IntegerTrail* integer_trail_;
307 
308   DISALLOW_COPY_AND_ASSIGN(FixedDivisionPropagator);
309 };
310 
311 // Propagates target == expr % mod. Basic version, we don't extract any special
312 // cases, and we only propagates the bounds. mod must be > 0.
313 class FixedModuloPropagator : public PropagatorInterface {
314  public:
315   FixedModuloPropagator(AffineExpression expr, IntegerValue mod,
316                         AffineExpression target, IntegerTrail* integer_trail);
317 
318   bool Propagate() final;
319   void RegisterWith(GenericLiteralWatcher* watcher);
320 
321  private:
322   bool PropagateSignsAndTargetRange();
323   bool PropagateBoundsWhenExprIsPositive(AffineExpression expr,
324                                          AffineExpression target);
325   bool PropagateOuterBounds();
326 
327   const AffineExpression expr_;
328   const IntegerValue mod_;
329   const AffineExpression target_;
330   const AffineExpression negated_expr_;
331   const AffineExpression negated_target_;
332   IntegerTrail* integer_trail_;
333 
334   DISALLOW_COPY_AND_ASSIGN(FixedModuloPropagator);
335 };
336 
337 // Propagates x * x = s.
338 // TODO(user): Only works for x nonnegative.
339 class SquarePropagator : public PropagatorInterface {
340  public:
341   SquarePropagator(AffineExpression x, AffineExpression s,
342                    IntegerTrail* integer_trail);
343 
344   bool Propagate() final;
345   void RegisterWith(GenericLiteralWatcher* watcher);
346 
347  private:
348   const AffineExpression x_;
349   const AffineExpression s_;
350   IntegerTrail* integer_trail_;
351 
352   DISALLOW_COPY_AND_ASSIGN(SquarePropagator);
353 };
354 
355 // =============================================================================
356 // Model based functions.
357 // =============================================================================
358 
359 // Weighted sum <= constant.
360 template <typename VectorInt>
WeightedSumLowerOrEqual(const std::vector<IntegerVariable> & vars,const VectorInt & coefficients,int64_t upper_bound)361 inline std::function<void(Model*)> WeightedSumLowerOrEqual(
362     const std::vector<IntegerVariable>& vars, const VectorInt& coefficients,
363     int64_t upper_bound) {
364   // Special cases.
365   CHECK_GE(vars.size(), 1);
366   if (vars.size() == 1) {
367     const int64_t c = coefficients[0];
368     CHECK_NE(c, 0);
369     if (c > 0) {
370       return LowerOrEqual(
371           vars[0],
372           FloorRatio(IntegerValue(upper_bound), IntegerValue(c)).value());
373     } else {
374       return GreaterOrEqual(
375           vars[0],
376           CeilRatio(IntegerValue(-upper_bound), IntegerValue(-c)).value());
377     }
378   }
379   if (vars.size() == 2 && (coefficients[0] == 1 || coefficients[0] == -1) &&
380       (coefficients[1] == 1 || coefficients[1] == -1)) {
381     return Sum2LowerOrEqual(
382         coefficients[0] == 1 ? vars[0] : NegationOf(vars[0]),
383         coefficients[1] == 1 ? vars[1] : NegationOf(vars[1]), upper_bound);
384   }
385   if (vars.size() == 3 && (coefficients[0] == 1 || coefficients[0] == -1) &&
386       (coefficients[1] == 1 || coefficients[1] == -1) &&
387       (coefficients[2] == 1 || coefficients[2] == -1)) {
388     return Sum3LowerOrEqual(
389         coefficients[0] == 1 ? vars[0] : NegationOf(vars[0]),
390         coefficients[1] == 1 ? vars[1] : NegationOf(vars[1]),
391         coefficients[2] == 1 ? vars[2] : NegationOf(vars[2]), upper_bound);
392   }
393 
394   return [=](Model* model) {
395     // We split large constraints into a square root number of parts.
396     // This is to avoid a bad complexity while propagating them since our
397     // algorithm is not in O(num_changes).
398     //
399     // TODO(user): Alternatively, we could use a O(num_changes) propagation (a
400     // bit tricky to implement), or a decomposition into a tree with more than
401     // one level. Both requires experimentations.
402     //
403     // TODO(user): If the initial constraint was an equalilty we will create
404     // the "intermediate" variable twice where we could have use the same for
405     // both direction. Improve?
406     const int num_vars = vars.size();
407     if (num_vars > 100) {
408       std::vector<IntegerVariable> bucket_sum_vars;
409 
410       std::vector<IntegerVariable> local_vars;
411       std::vector<IntegerValue> local_coeffs;
412 
413       int i = 0;
414       const int num_buckets = static_cast<int>(std::round(std::sqrt(num_vars)));
415       for (int b = 0; b < num_buckets; ++b) {
416         local_vars.clear();
417         local_coeffs.clear();
418         int64_t bucket_lb = 0;
419         int64_t bucket_ub = 0;
420         const int limit = num_vars * (b + 1);
421         for (; i * num_buckets < limit; ++i) {
422           local_vars.push_back(vars[i]);
423           local_coeffs.push_back(IntegerValue(coefficients[i]));
424           const int64_t term1 =
425               model->Get(LowerBound(vars[i])) * coefficients[i];
426           const int64_t term2 =
427               model->Get(UpperBound(vars[i])) * coefficients[i];
428           bucket_lb += std::min(term1, term2);
429           bucket_ub += std::max(term1, term2);
430         }
431 
432         const IntegerVariable bucket_sum =
433             model->Add(NewIntegerVariable(bucket_lb, bucket_ub));
434         bucket_sum_vars.push_back(bucket_sum);
435         local_vars.push_back(bucket_sum);
436         local_coeffs.push_back(IntegerValue(-1));
437         IntegerSumLE* constraint = new IntegerSumLE(
438             {}, local_vars, local_coeffs, IntegerValue(0), model);
439         constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
440         model->TakeOwnership(constraint);
441       }
442 
443       // Create the root-level sum.
444       local_vars.clear();
445       local_coeffs.clear();
446       for (const IntegerVariable var : bucket_sum_vars) {
447         local_vars.push_back(var);
448         local_coeffs.push_back(IntegerValue(1));
449       }
450       IntegerSumLE* constraint = new IntegerSumLE(
451           {}, local_vars, local_coeffs, IntegerValue(upper_bound), model);
452       constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
453       model->TakeOwnership(constraint);
454       return;
455     }
456 
457     IntegerSumLE* constraint = new IntegerSumLE(
458         {}, vars,
459         std::vector<IntegerValue>(coefficients.begin(), coefficients.end()),
460         IntegerValue(upper_bound), model);
461     constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
462     model->TakeOwnership(constraint);
463   };
464 }
465 
466 // Weighted sum >= constant.
467 template <typename VectorInt>
WeightedSumGreaterOrEqual(const std::vector<IntegerVariable> & vars,const VectorInt & coefficients,int64_t lower_bound)468 inline std::function<void(Model*)> WeightedSumGreaterOrEqual(
469     const std::vector<IntegerVariable>& vars, const VectorInt& coefficients,
470     int64_t lower_bound) {
471   // We just negate everything and use an <= constraints.
472   std::vector<int64_t> negated_coeffs(coefficients.begin(), coefficients.end());
473   for (int64_t& ref : negated_coeffs) ref = -ref;
474   return WeightedSumLowerOrEqual(vars, negated_coeffs, -lower_bound);
475 }
476 
477 // Weighted sum == constant.
478 template <typename VectorInt>
FixedWeightedSum(const std::vector<IntegerVariable> & vars,const VectorInt & coefficients,int64_t value)479 inline std::function<void(Model*)> FixedWeightedSum(
480     const std::vector<IntegerVariable>& vars, const VectorInt& coefficients,
481     int64_t value) {
482   return [=](Model* model) {
483     model->Add(WeightedSumGreaterOrEqual(vars, coefficients, value));
484     model->Add(WeightedSumLowerOrEqual(vars, coefficients, value));
485   };
486 }
487 
488 // enforcement_literals => sum <= upper_bound
489 template <typename VectorInt>
ConditionalWeightedSumLowerOrEqual(const std::vector<Literal> & enforcement_literals,const std::vector<IntegerVariable> & vars,const VectorInt & coefficients,int64_t upper_bound)490 inline std::function<void(Model*)> ConditionalWeightedSumLowerOrEqual(
491     const std::vector<Literal>& enforcement_literals,
492     const std::vector<IntegerVariable>& vars, const VectorInt& coefficients,
493     int64_t upper_bound) {
494   // Special cases.
495   CHECK_GE(vars.size(), 1);
496   if (vars.size() == 1) {
497     CHECK_NE(coefficients[0], 0);
498     if (coefficients[0] > 0) {
499       return Implication(
500           enforcement_literals,
501           IntegerLiteral::LowerOrEqual(
502               vars[0], FloorRatio(IntegerValue(upper_bound),
503                                   IntegerValue(coefficients[0]))));
504     } else {
505       return Implication(
506           enforcement_literals,
507           IntegerLiteral::GreaterOrEqual(
508               vars[0], CeilRatio(IntegerValue(-upper_bound),
509                                  IntegerValue(-coefficients[0]))));
510     }
511   }
512 
513   if (vars.size() == 2 && (coefficients[0] == 1 || coefficients[0] == -1) &&
514       (coefficients[1] == 1 || coefficients[1] == -1)) {
515     return ConditionalSum2LowerOrEqual(
516         coefficients[0] == 1 ? vars[0] : NegationOf(vars[0]),
517         coefficients[1] == 1 ? vars[1] : NegationOf(vars[1]), upper_bound,
518         enforcement_literals);
519   }
520   if (vars.size() == 3 && (coefficients[0] == 1 || coefficients[0] == -1) &&
521       (coefficients[1] == 1 || coefficients[1] == -1) &&
522       (coefficients[2] == 1 || coefficients[2] == -1)) {
523     return ConditionalSum3LowerOrEqual(
524         coefficients[0] == 1 ? vars[0] : NegationOf(vars[0]),
525         coefficients[1] == 1 ? vars[1] : NegationOf(vars[1]),
526         coefficients[2] == 1 ? vars[2] : NegationOf(vars[2]), upper_bound,
527         enforcement_literals);
528   }
529 
530   return [=](Model* model) {
531     // If value == min(expression), then we can avoid creating the sum.
532     IntegerValue expression_min(0);
533     auto* integer_trail = model->GetOrCreate<IntegerTrail>();
534     for (int i = 0; i < vars.size(); ++i) {
535       expression_min +=
536           coefficients[i] * (coefficients[i] >= 0
537                                  ? integer_trail->LowerBound(vars[i])
538                                  : integer_trail->UpperBound(vars[i]));
539     }
540     if (expression_min == upper_bound) {
541       // Tricky: as we create integer literal, we might propagate stuff and
542       // the bounds might change, so if the expression_min increase with the
543       // bound we use, then the literal must be false.
544       IntegerValue non_cached_min;
545       for (int i = 0; i < vars.size(); ++i) {
546         if (coefficients[i] > 0) {
547           const IntegerValue lb = integer_trail->LowerBound(vars[i]);
548           non_cached_min += coefficients[i] * lb;
549           model->Add(Implication(enforcement_literals,
550                                  IntegerLiteral::LowerOrEqual(vars[i], lb)));
551         } else if (coefficients[i] < 0) {
552           const IntegerValue ub = integer_trail->UpperBound(vars[i]);
553           non_cached_min += coefficients[i] * ub;
554           model->Add(Implication(enforcement_literals,
555                                  IntegerLiteral::GreaterOrEqual(vars[i], ub)));
556         }
557       }
558       if (non_cached_min > expression_min) {
559         std::vector<Literal> clause;
560         for (const Literal l : enforcement_literals) {
561           clause.push_back(l.Negated());
562         }
563         model->Add(ClauseConstraint(clause));
564       }
565     } else {
566       IntegerSumLE* constraint = new IntegerSumLE(
567           enforcement_literals, vars,
568           std::vector<IntegerValue>(coefficients.begin(), coefficients.end()),
569           IntegerValue(upper_bound), model);
570       constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
571       model->TakeOwnership(constraint);
572     }
573   };
574 }
575 
576 // enforcement_literals => sum >= lower_bound
577 template <typename VectorInt>
ConditionalWeightedSumGreaterOrEqual(const std::vector<Literal> & enforcement_literals,const std::vector<IntegerVariable> & vars,const VectorInt & coefficients,int64_t lower_bound)578 inline std::function<void(Model*)> ConditionalWeightedSumGreaterOrEqual(
579     const std::vector<Literal>& enforcement_literals,
580     const std::vector<IntegerVariable>& vars, const VectorInt& coefficients,
581     int64_t lower_bound) {
582   // We just negate everything and use an <= constraint.
583   std::vector<int64_t> negated_coeffs(coefficients.begin(), coefficients.end());
584   for (int64_t& ref : negated_coeffs) ref = -ref;
585   return ConditionalWeightedSumLowerOrEqual(enforcement_literals, vars,
586                                             negated_coeffs, -lower_bound);
587 }
588 
589 // Weighted sum <= constant reified.
590 template <typename VectorInt>
WeightedSumLowerOrEqualReif(Literal is_le,const std::vector<IntegerVariable> & vars,const VectorInt & coefficients,int64_t upper_bound)591 inline std::function<void(Model*)> WeightedSumLowerOrEqualReif(
592     Literal is_le, const std::vector<IntegerVariable>& vars,
593     const VectorInt& coefficients, int64_t upper_bound) {
594   return [=](Model* model) {
595     model->Add(ConditionalWeightedSumLowerOrEqual({is_le}, vars, coefficients,
596                                                   upper_bound));
597     model->Add(ConditionalWeightedSumGreaterOrEqual(
598         {is_le.Negated()}, vars, coefficients, upper_bound + 1));
599   };
600 }
601 
602 // Weighted sum >= constant reified.
603 template <typename VectorInt>
WeightedSumGreaterOrEqualReif(Literal is_ge,const std::vector<IntegerVariable> & vars,const VectorInt & coefficients,int64_t lower_bound)604 inline std::function<void(Model*)> WeightedSumGreaterOrEqualReif(
605     Literal is_ge, const std::vector<IntegerVariable>& vars,
606     const VectorInt& coefficients, int64_t lower_bound) {
607   return [=](Model* model) {
608     model->Add(ConditionalWeightedSumGreaterOrEqual({is_ge}, vars, coefficients,
609                                                     lower_bound));
610     model->Add(ConditionalWeightedSumLowerOrEqual(
611         {is_ge.Negated()}, vars, coefficients, lower_bound - 1));
612   };
613 }
614 
615 // LinearConstraint version.
LoadLinearConstraint(const LinearConstraint & cst,Model * model)616 inline void LoadLinearConstraint(const LinearConstraint& cst, Model* model) {
617   if (cst.vars.empty()) {
618     if (cst.lb <= 0 && cst.ub >= 0) return;
619     model->GetOrCreate<SatSolver>()->NotifyThatModelIsUnsat();
620     return;
621   }
622 
623   // TODO(user): Remove the conversion!
624   std::vector<int64_t> converted_coeffs;
625 
626   for (const IntegerValue v : cst.coeffs) converted_coeffs.push_back(v.value());
627   if (cst.ub < kMaxIntegerValue) {
628     model->Add(
629         WeightedSumLowerOrEqual(cst.vars, converted_coeffs, cst.ub.value()));
630   }
631   if (cst.lb > kMinIntegerValue) {
632     model->Add(
633         WeightedSumGreaterOrEqual(cst.vars, converted_coeffs, cst.lb.value()));
634   }
635 }
636 
LoadConditionalLinearConstraint(const absl::Span<const Literal> enforcement_literals,const LinearConstraint & cst,Model * model)637 inline void LoadConditionalLinearConstraint(
638     const absl::Span<const Literal> enforcement_literals,
639     const LinearConstraint& cst, Model* model) {
640   if (enforcement_literals.empty()) {
641     return LoadLinearConstraint(cst, model);
642   }
643   if (cst.vars.empty()) {
644     if (cst.lb <= 0 && cst.ub >= 0) return;
645     return model->Add(ClauseConstraint(enforcement_literals));
646   }
647 
648   // TODO(user): Remove the conversion!
649   std::vector<Literal> converted_literals(enforcement_literals.begin(),
650                                           enforcement_literals.end());
651   std::vector<int64_t> converted_coeffs;
652   for (const IntegerValue v : cst.coeffs) converted_coeffs.push_back(v.value());
653 
654   if (cst.ub < kMaxIntegerValue) {
655     model->Add(ConditionalWeightedSumLowerOrEqual(
656         converted_literals, cst.vars, converted_coeffs, cst.ub.value()));
657   }
658   if (cst.lb > kMinIntegerValue) {
659     model->Add(ConditionalWeightedSumGreaterOrEqual(
660         converted_literals, cst.vars, converted_coeffs, cst.lb.value()));
661   }
662 }
663 
664 // Weighted sum == constant reified.
665 // TODO(user): Simplify if the constant is at the edge of the possible values.
666 template <typename VectorInt>
FixedWeightedSumReif(Literal is_eq,const std::vector<IntegerVariable> & vars,const VectorInt & coefficients,int64_t value)667 inline std::function<void(Model*)> FixedWeightedSumReif(
668     Literal is_eq, const std::vector<IntegerVariable>& vars,
669     const VectorInt& coefficients, int64_t value) {
670   return [=](Model* model) {
671     // We creates two extra Boolean variables in this case. The alternative is
672     // to code a custom propagator for the direction equality => reified.
673     const Literal is_le = Literal(model->Add(NewBooleanVariable()), true);
674     const Literal is_ge = Literal(model->Add(NewBooleanVariable()), true);
675     model->Add(ReifiedBoolAnd({is_le, is_ge}, is_eq));
676     model->Add(WeightedSumLowerOrEqualReif(is_le, vars, coefficients, value));
677     model->Add(WeightedSumGreaterOrEqualReif(is_ge, vars, coefficients, value));
678   };
679 }
680 
681 // Weighted sum != constant.
682 // TODO(user): Simplify if the constant is at the edge of the possible values.
683 template <typename VectorInt>
WeightedSumNotEqual(const std::vector<IntegerVariable> & vars,const VectorInt & coefficients,int64_t value)684 inline std::function<void(Model*)> WeightedSumNotEqual(
685     const std::vector<IntegerVariable>& vars, const VectorInt& coefficients,
686     int64_t value) {
687   return [=](Model* model) {
688     // Exactly one of these alternative must be true.
689     const Literal is_lt = Literal(model->Add(NewBooleanVariable()), true);
690     const Literal is_gt = is_lt.Negated();
691     model->Add(ConditionalWeightedSumLowerOrEqual(is_lt, vars, coefficients,
692                                                   value - 1));
693     model->Add(ConditionalWeightedSumGreaterOrEqual(is_gt, vars, coefficients,
694                                                     value + 1));
695   };
696 }
697 
698 // Model-based function to create an IntegerVariable that corresponds to the
699 // given weighted sum of other IntegerVariables.
700 //
701 // Note that this is templated so that it can seamlessly accept vector<int> or
702 // vector<int64_t>.
703 //
704 // TODO(user): invert the coefficients/vars arguments.
705 template <typename VectorInt>
NewWeightedSum(const VectorInt & coefficients,const std::vector<IntegerVariable> & vars)706 inline std::function<IntegerVariable(Model*)> NewWeightedSum(
707     const VectorInt& coefficients, const std::vector<IntegerVariable>& vars) {
708   return [=](Model* model) {
709     std::vector<IntegerVariable> new_vars = vars;
710     // To avoid overflow in the FixedWeightedSum() constraint, we need to
711     // compute the basic bounds on the sum.
712     //
713     // TODO(user): deal with overflow here too!
714     int64_t sum_lb(0);
715     int64_t sum_ub(0);
716     for (int i = 0; i < new_vars.size(); ++i) {
717       if (coefficients[i] > 0) {
718         sum_lb += coefficients[i] * model->Get(LowerBound(new_vars[i]));
719         sum_ub += coefficients[i] * model->Get(UpperBound(new_vars[i]));
720       } else {
721         sum_lb += coefficients[i] * model->Get(UpperBound(new_vars[i]));
722         sum_ub += coefficients[i] * model->Get(LowerBound(new_vars[i]));
723       }
724     }
725 
726     const IntegerVariable sum = model->Add(NewIntegerVariable(sum_lb, sum_ub));
727     new_vars.push_back(sum);
728     std::vector<int64_t> new_coeffs(coefficients.begin(), coefficients.end());
729     new_coeffs.push_back(-1);
730     model->Add(FixedWeightedSum(new_vars, new_coeffs, 0));
731     return sum;
732   };
733 }
734 
735 // Expresses the fact that an existing integer variable is equal to the minimum
736 // of other integer variables.
IsEqualToMinOf(IntegerVariable min_var,const std::vector<IntegerVariable> & vars)737 inline std::function<void(Model*)> IsEqualToMinOf(
738     IntegerVariable min_var, const std::vector<IntegerVariable>& vars) {
739   return [=](Model* model) {
740     for (const IntegerVariable& var : vars) {
741       model->Add(LowerOrEqual(min_var, var));
742     }
743 
744     MinPropagator* constraint =
745         new MinPropagator(vars, min_var, model->GetOrCreate<IntegerTrail>());
746     constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
747     model->TakeOwnership(constraint);
748   };
749 }
750 
751 // Expresses the fact that an existing integer variable is equal to the minimum
752 // of linear expressions. Assumes Canonical expressions (all positive
753 // coefficients).
IsEqualToMinOf(const LinearExpression & min_expr,const std::vector<LinearExpression> & exprs)754 inline std::function<void(Model*)> IsEqualToMinOf(
755     const LinearExpression& min_expr,
756     const std::vector<LinearExpression>& exprs) {
757   return [=](Model* model) {
758     IntegerTrail* integer_trail = model->GetOrCreate<IntegerTrail>();
759 
760     IntegerVariable min_var;
761     if (min_expr.vars.size() == 1 &&
762         std::abs(min_expr.coeffs[0].value()) == 1 && min_expr.offset == 0) {
763       if (min_expr.coeffs[0].value() == 1) {
764         min_var = min_expr.vars[0];
765       } else {
766         min_var = NegationOf(min_expr.vars[0]);
767       }
768     } else {
769       // Create a new variable if the expression is not just a single variable.
770       IntegerValue min_lb = LinExprLowerBound(min_expr, *integer_trail);
771       IntegerValue min_ub = LinExprUpperBound(min_expr, *integer_trail);
772       min_var = integer_trail->AddIntegerVariable(min_lb, min_ub);
773 
774       // min_var = min_expr
775       std::vector<IntegerVariable> min_sum_vars = min_expr.vars;
776       std::vector<int64_t> min_sum_coeffs;
777       for (IntegerValue coeff : min_expr.coeffs) {
778         min_sum_coeffs.push_back(coeff.value());
779       }
780       min_sum_vars.push_back(min_var);
781       min_sum_coeffs.push_back(-1);
782 
783       model->Add(FixedWeightedSum(min_sum_vars, min_sum_coeffs,
784                                   -min_expr.offset.value()));
785     }
786     for (const LinearExpression& expr : exprs) {
787       // min_var <= expr
788       std::vector<IntegerVariable> vars = expr.vars;
789       std::vector<int64_t> coeffs;
790       for (IntegerValue coeff : expr.coeffs) {
791         coeffs.push_back(coeff.value());
792       }
793       vars.push_back(min_var);
794       coeffs.push_back(-1);
795       model->Add(WeightedSumGreaterOrEqual(vars, coeffs, -expr.offset.value()));
796     }
797 
798     LinMinPropagator* constraint = new LinMinPropagator(exprs, min_var, model);
799     constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
800     model->TakeOwnership(constraint);
801   };
802 }
803 
804 // Expresses the fact that an existing integer variable is equal to the maximum
805 // of other integer variables.
IsEqualToMaxOf(IntegerVariable max_var,const std::vector<IntegerVariable> & vars)806 inline std::function<void(Model*)> IsEqualToMaxOf(
807     IntegerVariable max_var, const std::vector<IntegerVariable>& vars) {
808   return [=](Model* model) {
809     std::vector<IntegerVariable> negated_vars;
810     for (const IntegerVariable& var : vars) {
811       negated_vars.push_back(NegationOf(var));
812       model->Add(GreaterOrEqual(max_var, var));
813     }
814 
815     MinPropagator* constraint = new MinPropagator(
816         negated_vars, NegationOf(max_var), model->GetOrCreate<IntegerTrail>());
817     constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
818     model->TakeOwnership(constraint);
819   };
820 }
821 
// Expresses the fact that an existing integer variable is equal to one of
// the given values, each selected by a given literal.
//
// This is a declaration only; the body is defined elsewhere.
//
// NOTE(review): presumably selectors[i] is the literal that selects values[i],
// which would require both vectors to have the same size -- confirm against
// the definition.
std::function<void(Model*)> IsOneOf(IntegerVariable var,
                                    const std::vector<Literal>& selectors,
                                    const std::vector<IntegerValue>& values);
827 
828 template <class T>
RegisterAndTransferOwnership(Model * model,T * ct)829 void RegisterAndTransferOwnership(Model* model, T* ct) {
830   ct->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
831   model->TakeOwnership(ct);
832 }
833 // Adds the constraint: a * b = p.
ProductConstraint(AffineExpression a,AffineExpression b,AffineExpression p)834 inline std::function<void(Model*)> ProductConstraint(AffineExpression a,
835                                                      AffineExpression b,
836                                                      AffineExpression p) {
837   return [=](Model* model) {
838     IntegerTrail* integer_trail = model->GetOrCreate<IntegerTrail>();
839     if (a == b) {
840       if (integer_trail->LowerBound(a) >= 0) {
841         RegisterAndTransferOwnership(model,
842                                      new SquarePropagator(a, p, integer_trail));
843         return;
844       }
845       if (integer_trail->UpperBound(a) <= 0) {
846         RegisterAndTransferOwnership(
847             model, new SquarePropagator(a.Negated(), p, integer_trail));
848         return;
849       }
850     }
851     RegisterAndTransferOwnership(model,
852                                  new ProductPropagator(a, b, p, integer_trail));
853   };
854 }
855 
856 // Adds the constraint: num / denom = div. (denom > 0).
DivisionConstraint(AffineExpression num,AffineExpression denom,AffineExpression div)857 inline std::function<void(Model*)> DivisionConstraint(AffineExpression num,
858                                                       AffineExpression denom,
859                                                       AffineExpression div) {
860   return [=](Model* model) {
861     IntegerTrail* integer_trail = model->GetOrCreate<IntegerTrail>();
862     DivisionPropagator* constraint;
863     if (integer_trail->UpperBound(denom) < 0) {
864       constraint = new DivisionPropagator(num.Negated(), denom.Negated(), div,
865                                           integer_trail);
866 
867     } else {
868       constraint = new DivisionPropagator(num, denom, div, integer_trail);
869     }
870     constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
871     model->TakeOwnership(constraint);
872   };
873 }
874 
875 // Adds the constraint: a / b = c where b is a constant.
FixedDivisionConstraint(AffineExpression a,IntegerValue b,AffineExpression c)876 inline std::function<void(Model*)> FixedDivisionConstraint(AffineExpression a,
877                                                            IntegerValue b,
878                                                            AffineExpression c) {
879   return [=](Model* model) {
880     IntegerTrail* integer_trail = model->GetOrCreate<IntegerTrail>();
881     FixedDivisionPropagator* constraint =
882         b > 0 ? new FixedDivisionPropagator(a, b, c, integer_trail)
883               : new FixedDivisionPropagator(a.Negated(), -b, c, integer_trail);
884     constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
885     model->TakeOwnership(constraint);
886   };
887 }
888 
889 // Adds the constraint: a % b = c where b is a constant.
FixedModuloConstraint(AffineExpression a,IntegerValue b,AffineExpression c)890 inline std::function<void(Model*)> FixedModuloConstraint(AffineExpression a,
891                                                          IntegerValue b,
892                                                          AffineExpression c) {
893   return [=](Model* model) {
894     IntegerTrail* integer_trail = model->GetOrCreate<IntegerTrail>();
895     FixedModuloPropagator* constraint =
896         new FixedModuloPropagator(a, b, c, integer_trail);
897     constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
898     model->TakeOwnership(constraint);
899   };
900 }
901 
902 }  // namespace sat
903 }  // namespace operations_research
904 
905 #endif  // OR_TOOLS_SAT_INTEGER_EXPR_H_
906