1 /********************* */
2 /*! \file shared_terms_database.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Dejan Jovanovic, Morgan Deters, Tim King
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** [[ Add lengthier description here ]]
13 ** \todo document this file
14 **/
15
16 #include "theory/shared_terms_database.h"
17
18 #include "smt/smt_statistics_registry.h"
19 #include "theory/theory_engine.h"
20
21 using namespace std;
22 using namespace CVC4::theory;
23
24 namespace CVC4 {
25
/**
 * Constructs the shared-terms database.
 *
 * The context-dependent members (the added-terms counter, the term/theory
 * maps, the registered-equalities set and the conflict flag) are all built on
 * the given context. This object is also registered as a ContextNotifyObj on
 * that context. The shared-terms statistic is registered here and
 * unregistered in the destructor.
 *
 * NOTE(review): members are initialized in their declaration order in the
 * header, not in the order listed below; keep the two in sync.
 *
 * @param theoryEngine engine that receives the propagations, assertions and
 *        conflicts produced by this database (not deleted by this class)
 * @param context the context that the context-dependent state lives in
 */
SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine,
                                         context::Context* context)
    : ContextNotifyObj(context),
      d_statSharedTerms("theory::shared_terms", 0),
      d_addedSharedTermsSize(context, 0),
      d_termsToTheories(context),
      d_alreadyNotifiedMap(context),
      d_registeredEqualities(context),
      d_EENotify(*this),
      // The equality engine reports back through d_EENotify, which wraps *this.
      d_equalityEngine(d_EENotify, context, "SharedTermsDatabase", true),
      d_theoryEngine(theoryEngine),
      d_inConflict(context, false),
      d_conflictPolarity() {
  smtStatisticsRegistry()->registerStat(&d_statSharedTerms);
}
41
~SharedTermsDatabase()42 SharedTermsDatabase::~SharedTermsDatabase()
43 {
44 smtStatisticsRegistry()->unregisterStat(&d_statSharedTerms);
45 }
46
addEqualityToPropagate(TNode equality)47 void SharedTermsDatabase::addEqualityToPropagate(TNode equality) {
48 d_registeredEqualities.insert(equality);
49 d_equalityEngine.addTriggerEquality(equality);
50 checkForConflict();
51 }
52
53
addSharedTerm(TNode atom,TNode term,Theory::Set theories)54 void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theories) {
55 Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", " << term << ", " << Theory::setToString(theories) << ")" << std::endl;
56
57 std::pair<TNode, TNode> search_pair(atom, term);
58 SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
59 if (find == d_termsToTheories.end()) {
60 // First time for this term and this atom
61 d_atomsToTerms[atom].push_back(term);
62 d_addedSharedTerms.push_back(atom);
63 d_addedSharedTermsSize = d_addedSharedTermsSize + 1;
64 d_termsToTheories[search_pair] = theories;
65 } else {
66 Assert(theories != (*find).second);
67 d_termsToTheories[search_pair] = Theory::setUnion(theories, (*find).second);
68 }
69 }
70
begin(TNode atom) const71 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::begin(TNode atom) const {
72 Assert(hasSharedTerms(atom));
73 return d_atomsToTerms.find(atom)->second.begin();
74 }
75
end(TNode atom) const76 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::end(TNode atom) const {
77 Assert(hasSharedTerms(atom));
78 return d_atomsToTerms.find(atom)->second.end();
79 }
80
hasSharedTerms(TNode atom) const81 bool SharedTermsDatabase::hasSharedTerms(TNode atom) const {
82 return d_atomsToTerms.find(atom) != d_atomsToTerms.end();
83 }
84
backtrack()85 void SharedTermsDatabase::backtrack() {
86 for (int i = d_addedSharedTerms.size() - 1, i_end = (int)d_addedSharedTermsSize; i >= i_end; -- i) {
87 TNode atom = d_addedSharedTerms[i];
88 shared_terms_list& list = d_atomsToTerms[atom];
89 list.pop_back();
90 if (list.empty()) {
91 d_atomsToTerms.erase(atom);
92 }
93 }
94 d_addedSharedTerms.resize(d_addedSharedTermsSize);
95 }
96
getTheoriesToNotify(TNode atom,TNode term) const97 Theory::Set SharedTermsDatabase::getTheoriesToNotify(TNode atom, TNode term) const {
98 // Get the theories that share this term from this atom
99 std::pair<TNode, TNode> search_pair(atom, term);
100 SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
101 Assert(find != d_termsToTheories.end());
102
103 // Get the theories that were already notified
104 Theory::Set alreadyNotified = 0;
105 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
106 if (theoriesFind != d_alreadyNotifiedMap.end()) {
107 alreadyNotified = (*theoriesFind).second;
108 }
109
110 // Return the ones that haven't been notified yet
111 return Theory::setDifference((*find).second, alreadyNotified);
112 }
113
114
getNotifiedTheories(TNode term) const115 Theory::Set SharedTermsDatabase::getNotifiedTheories(TNode term) const {
116 // Get the theories that were already notified
117 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
118 if (theoriesFind != d_alreadyNotifiedMap.end()) {
119 return (*theoriesFind).second;
120 } else {
121 return 0;
122 }
123 }
124
propagateSharedEquality(TheoryId theory,TNode a,TNode b,bool value)125 bool SharedTermsDatabase::propagateSharedEquality(TheoryId theory, TNode a, TNode b, bool value)
126 {
127 Debug("shared-terms-database") << "SharedTermsDatabase::newEquality(" << theory << "," << a << "," << b << ", " << (value ? "true" : "false") << ")" << endl;
128
129 if (d_inConflict) {
130 return false;
131 }
132
133 // Propagate away
134 Node equality = a.eqNode(b);
135 if (value) {
136 d_theoryEngine->assertToTheory(equality, equality, theory, THEORY_BUILTIN);
137 } else {
138 d_theoryEngine->assertToTheory(equality.notNode(), equality.notNode(), theory, THEORY_BUILTIN);
139 }
140
141 // As you were
142 return true;
143 }
144
markNotified(TNode term,Theory::Set theories)145 void SharedTermsDatabase::markNotified(TNode term, Theory::Set theories) {
146
147 // Find out if there are any new theories that were notified about this term
148 Theory::Set alreadyNotified = 0;
149 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
150 if (theoriesFind != d_alreadyNotifiedMap.end()) {
151 alreadyNotified = (*theoriesFind).second;
152 }
153 Theory::Set newlyNotified = Theory::setDifference(theories, alreadyNotified);
154
155 // If no new theories were notified, we are done
156 if (newlyNotified == 0) {
157 return;
158 }
159
160 Debug("shared-terms-database") << "SharedTermsDatabase::markNotified(" << term << ")" << endl;
161
162 // First update the set of notified theories for this term
163 d_alreadyNotifiedMap[term] = Theory::setUnion(newlyNotified, alreadyNotified);
164
165 // Mark the shared terms in the equality engine
166 theory::TheoryId currentTheory;
167 while ((currentTheory = Theory::setPop(newlyNotified)) != THEORY_LAST) {
168 d_equalityEngine.addTriggerTerm(term, currentTheory);
169 }
170
171 // Check for any conflits
172 checkForConflict();
173 }
174
areEqual(TNode a,TNode b) const175 bool SharedTermsDatabase::areEqual(TNode a, TNode b) const {
176 if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) {
177 return d_equalityEngine.areEqual(a,b);
178 } else {
179 Assert(d_equalityEngine.hasTerm(a) || a.isConst());
180 Assert(d_equalityEngine.hasTerm(b) || b.isConst());
181 // since one (or both) of them is a constant, and the other is in the equality engine, they are not same
182 return false;
183 }
184 }
185
areDisequal(TNode a,TNode b) const186 bool SharedTermsDatabase::areDisequal(TNode a, TNode b) const {
187 if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) {
188 return d_equalityEngine.areDisequal(a,b,false);
189 } else {
190 Assert(d_equalityEngine.hasTerm(a) || a.isConst());
191 Assert(d_equalityEngine.hasTerm(b) || b.isConst());
192 // one (or both) are in the equality engine
193 return false;
194 }
195 }
196
assertEquality(TNode equality,bool polarity,TNode reason)197 void SharedTermsDatabase::assertEquality(TNode equality, bool polarity, TNode reason)
198 {
199 Debug("shared-terms-database::assert") << "SharedTermsDatabase::assertEquality(" << equality << ", " << (polarity ? "true" : "false") << ", " << reason << ")" << endl;
200 // Add it to the equality engine
201 d_equalityEngine.assertEquality(equality, polarity, reason);
202 // Check for conflict
203 checkForConflict();
204 }
205
propagateEquality(TNode equality,bool polarity)206 bool SharedTermsDatabase::propagateEquality(TNode equality, bool polarity) {
207 if (polarity) {
208 d_theoryEngine->propagate(equality, THEORY_BUILTIN);
209 } else {
210 d_theoryEngine->propagate(equality.notNode(), THEORY_BUILTIN);
211 }
212 return true;
213 }
214
mkAnd(const std::vector<TNode> & conjunctions)215 static Node mkAnd(const std::vector<TNode>& conjunctions) {
216 Assert(conjunctions.size() > 0);
217
218 std::set<TNode> all;
219 all.insert(conjunctions.begin(), conjunctions.end());
220
221 if (all.size() == 1) {
222 // All the same, or just one
223 return conjunctions[0];
224 }
225
226 NodeBuilder<> conjunction(kind::AND);
227 std::set<TNode>::const_iterator it = all.begin();
228 std::set<TNode>::const_iterator it_end = all.end();
229 while (it != it_end) {
230 conjunction << *it;
231 ++ it;
232 }
233
234 return conjunction;
235 }
236
checkForConflict()237 void SharedTermsDatabase::checkForConflict() {
238 if (d_inConflict) {
239 d_inConflict = false;
240 std::vector<TNode> assumptions;
241 d_equalityEngine.explainEquality(d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions);
242 Node conflict = mkAnd(assumptions);
243 d_theoryEngine->conflict(conflict, THEORY_BUILTIN);
244 d_conflictLHS = d_conflictRHS = Node::null();
245 }
246 }
247
isKnown(TNode literal) const248 bool SharedTermsDatabase::isKnown(TNode literal) const {
249 bool polarity = literal.getKind() != kind::NOT;
250 TNode equality = polarity ? literal : literal[0];
251 if (polarity) {
252 return d_equalityEngine.areEqual(equality[0], equality[1]);
253 } else {
254 return d_equalityEngine.areDisequal(equality[0], equality[1], false);
255 }
256 }
257
explain(TNode literal) const258 Node SharedTermsDatabase::explain(TNode literal) const {
259 bool polarity = literal.getKind() != kind::NOT;
260 TNode atom = polarity ? literal : literal[0];
261 Assert(atom.getKind() == kind::EQUAL);
262 std::vector<TNode> assumptions;
263 d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions);
264 return mkAnd(assumptions);
265 }
266
267 } /* namespace CVC4 */
268