1 /*********************                                                        */
2 /*! \file theory_bv_rewrite_rules_operator_elimination.h
3  ** \verbatim
4  ** Top contributors (to current version):
5  **   Liana Hadarean, Aina Niemetz, Clark Barrett
6  ** This file is part of the CVC4 project.
7  ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8  ** in the top-level source directory) and their institutional affiliations.
9  ** All rights reserved.  See the file COPYING in the top-level source
10  ** directory for licensing information.\endverbatim
11  **
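 ** \brief Rewrite rules that eliminate bit-vector operators
 **
 ** Each rule rewrites one bit-vector operator (derived comparisons, bvcomp,
 ** bvsub, repeat, rotations, bv2nat/int2bv, bvnand/bvnor/bvxnor, signed
 ** division/remainder/modulus, zero/sign extension, redor/redand) into an
 ** equivalent term built from other operators.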
16  **/
17 
18 #include "cvc4_private.h"
19 
20 #pragma once
21 
22 #include "theory/bv/theory_bv_rewrite_rules.h"
23 #include "theory/bv/theory_bv_utils.h"
24 
25 namespace CVC4 {
26 namespace theory {
27 namespace bv {
28 
29 template<>
applies(TNode node)30 bool RewriteRule<UgtEliminate>::applies(TNode node) {
31   return (node.getKind() == kind::BITVECTOR_UGT);
32 }
33 
34 template <>
apply(TNode node)35 Node RewriteRule<UgtEliminate>::apply(TNode node)
36 {
37   Debug("bv-rewrite") << "RewriteRule<UgtEliminate>(" << node << ")"
38                       << std::endl;
39   TNode a = node[0];
40   TNode b = node[1];
41   Node result = NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, b, a);
42   return result;
43 }
44 
45 template<>
applies(TNode node)46 bool RewriteRule<UgeEliminate>::applies(TNode node) {
47   return (node.getKind() == kind::BITVECTOR_UGE);
48 }
49 
50 template <>
apply(TNode node)51 Node RewriteRule<UgeEliminate>::apply(TNode node)
52 {
53   Debug("bv-rewrite") << "RewriteRule<UgeEliminate>(" << node << ")"
54                       << std::endl;
55   TNode a = node[0];
56   TNode b = node[1];
57   Node result = NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULE, b, a);
58   return result;
59 }
60 
61 template<>
applies(TNode node)62 bool RewriteRule<SgtEliminate>::applies(TNode node) {
63   return (node.getKind() == kind::BITVECTOR_SGT);
64 }
65 
66 template <>
apply(TNode node)67 Node RewriteRule<SgtEliminate>::apply(TNode node)
68 {
69   Debug("bv-rewrite") << "RewriteRule<SgtEliminate>(" << node << ")"
70                       << std::endl;
71   TNode a = node[0];
72   TNode b = node[1];
73   Node result = NodeManager::currentNM()->mkNode(kind::BITVECTOR_SLT, b, a);
74   return result;
75 }
76 
77 template<>
applies(TNode node)78 bool RewriteRule<SgeEliminate>::applies(TNode node) {
79   return (node.getKind() == kind::BITVECTOR_SGE);
80 }
81 
82 template <>
apply(TNode node)83 Node RewriteRule<SgeEliminate>::apply(TNode node)
84 {
85   Debug("bv-rewrite") << "RewriteRule<SgeEliminate>(" << node << ")"
86                       << std::endl;
87   TNode a = node[0];
88   TNode b = node[1];
89   Node result = NodeManager::currentNM()->mkNode(kind::BITVECTOR_SLE, b, a);
90   return result;
91 }
92 
93 template <>
applies(TNode node)94 bool RewriteRule<SltEliminate>::applies(TNode node) {
95   return (node.getKind() == kind::BITVECTOR_SLT);
96 }
97 
98 template <>
apply(TNode node)99 Node RewriteRule<SltEliminate>::apply(TNode node)
100 {
101   Debug("bv-rewrite") << "RewriteRule<SltEliminate>(" << node << ")"
102                       << std::endl;
103   NodeManager *nm = NodeManager::currentNM();
104   unsigned size = utils::getSize(node[0]);
105   Integer val = Integer(1).multiplyByPow2(size - 1);
106   Node pow_two = utils::mkConst(size, val);
107   Node a = nm->mkNode(kind::BITVECTOR_PLUS, node[0], pow_two);
108   Node b = nm->mkNode(kind::BITVECTOR_PLUS, node[1], pow_two);
109 
110   return nm->mkNode(kind::BITVECTOR_ULT, a, b);
111 }
112 
113 template <>
applies(TNode node)114 bool RewriteRule<SleEliminate>::applies(TNode node) {
115   return (node.getKind() == kind::BITVECTOR_SLE);
116 }
117 
118 template <>
apply(TNode node)119 Node RewriteRule<SleEliminate>::apply(TNode node)
120 {
121   Debug("bv-rewrite") << "RewriteRule<SleEliminate>(" << node << ")"
122                       << std::endl;
123   NodeManager *nm = NodeManager::currentNM();
124   TNode a = node[0];
125   TNode b = node[1];
126   Node b_slt_a = nm->mkNode(kind::BITVECTOR_SLT, b, a);
127   return nm->mkNode(kind::NOT, b_slt_a);
128 }
129 
130 template <>
applies(TNode node)131 bool RewriteRule<UleEliminate>::applies(TNode node) {
132   return (node.getKind() == kind::BITVECTOR_ULE);
133 }
134 
135 template <>
apply(TNode node)136 Node RewriteRule<UleEliminate>::apply(TNode node)
137 {
138   Debug("bv-rewrite") << "RewriteRule<UleEliminate>(" << node << ")"
139                       << std::endl;
140   NodeManager *nm = NodeManager::currentNM();
141   TNode a = node[0];
142   TNode b = node[1];
143   Node b_ult_a = nm->mkNode(kind::BITVECTOR_ULT, b, a);
144   return nm->mkNode(kind::NOT, b_ult_a);
145 }
146 
147 template <>
applies(TNode node)148 bool RewriteRule<CompEliminate>::applies(TNode node) {
149   return (node.getKind() == kind::BITVECTOR_COMP);
150 }
151 
152 template <>
apply(TNode node)153 Node RewriteRule<CompEliminate>::apply(TNode node)
154 {
155   Debug("bv-rewrite") << "RewriteRule<CompEliminate>(" << node << ")"
156                       << std::endl;
157   NodeManager *nm = NodeManager::currentNM();
158   Node comp = nm->mkNode(kind::EQUAL, node[0], node[1]);
159   Node one = utils::mkConst(1, 1);
160   Node zero = utils::mkConst(1, 0);
161 
162   return nm->mkNode(kind::ITE, comp, one, zero);
163 }
164 
165 template <>
applies(TNode node)166 bool RewriteRule<SubEliminate>::applies(TNode node) {
167   return (node.getKind() == kind::BITVECTOR_SUB);
168 }
169 
170 template <>
apply(TNode node)171 Node RewriteRule<SubEliminate>::apply(TNode node)
172 {
173   Debug("bv-rewrite") << "RewriteRule<SubEliminate>(" << node << ")"
174                       << std::endl;
175   NodeManager *nm = NodeManager::currentNM();
176   Node negb = nm->mkNode(kind::BITVECTOR_NEG, node[1]);
177   Node a = node[0];
178 
179   return nm->mkNode(kind::BITVECTOR_PLUS, a, negb);
180 }
181 
182 template<>
applies(TNode node)183 bool RewriteRule<RepeatEliminate>::applies(TNode node) {
184   return (node.getKind() == kind::BITVECTOR_REPEAT);
185 }
186 
187 template<>
apply(TNode node)188 Node RewriteRule<RepeatEliminate>::apply(TNode node) {
189   Debug("bv-rewrite") << "RewriteRule<RepeatEliminate>(" << node << ")" << std::endl;
190   TNode a = node[0];
191   unsigned amount = node.getOperator().getConst<BitVectorRepeat>().repeatAmount;
192   Assert(amount >= 1);
193   if(amount == 1) {
194     return a;
195   }
196   NodeBuilder<> result(kind::BITVECTOR_CONCAT);
197   for(unsigned i = 0; i < amount; ++i) {
198     result << node[0];
199   }
200   Node resultNode = result;
201   return resultNode;
202 }
203 
204 template<>
applies(TNode node)205 bool RewriteRule<RotateLeftEliminate>::applies(TNode node) {
206   return (node.getKind() == kind::BITVECTOR_ROTATE_LEFT);
207 }
208 
209 template<>
apply(TNode node)210 Node RewriteRule<RotateLeftEliminate>::apply(TNode node) {
211   Debug("bv-rewrite") << "RewriteRule<RotateLeftEliminate>(" << node << ")" << std::endl;
212   TNode a = node[0];
213   unsigned amount = node.getOperator().getConst<BitVectorRotateLeft>().rotateLeftAmount;
214   amount = amount % utils::getSize(a);
215   if (amount == 0) {
216     return a;
217   }
218 
219   Node left   = utils::mkExtract(a, utils::getSize(a)-1 - amount, 0);
220   Node right  = utils::mkExtract(a, utils::getSize(a) -1, utils::getSize(a) - amount);
221   Node result = utils::mkConcat(left, right);
222 
223   return result;
224 }
225 
226 template<>
applies(TNode node)227 bool RewriteRule<RotateRightEliminate>::applies(TNode node) {
228   return (node.getKind() == kind::BITVECTOR_ROTATE_RIGHT);
229 }
230 
231 template<>
apply(TNode node)232 Node RewriteRule<RotateRightEliminate>::apply(TNode node) {
233   Debug("bv-rewrite") << "RewriteRule<RotateRightEliminate>(" << node << ")" << std::endl;
234   TNode a = node[0];
235   unsigned amount = node.getOperator().getConst<BitVectorRotateRight>().rotateRightAmount;
236   amount = amount % utils::getSize(a);
237   if (amount == 0) {
238     return a;
239   }
240 
241   Node left  = utils::mkExtract(a, amount - 1, 0);
242   Node right = utils::mkExtract(a, utils::getSize(a)-1, amount);
243   Node result = utils::mkConcat(left, right);
244 
245   return result;
246 }
247 
248 template<>
applies(TNode node)249 bool RewriteRule<BVToNatEliminate>::applies(TNode node) {
250   return (node.getKind() == kind::BITVECTOR_TO_NAT);
251 }
252 
253 template<>
apply(TNode node)254 Node RewriteRule<BVToNatEliminate>::apply(TNode node) {
255   Debug("bv-rewrite") << "RewriteRule<BVToNatEliminate>(" << node << ")" << std::endl;
256 
257   //if( node[0].isConst() ){
258     //TODO? direct computation instead of term construction+rewriting
259   //}
260 
261   const unsigned size = utils::getSize(node[0]);
262   NodeManager* const nm = NodeManager::currentNM();
263   const Node z = nm->mkConst(Rational(0));
264   const Node bvone = utils::mkOne(1);
265 
266   NodeBuilder<> result(kind::PLUS);
267   Integer i = 1;
268   for(unsigned bit = 0; bit < size; ++bit, i *= 2) {
269     Node cond = nm->mkNode(kind::EQUAL, nm->mkNode(nm->mkConst(BitVectorExtract(bit, bit)), node[0]), bvone);
270     result << nm->mkNode(kind::ITE, cond, nm->mkConst(Rational(i)), z);
271   }
272 
273   return Node(result);
274 }
275 
276 template<>
applies(TNode node)277 bool RewriteRule<IntToBVEliminate>::applies(TNode node) {
278   return (node.getKind() == kind::INT_TO_BITVECTOR);
279 }
280 
281 template<>
apply(TNode node)282 Node RewriteRule<IntToBVEliminate>::apply(TNode node) {
283   Debug("bv-rewrite") << "RewriteRule<IntToBVEliminate>(" << node << ")" << std::endl;
284 
285   //if( node[0].isConst() ){
286     //TODO? direct computation instead of term construction+rewriting
287   //}
288 
289   const unsigned size = node.getOperator().getConst<IntToBitVector>().size;
290   NodeManager* const nm = NodeManager::currentNM();
291   const Node bvzero = utils::mkZero(1);
292   const Node bvone = utils::mkOne(1);
293 
294   std::vector<Node> v;
295   Integer i = 2;
296   while(v.size() < size) {
297     Node cond = nm->mkNode(kind::GEQ, nm->mkNode(kind::INTS_MODULUS_TOTAL, node[0], nm->mkConst(Rational(i))), nm->mkConst(Rational(i, 2)));
298     v.push_back(nm->mkNode(kind::ITE, cond, bvone, bvzero));
299     i *= 2;
300   }
301 
302   NodeBuilder<> result(kind::BITVECTOR_CONCAT);
303   result.append(v.rbegin(), v.rend());
304   return Node(result);
305 }
306 
307 template<>
applies(TNode node)308 bool RewriteRule<NandEliminate>::applies(TNode node) {
309   return (node.getKind() == kind::BITVECTOR_NAND &&
310           node.getNumChildren() == 2);
311 }
312 
313 template <>
apply(TNode node)314 Node RewriteRule<NandEliminate>::apply(TNode node)
315 {
316   Debug("bv-rewrite") << "RewriteRule<NandEliminate>(" << node << ")"
317                       << std::endl;
318   NodeManager *nm = NodeManager::currentNM();
319   TNode a = node[0];
320   TNode b = node[1];
321   Node andNode = nm->mkNode(kind::BITVECTOR_AND, a, b);
322   Node result = nm->mkNode(kind::BITVECTOR_NOT, andNode);
323   return result;
324 }
325 
326 template <>
applies(TNode node)327 bool RewriteRule<NorEliminate>::applies(TNode node)
328 {
329   return (node.getKind() == kind::BITVECTOR_NOR && node.getNumChildren() == 2);
330 }
331 
332 template <>
apply(TNode node)333 Node RewriteRule<NorEliminate>::apply(TNode node)
334 {
335   Debug("bv-rewrite") << "RewriteRule<NorEliminate>(" << node << ")"
336                       << std::endl;
337   NodeManager *nm = NodeManager::currentNM();
338   TNode a = node[0];
339   TNode b = node[1];
340   Node orNode = nm->mkNode(kind::BITVECTOR_OR, a, b);
341   Node result = nm->mkNode(kind::BITVECTOR_NOT, orNode);
342   return result;
343 }
344 
345 template<>
applies(TNode node)346 bool RewriteRule<XnorEliminate>::applies(TNode node) {
347   return (node.getKind() == kind::BITVECTOR_XNOR &&
348           node.getNumChildren() == 2);
349 }
350 
351 template <>
apply(TNode node)352 Node RewriteRule<XnorEliminate>::apply(TNode node)
353 {
354   Debug("bv-rewrite") << "RewriteRule<XnorEliminate>(" << node << ")"
355                       << std::endl;
356   NodeManager *nm = NodeManager::currentNM();
357   TNode a = node[0];
358   TNode b = node[1];
359   Node xorNode = nm->mkNode(kind::BITVECTOR_XOR, a, b);
360   Node result = nm->mkNode(kind::BITVECTOR_NOT, xorNode);
361   return result;
362 }
363 
364 template<>
applies(TNode node)365 bool RewriteRule<SdivEliminate>::applies(TNode node) {
366   return (node.getKind() == kind::BITVECTOR_SDIV);
367 }
368 
369 template <>
apply(TNode node)370 Node RewriteRule<SdivEliminate>::apply(TNode node)
371 {
372   Debug("bv-rewrite") << "RewriteRule<SdivEliminate>(" << node << ")"
373                       << std::endl;
374 
375   NodeManager *nm = NodeManager::currentNM();
376   TNode a = node[0];
377   TNode b = node[1];
378   unsigned size = utils::getSize(a);
379 
380   Node one = utils::mkConst(1, 1);
381   Node a_lt_0 =
382       nm->mkNode(kind::EQUAL, utils::mkExtract(a, size - 1, size - 1), one);
383   Node b_lt_0 =
384       nm->mkNode(kind::EQUAL, utils::mkExtract(b, size - 1, size - 1), one);
385   Node abs_a =
386       nm->mkNode(kind::ITE, a_lt_0, nm->mkNode(kind::BITVECTOR_NEG, a), a);
387   Node abs_b =
388       nm->mkNode(kind::ITE, b_lt_0, nm->mkNode(kind::BITVECTOR_NEG, b), b);
389 
390   Node a_udiv_b =
391       nm->mkNode(options::bitvectorDivByZeroConst() ? kind::BITVECTOR_UDIV_TOTAL
392                                                     : kind::BITVECTOR_UDIV,
393                  abs_a,
394                  abs_b);
395   Node neg_result = nm->mkNode(kind::BITVECTOR_NEG, a_udiv_b);
396 
397   Node condition = nm->mkNode(kind::XOR, a_lt_0, b_lt_0);
398   Node result = nm->mkNode(kind::ITE, condition, neg_result, a_udiv_b);
399 
400   return result;
401 }
402 
403 template<>
applies(TNode node)404 bool RewriteRule<SremEliminate>::applies(TNode node) {
405   return (node.getKind() == kind::BITVECTOR_SREM);
406 }
407 
408 template <>
apply(TNode node)409 Node RewriteRule<SremEliminate>::apply(TNode node)
410 {
411   Debug("bv-rewrite") << "RewriteRule<SremEliminate>(" << node << ")"
412                       << std::endl;
413   NodeManager *nm = NodeManager::currentNM();
414   TNode a = node[0];
415   TNode b = node[1];
416   unsigned size = utils::getSize(a);
417 
418   Node one = utils::mkConst(1, 1);
419   Node a_lt_0 =
420       nm->mkNode(kind::EQUAL, utils::mkExtract(a, size - 1, size - 1), one);
421   Node b_lt_0 =
422       nm->mkNode(kind::EQUAL, utils::mkExtract(b, size - 1, size - 1), one);
423   Node abs_a =
424       nm->mkNode(kind::ITE, a_lt_0, nm->mkNode(kind::BITVECTOR_NEG, a), a);
425   Node abs_b =
426       nm->mkNode(kind::ITE, b_lt_0, nm->mkNode(kind::BITVECTOR_NEG, b), b);
427 
428   Node a_urem_b =
429       nm->mkNode(options::bitvectorDivByZeroConst() ? kind::BITVECTOR_UREM_TOTAL
430                                                     : kind::BITVECTOR_UREM,
431                  abs_a,
432                  abs_b);
433   Node neg_result = nm->mkNode(kind::BITVECTOR_NEG, a_urem_b);
434 
435   Node result = nm->mkNode(kind::ITE, a_lt_0, neg_result, a_urem_b);
436 
437   return result;
438 }
439 
440 template<>
applies(TNode node)441 bool RewriteRule<SmodEliminate>::applies(TNode node) {
442   return (node.getKind() == kind::BITVECTOR_SMOD);
443 }
444 
445 template <>
apply(TNode node)446 Node RewriteRule<SmodEliminate>::apply(TNode node)
447 {
448   Debug("bv-rewrite") << "RewriteRule<SmodEliminate>(" << node << ")"
449                       << std::endl;
450   NodeManager *nm = NodeManager::currentNM();
451   TNode s = node[0];
452   TNode t = node[1];
453   unsigned size = utils::getSize(s);
454 
455   // (bvsmod s t) abbreviates
456   //     (let ((?msb_s ((_ extract |m-1| |m-1|) s))
457   //           (?msb_t ((_ extract |m-1| |m-1|) t)))
458   //       (let ((abs_s (ite (= ?msb_s #b0) s (bvneg s)))
459   //             (abs_t (ite (= ?msb_t #b0) t (bvneg t))))
460   //         (let ((u (bvurem abs_s abs_t)))
461   //           (ite (= u (_ bv0 m))
462   //                u
463   //           (ite (and (= ?msb_s #b0) (= ?msb_t #b0))
464   //                u
465   //           (ite (and (= ?msb_s #b1) (= ?msb_t #b0))
466   //                (bvadd (bvneg u) t)
467   //           (ite (and (= ?msb_s #b0) (= ?msb_t #b1))
468   //                (bvadd u t)
469   //                (bvneg u))))))))
470 
471   Node msb_s = utils::mkExtract(s, size - 1, size - 1);
472   Node msb_t = utils::mkExtract(t, size - 1, size - 1);
473 
474   Node bit1 = utils::mkConst(1, 1);
475   Node bit0 = utils::mkConst(1, 0);
476 
477   Node abs_s =
478       msb_s.eqNode(bit0).iteNode(s, nm->mkNode(kind::BITVECTOR_NEG, s));
479   Node abs_t =
480       msb_t.eqNode(bit0).iteNode(t, nm->mkNode(kind::BITVECTOR_NEG, t));
481 
482   Node u = nm->mkNode(kind::BITVECTOR_UREM, abs_s, abs_t);
483   Node neg_u = nm->mkNode(kind::BITVECTOR_NEG, u);
484 
485   Node cond0 = u.eqNode(utils::mkConst(size, 0));
486   Node cond1 = msb_s.eqNode(bit0).andNode(msb_t.eqNode(bit0));
487   Node cond2 = msb_s.eqNode(bit1).andNode(msb_t.eqNode(bit0));
488   Node cond3 = msb_s.eqNode(bit0).andNode(msb_t.eqNode(bit1));
489 
490   Node result = cond0.iteNode(
491       u,
492       cond1.iteNode(
493           u,
494           cond2.iteNode(
495               nm->mkNode(kind::BITVECTOR_PLUS, neg_u, t),
496               cond3.iteNode(nm->mkNode(kind::BITVECTOR_PLUS, u, t), neg_u))));
497 
498   return result;
499 }
500 
501 template<>
applies(TNode node)502 bool RewriteRule<ZeroExtendEliminate>::applies(TNode node) {
503   return (node.getKind() == kind::BITVECTOR_ZERO_EXTEND);
504 }
505 
506 template<>
apply(TNode node)507 Node RewriteRule<ZeroExtendEliminate>::apply(TNode node) {
508   Debug("bv-rewrite") << "RewriteRule<ZeroExtendEliminate>(" << node << ")" << std::endl;
509 
510   TNode bv = node[0];
511   unsigned amount = node.getOperator().getConst<BitVectorZeroExtend>().zeroExtendAmount;
512   if (amount == 0) {
513     return node[0];
514   }
515   Node zero = utils::mkConst(amount, 0);
516   Node result = utils::mkConcat(zero, node[0]);
517 
518   return result;
519 }
520 
521 template<>
applies(TNode node)522 bool RewriteRule<SignExtendEliminate>::applies(TNode node) {
523   return (node.getKind() == kind::BITVECTOR_SIGN_EXTEND);
524 }
525 
526 template<>
apply(TNode node)527 Node RewriteRule<SignExtendEliminate>::apply(TNode node) {
528   Debug("bv-rewrite") << "RewriteRule<SignExtendEliminate>(" << node << ")" << std::endl;
529 
530   unsigned amount = node.getOperator().getConst<BitVectorSignExtend>().signExtendAmount;
531   if(amount == 0) {
532     return node[0];
533   }
534   unsigned size = utils::getSize(node[0]);
535   Node sign_bit = utils::mkExtract(node[0], size-1, size-1);
536   Node extension = utils::mkConcat(sign_bit, amount);
537 
538   return utils::mkConcat(extension, node[0]);
539 }
540 
541 template<>
applies(TNode node)542 bool RewriteRule<RedorEliminate>::applies(TNode node) {
543   return (node.getKind() == kind::BITVECTOR_REDOR);
544 }
545 
546 template<>
apply(TNode node)547 Node RewriteRule<RedorEliminate>::apply(TNode node) {
548   Debug("bv-rewrite") << "RewriteRule<RedorEliminate>(" << node << ")" << std::endl;
549   TNode a = node[0];
550   unsigned size = utils::getSize(node[0]);
551   Node result = NodeManager::currentNM()->mkNode(kind::EQUAL, a, utils::mkConst( size, 0 ) );
552   return result.negate();
553 }
554 
555 template<>
applies(TNode node)556 bool RewriteRule<RedandEliminate>::applies(TNode node) {
557   return (node.getKind() == kind::BITVECTOR_REDAND);
558 }
559 
560 template<>
apply(TNode node)561 Node RewriteRule<RedandEliminate>::apply(TNode node) {
562   Debug("bv-rewrite") << "RewriteRule<RedandEliminate>(" << node << ")" << std::endl;
563   TNode a = node[0];
564   unsigned size = utils::getSize(node[0]);
565   Node result = NodeManager::currentNM()->mkNode(kind::EQUAL, a, utils::mkOnes( size ) );
566   return result;
567 }
568 
569 }
570 }
571 }
572