1 /*********************                                                        */
2 /*! \file theory_sets_rels.cpp
3  ** \verbatim
4  ** Top contributors (to current version):
5  **   Paul Meng, Andrew Reynolds, Tim King
6  ** This file is part of the CVC4 project.
7  ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8  ** in the top-level source directory) and their institutional affiliations.
9  ** All rights reserved.  See the file COPYING in the top-level source
10  ** directory for licensing information.\endverbatim
11  **
12  ** \brief Sets theory implementation.
13  **
14  ** Extension to Sets theory.
15  **/
16 
17 #include "theory/sets/theory_sets_rels.h"
18 #include "expr/datatype.h"
19 #include "theory/sets/theory_sets_private.h"
20 #include "theory/sets/theory_sets.h"
21 
22 using namespace std;
23 
24 namespace CVC4 {
25 namespace theory {
26 namespace sets {
27 
28 typedef std::map< Node, std::vector< Node > >::iterator                                         MEM_IT;
29 typedef std::map< kind::Kind_t, std::vector< Node > >::iterator                                 KIND_TERM_IT;
30 typedef std::map< Node, std::unordered_set< Node, NodeHashFunction > >::iterator                     TC_GRAPH_IT;
31 typedef std::map< Node, std::map< kind::Kind_t, std::vector< Node > > >::iterator               TERM_IT;
32 typedef std::map< Node, std::map< Node, std::unordered_set< Node, NodeHashFunction > > >::iterator   TC_IT;
33 
check(Theory::Effort level)34   void TheorySetsRels::check(Theory::Effort level) {
35     Trace("rels") << "\n[sets-rels] ******************************* Start the relational solver, effort = " << level << " *******************************\n" << std::endl;
36     if(Theory::fullEffort(level)) {
37       collectRelsInfo();
38       check();
39       doPendingLemmas();
40       Assert( d_lemmas_out.empty() );
41       Assert( d_pending_facts.empty() );
42     } else {
43       doPendingMerge();
44     }
45     Trace("rels") << "\n[sets-rels] ******************************* Done with the relational solver *******************************\n" << std::endl;
46   }
47 
check()48   void TheorySetsRels::check() {
49     MEM_IT m_it = d_rReps_memberReps_cache.begin();
50 
51     while(m_it != d_rReps_memberReps_cache.end()) {
52       Node rel_rep = m_it->first;
53 
54       for(unsigned int i = 0; i < m_it->second.size(); i++) {
55         Node    mem     = d_rReps_memberReps_cache[rel_rep][i];
56         Node    exp     = d_rReps_memberReps_exp_cache[rel_rep][i];
57         std::map<kind::Kind_t, std::vector<Node> >    kind_terms      = d_terms_cache[rel_rep];
58 
59         if( kind_terms.find(kind::TRANSPOSE) != kind_terms.end() ) {
60           std::vector<Node> tp_terms = kind_terms[kind::TRANSPOSE];
61           if( tp_terms.size() > 0 ) {
62             applyTransposeRule( tp_terms );
63             applyTransposeRule( tp_terms[0], rel_rep, exp );
64           }
65         }
66         if( kind_terms.find(kind::JOIN) != kind_terms.end() ) {
67           std::vector<Node> join_terms = kind_terms[kind::JOIN];
68           for( unsigned int j = 0; j < join_terms.size(); j++ ) {
69             applyJoinRule( join_terms[j], rel_rep, exp );
70           }
71         }
72         if( kind_terms.find(kind::PRODUCT) != kind_terms.end() ) {
73           std::vector<Node> product_terms = kind_terms[kind::PRODUCT];
74           for( unsigned int j = 0; j < product_terms.size(); j++ ) {
75             applyProductRule( product_terms[j], rel_rep, exp );
76           }
77         }
78         if( kind_terms.find(kind::TCLOSURE) != kind_terms.end() ) {
79           std::vector<Node> tc_terms = kind_terms[kind::TCLOSURE];
80           for( unsigned int j = 0; j < tc_terms.size(); j++ ) {
81             applyTCRule( mem, tc_terms[j], rel_rep, exp );
82           }
83         }
84         if( kind_terms.find(kind::JOIN_IMAGE) != kind_terms.end() ) {
85           std::vector<Node> join_image_terms = kind_terms[kind::JOIN_IMAGE];
86           for( unsigned int j = 0; j < join_image_terms.size(); j++ ) {
87             applyJoinImageRule( mem, join_image_terms[j], exp );
88           }
89         }
90         if( kind_terms.find(kind::IDEN) != kind_terms.end() ) {
91           std::vector<Node> iden_terms = kind_terms[kind::IDEN];
92           for( unsigned int j = 0; j < iden_terms.size(); j++ ) {
93             applyIdenRule( mem, iden_terms[j], exp );
94           }
95         }
96       }
97       m_it++;
98     }
99 
100     TERM_IT t_it = d_terms_cache.begin();
101     while( t_it != d_terms_cache.end() ) {
102       if( d_rReps_memberReps_cache.find(t_it->first) == d_rReps_memberReps_cache.end() ) {
103         Trace("rels-debug") << "[sets-rels] A term does not have membership constraints: " << t_it->first << std::endl;
104         KIND_TERM_IT k_t_it = t_it->second.begin();
105 
106         while( k_t_it != t_it->second.end() ) {
107           if( k_t_it->first == kind::JOIN || k_t_it->first == kind::PRODUCT ) {
108             std::vector<Node>::iterator term_it = k_t_it->second.begin();
109             while(term_it != k_t_it->second.end()) {
110               computeMembersForBinOpRel( *term_it );
111               term_it++;
112             }
113           } else if( k_t_it->first == kind::TRANSPOSE ) {
114             std::vector<Node>::iterator term_it = k_t_it->second.begin();
115             while( term_it != k_t_it->second.end() ) {
116               computeMembersForUnaryOpRel( *term_it );
117               term_it++;
118             }
119           } else if ( k_t_it->first == kind::TCLOSURE ) {
120             std::vector<Node>::iterator term_it = k_t_it->second.begin();
121             while( term_it != k_t_it->second.end() ) {
122               buildTCGraphForRel( *term_it );
123               term_it++;
124             }
125           } else if( k_t_it->first == kind::JOIN_IMAGE ) {
126             std::vector<Node>::iterator term_it = k_t_it->second.begin();
127             while( term_it != k_t_it->second.end() ) {
128               computeMembersForJoinImageTerm( *term_it );
129               term_it++;
130             }
131           } else if( k_t_it->first == kind::IDEN ) {
132             std::vector<Node>::iterator term_it = k_t_it->second.begin();
133             while( term_it != k_t_it->second.end() ) {
134               computeMembersForIdenTerm( *term_it );
135               term_it++;
136             }
137           }
138           k_t_it++;
139         }
140       }
141       t_it++;
142     }
143     doTCInference();
144   }
145 
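  /*
   * Populate the relational data structures from the current equivalence
   * classes: positive tuple memberships of relations, and the relational
   * terms (TRANSPOSE, JOIN, PRODUCT, TCLOSURE, JOIN_IMAGE, IDEN) of each class.
   */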
149 
collectRelsInfo()150   void TheorySetsRels::collectRelsInfo() {
151     Trace("rels") << "[sets-rels] Start collecting relational terms..." << std::endl;
152     eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( d_eqEngine );
153     while( !eqcs_i.isFinished() ){
154       Node                      eqc_rep  = (*eqcs_i);
155       eq::EqClassIterator       eqc_i   = eq::EqClassIterator( eqc_rep, d_eqEngine );
156 
157       Trace("rels-ee") << "[sets-rels-ee] Eqc term representative: " << eqc_rep << " with type " << eqc_rep.getType() << std::endl;
158 
159       while( !eqc_i.isFinished() ){
160         Node eqc_node = (*eqc_i);
161 
162         Trace("rels-ee") << "  term : " << eqc_node << std::endl;
163 
164         if( getRepresentative(eqc_rep) == getRepresentative(d_trueNode) ||
165             getRepresentative(eqc_rep) == getRepresentative(d_falseNode) ) {
166 
167           // collect membership info
168           if( eqc_node.getKind() == kind::MEMBER && eqc_node[1].getType().getSetElementType().isTuple()) {
169             Node tup_rep = getRepresentative( eqc_node[0] );
170             Node rel_rep = getRepresentative( eqc_node[1] );
171 
172             if( eqc_node[0].isVar() ){
173               reduceTupleVar( eqc_node );
174             }
175 
176             bool is_true_eq    = areEqual( eqc_rep, d_trueNode );
177             Node reason        = is_true_eq ? eqc_node : eqc_node.negate();
178 
179             if( is_true_eq ) {
180               if( safelyAddToMap(d_rReps_memberReps_cache, rel_rep, tup_rep) ) {
181                 addToMap(d_rReps_memberReps_exp_cache, rel_rep, reason);
182                 computeTupleReps(tup_rep);
183                 d_membership_trie[rel_rep].addTerm(tup_rep, d_tuple_reps[tup_rep]);
184               }
185             }
186           }
187         // collect relational terms info
188         } else if( eqc_rep.getType().isSet() && eqc_rep.getType().getSetElementType().isTuple() ) {
189           if( eqc_node.getKind() == kind::TRANSPOSE || eqc_node.getKind() == kind::JOIN ||
190               eqc_node.getKind() == kind::PRODUCT || eqc_node.getKind() == kind::TCLOSURE ||
191               eqc_node.getKind() == kind::JOIN_IMAGE || eqc_node.getKind() == kind::IDEN ) {
192             std::vector<Node> terms;
193             std::map< kind::Kind_t, std::vector<Node> >  rel_terms;
194             TERM_IT terms_it = d_terms_cache.find(eqc_rep);
195 
196             if( terms_it == d_terms_cache.end() ) {
197               terms.push_back(eqc_node);
198               rel_terms[eqc_node.getKind()]      = terms;
199               d_terms_cache[eqc_rep]             = rel_terms;
200             } else {
201               KIND_TERM_IT kind_term_it = terms_it->second.find(eqc_node.getKind());
202 
203               if( kind_term_it == terms_it->second.end() ) {
204                 terms.push_back(eqc_node);
205                 d_terms_cache[eqc_rep][eqc_node.getKind()] = terms;
206               } else {
207                 kind_term_it->second.push_back(eqc_node);
208               }
209             }
210           }
211         // need to add all tuple elements as shared terms
212         } else if( eqc_node.getType().isTuple() && !eqc_node.isConst() && !eqc_node.isVar() ) {
213           for( unsigned int i = 0; i < eqc_node.getType().getTupleLength(); i++ ) {
214             Node element = RelsUtils::nthElementOfTuple( eqc_node, i );
215 
216             if( !element.isConst() ) {
217               makeSharedTerm( element );
218             }
219           }
220         }
221         ++eqc_i;
222       }
223       ++eqcs_i;
224     }
225     Trace("rels-debug") << "[Theory::Rels] Done with collecting relational terms!" << std::endl;
226   }
227 
228   /* JOIN-IMAGE UP  :   (x, x1) IS_IN R, ..., (x, xn) IS_IN R  (R JOIN_IMAGE n)
229   *                     -------------------------------------------------------
230   *                     x IS_IN (R JOIN_IMAGE n) || NOT DISTINCT(x1, ... , xn)
231   *
232   */
233 
computeMembersForJoinImageTerm(Node join_image_term)234   void TheorySetsRels::computeMembersForJoinImageTerm( Node join_image_term ) {
235     Trace("rels-debug") << "\n[Theory::Rels] *********** Compute members for JoinImage Term = " << join_image_term << std::endl;
236     MEM_IT rel_mem_it = d_rReps_memberReps_cache.find( getRepresentative( join_image_term[0] ) );
237 
238     if( rel_mem_it == d_rReps_memberReps_cache.end() ) {
239       return;
240     }
241 
242     Node join_image_rel = join_image_term[0];
243     std::unordered_set< Node, NodeHashFunction > hasChecked;
244     Node join_image_rel_rep = getRepresentative( join_image_rel );
245     std::vector< Node >::iterator mem_rep_it = (*rel_mem_it).second.begin();
246     MEM_IT rel_mem_exp_it = d_rReps_memberReps_exp_cache.find( join_image_rel_rep );
247     std::vector< Node >::iterator mem_rep_exp_it = (*rel_mem_exp_it).second.begin();
248     unsigned int min_card = join_image_term[1].getConst<Rational>().getNumerator().getUnsignedInt();
249 
250     while( mem_rep_it != (*rel_mem_it).second.end() ) {
251       Node fst_mem_rep = RelsUtils::nthElementOfTuple( *mem_rep_it, 0 );
252 
253       if( hasChecked.find( fst_mem_rep ) != hasChecked.end() ) {
254         ++mem_rep_it;
255         ++mem_rep_exp_it;
256         continue;
257       }
258       hasChecked.insert( fst_mem_rep );
259 
260       Datatype dt = join_image_term.getType().getSetElementType().getDatatype();
261       Node new_membership = NodeManager::currentNM()->mkNode(kind::MEMBER,
262                                                              NodeManager::currentNM()->mkNode( kind::APPLY_CONSTRUCTOR,
263                                                                                                Node::fromExpr(dt[0].getConstructor()), fst_mem_rep ),
264                                                              join_image_term);
265       if( holds( new_membership ) ) {
266         ++mem_rep_it;
267         ++mem_rep_exp_it;
268         continue;
269       }
270 
271       std::vector< Node > reasons;
272       std::vector< Node > existing_members;
273       std::vector< Node >::iterator mem_rep_exp_it_snd = (*rel_mem_exp_it).second.begin();
274 
275       while( mem_rep_exp_it_snd != (*rel_mem_exp_it).second.end() ) {
276         Node fst_element_snd_mem = RelsUtils::nthElementOfTuple( (*mem_rep_exp_it_snd)[0], 0 );
277 
278         if( areEqual( fst_mem_rep,  fst_element_snd_mem ) ) {
279           bool notExist = true;
280           std::vector< Node >::iterator existing_mem_it = existing_members.begin();
281           Node snd_element_snd_mem = RelsUtils::nthElementOfTuple( (*mem_rep_exp_it_snd)[0], 1 );
282 
283           while( existing_mem_it != existing_members.end() ) {
284             if( areEqual( (*existing_mem_it), snd_element_snd_mem ) ) {
285               notExist = false;
286               break;
287             }
288             ++existing_mem_it;
289           }
290 
291           if( notExist ) {
292             existing_members.push_back( snd_element_snd_mem );
293             reasons.push_back( *mem_rep_exp_it_snd );
294             if( fst_mem_rep != fst_element_snd_mem ) {
295               reasons.push_back( NodeManager::currentNM()->mkNode( kind::EQUAL, fst_mem_rep, fst_element_snd_mem ) );
296             }
297             if( join_image_rel != (*mem_rep_exp_it_snd)[1] ) {
298               reasons.push_back( NodeManager::currentNM()->mkNode( kind::EQUAL, (*mem_rep_exp_it_snd)[1], join_image_rel ));
299             }
300             if( existing_members.size() == min_card ) {
301               if( min_card >= 2) {
302                 new_membership = NodeManager::currentNM()->mkNode( kind::OR, new_membership, NodeManager::currentNM()->mkNode( kind::NOT, NodeManager::currentNM()->mkNode( kind::DISTINCT, existing_members ) ));
303               }
304               Assert(reasons.size() >= 1);
305               sendInfer( new_membership, reasons.size() > 1 ? NodeManager::currentNM()->mkNode( kind::AND, reasons) : reasons[0], "JOIN-IMAGE UP" );
306               break;
307             }
308           }
309         }
310         ++mem_rep_exp_it_snd;
311       }
312       ++mem_rep_it;
313       ++mem_rep_exp_it;
314     }
315     Trace("rels-debug") << "\n[Theory::Rels] *********** Done with computing members for JoinImage Term" << join_image_term << "*********** " << std::endl;
316   }
317 
318   /* JOIN-IMAGE DOWN  : (x) IS_IN (R JOIN_IMAGE n)
319   *                     -------------------------------------------------------
320   *                     (x, x1) IS_IN R .... (x, xn) IS_IN R  DISTINCT(x1, ... , xn)
321   *
322   */
323 
applyJoinImageRule(Node mem_rep,Node join_image_term,Node exp)324   void TheorySetsRels::applyJoinImageRule( Node mem_rep, Node join_image_term, Node exp ) {
325     Trace("rels-debug") << "\n[Theory::Rels] *********** applyJoinImageRule on " << join_image_term
326                         << " with mem_rep = " << mem_rep  << " and exp = " << exp << std::endl;
327     if( d_rel_nodes.find( join_image_term ) == d_rel_nodes.end() ) {
328       computeMembersForJoinImageTerm( join_image_term );
329       d_rel_nodes.insert( join_image_term );
330     }
331 
332     Node join_image_rel = join_image_term[0];
333     Node join_image_rel_rep = getRepresentative( join_image_rel );
334     MEM_IT rel_mem_it = d_rReps_memberReps_cache.find( join_image_rel_rep );
335     unsigned int min_card = join_image_term[1].getConst<Rational>().getNumerator().getUnsignedInt();
336 
337     if( rel_mem_it != d_rReps_memberReps_cache.end() ) {
338       if( d_membership_trie.find( join_image_rel_rep ) != d_membership_trie.end() ) {
339         computeTupleReps( mem_rep );
340         if( d_membership_trie[join_image_rel_rep].findSuccessors(d_tuple_reps[mem_rep]).size() >= min_card ) {
341           return;
342         }
343       }
344     }
345 
346     Node reason = exp;
347     Node conclusion = d_trueNode;
348     std::vector< Node > distinct_skolems;
349     Node fst_mem_element = RelsUtils::nthElementOfTuple( exp[0], 0 );
350 
351     if( exp[1] != join_image_term ) {
352       reason = NodeManager::currentNM()->mkNode( kind::AND, reason, NodeManager::currentNM()->mkNode( kind::EQUAL, exp[1], join_image_term ) );
353     }
354     for( unsigned int i = 0; i < min_card; i++ ) {
355       Node skolem = NodeManager::currentNM()->mkSkolem( "jig", join_image_rel.getType()[0].getTupleTypes()[0] );
356       distinct_skolems.push_back( skolem );
357       conclusion = NodeManager::currentNM()->mkNode( kind::AND, conclusion, NodeManager::currentNM()->mkNode( kind::MEMBER, RelsUtils::constructPair( join_image_rel, fst_mem_element, skolem ), join_image_rel ) );
358     }
359     if( distinct_skolems.size() >= 2 ) {
360       conclusion =  NodeManager::currentNM()->mkNode( kind::AND, conclusion, NodeManager::currentNM()->mkNode( kind::DISTINCT, distinct_skolems ) );
361     }
362     sendInfer( conclusion, reason, "JOIN-IMAGE DOWN" );
363     Trace("rels-debug") << "\n[Theory::Rels] *********** Done with applyJoinImageRule ***********" << std::endl;
364 
365   }
366 
367 
368   /* IDENTITY-DOWN  : (x, y) IS_IN IDEN(R)
369   *               -------------------------------------------------------
370   *                   x = y,  (x IS_IN R)
371   *
372   */
373 
applyIdenRule(Node mem_rep,Node iden_term,Node exp)374   void TheorySetsRels::applyIdenRule( Node mem_rep, Node iden_term, Node exp) {
375     Trace("rels-debug") << "\n[Theory::Rels] *********** applyIdenRule on " << iden_term
376                         << " with mem_rep = " << mem_rep  << " and exp = " << exp << std::endl;
377     if( d_rel_nodes.find( iden_term ) == d_rel_nodes.end() ) {
378       computeMembersForIdenTerm( iden_term );
379       d_rel_nodes.insert( iden_term );
380     }
381     Node reason = exp;
382     Node fst_mem = RelsUtils::nthElementOfTuple( exp[0], 0 );
383     Node snd_mem = RelsUtils::nthElementOfTuple( exp[0], 1 );
384     Datatype dt = iden_term[0].getType().getSetElementType().getDatatype();
385     Node fact = NodeManager::currentNM()->mkNode( kind::MEMBER, NodeManager::currentNM()->mkNode( kind::APPLY_CONSTRUCTOR, Node::fromExpr(dt[0].getConstructor()), fst_mem ), iden_term[0] );
386 
387     if( exp[1] != iden_term ) {
388       reason = NodeManager::currentNM()->mkNode( kind::AND, reason, NodeManager::currentNM()->mkNode( kind::EQUAL, exp[1], iden_term ) );
389     }
390     sendInfer( NodeManager::currentNM()->mkNode( kind::AND, fact, NodeManager::currentNM()->mkNode( kind::EQUAL, fst_mem, snd_mem ) ), reason, "IDENTITY-DOWN" );
391     Trace("rels-debug") << "\n[Theory::Rels] *********** Done with applyIdenRule on " << iden_term << std::endl;
392   }
393 
394   /* IDEN UP  : (x) IS_IN R        IDEN(R) IN T
395   *             --------------------------------
396   *                   (x, x) IS_IN IDEN(R)
397   *
398   */
399 
computeMembersForIdenTerm(Node iden_term)400   void TheorySetsRels::computeMembersForIdenTerm( Node iden_term ) {
401     Trace("rels-debug") << "\n[Theory::Rels] *********** Compute members for Iden Term = " << iden_term << std::endl;
402     Node iden_term_rel = iden_term[0];
403     Node iden_term_rel_rep = getRepresentative( iden_term_rel );
404 
405     if( d_rReps_memberReps_cache.find( iden_term_rel_rep ) == d_rReps_memberReps_cache.end() ) {
406       return;
407     }
408 
409     MEM_IT rel_mem_exp_it = d_rReps_memberReps_exp_cache.find( iden_term_rel_rep );
410     std::vector< Node >::iterator mem_rep_exp_it = (*rel_mem_exp_it).second.begin();
411 
412     while( mem_rep_exp_it != (*rel_mem_exp_it).second.end() ) {
413       Node reason = *mem_rep_exp_it;
414       Node fst_exp_mem = RelsUtils::nthElementOfTuple( (*mem_rep_exp_it)[0], 0 );
415       Node new_mem = RelsUtils::constructPair( iden_term, fst_exp_mem, fst_exp_mem );
416 
417       if( (*mem_rep_exp_it)[1] != iden_term_rel ) {
418         reason = NodeManager::currentNM()->mkNode( kind::AND, reason, NodeManager::currentNM()->mkNode( kind::EQUAL, (*mem_rep_exp_it)[1], iden_term_rel ) );
419       }
420       sendInfer( NodeManager::currentNM()->mkNode( kind::MEMBER, new_mem, iden_term ), reason, "IDENTITY-UP" );
421       ++mem_rep_exp_it;
422     }
423     Trace("rels-debug") << "\n[Theory::Rels] *********** Done with computing members for Iden Term = " << iden_term << std::endl;
424   }
425 
426 
  /*
   * The transitive closure graph of a TCLOSURE term is constructed from the
   * members of (the representative of) its argument relation.
   */
430 
431   /*
432    * TCLOSURE TCLOSURE(x) = x | x.x | x.x.x | ... (| is union)
433    *
434    * TCLOSURE-UP I:   (a, b) IS_IN x            TCLOSURE(x) in T
435    *              ---------------------------------------------
436    *                              (a, b) IS_IN TCLOSURE(x)
437    *
438    *
439    *
440    * TCLOSURE-UP II : (a, b) IS_IN TCLOSURE(x)  (b, c) IS_IN TCLOSURE(x)
441    *              -----------------------------------------------------------
442    *                            (a, c) IS_IN TCLOSURE(x)
443    *
444    */
applyTCRule(Node mem_rep,Node tc_rel,Node tc_rel_rep,Node exp)445   void TheorySetsRels::applyTCRule( Node mem_rep, Node tc_rel, Node tc_rel_rep, Node exp ) {
446     Trace("rels-debug") << "[Theory::Rels] *********** Applying TCLOSURE rule on a tc term = " << tc_rel
447                             << ", its representative = " << tc_rel_rep
448                             << " with member rep = " << mem_rep << " and explanation = " << exp << std::endl;
449     MEM_IT mem_it = d_rReps_memberReps_cache.find( tc_rel[0] );
450 
451     if( mem_it != d_rReps_memberReps_cache.end() && d_rel_nodes.find( tc_rel ) == d_rel_nodes.end()
452         && d_rRep_tcGraph.find( getRepresentative( tc_rel[0] ) ) ==  d_rRep_tcGraph.end() ) {
453       buildTCGraphForRel( tc_rel );
454       d_rel_nodes.insert( tc_rel );
455     }
456 
457     // mem_rep is a member of tc_rel[0] or mem_rep can be infered by TC_Graph of tc_rel[0], thus skip
458     if( isTCReachable( mem_rep, tc_rel) ) {
459       Trace("rels-debug") << "[Theory::Rels] mem_rep is a member of tc_rel[0] = " << tc_rel[0]
460                               << " or can be infered by TC_Graph of tc_rel[0]! " << std::endl;
461       return;
462     }
463     // add mem_rep to d_tcrRep_tcGraph
464     TC_IT tc_it = d_tcr_tcGraph.find( tc_rel );
465     Node mem_rep_fst = getRepresentative( RelsUtils::nthElementOfTuple( mem_rep, 0 ) );
466     Node mem_rep_snd = getRepresentative( RelsUtils::nthElementOfTuple( mem_rep, 1 ) );
467     Node mem_rep_tup = RelsUtils::constructPair( tc_rel, mem_rep_fst, mem_rep_snd );
468 
469     if( tc_it != d_tcr_tcGraph.end() ) {
470       std::map< Node, std::map< Node, Node > >::iterator tc_exp_it = d_tcr_tcGraph_exps.find( tc_rel );
471 
472       TC_GRAPH_IT tc_graph_it = (tc_it->second).find( mem_rep_fst );
473       Assert( tc_exp_it != d_tcr_tcGraph_exps.end() );
474       std::map< Node, Node >::iterator exp_map_it = (tc_exp_it->second).find( mem_rep_tup );
475 
476       if( exp_map_it == (tc_exp_it->second).end() ) {
477         (tc_exp_it->second)[mem_rep_tup] = exp;
478       }
479 
480       if( tc_graph_it != (tc_it->second).end() ) {
481         (tc_graph_it->second).insert( mem_rep_snd );
482       } else {
483         std::unordered_set< Node, NodeHashFunction > sets;
484         sets.insert( mem_rep_snd );
485         (tc_it->second)[mem_rep_fst] = sets;
486       }
487     } else {
488       std::map< Node, Node > exp_map;
489       std::unordered_set< Node, NodeHashFunction > sets;
490       std::map< Node, std::unordered_set<Node, NodeHashFunction> > element_map;
491       sets.insert( mem_rep_snd );
492       element_map[mem_rep_fst] = sets;
493       d_tcr_tcGraph[tc_rel] = element_map;
494       exp_map[mem_rep_tup] = exp;
495       d_tcr_tcGraph_exps[tc_rel] = exp_map;
496     }
497 
498     Node fst_element = RelsUtils::nthElementOfTuple( exp[0], 0 );
499     Node snd_element = RelsUtils::nthElementOfTuple( exp[0], 1 );
500     Node sk_1     = NodeManager::currentNM()->mkSkolem("stc", fst_element.getType());
501     Node sk_2     = NodeManager::currentNM()->mkSkolem("stc", snd_element.getType());
502     Node mem_of_r = NodeManager::currentNM()->mkNode(kind::MEMBER, exp[0], tc_rel[0]);
503     Node sk_eq    = NodeManager::currentNM()->mkNode(kind::EQUAL, sk_1, sk_2);
504     Node reason   = exp;
505 
506     if( tc_rel != exp[1] ) {
507       reason = NodeManager::currentNM()->mkNode(kind::AND, reason, NodeManager::currentNM()->mkNode(kind::EQUAL, tc_rel, exp[1]));
508     }
509 
510     Node conclusion = NodeManager::currentNM()->mkNode(kind::OR, mem_of_r,
511                                                      (NodeManager::currentNM()->mkNode(kind::AND, NodeManager::currentNM()->mkNode(kind::MEMBER, RelsUtils::constructPair(tc_rel, fst_element, sk_1), tc_rel[0]),
512                                                      (NodeManager::currentNM()->mkNode(kind::AND, NodeManager::currentNM()->mkNode(kind::MEMBER, RelsUtils::constructPair(tc_rel, sk_2, snd_element), tc_rel[0]),
513                                                      (NodeManager::currentNM()->mkNode(kind::OR, sk_eq, NodeManager::currentNM()->mkNode(kind::MEMBER, RelsUtils::constructPair(tc_rel, sk_1, sk_2), tc_rel))))))));
514 
515     Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, conclusion );
516     std::vector< Node > require_phase;
517     require_phase.push_back(Rewriter::rewrite(mem_of_r));
518     require_phase.push_back(Rewriter::rewrite(sk_eq));
519     d_tc_lemmas_last[tc_lemma] = require_phase;
520   }
521 
isTCReachable(Node mem_rep,Node tc_rel)522   bool TheorySetsRels::isTCReachable( Node mem_rep, Node tc_rel ) {
523     MEM_IT mem_it = d_rReps_memberReps_cache.find( getRepresentative( tc_rel[0] ) );
524 
525     if( mem_it != d_rReps_memberReps_cache.end() && std::find( (mem_it->second).begin(), (mem_it->second).end(), mem_rep) != (mem_it->second).end() ) {
526       return true;
527     }
528 
529     TC_IT tc_it = d_rRep_tcGraph.find( getRepresentative(tc_rel[0]) );
530     if( tc_it != d_rRep_tcGraph.end() ) {
531       bool isReachable = false;
532       std::unordered_set<Node, NodeHashFunction> seen;
533       isTCReachable( getRepresentative( RelsUtils::nthElementOfTuple(mem_rep, 0) ),
534                      getRepresentative( RelsUtils::nthElementOfTuple(mem_rep, 1) ), seen, tc_it->second, isReachable );
535       return isReachable;
536     }
537     return false;
538   }
539 
isTCReachable(Node start,Node dest,std::unordered_set<Node,NodeHashFunction> & hasSeen,std::map<Node,std::unordered_set<Node,NodeHashFunction>> & tc_graph,bool & isReachable)540   void TheorySetsRels::isTCReachable( Node start, Node dest, std::unordered_set<Node, NodeHashFunction>& hasSeen,
541                                     std::map< Node, std::unordered_set< Node, NodeHashFunction > >& tc_graph, bool& isReachable ) {
542     if(hasSeen.find(start) == hasSeen.end()) {
543       hasSeen.insert(start);
544     }
545 
546     TC_GRAPH_IT pair_set_it = tc_graph.find(start);
547 
548     if(pair_set_it != tc_graph.end()) {
549       if(pair_set_it->second.find(dest) != pair_set_it->second.end()) {
550         isReachable = true;
551         return;
552       } else {
553         std::unordered_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin();
554 
555         while( set_it != pair_set_it->second.end() ) {
556           // need to check if *set_it has been looked already
557           if( hasSeen.find(*set_it) == hasSeen.end() ) {
558             isTCReachable( *set_it, dest, hasSeen, tc_graph, isReachable );
559           }
560           set_it++;
561         }
562       }
563     }
564   }
565 
buildTCGraphForRel(Node tc_rel)566   void TheorySetsRels::buildTCGraphForRel( Node tc_rel ) {
567     std::map< Node, Node > rel_tc_graph_exps;
568     std::map< Node, std::unordered_set<Node, NodeHashFunction> > rel_tc_graph;
569 
570     Node rel_rep = getRepresentative( tc_rel[0] );
571     Node tc_rel_rep = getRepresentative( tc_rel );
572     std::vector< Node > members = d_rReps_memberReps_cache[rel_rep];
573     std::vector< Node > exps = d_rReps_memberReps_exp_cache[rel_rep];
574 
575     for( unsigned int i = 0; i < members.size(); i++ ) {
576       Node fst_element_rep = getRepresentative( RelsUtils::nthElementOfTuple( members[i], 0 ));
577       Node snd_element_rep = getRepresentative( RelsUtils::nthElementOfTuple( members[i], 1 ));
578       Node tuple_rep = RelsUtils::constructPair( rel_rep, fst_element_rep, snd_element_rep );
579       std::map< Node, std::unordered_set<Node, NodeHashFunction> >::iterator rel_tc_graph_it = rel_tc_graph.find( fst_element_rep );
580 
581       if( rel_tc_graph_it == rel_tc_graph.end() ) {
582         std::unordered_set< Node, NodeHashFunction > snd_elements;
583         snd_elements.insert( snd_element_rep );
584         rel_tc_graph[fst_element_rep] = snd_elements;
585         rel_tc_graph_exps[tuple_rep] = exps[i];
586       } else if( (rel_tc_graph_it->second).find( snd_element_rep ) ==  (rel_tc_graph_it->second).end() ) {
587         (rel_tc_graph_it->second).insert( snd_element_rep );
588         rel_tc_graph_exps[tuple_rep] = exps[i];
589       }
590     }
591 
592     if( members.size() > 0 ) {
593       d_rRep_tcGraph[rel_rep] = rel_tc_graph;
594       d_tcr_tcGraph_exps[tc_rel] = rel_tc_graph_exps;
595       d_tcr_tcGraph[tc_rel] = rel_tc_graph;
596     }
597   }
598 
doTCInference(std::map<Node,std::unordered_set<Node,NodeHashFunction>> rel_tc_graph,std::map<Node,Node> rel_tc_graph_exps,Node tc_rel)599   void TheorySetsRels::doTCInference( std::map< Node, std::unordered_set<Node, NodeHashFunction> > rel_tc_graph, std::map< Node, Node > rel_tc_graph_exps, Node tc_rel ) {
600     Trace("rels-debug") << "[Theory::Rels] ****** doTCInference !" << std::endl;
601     for( TC_GRAPH_IT tc_graph_it = rel_tc_graph.begin(); tc_graph_it != rel_tc_graph.end(); tc_graph_it++ ) {
602       for( std::unordered_set< Node, NodeHashFunction >::iterator snd_elements_it = tc_graph_it->second.begin();
603            snd_elements_it != tc_graph_it->second.end(); snd_elements_it++ ) {
604         std::vector< Node > reasons;
605         std::unordered_set<Node, NodeHashFunction> seen;
606         Node tuple = RelsUtils::constructPair( tc_rel, getRepresentative( tc_graph_it->first ), getRepresentative( *snd_elements_it) );
607         Assert( rel_tc_graph_exps.find( tuple ) != rel_tc_graph_exps.end() );
608         Node exp   = rel_tc_graph_exps.find( tuple )->second;
609 
610         reasons.push_back( exp );
611         seen.insert( tc_graph_it->first );
612         doTCInference( tc_rel, reasons, rel_tc_graph, rel_tc_graph_exps, tc_graph_it->first, *snd_elements_it, seen);
613       }
614     }
615     Trace("rels-debug") << "[Theory::Rels] ****** Done with doTCInference !" << std::endl;
616   }
617 
doTCInference(Node tc_rel,std::vector<Node> reasons,std::map<Node,std::unordered_set<Node,NodeHashFunction>> & tc_graph,std::map<Node,Node> & rel_tc_graph_exps,Node start_node_rep,Node cur_node_rep,std::unordered_set<Node,NodeHashFunction> & seen)618   void TheorySetsRels::doTCInference(Node tc_rel, std::vector< Node > reasons, std::map< Node, std::unordered_set< Node, NodeHashFunction > >& tc_graph,
619                                        std::map< Node, Node >& rel_tc_graph_exps, Node start_node_rep, Node cur_node_rep, std::unordered_set< Node, NodeHashFunction >& seen ) {
620     Node tc_mem = RelsUtils::constructPair( tc_rel, RelsUtils::nthElementOfTuple((reasons.front())[0], 0), RelsUtils::nthElementOfTuple((reasons.back())[0], 1) );
621     std::vector< Node > all_reasons( reasons );
622 
623     for( unsigned int i = 0 ; i < reasons.size()-1; i++ ) {
624       Node fst_element_end = RelsUtils::nthElementOfTuple( reasons[i][0], 1 );
625       Node snd_element_begin = RelsUtils::nthElementOfTuple( reasons[i+1][0], 0 );
626       if( fst_element_end != snd_element_begin ) {
627         all_reasons.push_back( NodeManager::currentNM()->mkNode(kind::EQUAL, fst_element_end, snd_element_begin) );
628       }
629       if( tc_rel != reasons[i][1] && tc_rel[0] != reasons[i][1] ) {
630         all_reasons.push_back( NodeManager::currentNM()->mkNode(kind::EQUAL, tc_rel[0], reasons[i][1]) );
631       }
632     }
633     if( tc_rel != reasons.back()[1] && tc_rel[0] != reasons.back()[1] ) {
634       all_reasons.push_back( NodeManager::currentNM()->mkNode(kind::EQUAL, tc_rel[0], reasons.back()[1]) );
635     }
636     if( all_reasons.size() > 1) {
637       sendInfer( NodeManager::currentNM()->mkNode(kind::MEMBER, tc_mem, tc_rel), NodeManager::currentNM()->mkNode(kind::AND, all_reasons), "TCLOSURE-Forward");
638     } else {
639       sendInfer( NodeManager::currentNM()->mkNode(kind::MEMBER, tc_mem, tc_rel), all_reasons.front(), "TCLOSURE-Forward");
640     }
641 
642     // check if cur_node has been traversed or not
643     if( seen.find( cur_node_rep ) != seen.end() ) {
644       return;
645     }
646     seen.insert( cur_node_rep );
647     TC_GRAPH_IT  cur_set = tc_graph.find( cur_node_rep );
648     if( cur_set != tc_graph.end() ) {
649       for( std::unordered_set< Node, NodeHashFunction >::iterator set_it = cur_set->second.begin();
650            set_it != cur_set->second.end(); set_it++ ) {
651         Node new_pair = RelsUtils::constructPair( tc_rel, cur_node_rep, *set_it );
652         std::vector< Node > new_reasons( reasons );
653         new_reasons.push_back( rel_tc_graph_exps.find( new_pair )->second );
654         doTCInference( tc_rel, new_reasons, tc_graph, rel_tc_graph_exps, start_node_rep, *set_it, seen );
655       }
656     }
657   }
658 
659  /*  product-split rule:  (a, b) IS_IN (X PRODUCT Y)
660   *                     ----------------------------------
661   *                       a IS_IN X  && b IS_IN Y
662   *
663   *  product-compose rule: (a, b) IS_IN X    (c, d) IS_IN Y
664   *                        ---------------------------------
665   *                        (a, b, c, d) IS_IN (X PRODUCT Y)
666   */
667 
668 
applyProductRule(Node pt_rel,Node pt_rel_rep,Node exp)669   void TheorySetsRels::applyProductRule( Node pt_rel, Node pt_rel_rep, Node exp ) {
670     Trace("rels-debug") << "\n[Theory::Rels] *********** Applying PRODUCT rule on producted term = " << pt_rel
671                             << ", its representative = " << pt_rel_rep
672                             << " with explanation = " << exp << std::endl;
673 
674     if(d_rel_nodes.find( pt_rel ) == d_rel_nodes.end()) {
675       Trace("rels-debug") <<  "\n[Theory::Rels] Apply PRODUCT-COMPOSE rule on term: " << pt_rel
676                           << " with explanation: " << exp << std::endl;
677 
678       computeMembersForBinOpRel( pt_rel );
679       d_rel_nodes.insert( pt_rel );
680     }
681 
682     Node mem = exp[0];
683     std::vector<Node>   r1_element;
684     std::vector<Node>   r2_element;
685     Datatype     dt      = pt_rel[0].getType().getSetElementType().getDatatype();
686     unsigned int s1_len  = pt_rel[0].getType().getSetElementType().getTupleLength();
687     unsigned int tup_len = pt_rel.getType().getSetElementType().getTupleLength();
688 
689     r1_element.push_back(Node::fromExpr(dt[0].getConstructor()));
690 
691     unsigned int i = 0;
692     for(; i < s1_len; ++i) {
693       r1_element.push_back(RelsUtils::nthElementOfTuple(mem, i));
694     }
695     dt = pt_rel[1].getType().getSetElementType().getDatatype();
696     r2_element.push_back(Node::fromExpr(dt[0].getConstructor()));
697     for(; i < tup_len; ++i) {
698       r2_element.push_back(RelsUtils::nthElementOfTuple(mem, i));
699     }
700     Node reason   = exp;
701     Node mem1     = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r1_element);
702     Node mem2     = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r2_element);
703     Node fact_1   = NodeManager::currentNM()->mkNode(kind::MEMBER, mem1, pt_rel[0]);
704     Node fact_2   = NodeManager::currentNM()->mkNode(kind::MEMBER, mem2, pt_rel[1]);
705 
706     if( pt_rel != exp[1] ) {
707       reason = NodeManager::currentNM()->mkNode(kind::AND, exp, NodeManager::currentNM()->mkNode(kind::EQUAL, pt_rel, exp[1]));
708     }
709     sendInfer( fact_1, reason, "product-split" );
710     sendInfer( fact_2, reason, "product-split" );
711   }
712 
713   /* join-split rule:           (a, b) IS_IN (X JOIN Y)
714    *                  --------------------------------------------
715    *                  exists z | (a, z) IS_IN X  && (z, b) IS_IN Y
716    *
717    *
718    * join-compose rule: (a, b) IS_IN X    (b, c) IS_IN Y  NOT (t, u) IS_IN (X JOIN Y)
719    *                    -------------------------------------------------------------
720    *                                      (a, c) IS_IN (X JOIN Y)
721    */
722 
applyJoinRule(Node join_rel,Node join_rel_rep,Node exp)723   void TheorySetsRels::applyJoinRule( Node join_rel, Node join_rel_rep, Node exp ) {
724     Trace("rels-debug") << "\n[Theory::Rels] *********** Applying JOIN rule on joined term = " << join_rel
725                             << ", its representative = " << join_rel_rep
726                             << " with explanation = " << exp << std::endl;
727     if(d_rel_nodes.find( join_rel ) == d_rel_nodes.end()) {
728       Trace("rels-debug") <<  "\n[Theory::Rels] Apply JOIN-COMPOSE rule on term: " << join_rel
729                           << " with explanation: " << exp << std::endl;
730 
731       computeMembersForBinOpRel( join_rel );
732       d_rel_nodes.insert( join_rel );
733     }
734 
735     Node mem = exp[0];
736     std::vector<Node> r1_element;
737     std::vector<Node> r2_element;
738     Node r1_rep = getRepresentative(join_rel[0]);
739     Node r2_rep = getRepresentative(join_rel[1]);
740     TypeNode     shared_type    = r2_rep.getType().getSetElementType().getTupleTypes()[0];
741     Node         shared_x       = NodeManager::currentNM()->mkSkolem("srj_", shared_type);
742     Datatype     dt             = join_rel[0].getType().getSetElementType().getDatatype();
743     unsigned int s1_len         = join_rel[0].getType().getSetElementType().getTupleLength();
744     unsigned int tup_len        = join_rel.getType().getSetElementType().getTupleLength();
745 
746     unsigned int i = 0;
747     r1_element.push_back(Node::fromExpr(dt[0].getConstructor()));
748     for(; i < s1_len-1; ++i) {
749       r1_element.push_back(RelsUtils::nthElementOfTuple(mem, i));
750     }
751     r1_element.push_back(shared_x);
752     dt = join_rel[1].getType().getSetElementType().getDatatype();
753     r2_element.push_back(Node::fromExpr(dt[0].getConstructor()));
754     r2_element.push_back(shared_x);
755     for(; i < tup_len; ++i) {
756       r2_element.push_back(RelsUtils::nthElementOfTuple(mem, i));
757     }
758     Node mem1 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r1_element);
759     Node mem2 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r2_element);
760 
761     computeTupleReps(mem1);
762     computeTupleReps(mem2);
763 
764     std::vector<Node> elements = d_membership_trie[r1_rep].findTerms(d_tuple_reps[mem1]);
765 
766     for(unsigned int j = 0; j < elements.size(); j++) {
767       std::vector<Node> new_tup;
768       new_tup.push_back(elements[j]);
769       new_tup.insert(new_tup.end(), d_tuple_reps[mem2].begin()+1, d_tuple_reps[mem2].end());
770       if(d_membership_trie[r2_rep].existsTerm(new_tup) != Node::null()) {
771         return;
772       }
773     }
774     Node reason = exp;
775     if( join_rel != exp[1] ) {
776       reason = NodeManager::currentNM()->mkNode(kind::AND, reason, NodeManager::currentNM()->mkNode(kind::EQUAL, join_rel, exp[1]));
777     }
778     Node fact = NodeManager::currentNM()->mkNode(kind::MEMBER, mem1, join_rel[0]);
779     sendInfer( fact, reason, "JOIN-Split" );
780     fact = NodeManager::currentNM()->mkNode(kind::MEMBER, mem2, join_rel[1]);
781     sendInfer( fact, reason, "JOIN-Split" );
782     makeSharedTerm( shared_x );
783   }
784 
785   /*
786    * transpose-occur rule:    (a, b) IS_IN X   (TRANSPOSE X) in T
787    *                         ---------------------------------------
788    *                            (b, a) IS_IN (TRANSPOSE X)
789    *
790    * transpose-reverse rule:    (a, b) IS_IN (TRANSPOSE X)
791    *                         ---------------------------------------
792    *                            (b, a) IS_IN X
793    *
794    */
applyTransposeRule(std::vector<Node> tp_terms)795   void TheorySetsRels::applyTransposeRule( std::vector<Node> tp_terms ) {
796     if( tp_terms.size() < 1) {
797       return;
798     }
799     for( unsigned int i = 1; i < tp_terms.size(); i++ ) {
800       Trace("rels-debug") << "\n[Theory::Rels] *********** Applying TRANSPOSE-Equal rule on transposed term = " << tp_terms[0] << " and " << tp_terms[i] << std::endl;
801       sendInfer(NodeManager::currentNM()->mkNode(kind::EQUAL, tp_terms[0][0], tp_terms[i][0]), NodeManager::currentNM()->mkNode(kind::EQUAL, tp_terms[0], tp_terms[i]), "TRANSPOSE-Equal");
802     }
803   }
804 
applyTransposeRule(Node tp_rel,Node tp_rel_rep,Node exp)805   void TheorySetsRels::applyTransposeRule( Node tp_rel, Node tp_rel_rep, Node exp ) {
806     Trace("rels-debug") << "\n[Theory::Rels] *********** Applying TRANSPOSE rule on transposed term = " << tp_rel
807                         << ", its representative = " << tp_rel_rep
808                         << " with explanation = " << exp << std::endl;
809 
810     if(d_rel_nodes.find( tp_rel ) == d_rel_nodes.end()) {
811       Trace("rels-debug") <<  "\n[Theory::Rels] Apply TRANSPOSE-Compose rule on term: " << tp_rel
812                           << " with explanation: " << exp << std::endl;
813 
814       computeMembersForUnaryOpRel( tp_rel );
815       d_rel_nodes.insert( tp_rel );
816     }
817 
818     Node reason = exp;
819     Node reversed_mem = RelsUtils::reverseTuple( exp[0] );
820 
821     if( tp_rel != exp[1] ) {
822       reason = NodeManager::currentNM()->mkNode(kind::AND, reason, NodeManager::currentNM()->mkNode(kind::EQUAL, tp_rel, exp[1]));
823     }
824     sendInfer( NodeManager::currentNM()->mkNode(kind::MEMBER, reversed_mem, tp_rel[0]), reason, "TRANSPOSE-Reverse" );
825   }
826 
doTCInference()827   void TheorySetsRels::doTCInference() {
828     Trace("rels-debug") << "[Theory::Rels] ****** Finalizing transitive closure inferences!" << std::endl;
829     TC_IT tc_graph_it = d_tcr_tcGraph.begin();
830     while( tc_graph_it != d_tcr_tcGraph.end() ) {
831       Assert ( d_tcr_tcGraph_exps.find(tc_graph_it->first) != d_tcr_tcGraph_exps.end() );
832       doTCInference( tc_graph_it->second, d_tcr_tcGraph_exps.find(tc_graph_it->first)->second, tc_graph_it->first );
833       tc_graph_it++;
834     }
835     Trace("rels-debug") << "[Theory::Rels] ****** Done with finalizing transitive closure inferences!" << std::endl;
836   }
837 
838 
  // Compute, bottom-up, the members of a relation built by a binary operator (JOIN or PRODUCT)
computeMembersForBinOpRel(Node rel)840   void TheorySetsRels::computeMembersForBinOpRel(Node rel) {
841     Trace("rels-debug") << "\n[Theory::Rels] computeMembersForBinOpRel for relation  " << rel << std::endl;
842 
843     switch(rel[0].getKind()) {
844       case kind::TRANSPOSE:
845       case kind::TCLOSURE: {
846         computeMembersForUnaryOpRel(rel[0]);
847         break;
848       }
849       case kind::JOIN:
850       case kind::PRODUCT: {
851         computeMembersForBinOpRel(rel[0]);
852         break;
853       }
854       default:
855         break;
856     }
857     switch(rel[1].getKind()) {
858       case kind::TRANSPOSE: {
859         computeMembersForUnaryOpRel(rel[1]);
860         break;
861       }
862       case kind::JOIN:
863       case kind::PRODUCT: {
864         computeMembersForBinOpRel(rel[1]);
865         break;
866       }
867       default:
868         break;
869     }
870     if(d_rReps_memberReps_cache.find(getRepresentative(rel[0])) == d_rReps_memberReps_cache.end() ||
871        d_rReps_memberReps_cache.find(getRepresentative(rel[1])) == d_rReps_memberReps_cache.end()) {
872       return;
873     }
874     composeMembersForRels(rel);
875   }
876 
  // Compute, bottom-up, the members of a relation built by a unary operator (TRANSPOSE or TCLOSURE)
computeMembersForUnaryOpRel(Node rel)878   void TheorySetsRels::computeMembersForUnaryOpRel(Node rel) {
879     Trace("rels-debug") << "\n[Theory::Rels] computeMembersForUnaryOpRel for relation  " << rel << std::endl;
880 
881     switch(rel[0].getKind()) {
882       case kind::TRANSPOSE:
883       case kind::TCLOSURE:
884         computeMembersForUnaryOpRel(rel[0]);
885         break;
886       case kind::JOIN:
887       case kind::PRODUCT:
888         computeMembersForBinOpRel(rel[0]);
889         break;
890       default:
891         break;
892     }
893 
894     Node rel0_rep  = getRepresentative(rel[0]);
895     if(d_rReps_memberReps_cache.find( rel0_rep ) == d_rReps_memberReps_cache.end())
896       return;
897 
898     std::vector<Node>   members = d_rReps_memberReps_cache[rel0_rep];
899     std::vector<Node>   exps    = d_rReps_memberReps_exp_cache[rel0_rep];
900 
901     Assert( members.size() == exps.size() );
902 
903     for(unsigned int i = 0; i < members.size(); i++) {
904       Node reason = exps[i];
905       if( rel.getKind() == kind::TRANSPOSE) {
906         if( rel[0] != exps[i][1] ) {
907           reason = NodeManager::currentNM()->mkNode(kind::AND, reason, NodeManager::currentNM()->mkNode(kind::EQUAL, rel[0], exps[i][1]));
908         }
909         sendInfer(NodeManager::currentNM()->mkNode(kind::MEMBER, RelsUtils::reverseTuple(exps[i][0]), rel), reason, "TRANSPOSE-reverse");
910       }
911     }
912   }
913 
914   /*
915    * Explicitly compose the join or product relations of r1 and r2
   * e.g. if (a, b) is in X and (b, c) is in Y, then (a, c) is in (X JOIN Y)
917    *
918    */
composeMembersForRels(Node rel)919   void TheorySetsRels::composeMembersForRels( Node rel ) {
920     Trace("rels-debug") << "[Theory::Rels] Start composing members for relation = " << rel << std::endl;
921     Node r1 = rel[0];
922     Node r2 = rel[1];
923     Node r1_rep = getRepresentative( r1 );
924     Node r2_rep = getRepresentative( r2 );
925 
926     if(d_rReps_memberReps_cache.find( r1_rep ) == d_rReps_memberReps_cache.end() ||
927        d_rReps_memberReps_cache.find( r2_rep ) == d_rReps_memberReps_cache.end() ) {
928       return;
929     }
930 
931     std::vector<Node> r1_rep_exps = d_rReps_memberReps_exp_cache[r1_rep];
932     std::vector<Node> r2_rep_exps = d_rReps_memberReps_exp_cache[r2_rep];
933     unsigned int r1_tuple_len = r1.getType().getSetElementType().getTupleLength();
934     unsigned int r2_tuple_len = r2.getType().getSetElementType().getTupleLength();
935 
936     for( unsigned int i = 0; i < r1_rep_exps.size(); i++ ) {
937       for( unsigned int j = 0; j < r2_rep_exps.size(); j++ ) {
938         std::vector<Node> tuple_elements;
939         TypeNode tn = rel.getType().getSetElementType();
940         Node r1_rmost = RelsUtils::nthElementOfTuple( r1_rep_exps[i][0], r1_tuple_len-1 );
941         Node r2_lmost = RelsUtils::nthElementOfTuple( r2_rep_exps[j][0], 0 );
942         tuple_elements.push_back( Node::fromExpr(tn.getDatatype()[0].getConstructor()) );
943 
944         if( (areEqual(r1_rmost, r2_lmost) && rel.getKind() == kind::JOIN) ||
945             rel.getKind() == kind::PRODUCT ) {
946           bool isProduct = rel.getKind() == kind::PRODUCT;
947           unsigned int k = 0;
948           unsigned int l = 1;
949 
950           for( ; k < r1_tuple_len - 1; ++k ) {
951             tuple_elements.push_back( RelsUtils::nthElementOfTuple( r1_rep_exps[i][0], k ) );
952           }
953           if(isProduct) {
954             tuple_elements.push_back( RelsUtils::nthElementOfTuple( r1_rep_exps[i][0], k ) );
955             tuple_elements.push_back( RelsUtils::nthElementOfTuple( r2_rep_exps[j][0], 0 ) );
956           }
957           for( ; l < r2_tuple_len; ++l ) {
958             tuple_elements.push_back( RelsUtils::nthElementOfTuple( r2_rep_exps[j][0], l ) );
959           }
960 
961           Node composed_tuple = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, tuple_elements);
962           Node fact = NodeManager::currentNM()->mkNode(kind::MEMBER, composed_tuple, rel);
963           std::vector<Node> reasons;
964           reasons.push_back( r1_rep_exps[i] );
965           reasons.push_back( r2_rep_exps[j] );
966 
967           if( r1 != r1_rep_exps[i][1] ) {
968             reasons.push_back( NodeManager::currentNM()->mkNode(kind::EQUAL, r1, r1_rep_exps[i][1]) );
969           }
970           if( r2 != r2_rep_exps[j][1] ) {
971             reasons.push_back( NodeManager::currentNM()->mkNode(kind::EQUAL, r2, r2_rep_exps[j][1]) );
972           }
973           if( isProduct ) {
974             sendInfer( fact, NodeManager::currentNM()->mkNode(kind::AND, reasons), "PRODUCT-Compose" );
975           } else {
976             if( r1_rmost != r2_lmost ) {
977               reasons.push_back( NodeManager::currentNM()->mkNode(kind::EQUAL, r1_rmost, r2_lmost) );
978             }
979             sendInfer( fact, NodeManager::currentNM()->mkNode(kind::AND, reasons), "JOIN-Compose" );
980           }
981         }
982       }
983     }
984 
985   }
986 
doPendingLemmas()987   void TheorySetsRels::doPendingLemmas() {
988     Trace("rels-debug") << "[Theory::Rels] **************** Start doPendingLemmas !" << std::endl;
989     if( !(*d_conflict) ){
990       if ( (!d_lemmas_out.empty() || !d_pending_facts.empty()) ) {
991         for( unsigned i=0; i < d_lemmas_out.size(); i++ ){
992           Assert(d_lemmas_out[i].getKind() == kind::IMPLIES);
993           if(holds( d_lemmas_out[i][1] )) {
994             Trace("rels-lemma-skip") << "[sets-rels-lemma-skip] Skip an already held lemma: "
995                                      << d_lemmas_out[i]<< std::endl;
996             continue;
997           }
998           d_sets_theory.d_out->lemma( d_lemmas_out[i] );
999           Trace("rels-lemma") << "[sets-rels-lemma] Send out a lemma : "
1000                               << d_lemmas_out[i] << std::endl;
1001         }
1002         for( std::map<Node, Node>::iterator pending_it = d_pending_facts.begin();
1003              pending_it != d_pending_facts.end(); pending_it++ ) {
1004           if( holds( pending_it->first ) ) {
1005             Trace("rels-lemma-skip") << "[sets-rels-fact-lemma-skip] Skip an already held fact,: "
1006                                      << pending_it->first << std::endl;
1007             continue;
1008           }
1009           Node lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, pending_it->second, pending_it->first);
1010           if( d_lemmas_produced.find( lemma ) == d_lemmas_produced.end() ) {
1011             d_sets_theory.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, pending_it->second, pending_it->first));
1012             Trace("rels-lemma") << "[sets-rels-fact-lemma] Send out a fact as lemma : "
1013                                 << pending_it->first << " with reason " << pending_it->second << std::endl;
1014             d_lemmas_produced.insert( lemma );
1015           }
1016         }
1017       }
1018     }
1019     doTCLemmas();
1020     Trace("rels-debug") << "[Theory::Rels] **************** Done with doPendingLemmas !" << std::endl;
1021     d_tuple_reps.clear();
1022     d_rReps_memberReps_exp_cache.clear();
1023     d_terms_cache.clear();
1024     d_lemmas_out.clear();
1025     d_membership_trie.clear();
1026     d_rel_nodes.clear();
1027     d_pending_facts.clear();
1028     d_rReps_memberReps_cache.clear();
1029     d_rRep_tcGraph.clear();
1030     d_tcr_tcGraph_exps.clear();
1031     d_tcr_tcGraph.clear();
1032     d_tc_lemmas_last.clear();
1033   }
1034 
1035 
isRelationKind(Kind k)1036   bool TheorySetsRels::isRelationKind( Kind k ) {
1037     return k == kind::TRANSPOSE || k == kind::PRODUCT || k == kind::JOIN || k == kind::TCLOSURE;
1038   }
1039 
doTCLemmas()1040   void TheorySetsRels::doTCLemmas() {
1041     Trace("rels-debug") << "[Theory::Rels] **************** Start doTCLemmas !" << std::endl;
1042     std::map< Node, std::vector< Node > >::iterator tc_lemma_it = d_tc_lemmas_last.begin();
1043     while( tc_lemma_it != d_tc_lemmas_last.end() ) {
1044       if( !holds( tc_lemma_it->first[1] ) ) {
1045         if( d_lemmas_produced.find( tc_lemma_it->first ) == d_lemmas_produced.end() ) {
1046           d_sets_theory.d_out->lemma( tc_lemma_it->first );
1047           d_lemmas_produced.insert( tc_lemma_it->first );
1048 
1049           for( unsigned int i = 0; i < (tc_lemma_it->second).size(); i++ ) {
1050             if( (tc_lemma_it->second)[i] == d_falseNode ) {
1051               d_sets_theory.d_out->requirePhase((tc_lemma_it->second)[i], true);
1052             }
1053           }
1054           Trace("rels-lemma") << "[Theory::Rels] **** Send out a TC lemma = " << tc_lemma_it->first << " by " << "TCLOSURE-Forward"<< std::endl;
1055         }
1056       }
1057       ++tc_lemma_it;
1058     }
1059     Trace("rels-debug") << "[Theory::Rels] **************** Done with doTCLemmas !" << std::endl;
1060   }
1061 
sendLemma(Node conc,Node ant,const char * c)1062   void TheorySetsRels::sendLemma(Node conc, Node ant, const char * c) {
1063     if( !holds( conc ) ) {
1064       Node lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, ant, conc);
1065       if( d_lemmas_produced.find( lemma ) == d_lemmas_produced.end() ) {
1066         d_lemmas_out.push_back( lemma );
1067         d_lemmas_produced.insert( lemma );
1068         Trace("rels-send-lemma") << "[Theory::Rels] **** Generate a lemma conclusion = " << conc << " with reason = " << ant << " by " << c << std::endl;
1069       }
1070     }
1071   }
1072 
sendInfer(Node fact,Node exp,const char * c)1073   void TheorySetsRels::sendInfer( Node fact, Node exp, const char * c ) {
1074     if( !holds( fact ) ) {
1075       Trace("rels-send-lemma") << "[Theory::Rels] **** Generate an infered fact "
1076                                << fact << " with reason " << exp << " by "<< c << std::endl;
1077       d_pending_facts[fact] = exp;
1078     } else {
1079       Trace("rels-send-lemma-debug") << "[Theory::Rels] **** Generate an infered fact "
1080                                      << fact << " with reason " << exp << " by "<< c
1081                                      << ", but it holds already, thus skip it!" << std::endl;
1082     }
1083   }
1084 
getRepresentative(Node t)1085   Node TheorySetsRels::getRepresentative( Node t ) {
1086     if( d_eqEngine->hasTerm( t ) ){
1087       return d_eqEngine->getRepresentative( t );
1088     }else{
1089       return t;
1090     }
1091   }
1092 
hasTerm(Node a)1093   bool TheorySetsRels::hasTerm( Node a ){
1094     return d_eqEngine->hasTerm( a );
1095   }
1096 
areEqual(Node a,Node b)1097   bool TheorySetsRels::areEqual( Node a, Node b ){
1098     Assert(a.getType() == b.getType());
1099     Trace("rels-eq") << "[sets-rels]**** checking equality between " << a << " and " << b << std::endl;
1100     if(a == b) {
1101       return true;
1102     } else if( hasTerm( a ) && hasTerm( b ) ){
1103       return d_eqEngine->areEqual( a, b );
1104     } else if(a.getType().isTuple()) {
1105       bool equal = true;
1106       for(unsigned int i = 0; i < a.getType().getTupleLength(); i++) {
1107         equal = equal && areEqual(RelsUtils::nthElementOfTuple(a, i), RelsUtils::nthElementOfTuple(b, i));
1108       }
1109       return equal;
1110     } else if(!a.getType().isBoolean()){
1111       makeSharedTerm(a);
1112       makeSharedTerm(b);
1113     }
1114     return false;
1115   }
1116 
1117   /*
   * Add member to map[rel_rep] unless an equal member is already there; returns whether it was added
1119    */
safelyAddToMap(std::map<Node,std::vector<Node>> & map,Node rel_rep,Node member)1120   bool TheorySetsRels::safelyAddToMap(std::map< Node, std::vector<Node> >& map, Node rel_rep, Node member) {
1121     std::map< Node, std::vector< Node > >::iterator mem_it = map.find(rel_rep);
1122     if(mem_it == map.end()) {
1123       std::vector<Node> members;
1124       members.push_back(member);
1125       map[rel_rep] = members;
1126       return true;
1127     } else {
1128       std::vector<Node>::iterator mems = mem_it->second.begin();
1129       while(mems != mem_it->second.end()) {
1130         if(areEqual(*mems, member)) {
1131           return false;
1132         }
1133         mems++;
1134       }
1135       map[rel_rep].push_back(member);
1136       return true;
1137     }
1138     return false;
1139   }
1140 
addToMap(std::map<Node,std::vector<Node>> & map,Node rel_rep,Node member)1141   void TheorySetsRels::addToMap(std::map< Node, std::vector<Node> >& map, Node rel_rep, Node member) {
1142     if(map.find(rel_rep) == map.end()) {
1143       std::vector<Node> members;
1144       members.push_back(member);
1145       map[rel_rep] = members;
1146     } else {
1147       map[rel_rep].push_back(member);
1148     }
1149   }
1150 
addSharedTerm(TNode n)1151   void TheorySetsRels::addSharedTerm( TNode n ) {
1152     Trace("rels-debug") << "[sets-rels] Add a shared term:  " << n << std::endl;
1153     d_sets_theory.addSharedTerm(n);
1154     d_eqEngine->addTriggerTerm(n, THEORY_SETS);
1155   }
1156 
makeSharedTerm(Node n)1157   void TheorySetsRels::makeSharedTerm( Node n ) {
1158     Trace("rels-share") << " [sets-rels] making shared term " << n << std::endl;
1159     if(d_shared_terms.find(n) == d_shared_terms.end()) {
1160       Node skolem = NodeManager::currentNM()->mkSkolem( "sts", NodeManager::currentNM()->mkSetType( n.getType() ) );
1161       sendLemma(skolem.eqNode(NodeManager::currentNM()->mkNode(kind::SINGLETON,n)), d_trueNode, "share-term");
1162       d_shared_terms.insert(n);
1163     }
1164   }
1165 
holds(Node node)1166   bool TheorySetsRels::holds(Node node) {
1167     bool polarity       = node.getKind() != kind::NOT;
1168     Node atom           = polarity ? node : node[0];
1169     return d_sets_theory.isEntailed( atom, polarity );
1170   }
1171 
1172   /*
   * For each tuple n, we store a mapping between n and a list of its elements' representatives
   * in d_tuple_reps. This is later used when applying the JOIN rule.
1175    */
computeTupleReps(Node n)1176   void TheorySetsRels::computeTupleReps( Node n ) {
1177     if( d_tuple_reps.find( n ) == d_tuple_reps.end() ){
1178       for( unsigned i = 0; i < n.getType().getTupleLength(); i++ ){
1179         d_tuple_reps[n].push_back( getRepresentative( RelsUtils::nthElementOfTuple(n, i) ) );
1180       }
1181     }
1182   }
1183 
1184   /*
1185    * Node n[0] is a tuple variable, reduce n[0] to a concrete representation,
1186    * which is (e1, ..., en) where e1, ... ,en are concrete elements of tuple n[0].
1187    */
reduceTupleVar(Node n)1188   void TheorySetsRels::reduceTupleVar(Node n) {
1189     if(d_symbolic_tuples.find(n) == d_symbolic_tuples.end()) {
1190       Trace("rels-debug") << "[Theory::Rels] Reduce tuple var: " << n[0] << " to a concrete one " << " node = " << n << std::endl;
1191       std::vector<Node> tuple_elements;
1192       tuple_elements.push_back(Node::fromExpr((n[0].getType().getDatatype())[0].getConstructor()));
1193       for(unsigned int i = 0; i < n[0].getType().getTupleLength(); i++) {
1194         Node element = RelsUtils::nthElementOfTuple(n[0], i);
1195         makeSharedTerm(element);
1196         tuple_elements.push_back(element);
1197       }
1198       Node tuple_reduct = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, tuple_elements);
1199       tuple_reduct = NodeManager::currentNM()->mkNode(kind::MEMBER,tuple_reduct, n[1]);
1200       Node tuple_reduction_lemma = NodeManager::currentNM()->mkNode(kind::EQUAL, n, tuple_reduct);
1201       sendLemma(tuple_reduction_lemma, d_trueNode, "tuple-reduction");
1202       d_symbolic_tuples.insert(n);
1203     }
1204   }
1205 
TheorySetsRels(context::Context * c,context::UserContext * u,eq::EqualityEngine * eq,context::CDO<bool> * conflict,TheorySets & d_set)1206   TheorySetsRels::TheorySetsRels( context::Context* c,
1207                                   context::UserContext* u,
1208                                   eq::EqualityEngine* eq,
1209                                   context::CDO<bool>* conflict,
1210                                   TheorySets& d_set ):
1211     d_eqEngine(eq),
1212     d_conflict(conflict),
1213     d_sets_theory(d_set),
1214     d_trueNode(NodeManager::currentNM()->mkConst<bool>(true)),
1215     d_falseNode(NodeManager::currentNM()->mkConst<bool>(false)),
1216     d_pending_merge(c),
1217     d_lemmas_produced(u),
1218     d_shared_terms(u)
1219   {
1220     d_eqEngine->addFunctionKind(kind::PRODUCT);
1221     d_eqEngine->addFunctionKind(kind::JOIN);
1222     d_eqEngine->addFunctionKind(kind::TRANSPOSE);
1223     d_eqEngine->addFunctionKind(kind::TCLOSURE);
1224     d_eqEngine->addFunctionKind(kind::JOIN_IMAGE);
1225     d_eqEngine->addFunctionKind(kind::IDEN);
1226   }
1227 
~TheorySetsRels()1228   TheorySetsRels::~TheorySetsRels() {
1229     for(std::map< Node, EqcInfo* >::iterator i = d_eqc_info.begin(), iend = d_eqc_info.end();
1230         i != iend; ++i){
1231       EqcInfo* current = (*i).second;
1232       Assert(current != NULL);
1233       delete current;
1234     }
1235   }
1236 
findTerms(std::vector<Node> & reps,int argIndex)1237   std::vector<Node> TupleTrie::findTerms( std::vector< Node >& reps, int argIndex ) {
1238     std::vector<Node>                           nodes;
1239     std::map< Node, TupleTrie >::iterator       it;
1240 
1241     if( argIndex==(int)reps.size()-1 ){
1242       if(reps[argIndex].getKind() == kind::SKOLEM) {
1243         it = d_data.begin();
1244         while(it != d_data.end()) {
1245           nodes.push_back(it->first);
1246           it++;
1247         }
1248       }
1249       return nodes;
1250     }else{
1251       it = d_data.find( reps[argIndex] );
1252       if( it==d_data.end() ){
1253         return nodes;
1254       }else{
1255         return it->second.findTerms( reps, argIndex+1 );
1256       }
1257     }
1258   }
1259 
findSuccessors(std::vector<Node> & reps,int argIndex)1260   std::vector<Node> TupleTrie::findSuccessors( std::vector< Node >& reps, int argIndex ) {
1261     std::vector<Node>   nodes;
1262     std::map< Node, TupleTrie >::iterator it;
1263 
1264     if( argIndex==(int)reps.size() ){
1265       it = d_data.begin();
1266       while(it != d_data.end()) {
1267         nodes.push_back(it->first);
1268         it++;
1269       }
1270       return nodes;
1271     }else{
1272       it = d_data.find( reps[argIndex] );
1273       if( it==d_data.end() ){
1274         return nodes;
1275       }else{
1276         return it->second.findSuccessors( reps, argIndex+1 );
1277       }
1278     }
1279   }
1280 
existsTerm(std::vector<Node> & reps,int argIndex)1281   Node TupleTrie::existsTerm( std::vector< Node >& reps, int argIndex ) {
1282     if( argIndex==(int)reps.size() ){
1283       if( d_data.empty() ){
1284         return Node::null();
1285       }else{
1286         return d_data.begin()->first;
1287       }
1288     }else{
1289       std::map< Node, TupleTrie >::iterator it = d_data.find( reps[argIndex] );
1290       if( it==d_data.end() ){
1291         return Node::null();
1292       }else{
1293         return it->second.existsTerm( reps, argIndex+1 );
1294       }
1295     }
1296   }
1297 
addTerm(Node n,std::vector<Node> & reps,int argIndex)1298   bool TupleTrie::addTerm( Node n, std::vector< Node >& reps, int argIndex ){
1299     if( argIndex==(int)reps.size() ){
1300       if( d_data.empty() ){
1301         //store n in d_data (this should be interpretted as the "data" and not as a reference to a child)
1302         d_data[n].clear();
1303         return true;
1304       }else{
1305         return false;
1306       }
1307     }else{
1308       return d_data[reps[argIndex]].addTerm( n, reps, argIndex+1 );
1309     }
1310   }
1311 
debugPrint(const char * c,Node n,unsigned depth)1312   void TupleTrie::debugPrint( const char * c, Node n, unsigned depth ) {
1313     for( std::map< Node, TupleTrie >::iterator it = d_data.begin(); it != d_data.end(); ++it ){
1314       for( unsigned i=0; i<depth; i++ ){ Debug(c) << "  "; }
1315       Debug(c) << it->first << std::endl;
1316       it->second.debugPrint( c, n, depth+1 );
1317     }
1318   }
1319 
EqcInfo(context::Context * c)1320   TheorySetsRels::EqcInfo::EqcInfo( context::Context* c ) :
1321   d_mem(c), d_mem_exp(c), d_tp(c), d_pt(c), d_tc(c), d_rel_tc(c) {}
1322 
eqNotifyNewClass(Node n)1323   void TheorySetsRels::eqNotifyNewClass( Node n ) {
1324     Trace("rels-std") << "[sets-rels] eqNotifyNewClass:" << " t = " << n << std::endl;
1325     if(n.getKind() == kind::TRANSPOSE || n.getKind() == kind::PRODUCT || n.getKind() == kind::TCLOSURE) {
1326       getOrMakeEqcInfo( n, true );
1327       if( n.getKind() == kind::TCLOSURE ) {
1328         Node relRep_of_tc = getRepresentative( n[0] );
1329         EqcInfo*  rel_ei = getOrMakeEqcInfo( relRep_of_tc, true );
1330 
1331         if( rel_ei->d_rel_tc.get().isNull() ) {
1332           rel_ei->d_rel_tc = n;
1333           Node exp = relRep_of_tc == n[0] ? d_trueNode : NodeManager::currentNM()->mkNode( kind::EQUAL, relRep_of_tc, n[0] );
1334           for( NodeSet::const_iterator mem_it = rel_ei->d_mem.begin(); mem_it != rel_ei->d_mem.end(); mem_it++ ) {
1335             Node mem_exp = (*rel_ei->d_mem_exp.find(*mem_it)).second;
1336             exp = NodeManager::currentNM()->mkNode( kind::AND, exp, mem_exp);
1337             if( mem_exp[1] != relRep_of_tc ) {
1338               exp = NodeManager::currentNM()->mkNode( kind::AND, exp, NodeManager::currentNM()->mkNode(kind::EQUAL, relRep_of_tc, mem_exp[1] ) );
1339             }
1340             sendMergeInfer( NodeManager::currentNM()->mkNode(kind::MEMBER, mem_exp[0], n), exp, "TCLOSURE-UP I" );
1341           }
1342         }
1343       }
1344     }
1345   }
1346 
1347   // Merge t2 into t1, t1 will be the rep of the new eqc
eqNotifyPostMerge(Node t1,Node t2)1348   void TheorySetsRels::eqNotifyPostMerge( Node t1, Node t2 ) {
1349     Trace("rels-std") << "[sets-rels-std] eqNotifyPostMerge:" << " t1 = " << t1 << " t2 = " << t2 << std::endl;
1350 
1351     // Merge membership constraint with "true" eqc
1352     if( t1 == d_trueNode && t2.getKind() == kind::MEMBER && t2[0].getType().isTuple() ) {
1353       Node      mem_rep  = getRepresentative( t2[0] );
1354       Node      t2_1rep  = getRepresentative( t2[1] );
1355       EqcInfo*  ei       = getOrMakeEqcInfo( t2_1rep, true );
1356       if(ei->d_mem.contains(mem_rep)) {
1357         return;
1358       }
1359       Node exp = t2;
1360 
1361       ei->d_mem.insert( mem_rep );
1362       ei->d_mem_exp.insert( mem_rep, exp );
1363 
1364       // Process a membership constraint that a tuple is a member of transpose of rel
1365       if( !ei->d_tp.get().isNull() ) {
1366         if( ei->d_tp.get() != t2[1] ) {
1367           exp = NodeManager::currentNM()->mkNode(kind::AND, NodeManager::currentNM()->mkNode(kind::EQUAL, ei->d_tp.get(), t2[1]), t2 );
1368         }
1369         sendInferTranspose( t2[0], ei->d_tp.get(), exp );
1370       }
1371       // Process a membership constraint that a tuple is a member of product of rel
1372       if( !ei->d_pt.get().isNull() ) {
1373         if( ei->d_pt.get() != t2[1] ) {
1374           exp = NodeManager::currentNM()->mkNode(kind::AND, NodeManager::currentNM()->mkNode(kind::EQUAL, ei->d_pt.get(), t2[1]), t2 );
1375         }
1376         sendInferProduct( t2[0], ei->d_pt.get(), exp );
1377       }
1378 
1379       if( !ei->d_rel_tc.get().isNull() ) {
1380         if( ei->d_rel_tc.get()[0] != t2[1] ) {
1381           exp = NodeManager::currentNM()->mkNode(kind::AND, NodeManager::currentNM()->mkNode(kind::EQUAL, ei->d_rel_tc.get()[0], t2[1]), t2 );
1382         }
1383         sendMergeInfer(NodeManager::currentNM()->mkNode(kind::MEMBER, t2[0], ei->d_rel_tc.get()), exp, "TCLOSURE-UP I");
1384       }
1385       // Process a membership constraint that a tuple is a member of transitive closure of rel
1386       if( !ei->d_tc.get().isNull() ) {
1387         sendInferTClosure( t2, ei );
1388       }
1389 
1390     // Merge two relation eqcs
1391     } else if( t1.getType().isSet() && t2.getType().isSet() && t1.getType().getSetElementType().isTuple() ) {
1392       EqcInfo* t1_ei = getOrMakeEqcInfo( t1 );
1393       EqcInfo* t2_ei = getOrMakeEqcInfo( t2 );
1394 
1395       if( t1_ei != NULL && t2_ei != NULL ) {
1396         if( !t1_ei->d_tp.get().isNull() && !t2_ei->d_tp.get().isNull() ) {
1397           sendInferTranspose(t1_ei->d_tp.get(), t2_ei->d_tp.get(), NodeManager::currentNM()->mkNode(kind::EQUAL, t1_ei->d_tp.get(), t2_ei->d_tp.get() ) );
1398         }
1399         std::vector< Node > t2_new_mems;
1400         std::vector< Node > t2_new_exps;
1401         NodeSet::const_iterator t2_mem_it = t2_ei->d_mem.begin();
1402         NodeSet::const_iterator t1_mem_it = t1_ei->d_mem.begin();
1403 
1404         for( ; t2_mem_it != t2_ei->d_mem.end(); t2_mem_it++ ) {
1405           if( !t1_ei->d_mem.contains( *t2_mem_it ) ) {
1406             Node t2_mem_exp = (*t2_ei->d_mem_exp.find(*t2_mem_it)).second;
1407 
1408             if( t2_ei->d_tp.get().isNull() && !t1_ei->d_tp.get().isNull() ) {
1409               Node reason = t1_ei->d_tp.get() == (t2_mem_exp)[1]
1410                             ? (t2_mem_exp) : NodeManager::currentNM()->mkNode(kind::AND, t2_mem_exp, NodeManager::currentNM()->mkNode(kind::EQUAL, (t2_mem_exp)[1], t1_ei->d_tp.get()));
1411               sendInferTranspose( t2_mem_exp[0], t1_ei->d_tp.get(), reason );
1412             }
1413             if( t2_ei->d_pt.get().isNull() && !t1_ei->d_pt.get().isNull() ) {
1414               Node reason = t1_ei->d_pt.get() == (t2_mem_exp)[1]
1415                             ? (t2_mem_exp) : NodeManager::currentNM()->mkNode(kind::AND, t2_mem_exp, NodeManager::currentNM()->mkNode(kind::EQUAL, (t2_mem_exp)[1], t1_ei->d_pt.get()));
1416               sendInferProduct( t2_mem_exp[0], t1_ei->d_pt.get(), reason );
1417             }
1418             if( t2_ei->d_tc.get().isNull() && !t1_ei->d_tc.get().isNull() ) {
1419               sendInferTClosure( t2_mem_exp, t1_ei );
1420             }
1421             if( t2_ei->d_rel_tc.get().isNull() && !t1_ei->d_rel_tc.get().isNull() ) {
1422               Node reason = t1_ei->d_rel_tc.get()[0] == t2_mem_exp[1] ?
1423                             t2_mem_exp : NodeManager::currentNM()->mkNode(kind::AND, NodeManager::currentNM()->mkNode(kind::EQUAL, t1_ei->d_rel_tc.get()[0], t2_mem_exp[1]), t2_mem_exp );
1424               sendMergeInfer(NodeManager::currentNM()->mkNode(kind::MEMBER, t2_mem_exp[0], t1_ei->d_rel_tc.get()), reason, "TCLOSURE-UP I");
1425             }
1426             t2_new_mems.push_back( *t2_mem_it );
1427             t2_new_exps.push_back( t2_mem_exp );
1428           }
1429         }
1430         for( ; t1_mem_it != t1_ei->d_mem.end(); t1_mem_it++ ) {
1431           if( !t2_ei->d_mem.contains( *t1_mem_it ) ) {
1432             Node t1_mem_exp = (*t1_ei->d_mem_exp.find(*t1_mem_it)).second;
1433             if( t1_ei->d_tp.get().isNull() && !t2_ei->d_tp.get().isNull() ) {
1434               Node reason = t2_ei->d_tp.get() == (t1_mem_exp)[1]
1435                             ? (t1_mem_exp) : NodeManager::currentNM()->mkNode(kind::AND, t1_mem_exp, NodeManager::currentNM()->mkNode(kind::EQUAL, (t1_mem_exp)[1], t2_ei->d_tp.get()));
1436               sendInferTranspose( (t1_mem_exp)[0], t2_ei->d_tp.get(), reason );
1437             }
1438             if( t1_ei->d_pt.get().isNull() && !t2_ei->d_pt.get().isNull() ) {
1439               Node reason = t2_ei->d_pt.get() == (t1_mem_exp)[1]
1440                             ? (t1_mem_exp) : NodeManager::currentNM()->mkNode(kind::AND, t1_mem_exp, NodeManager::currentNM()->mkNode(kind::EQUAL, (t1_mem_exp)[1], t2_ei->d_pt.get()));
1441               sendInferProduct( (t1_mem_exp)[0], t2_ei->d_pt.get(), reason );
1442             }
1443             if( t1_ei->d_tc.get().isNull() && !t2_ei->d_tc.get().isNull() ) {
1444               sendInferTClosure(t1_mem_exp, t2_ei );
1445             }
1446             if( t1_ei->d_rel_tc.get().isNull() && !t2_ei->d_rel_tc.get().isNull() ) {
1447               Node reason = t2_ei->d_rel_tc.get()[0] == t1_mem_exp[1] ?
1448                             t1_mem_exp : NodeManager::currentNM()->mkNode(kind::AND, NodeManager::currentNM()->mkNode(kind::EQUAL, t2_ei->d_rel_tc.get()[0], t1_mem_exp[1]), t1_mem_exp );
1449               sendMergeInfer(NodeManager::currentNM()->mkNode(kind::MEMBER, t1_mem_exp[0], t2_ei->d_rel_tc.get()), reason, "TCLOSURE-UP I");
1450             }
1451           }
1452         }
1453         std::vector< Node >::iterator t2_new_mem_it = t2_new_mems.begin();
1454         std::vector< Node >::iterator t2_new_exp_it = t2_new_exps.begin();
1455         for( ; t2_new_mem_it != t2_new_mems.end(); t2_new_mem_it++, t2_new_exp_it++ ) {
1456           t1_ei->d_mem.insert( *t2_new_mem_it );
1457           t1_ei->d_mem_exp.insert( *t2_new_mem_it, *t2_new_exp_it );
1458         }
1459         if( t1_ei->d_tp.get().isNull() && !t2_ei->d_tp.get().isNull() ) {
1460           t1_ei->d_tp.set(t2_ei->d_tp.get());
1461         }
1462         if( t1_ei->d_pt.get().isNull() && !t2_ei->d_pt.get().isNull() ) {
1463           t1_ei->d_pt.set(t2_ei->d_pt.get());
1464         }
1465         if( t1_ei->d_tc.get().isNull() && !t2_ei->d_tc.get().isNull() ) {
1466           t1_ei->d_tc.set(t2_ei->d_tc.get());
1467         }
1468         if( t1_ei->d_rel_tc.get().isNull() && !t2_ei->d_rel_tc.get().isNull() ) {
1469           t1_ei->d_rel_tc.set(t2_ei->d_rel_tc.get());
1470         }
1471       } else if( t1_ei != NULL ) {
1472         if( (t2.getKind() == kind::TRANSPOSE && t1_ei->d_tp.get().isNull()) ||
1473             (t2.getKind() == kind::PRODUCT && t1_ei->d_pt.get().isNull()) ||
1474             (t2.getKind() == kind::TCLOSURE && t1_ei->d_tc.get().isNull()) ) {
1475           NodeSet::const_iterator t1_mem_it = t1_ei->d_mem.begin();
1476 
1477           if( t2.getKind() == kind::TRANSPOSE ) {
1478             t1_ei->d_tp = t2;
1479           } else if( t2.getKind() == kind::PRODUCT ) {
1480             t1_ei->d_pt = t2;
1481           } else if( t2.getKind() == kind::TCLOSURE ) {
1482             t1_ei->d_tc = t2;
1483           }
1484           for( ; t1_mem_it != t1_ei->d_mem.end(); t1_mem_it++ ) {
1485             Node t1_exp = (*t1_ei->d_mem_exp.find(*t1_mem_it)).second;
1486             if( t2.getKind() == kind::TRANSPOSE ) {
1487               Node reason = t2 == t1_exp[1]
1488                             ? (t1_exp) : NodeManager::currentNM()->mkNode(kind::AND, (t1_exp), NodeManager::currentNM()->mkNode(kind::EQUAL, (t1_exp)[1], t2));
1489               sendInferTranspose( (t1_exp)[0], t2, reason );
1490             } else if( t2.getKind() == kind::PRODUCT ) {
1491               Node reason = t2 == (t1_exp)[1]
1492                             ? (t1_exp) : NodeManager::currentNM()->mkNode(kind::AND, (t1_exp), NodeManager::currentNM()->mkNode(kind::EQUAL, (t1_exp)[1], t2));
1493               sendInferProduct( (t1_exp)[0], t2, reason );
1494             } else if( t2.getKind() == kind::TCLOSURE ) {
1495               sendInferTClosure( t1_exp, t1_ei );
1496             }
1497           }
1498         }
1499       } else if( t2_ei != NULL ) {
1500         EqcInfo* new_t1_ei = getOrMakeEqcInfo( t1, true );
1501         if( new_t1_ei->d_tp.get().isNull() && !t2_ei->d_tp.get().isNull() ) {
1502           new_t1_ei->d_tp.set(t2_ei->d_tp.get());
1503         }
1504         if( new_t1_ei->d_pt.get().isNull() && !t2_ei->d_pt.get().isNull() ) {
1505           new_t1_ei->d_pt.set(t2_ei->d_pt.get());
1506         }
1507         if( new_t1_ei->d_tc.get().isNull() && !t2_ei->d_tc.get().isNull() ) {
1508           new_t1_ei->d_tc.set(t2_ei->d_tc.get());
1509         }
1510         if( new_t1_ei->d_rel_tc.get().isNull() && !t2_ei->d_rel_tc.get().isNull() ) {
1511           new_t1_ei->d_rel_tc.set(t2_ei->d_rel_tc.get());
1512         }
1513         if( (t1.getKind() == kind::TRANSPOSE && t2_ei->d_tp.get().isNull()) ||
1514             (t1.getKind() == kind::PRODUCT && t2_ei->d_pt.get().isNull()) ||
1515             (t1.getKind() == kind::TCLOSURE && t2_ei->d_tc.get().isNull()) ) {
1516           NodeSet::const_iterator t2_mem_it = t2_ei->d_mem.begin();
1517 
1518           for( ; t2_mem_it != t2_ei->d_mem.end(); t2_mem_it++ ) {
1519             Node t2_exp = (*t1_ei->d_mem_exp.find(*t2_mem_it)).second;
1520 
1521             if( t1.getKind() == kind::TRANSPOSE ) {
1522               Node reason = t1 == (t2_exp)[1]
1523                             ? (t2_exp) : NodeManager::currentNM()->mkNode(kind::AND, (t2_exp), NodeManager::currentNM()->mkNode(kind::EQUAL, (t2_exp)[1], t1));
1524               sendInferTranspose( (t2_exp)[0], t1, reason );
1525             } else if( t1.getKind() == kind::PRODUCT ) {
1526               Node reason = t1 == (t2_exp)[1]
1527                             ? (t2_exp) : NodeManager::currentNM()->mkNode(kind::AND, (t2_exp), NodeManager::currentNM()->mkNode(kind::EQUAL, (t2_exp)[1], t1));
1528               sendInferProduct( (t2_exp)[0], t1, reason );
1529             } else if( t1.getKind() == kind::TCLOSURE ) {
1530               sendInferTClosure( t2_exp, new_t1_ei );
1531             }
1532           }
1533         }
1534       }
1535     }
1536 
1537     Trace("rels-std") << "[sets-rels] done with eqNotifyPostMerge:" << " t1 = " << t1 << " t2 = " << t2 << std::endl;
1538   }
1539 
sendInferTClosure(Node new_mem_exp,EqcInfo * ei)1540   void TheorySetsRels::sendInferTClosure( Node new_mem_exp, EqcInfo* ei ) {
1541     NodeSet::const_iterator mem_it = ei->d_mem.begin();
1542     Node mem_rep = getRepresentative( new_mem_exp[0] );
1543     Node new_mem_rel = new_mem_exp[1];
1544     Node new_mem_fst = RelsUtils::nthElementOfTuple( new_mem_exp[0], 0 );
1545     Node new_mem_snd = RelsUtils::nthElementOfTuple( new_mem_exp[0], 1 );
1546     for( ; mem_it != ei->d_mem.end(); mem_it++ ) {
1547       if( *mem_it == mem_rep ) {
1548         continue;
1549       }
1550       Node d_mem_exp = (*ei->d_mem_exp.find(*mem_it)).second;
1551       Node d_mem_fst = RelsUtils::nthElementOfTuple( d_mem_exp[0], 0 );
1552       Node d_mem_snd = RelsUtils::nthElementOfTuple( d_mem_exp[0], 1 );
1553       Node d_mem_rel = d_mem_exp[1];
1554       if( areEqual( new_mem_fst, d_mem_snd) ) {
1555         Node reason = NodeManager::currentNM()->mkNode( kind::AND, new_mem_exp, d_mem_exp );
1556         reason = NodeManager::currentNM()->mkNode( kind::AND, reason, NodeManager::currentNM()->mkNode(kind::EQUAL, new_mem_fst, d_mem_snd ) );
1557         if( new_mem_rel != ei->d_tc.get() ) {
1558           reason = NodeManager::currentNM()->mkNode( kind::AND, reason, NodeManager::currentNM()->mkNode(kind::EQUAL, new_mem_rel, ei->d_tc.get() ) );
1559         }
1560         if( d_mem_rel != ei->d_tc.get() ) {
1561           reason = NodeManager::currentNM()->mkNode( kind::AND, reason, NodeManager::currentNM()->mkNode(kind::EQUAL, d_mem_rel, ei->d_tc.get() ) );
1562         }
1563         Node new_membership = NodeManager::currentNM()->mkNode( kind::MEMBER, RelsUtils::constructPair( d_mem_rel, d_mem_fst, new_mem_snd ), ei->d_tc.get() );
1564         sendMergeInfer( new_membership, reason, "TCLOSURE-UP II" );
1565       }
1566       if( areEqual( new_mem_snd, d_mem_fst ) ) {
1567         Node reason = NodeManager::currentNM()->mkNode( kind::AND, new_mem_exp, d_mem_exp );
1568         reason = NodeManager::currentNM()->mkNode( kind::AND, reason, NodeManager::currentNM()->mkNode(kind::EQUAL, new_mem_snd, d_mem_fst ) );
1569         if( new_mem_rel != ei->d_tc.get() ) {
1570           reason = NodeManager::currentNM()->mkNode( kind::AND, reason, NodeManager::currentNM()->mkNode(kind::EQUAL, new_mem_rel, ei->d_tc.get() ) );
1571         }
1572         if( d_mem_rel != ei->d_tc.get() ) {
1573           reason = NodeManager::currentNM()->mkNode( kind::AND, reason, NodeManager::currentNM()->mkNode(kind::EQUAL, d_mem_rel, ei->d_tc.get() ) );
1574         }
1575         Node new_membership = NodeManager::currentNM()->mkNode( kind::MEMBER, RelsUtils::constructPair( d_mem_rel, new_mem_fst, d_mem_snd ), ei->d_tc.get() );
1576         sendMergeInfer( new_membership, reason, "TCLOSURE-UP II" );
1577       }
1578     }
1579   }
1580 
1581 
doPendingMerge()1582   void TheorySetsRels::doPendingMerge() {
1583     for( NodeList::const_iterator itr = d_pending_merge.begin(); itr != d_pending_merge.end(); itr++ ) {
1584       if( !holds(*itr) ) {
1585         if( d_lemmas_produced.find(*itr)==d_lemmas_produced.end() ) {
1586           Trace("rels-std-lemma") << "[std-sets-rels-lemma] Send out a merge fact as lemma: "
1587                               << *itr << std::endl;
1588           d_sets_theory.d_out->lemma( *itr );
1589           d_lemmas_produced.insert(*itr);
1590         }
1591       }
1592     }
1593   }
1594 
1595   // t1 and t2 can be both relations
1596   // or t1 is a tuple, t2 is a transposed relation
sendInferTranspose(Node t1,Node t2,Node exp)1597   void TheorySetsRels::sendInferTranspose( Node t1, Node t2, Node exp ) {
1598     Assert( t2.getKind() == kind::TRANSPOSE );
1599 
1600     if( isRel(t1) && isRel(t2) ) {
1601       Assert(t1.getKind() == kind::TRANSPOSE);
1602       sendMergeInfer(NodeManager::currentNM()->mkNode(kind::EQUAL, t1[0], t2[0]), exp, "Transpose-Equal");
1603       return;
1604     }
1605     sendMergeInfer(NodeManager::currentNM()->mkNode(kind::MEMBER, RelsUtils::reverseTuple(t1), t2[0]), exp, "Transpose-Rule");
1606   }
1607 
sendMergeInfer(Node fact,Node reason,const char * c)1608   void TheorySetsRels::sendMergeInfer( Node fact, Node reason, const char * c ) {
1609     if( !holds( fact ) ) {
1610       Node lemma = NodeManager::currentNM()->mkNode( kind::IMPLIES, reason, fact);
1611       d_pending_merge.push_back(lemma);
1612       Trace("rels-std") << "[std-rels-lemma] Generate a lemma by applying " << c
1613                         << ": " << lemma << std::endl;
1614     }
1615   }
1616 
sendInferProduct(Node member,Node pt_rel,Node exp)1617   void TheorySetsRels::sendInferProduct( Node member, Node pt_rel, Node exp ) {
1618     Assert( pt_rel.getKind() == kind::PRODUCT );
1619 
1620     std::vector<Node>   r1_element;
1621     std::vector<Node>   r2_element;
1622     Node                r1      = pt_rel[0];
1623     Node                r2      = pt_rel[1];
1624     Datatype            dt      = r1.getType().getSetElementType().getDatatype();
1625     unsigned int        i       = 0;
1626     unsigned int        s1_len  = r1.getType().getSetElementType().getTupleLength();
1627     unsigned int        tup_len = pt_rel.getType().getSetElementType().getTupleLength();
1628 
1629     r1_element.push_back(Node::fromExpr(dt[0].getConstructor()));
1630     for( ; i < s1_len; ++i ) {
1631       r1_element.push_back( RelsUtils::nthElementOfTuple( member, i ) );
1632     }
1633 
1634     dt = r2.getType().getSetElementType().getDatatype();
1635     r2_element.push_back( Node::fromExpr( dt[0].getConstructor() ) );
1636     for( ; i < tup_len; ++i ) {
1637       r2_element.push_back( RelsUtils::nthElementOfTuple(member, i) );
1638     }
1639 
1640     Node tuple_1 = NodeManager::currentNM()->mkNode( kind::APPLY_CONSTRUCTOR, r1_element );
1641     Node tuple_2 = NodeManager::currentNM()->mkNode( kind::APPLY_CONSTRUCTOR, r2_element );
1642     sendMergeInfer( NodeManager::currentNM()->mkNode(kind::MEMBER, tuple_1, r1), exp, "Product-Split" );
1643     sendMergeInfer( NodeManager::currentNM()->mkNode(kind::MEMBER, tuple_2, r2), exp, "Product-Split" );
1644   }
1645 
getOrMakeEqcInfo(Node n,bool doMake)1646   TheorySetsRels::EqcInfo* TheorySetsRels::getOrMakeEqcInfo( Node n, bool doMake ){
1647     std::map< Node, EqcInfo* >::iterator eqc_i = d_eqc_info.find( n );
1648     if( eqc_i == d_eqc_info.end() ){
1649       if( doMake ){
1650         EqcInfo* ei;
1651         if( eqc_i!=d_eqc_info.end() ){
1652           ei = eqc_i->second;
1653         }else{
1654           ei = new EqcInfo(d_sets_theory.getSatContext());
1655           d_eqc_info[n] = ei;
1656         }
1657         if( n.getKind() == kind::TRANSPOSE ){
1658           ei->d_tp = n;
1659         } else if( n.getKind() == kind::PRODUCT ) {
1660           ei->d_pt = n;
1661         } else if( n.getKind() == kind::TCLOSURE ) {
1662           ei->d_tc = n;
1663         }
1664         return ei;
1665       }else{
1666         return NULL;
1667       }
1668     }else{
1669       return (*eqc_i).second;
1670     }
1671   }
1672 
printNodeMap(const char * fst,const char * snd,const NodeMap & map)1673   void TheorySetsRels::printNodeMap(const char* fst,
1674                                     const char* snd,
1675                                     const NodeMap& map)
1676   {
1677     for (const auto& key_data : map)
1678     {
1679       Trace("rels-debug") << fst << " " << key_data.first << " " << snd << " "
1680                           << key_data.second << std::endl;
1681     }
1682   }
1683 }
1684 }
1685 }
1686