1 /*********************                                                        */
2 /*! \file shared_terms_database.cpp
3  ** \verbatim
4  ** Top contributors (to current version):
5  **   Dejan Jovanovic, Morgan Deters, Tim King
6  ** This file is part of the CVC4 project.
7  ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8  ** in the top-level source directory) and their institutional affiliations.
9  ** All rights reserved.  See the file COPYING in the top-level source
10  ** directory for licensing information.\endverbatim
11  **
12  ** [[ Add lengthier description here ]]
13  ** \todo document this file
14  **/
15 
16 #include "theory/shared_terms_database.h"
17 
18 #include "smt/smt_statistics_registry.h"
19 #include "theory/theory_engine.h"
20 
21 using namespace std;
22 using namespace CVC4::theory;
23 
24 namespace CVC4 {
25 
SharedTermsDatabase(TheoryEngine * theoryEngine,context::Context * context)26 SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine,
27                                          context::Context* context)
28     : ContextNotifyObj(context),
29       d_statSharedTerms("theory::shared_terms", 0),
30       d_addedSharedTermsSize(context, 0),
31       d_termsToTheories(context),
32       d_alreadyNotifiedMap(context),
33       d_registeredEqualities(context),
34       d_EENotify(*this),
35       d_equalityEngine(d_EENotify, context, "SharedTermsDatabase", true),
36       d_theoryEngine(theoryEngine),
37       d_inConflict(context, false),
38       d_conflictPolarity() {
39   smtStatisticsRegistry()->registerStat(&d_statSharedTerms);
40 }
41 
~SharedTermsDatabase()42 SharedTermsDatabase::~SharedTermsDatabase()
43 {
44   smtStatisticsRegistry()->unregisterStat(&d_statSharedTerms);
45 }
46 
addEqualityToPropagate(TNode equality)47 void SharedTermsDatabase::addEqualityToPropagate(TNode equality) {
48   d_registeredEqualities.insert(equality);
49   d_equalityEngine.addTriggerEquality(equality);
50   checkForConflict();
51 }
52 
53 
addSharedTerm(TNode atom,TNode term,Theory::Set theories)54 void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theories) {
55   Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", " << term << ", " << Theory::setToString(theories) << ")" << std::endl;
56 
57   std::pair<TNode, TNode> search_pair(atom, term);
58   SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
59   if (find == d_termsToTheories.end()) {
60     // First time for this term and this atom
61     d_atomsToTerms[atom].push_back(term);
62     d_addedSharedTerms.push_back(atom);
63     d_addedSharedTermsSize = d_addedSharedTermsSize + 1;
64     d_termsToTheories[search_pair] = theories;
65   } else {
66     Assert(theories != (*find).second);
67     d_termsToTheories[search_pair] = Theory::setUnion(theories, (*find).second);
68   }
69 }
70 
begin(TNode atom) const71 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::begin(TNode atom) const {
72   Assert(hasSharedTerms(atom));
73   return d_atomsToTerms.find(atom)->second.begin();
74 }
75 
end(TNode atom) const76 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::end(TNode atom) const {
77   Assert(hasSharedTerms(atom));
78   return d_atomsToTerms.find(atom)->second.end();
79 }
80 
hasSharedTerms(TNode atom) const81 bool SharedTermsDatabase::hasSharedTerms(TNode atom) const {
82   return d_atomsToTerms.find(atom) != d_atomsToTerms.end();
83 }
84 
backtrack()85 void SharedTermsDatabase::backtrack() {
86   for (int i = d_addedSharedTerms.size() - 1, i_end = (int)d_addedSharedTermsSize; i >= i_end; -- i) {
87     TNode atom = d_addedSharedTerms[i];
88     shared_terms_list& list = d_atomsToTerms[atom];
89     list.pop_back();
90     if (list.empty()) {
91       d_atomsToTerms.erase(atom);
92     }
93   }
94   d_addedSharedTerms.resize(d_addedSharedTermsSize);
95 }
96 
getTheoriesToNotify(TNode atom,TNode term) const97 Theory::Set SharedTermsDatabase::getTheoriesToNotify(TNode atom, TNode term) const {
98   // Get the theories that share this term from this atom
99   std::pair<TNode, TNode> search_pair(atom, term);
100   SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
101   Assert(find != d_termsToTheories.end());
102 
103   // Get the theories that were already notified
104   Theory::Set alreadyNotified = 0;
105   AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
106   if (theoriesFind != d_alreadyNotifiedMap.end()) {
107     alreadyNotified = (*theoriesFind).second;
108   }
109 
110   // Return the ones that haven't been notified yet
111   return Theory::setDifference((*find).second, alreadyNotified);
112 }
113 
114 
getNotifiedTheories(TNode term) const115 Theory::Set SharedTermsDatabase::getNotifiedTheories(TNode term) const {
116   // Get the theories that were already notified
117   AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
118   if (theoriesFind != d_alreadyNotifiedMap.end()) {
119     return (*theoriesFind).second;
120   } else {
121     return 0;
122   }
123 }
124 
propagateSharedEquality(TheoryId theory,TNode a,TNode b,bool value)125 bool SharedTermsDatabase::propagateSharedEquality(TheoryId theory, TNode a, TNode b, bool value)
126 {
127   Debug("shared-terms-database") << "SharedTermsDatabase::newEquality(" << theory << "," << a << "," << b << ", " << (value ? "true" : "false") << ")" << endl;
128 
129   if (d_inConflict) {
130     return false;
131   }
132 
133   // Propagate away
134   Node equality = a.eqNode(b);
135   if (value) {
136     d_theoryEngine->assertToTheory(equality, equality, theory, THEORY_BUILTIN);
137   } else {
138     d_theoryEngine->assertToTheory(equality.notNode(), equality.notNode(), theory, THEORY_BUILTIN);
139   }
140 
141   // As you were
142   return true;
143 }
144 
markNotified(TNode term,Theory::Set theories)145 void SharedTermsDatabase::markNotified(TNode term, Theory::Set theories) {
146 
147   // Find out if there are any new theories that were notified about this term
148   Theory::Set alreadyNotified = 0;
149   AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
150   if (theoriesFind != d_alreadyNotifiedMap.end()) {
151     alreadyNotified = (*theoriesFind).second;
152   }
153   Theory::Set newlyNotified = Theory::setDifference(theories, alreadyNotified);
154 
155   // If no new theories were notified, we are done
156   if (newlyNotified == 0) {
157     return;
158   }
159 
160   Debug("shared-terms-database") << "SharedTermsDatabase::markNotified(" << term << ")" << endl;
161 
162   // First update the set of notified theories for this term
163   d_alreadyNotifiedMap[term] = Theory::setUnion(newlyNotified, alreadyNotified);
164 
165   // Mark the shared terms in the equality engine
166   theory::TheoryId currentTheory;
167   while ((currentTheory = Theory::setPop(newlyNotified)) != THEORY_LAST) {
168     d_equalityEngine.addTriggerTerm(term, currentTheory);
169   }
170 
171   // Check for any conflits
172   checkForConflict();
173 }
174 
areEqual(TNode a,TNode b) const175 bool SharedTermsDatabase::areEqual(TNode a, TNode b) const {
176   if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) {
177     return d_equalityEngine.areEqual(a,b);
178   } else {
179     Assert(d_equalityEngine.hasTerm(a) || a.isConst());
180     Assert(d_equalityEngine.hasTerm(b) || b.isConst());
181     // since one (or both) of them is a constant, and the other is in the equality engine, they are not same
182     return false;
183   }
184 }
185 
areDisequal(TNode a,TNode b) const186 bool SharedTermsDatabase::areDisequal(TNode a, TNode b) const {
187   if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) {
188     return d_equalityEngine.areDisequal(a,b,false);
189   } else {
190     Assert(d_equalityEngine.hasTerm(a) || a.isConst());
191     Assert(d_equalityEngine.hasTerm(b) || b.isConst());
192     // one (or both) are in the equality engine
193     return false;
194   }
195 }
196 
assertEquality(TNode equality,bool polarity,TNode reason)197 void SharedTermsDatabase::assertEquality(TNode equality, bool polarity, TNode reason)
198 {
199   Debug("shared-terms-database::assert") << "SharedTermsDatabase::assertEquality(" << equality << ", " << (polarity ? "true" : "false") << ", " << reason << ")" << endl;
200   // Add it to the equality engine
201   d_equalityEngine.assertEquality(equality, polarity, reason);
202   // Check for conflict
203   checkForConflict();
204 }
205 
propagateEquality(TNode equality,bool polarity)206 bool SharedTermsDatabase::propagateEquality(TNode equality, bool polarity) {
207   if (polarity) {
208     d_theoryEngine->propagate(equality, THEORY_BUILTIN);
209   } else {
210     d_theoryEngine->propagate(equality.notNode(), THEORY_BUILTIN);
211   }
212   return true;
213 }
214 
mkAnd(const std::vector<TNode> & conjunctions)215 static Node mkAnd(const std::vector<TNode>& conjunctions) {
216   Assert(conjunctions.size() > 0);
217 
218   std::set<TNode> all;
219   all.insert(conjunctions.begin(), conjunctions.end());
220 
221   if (all.size() == 1) {
222     // All the same, or just one
223     return conjunctions[0];
224   }
225 
226   NodeBuilder<> conjunction(kind::AND);
227   std::set<TNode>::const_iterator it = all.begin();
228   std::set<TNode>::const_iterator it_end = all.end();
229   while (it != it_end) {
230     conjunction << *it;
231     ++ it;
232   }
233 
234   return conjunction;
235 }
236 
checkForConflict()237 void SharedTermsDatabase::checkForConflict() {
238   if (d_inConflict) {
239     d_inConflict = false;
240     std::vector<TNode> assumptions;
241     d_equalityEngine.explainEquality(d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions);
242     Node conflict = mkAnd(assumptions);
243     d_theoryEngine->conflict(conflict, THEORY_BUILTIN);
244     d_conflictLHS = d_conflictRHS = Node::null();
245   }
246 }
247 
isKnown(TNode literal) const248 bool SharedTermsDatabase::isKnown(TNode literal) const {
249   bool polarity = literal.getKind() != kind::NOT;
250   TNode equality = polarity ? literal : literal[0];
251   if (polarity) {
252     return d_equalityEngine.areEqual(equality[0], equality[1]);
253   } else {
254     return d_equalityEngine.areDisequal(equality[0], equality[1], false);
255   }
256 }
257 
explain(TNode literal) const258 Node SharedTermsDatabase::explain(TNode literal) const {
259   bool polarity = literal.getKind() != kind::NOT;
260   TNode atom = polarity ? literal : literal[0];
261   Assert(atom.getKind() == kind::EQUAL);
262   std::vector<TNode> assumptions;
263   d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions);
264   return mkAnd(assumptions);
265 }
266 
267 } /* namespace CVC4 */
268