1 /*********************                                                        */
2 /*! \file pseudo_boolean_processor.cpp
3  ** \verbatim
4  ** Top contributors (to current version):
5  **   Tim King, Andres Noetzli
6  ** This file is part of the CVC4 project.
7  ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8  ** in the top-level source directory) and their institutional affiliations.
9  ** All rights reserved.  See the file COPYING in the top-level source
10  ** directory for licensing information.\endverbatim
11  **
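 ** \brief Preprocessing pass that detects pseudo-boolean integer variables.
 **
 ** Learns which integer variables are bounded by 0 <= x <= 1 from the
 ** asserted atoms. It then rewrites simple linear GEQ constraints over such
 ** variables into propositional form, but only once enough pseudo-boolean
 ** variables have been found for the rewrite to be likely to help.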
16  **/
17 
18 #include "preprocessing/passes/pseudo_boolean_processor.h"
19 
20 #include "base/output.h"
21 #include "theory/arith/arith_utilities.h"
22 #include "theory/arith/normal_form.h"
23 #include "theory/rewriter.h"
24 
25 namespace CVC4 {
26 namespace preprocessing {
27 namespace passes {
28 
29 using namespace CVC4::theory;
30 using namespace CVC4::theory::arith;
31 
PseudoBooleanProcessor(PreprocessingPassContext * preprocContext)32 PseudoBooleanProcessor::PseudoBooleanProcessor(
33     PreprocessingPassContext* preprocContext)
34     : PreprocessingPass(preprocContext, "pseudo-boolean-processor"),
35       d_pbBounds(preprocContext->getUserContext()),
36       d_subCache(preprocContext->getUserContext()),
37       d_pbs(preprocContext->getUserContext(), 0)
38 {
39 }
40 
applyInternal(AssertionPipeline * assertionsToPreprocess)41 PreprocessingPassResult PseudoBooleanProcessor::applyInternal(
42     AssertionPipeline* assertionsToPreprocess)
43 {
44   learn(assertionsToPreprocess->ref());
45   if (likelyToHelp())
46   {
47     applyReplacements(assertionsToPreprocess);
48   }
49 
50   return PreprocessingPassResult::NO_CONFLICT;
51 }
52 
decomposeAssertion(Node assertion,bool negated)53 bool PseudoBooleanProcessor::decomposeAssertion(Node assertion, bool negated)
54 {
55   if (assertion.getKind() != kind::GEQ)
56   {
57     return false;
58   }
59   Assert(assertion.getKind() == kind::GEQ);
60 
61   Debug("pbs::rewrites") << "decomposeAssertion" << assertion << std::endl;
62 
63   Node l = assertion[0];
64   Node r = assertion[1];
65 
66   if (r.getKind() != kind::CONST_RATIONAL)
67   {
68     Debug("pbs::rewrites") << "not rhs constant" << assertion << std::endl;
69     return false;
70   }
71   // don't bother matching on anything other than + on the left hand side
72   if (l.getKind() != kind::PLUS)
73   {
74     Debug("pbs::rewrites") << "not plus" << assertion << std::endl;
75     return false;
76   }
77 
78   if (!Polynomial::isMember(l))
79   {
80     Debug("pbs::rewrites") << "not polynomial" << assertion << std::endl;
81     return false;
82   }
83 
84   Polynomial p = Polynomial::parsePolynomial(l);
85   clear();
86   if (negated)
87   {
88     // (not (>= p r))
89     // (< p r)
90     // (> (-p) (-r))
91     // (>= (-p) (-r +1))
92     d_off = (-r.getConst<Rational>());
93 
94     if (d_off.value().isIntegral())
95     {
96       d_off = d_off.value() + Rational(1);
97     }
98     else
99     {
100       d_off = Rational(d_off.value().ceiling());
101     }
102   }
103   else
104   {
105     // (>= p r)
106     d_off = r.getConst<Rational>();
107     d_off = Rational(d_off.value().ceiling());
108   }
109   Assert(d_off.value().isIntegral());
110 
111   int adj = negated ? -1 : 1;
112   for (Polynomial::iterator i = p.begin(), end = p.end(); i != end; ++i)
113   {
114     Monomial m = *i;
115     const Rational& coeff = m.getConstant().getValue();
116     if (!(coeff.isOne() || coeff.isNegativeOne()))
117     {
118       return false;
119     }
120     Assert(coeff.sgn() != 0);
121 
122     const VarList& vl = m.getVarList();
123     Node v = vl.getNode();
124 
125     if (!isPseudoBoolean(v))
126     {
127       return false;
128     }
129     int sgn = adj * coeff.sgn();
130     if (sgn > 0)
131     {
132       d_pos.push_back(v);
133     }
134     else
135     {
136       d_neg.push_back(v);
137     }
138   }
139   // all of the variables are pseudoboolean
140   // with coefficients +/- and the offsetoff
141   return true;
142 }
143 
isPseudoBoolean(Node v) const144 bool PseudoBooleanProcessor::isPseudoBoolean(Node v) const
145 {
146   CDNode2PairMap::const_iterator ci = d_pbBounds.find(v);
147   if (ci != d_pbBounds.end())
148   {
149     const std::pair<Node, Node>& p = (*ci).second;
150     return !(p.first).isNull() && !(p.second).isNull();
151   }
152   return false;
153 }
154 
addGeqZero(Node v,Node exp)155 void PseudoBooleanProcessor::addGeqZero(Node v, Node exp)
156 {
157   Assert(isIntVar(v));
158   Assert(!exp.isNull());
159   CDNode2PairMap::const_iterator ci = d_pbBounds.find(v);
160 
161   Debug("pbs::rewrites") << "addGeqZero " << v << std::endl;
162 
163   if (ci == d_pbBounds.end())
164   {
165     d_pbBounds.insert(v, std::make_pair(exp, Node::null()));
166   }
167   else
168   {
169     const std::pair<Node, Node>& p = (*ci).second;
170     if (p.first.isNull())
171     {
172       Assert(!p.second.isNull());
173       d_pbBounds.insert(v, std::make_pair(exp, p.second));
174       Debug("pbs::rewrites") << "add pbs " << v << std::endl;
175       Assert(isPseudoBoolean(v));
176       d_pbs = d_pbs + 1;
177     }
178   }
179 }
180 
addLeqOne(Node v,Node exp)181 void PseudoBooleanProcessor::addLeqOne(Node v, Node exp)
182 {
183   Assert(isIntVar(v));
184   Assert(!exp.isNull());
185   Debug("pbs::rewrites") << "addLeqOne " << v << std::endl;
186   CDNode2PairMap::const_iterator ci = d_pbBounds.find(v);
187   if (ci == d_pbBounds.end())
188   {
189     d_pbBounds.insert(v, std::make_pair(Node::null(), exp));
190   }
191   else
192   {
193     const std::pair<Node, Node>& p = (*ci).second;
194     if (p.second.isNull())
195     {
196       Assert(!p.first.isNull());
197       d_pbBounds.insert(v, std::make_pair(p.first, exp));
198       Debug("pbs::rewrites") << "add pbs " << v << std::endl;
199       Assert(isPseudoBoolean(v));
200       d_pbs = d_pbs + 1;
201     }
202   }
203 }
204 
learnRewrittenGeq(Node assertion,bool negated,Node orig)205 void PseudoBooleanProcessor::learnRewrittenGeq(Node assertion,
206                                                bool negated,
207                                                Node orig)
208 {
209   Assert(assertion.getKind() == kind::GEQ);
210   Assert(assertion == Rewriter::rewrite(assertion));
211 
212   // assume assertion is rewritten
213   Node l = assertion[0];
214   Node r = assertion[1];
215 
216   if (r.getKind() == kind::CONST_RATIONAL)
217   {
218     const Rational& rc = r.getConst<Rational>();
219     if (isIntVar(l))
220     {
221       if (!negated && rc.isZero())
222       {  // (>= x 0)
223         addGeqZero(l, orig);
224       }
225       else if (negated && rc == Rational(2))
226       {
227         addLeqOne(l, orig);
228       }
229     }
230     else if (l.getKind() == kind::MULT && l.getNumChildren() == 2)
231     {
232       Node c = l[0], v = l[1];
233       if (c.getKind() == kind::CONST_RATIONAL
234           && c.getConst<Rational>().isNegativeOne())
235       {
236         if (isIntVar(v))
237         {
238           if (!negated && rc.isNegativeOne())
239           {  // (>= (* -1 x) -1)
240             addLeqOne(v, orig);
241           }
242         }
243       }
244     }
245   }
246 
247   if (!negated)
248   {
249     learnGeqSub(assertion);
250   }
251 }
252 
learnInternal(Node assertion,bool negated,Node orig)253 void PseudoBooleanProcessor::learnInternal(Node assertion,
254                                            bool negated,
255                                            Node orig)
256 {
257   switch (assertion.getKind())
258   {
259     case kind::GEQ:
260     case kind::GT:
261     case kind::LEQ:
262     case kind::LT:
263     {
264       Node rw = Rewriter::rewrite(assertion);
265       if (assertion == rw)
266       {
267         if (assertion.getKind() == kind::GEQ)
268         {
269           learnRewrittenGeq(assertion, negated, orig);
270         }
271       }
272       else
273       {
274         learnInternal(rw, negated, orig);
275       }
276     }
277     break;
278     case kind::NOT: learnInternal(assertion[0], !negated, orig); break;
279     default: break;  // do nothing
280   }
281 }
282 
learn(Node assertion)283 void PseudoBooleanProcessor::learn(Node assertion)
284 {
285   if (assertion.getKind() == kind::AND)
286   {
287     Node::iterator ci = assertion.begin(), cend = assertion.end();
288     for (; ci != cend; ++ci)
289     {
290       learn(*ci);
291     }
292   }
293   else
294   {
295     learnInternal(assertion, false, assertion);
296   }
297 }
298 
mkGeqOne(Node v)299 Node PseudoBooleanProcessor::mkGeqOne(Node v)
300 {
301   NodeManager* nm = NodeManager::currentNM();
302   return nm->mkNode(kind::GEQ, v, mkRationalNode(Rational(1)));
303 }
304 
learn(const std::vector<Node> & assertions)305 void PseudoBooleanProcessor::learn(const std::vector<Node>& assertions)
306 {
307   std::vector<Node>::const_iterator ci, cend;
308   ci = assertions.begin();
309   cend = assertions.end();
310   for (; ci != cend; ++ci)
311   {
312     learn(*ci);
313   }
314 }
315 
addSub(Node from,Node to)316 void PseudoBooleanProcessor::addSub(Node from, Node to)
317 {
318   if (!d_subCache.hasSubstitution(from))
319   {
320     Node rw_to = Rewriter::rewrite(to);
321     d_subCache.addSubstitution(from, rw_to);
322   }
323 }
324 
learnGeqSub(Node geq)325 void PseudoBooleanProcessor::learnGeqSub(Node geq)
326 {
327   Assert(geq.getKind() == kind::GEQ);
328   const bool negated = false;
329   bool success = decomposeAssertion(geq, negated);
330   if (!success)
331   {
332     Debug("pbs::rewrites") << "failed " << std::endl;
333     return;
334   }
335   Assert(d_off.value().isIntegral());
336   Integer off = d_off.value().ceiling();
337 
338   // \sum pos >= \sum neg + off
339 
340   // for now special case everything we want
341   // target easy clauses
342   if (d_pos.size() == 1 && d_neg.size() == 1 && off.isZero())
343   {
344     // x >= y
345     // |- (y >= 1) => (x >= 1)
346     Node x = d_pos.front();
347     Node y = d_neg.front();
348 
349     Node xGeq1 = mkGeqOne(x);
350     Node yGeq1 = mkGeqOne(y);
351     Node imp = yGeq1.impNode(xGeq1);
352     addSub(geq, imp);
353   }
354   else if (d_pos.size() == 0 && d_neg.size() == 2 && off.isNegativeOne())
355   {
356     // 0 >= (x + y -1)
357     // |- 1 >= x + y
358     // |- (or (not (x >= 1)) (not (y >= 1)))
359     Node x = d_neg[0];
360     Node y = d_neg[1];
361 
362     Node xGeq1 = mkGeqOne(x);
363     Node yGeq1 = mkGeqOne(y);
364     Node cases = (xGeq1.notNode()).orNode(yGeq1.notNode());
365     addSub(geq, cases);
366   }
367   else if (d_pos.size() == 2 && d_neg.size() == 1 && off.isZero())
368   {
369     // (x + y) >= z
370     // |- (z >= 1) => (or (x >= 1) (y >=1 ))
371     Node x = d_pos[0];
372     Node y = d_pos[1];
373     Node z = d_neg[0];
374 
375     Node xGeq1 = mkGeqOne(x);
376     Node yGeq1 = mkGeqOne(y);
377     Node zGeq1 = mkGeqOne(z);
378     NodeManager* nm = NodeManager::currentNM();
379     Node dis = nm->mkNode(kind::OR, zGeq1.notNode(), xGeq1, yGeq1);
380     addSub(geq, dis);
381   }
382 }
383 
applyReplacements(Node pre)384 Node PseudoBooleanProcessor::applyReplacements(Node pre)
385 {
386   Node assertion = Rewriter::rewrite(pre);
387 
388   Node result = d_subCache.apply(assertion);
389   if (Debug.isOn("pbs::rewrites") && result != assertion)
390   {
391     Debug("pbs::rewrites") << "applyReplacements" << assertion << "-> "
392                            << result << std::endl;
393   }
394   return result;
395 }
396 
likelyToHelp() const397 bool PseudoBooleanProcessor::likelyToHelp() const { return d_pbs >= 100; }
398 
applyReplacements(AssertionPipeline * assertionsToPreprocess)399 void PseudoBooleanProcessor::applyReplacements(
400     AssertionPipeline* assertionsToPreprocess)
401 {
402   for (size_t i = 0, N = assertionsToPreprocess->size(); i < N; ++i)
403   {
404     assertionsToPreprocess->replace(
405         i, applyReplacements((*assertionsToPreprocess)[i]));
406   }
407 }
408 
clear()409 void PseudoBooleanProcessor::clear()
410 {
411   d_off.clear();
412   d_pos.clear();
413   d_neg.clear();
414 }
415 
416 
417 }  // namespace passes
418 }  // namespace preprocessing
419 }  // namespace CVC4
420