1 /********************* */
2 /*! \file arith_static_learner.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Tim King, Dejan Jovanovic, Morgan Deters
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
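 ** \brief Static learning of bounds for arithmetic ITE terms
 **
 ** Walks a formula passed to staticLearning() and learns facts about
 ** arithmetic ITE terms: min/max facts for ITEs of the form
 ** (ite (rel x y) x y), and constant lower/upper bounds for ITEs whose
 ** branches have known bounds. Bounds on terms can also be recorded
 ** directly via addBound().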
16 **/
17
18 #include <vector>
19
20 #include "base/output.h"
21 #include "expr/expr.h"
22 #include "expr/node_algorithm.h"
23 #include "options/arith_options.h"
24 #include "smt/smt_statistics_registry.h"
25 #include "theory/arith/arith_static_learner.h"
26 #include "theory/arith/arith_utilities.h"
27 #include "theory/arith/normal_form.h"
28 #include "theory/rewriter.h"
29
30 using namespace std;
31 using namespace CVC4::kind;
32
33 namespace CVC4 {
34 namespace theory {
35 namespace arith {
36
37
ArithStaticLearner(context::Context * userContext)38 ArithStaticLearner::ArithStaticLearner(context::Context* userContext) :
39 d_minMap(userContext),
40 d_maxMap(userContext),
41 d_statistics()
42 {
43 }
44
~ArithStaticLearner()45 ArithStaticLearner::~ArithStaticLearner(){
46 }
47
Statistics()48 ArithStaticLearner::Statistics::Statistics():
49 d_iteMinMaxApplications("theory::arith::iteMinMaxApplications", 0),
50 d_iteConstantApplications("theory::arith::iteConstantApplications", 0)
51 {
52 smtStatisticsRegistry()->registerStat(&d_iteMinMaxApplications);
53 smtStatisticsRegistry()->registerStat(&d_iteConstantApplications);
54 }
55
~Statistics()56 ArithStaticLearner::Statistics::~Statistics(){
57 smtStatisticsRegistry()->unregisterStat(&d_iteMinMaxApplications);
58 smtStatisticsRegistry()->unregisterStat(&d_iteConstantApplications);
59 }
60
staticLearning(TNode n,NodeBuilder<> & learned)61 void ArithStaticLearner::staticLearning(TNode n, NodeBuilder<>& learned){
62
63 vector<TNode> workList;
64 workList.push_back(n);
65 TNodeSet processed;
66
67 //Contains an underapproximation of nodes that must hold.
68 TNodeSet defTrue;
69
70 defTrue.insert(n);
71
72 while(!workList.empty()) {
73 n = workList.back();
74
75 bool unprocessedChildren = false;
76 for(TNode::iterator i = n.begin(), iend = n.end(); i != iend; ++i) {
77 if(processed.find(*i) == processed.end()) {
78 // unprocessed child
79 workList.push_back(*i);
80 unprocessedChildren = true;
81 }
82 }
83 if(n.getKind() == AND && defTrue.find(n) != defTrue.end() ){
84 for(TNode::iterator i = n.begin(), iend = n.end(); i != iend; ++i) {
85 defTrue.insert(*i);
86 }
87 }
88
89 if(unprocessedChildren) {
90 continue;
91 }
92
93 workList.pop_back();
94 // has node n been processed in the meantime ?
95 if(processed.find(n) != processed.end()) {
96 continue;
97 }
98 processed.insert(n);
99
100 process(n,learned, defTrue);
101
102 }
103 }
104
105
process(TNode n,NodeBuilder<> & learned,const TNodeSet & defTrue)106 void ArithStaticLearner::process(TNode n, NodeBuilder<>& learned, const TNodeSet& defTrue){
107 Debug("arith::static") << "===================== looking at " << n << endl;
108
109 switch(n.getKind()){
110 case ITE:
111 if (expr::hasBoundVar(n))
112 {
113 // Unsafe with non-ground ITEs; do nothing
114 Debug("arith::static")
115 << "(potentially) non-ground ITE, ignoring..." << endl;
116 break;
117 }
118
119 if(n[0].getKind() != EQUAL &&
120 isRelationOperator(n[0].getKind()) ){
121 iteMinMax(n, learned);
122 }
123
124 if((d_minMap.find(n[1]) != d_minMap.end() && d_minMap.find(n[2]) != d_minMap.end()) ||
125 (d_maxMap.find(n[1]) != d_maxMap.end() && d_maxMap.find(n[2]) != d_maxMap.end())) {
126 iteConstant(n, learned);
127 }
128 break;
129
130 case CONST_RATIONAL:
131 // Mark constants as minmax
132 d_minMap.insert(n, n.getConst<Rational>());
133 d_maxMap.insert(n, n.getConst<Rational>());
134 break;
135 default: // Do nothing
136 break;
137 }
138 }
139
iteMinMax(TNode n,NodeBuilder<> & learned)140 void ArithStaticLearner::iteMinMax(TNode n, NodeBuilder<>& learned){
141 Assert(n.getKind() == kind::ITE);
142 Assert(n[0].getKind() != EQUAL);
143 Assert(isRelationOperator(n[0].getKind()));
144
145 TNode c = n[0];
146 Kind k = oldSimplifiedKind(c);
147 TNode t = n[1];
148 TNode e = n[2];
149 TNode cleft = (c.getKind() == NOT) ? c[0][0] : c[0];
150 TNode cright = (c.getKind() == NOT) ? c[0][1] : c[1];
151
152 if((t == cright) && (e == cleft)){
153 TNode tmp = t;
154 t = e;
155 e = tmp;
156 k = reverseRelationKind(k);
157 }
158 //(ite (< x y) x y)
159 //(ite (x < y) x y)
160 //(ite (x - y < 0) x y)
161 // ----------------
162 // (ite (x - y < -c) )
163
164 if(t == cleft && e == cright){
165 // t == cleft && e == cright
166 Assert( t == cleft );
167 Assert( e == cright );
168 switch(k){
169 case LT: // (ite (< x y) x y)
170 case LEQ: { // (ite (<= x y) x y)
171 Node nLeqX = NodeBuilder<2>(LEQ) << n << t;
172 Node nLeqY = NodeBuilder<2>(LEQ) << n << e;
173 Debug("arith::static") << n << "is a min =>" << nLeqX << nLeqY << endl;
174 learned << nLeqX << nLeqY;
175 ++(d_statistics.d_iteMinMaxApplications);
176 break;
177 }
178 case GT: // (ite (> x y) x y)
179 case GEQ: { // (ite (>= x y) x y)
180 Node nGeqX = NodeBuilder<2>(GEQ) << n << t;
181 Node nGeqY = NodeBuilder<2>(GEQ) << n << e;
182 Debug("arith::static") << n << "is a max =>" << nGeqX << nGeqY << endl;
183 learned << nGeqX << nGeqY;
184 ++(d_statistics.d_iteMinMaxApplications);
185 break;
186 }
187 default: Unreachable();
188 }
189 }
190 }
191
iteConstant(TNode n,NodeBuilder<> & learned)192 void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){
193 Assert(n.getKind() == ITE);
194
195 Debug("arith::static") << "iteConstant(" << n << ")" << endl;
196
197 if (d_minMap.find(n[1]) != d_minMap.end() && d_minMap.find(n[2]) != d_minMap.end()) {
198 const DeltaRational& first = d_minMap[n[1]];
199 const DeltaRational& second = d_minMap[n[2]];
200 DeltaRational min = std::min(first, second);
201 CDNodeToMinMaxMap::const_iterator minFind = d_minMap.find(n);
202 if (minFind == d_minMap.end() || (*minFind).second < min) {
203 d_minMap.insert(n, min);
204 Node nGeqMin;
205 if (min.getInfinitesimalPart() == 0) {
206 nGeqMin = NodeBuilder<2>(kind::GEQ) << n << mkRationalNode(min.getNoninfinitesimalPart());
207 } else {
208 nGeqMin = NodeBuilder<2>(kind::GT) << n << mkRationalNode(min.getNoninfinitesimalPart());
209 }
210 learned << nGeqMin;
211 Debug("arith::static") << n << " iteConstant" << nGeqMin << endl;
212 ++(d_statistics.d_iteConstantApplications);
213 }
214 }
215
216 if (d_maxMap.find(n[1]) != d_maxMap.end() && d_maxMap.find(n[2]) != d_maxMap.end()) {
217 const DeltaRational& first = d_maxMap[n[1]];
218 const DeltaRational& second = d_maxMap[n[2]];
219 DeltaRational max = std::max(first, second);
220 CDNodeToMinMaxMap::const_iterator maxFind = d_maxMap.find(n);
221 if (maxFind == d_maxMap.end() || (*maxFind).second > max) {
222 d_maxMap.insert(n, max);
223 Node nLeqMax;
224 if (max.getInfinitesimalPart() == 0) {
225 nLeqMax = NodeBuilder<2>(kind::LEQ) << n << mkRationalNode(max.getNoninfinitesimalPart());
226 } else {
227 nLeqMax = NodeBuilder<2>(kind::LT) << n << mkRationalNode(max.getNoninfinitesimalPart());
228 }
229 learned << nLeqMax;
230 Debug("arith::static") << n << " iteConstant" << nLeqMax << endl;
231 ++(d_statistics.d_iteConstantApplications);
232 }
233 }
234 }
235
listToSet(TNode l)236 std::set<Node> listToSet(TNode l){
237 std::set<Node> ret;
238 while(l.getKind() == OR){
239 Assert(l.getNumChildren() == 2);
240 ret.insert(l[0]);
241 l = l[1];
242 }
243 return ret;
244 }
245
addBound(TNode n)246 void ArithStaticLearner::addBound(TNode n) {
247
248 CDNodeToMinMaxMap::const_iterator minFind = d_minMap.find(n[0]);
249 CDNodeToMinMaxMap::const_iterator maxFind = d_maxMap.find(n[0]);
250
251 Rational constant = n[1].getConst<Rational>();
252 DeltaRational bound = constant;
253
254 switch(Kind k = n.getKind()) {
255 case kind::LT:
256 bound = DeltaRational(constant, -1);
257 /* fall through */
258 case kind::LEQ:
259 if (maxFind == d_maxMap.end() || (*maxFind).second > bound) {
260 d_maxMap.insert(n[0], bound);
261 Debug("arith::static") << "adding bound " << n << endl;
262 }
263 break;
264 case kind::GT:
265 bound = DeltaRational(constant, 1);
266 /* fall through */
267 case kind::GEQ:
268 if (minFind == d_minMap.end() || (*minFind).second < bound) {
269 d_minMap.insert(n[0], bound);
270 Debug("arith::static") << "adding bound " << n << endl;
271 }
272 break;
273 default:
274 Unhandled(k);
275 break;
276 }
277 }
278
279 }/* CVC4::theory::arith namespace */
280 }/* CVC4::theory namespace */
281 }/* CVC4 namespace */
282