1 /*********************                                                        */
2 /*! \file theory_uf_strong_solver.h
3  ** \verbatim
4  ** Top contributors (to current version):
5  **   Andrew Reynolds, Morgan Deters, Tim King
6  ** This file is part of the CVC4 project.
7  ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8  ** in the top-level source directory) and their institutional affiliations.
9  ** All rights reserved.  See the file COPYING in the top-level source
10  ** directory for licensing information.\endverbatim
11  **
12  ** \brief Theory uf strong solver
13  **/
14 
15 #include "cvc4_private.h"
16 
17 #ifndef __CVC4__THEORY_UF_STRONG_SOLVER_H
18 #define __CVC4__THEORY_UF_STRONG_SOLVER_H
19 
20 #include "context/cdhashmap.h"
21 #include "context/context.h"
22 #include "context/context_mm.h"
23 #include "theory/theory.h"
24 #include "util/statistics_registry.h"
25 
26 #include "theory/decision_manager.h"
27 
// Forward declarations: this header only needs pointers to these classes,
// so their full definitions are not included here.
namespace CVC4 {

class SortInference;

namespace theory {
namespace uf {

class TheoryUF;

}  // namespace uf
}  // namespace theory
}  // namespace CVC4
36 
37 namespace CVC4 {
38 namespace theory {
39 namespace uf {
40 
41 class StrongSolverTheoryUF{
42 protected:
43   typedef context::CDHashMap<Node, bool, NodeHashFunction> NodeBoolMap;
44   typedef context::CDHashMap<Node, int, NodeHashFunction> NodeIntMap;
45   typedef context::CDHashMap<Node, Node, NodeHashFunction> NodeNodeMap;
46   typedef context::CDHashMap<TypeNode, bool, TypeNodeHashFunction> TypeNodeBoolMap;
47 public:
48   /**
49    * Information for incremental conflict/clique finding for a
50    * particular sort.
51    */
52   class SortModel {
53   private:
54     std::map< Node, std::vector< int > > d_totality_lems;
55     std::map< TypeNode, std::map< int, std::vector< Node > > > d_sym_break_terms;
56     std::map< Node, int > d_sym_break_index;
57   public:
58 
59     /**
60      * A partition of the current equality graph for which cliques
61      * can occur internally.
62      */
63     class Region {
64     public:
65       /** information stored about each node in region */
66       class RegionNodeInfo {
67       public:
68         /** disequality list for node */
69         class DiseqList {
70         public:
DiseqList(context::Context * c)71           DiseqList( context::Context* c )
72             : d_size( c, 0 ), d_disequalities( c ) {}
~DiseqList()73           ~DiseqList(){}
74 
setDisequal(Node n,bool valid)75           void setDisequal( Node n, bool valid ){
76             Assert( (!isSet(n)) || getDisequalityValue(n) != valid );
77             d_disequalities[ n ] = valid;
78             d_size = d_size + ( valid ? 1 : -1 );
79           }
isSet(Node n)80           bool isSet(Node n) const {
81             return d_disequalities.find(n) != d_disequalities.end();
82           }
getDisequalityValue(Node n)83           bool getDisequalityValue(Node n) const {
84             Assert(isSet(n));
85             return (*(d_disequalities.find(n))).second;
86           }
87 
size()88           int size() const { return d_size; }
89 
90           typedef NodeBoolMap::iterator iterator;
begin()91           iterator begin() { return d_disequalities.begin(); }
end()92           iterator end() { return d_disequalities.end(); }
93 
94         private:
95           context::CDO< int > d_size;
96           NodeBoolMap d_disequalities;
97         }; /* class DiseqList */
98       public:
99         /** constructor */
RegionNodeInfo(context::Context * c)100         RegionNodeInfo( context::Context* c )
101           : d_internal(c), d_external(c), d_valid(c, true) {
102           d_disequalities[0] = &d_internal;
103           d_disequalities[1] = &d_external;
104         }
~RegionNodeInfo()105         ~RegionNodeInfo(){}
106 
getNumDisequalities()107         int getNumDisequalities() const {
108           return d_disequalities[0]->size() + d_disequalities[1]->size();
109         }
getNumExternalDisequalities()110         int getNumExternalDisequalities() const {
111           return d_disequalities[0]->size();
112         }
getNumInternalDisequalities()113         int getNumInternalDisequalities() const {
114           return d_disequalities[1]->size();
115         }
116 
valid()117         bool valid() const { return d_valid; }
setValid(bool valid)118         void setValid(bool valid) { d_valid = valid; }
119 
get(unsigned i)120         DiseqList* get(unsigned i) { return d_disequalities[i]; }
121 
122       private:
123         DiseqList d_internal;
124         DiseqList d_external;
125         context::CDO< bool > d_valid;
126         DiseqList* d_disequalities[2];
127       }; /* class RegionNodeInfo */
128 
129     private:
130       /** conflict find pointer */
131       SortModel* d_cf;
132 
133       context::CDO< unsigned > d_testCliqueSize;
134       context::CDO< unsigned > d_splitsSize;
135       //a postulated clique
136       NodeBoolMap d_testClique;
137       //disequalities needed for this clique to happen
138       NodeBoolMap d_splits;
139       //number of valid representatives in this region
140       context::CDO< unsigned > d_reps_size;
141       //total disequality size (external)
142       context::CDO< unsigned > d_total_diseq_external;
143       //total disequality size (internal)
144       context::CDO< unsigned > d_total_diseq_internal;
145       /** set rep */
146       void setRep( Node n, bool valid );
147       //region node infomation
148       std::map< Node, RegionNodeInfo* > d_nodes;
149       //whether region is valid
150       context::CDO< bool > d_valid;
151 
152     public:
153       //constructor
154       Region( SortModel* cf, context::Context* c );
155       virtual ~Region();
156 
157       typedef std::map< Node, RegionNodeInfo* >::iterator iterator;
begin()158       iterator begin() { return d_nodes.begin(); }
end()159       iterator end() { return d_nodes.end(); }
160 
161       typedef NodeBoolMap::iterator split_iterator;
begin_splits()162       split_iterator begin_splits() { return d_splits.begin(); }
end_splits()163       split_iterator end_splits() { return d_splits.end(); }
164 
165       /** Returns a RegionInfo. */
getRegionInfo(Node n)166       RegionNodeInfo* getRegionInfo(Node n) {
167         Assert(d_nodes.find(n) != d_nodes.end());
168         return (* (d_nodes.find(n))).second;
169       }
170 
171       /** Returns whether or not d_valid is set in current context. */
valid()172       bool valid() const { return d_valid; }
173 
174       /** Sets d_valid to the value valid in the current context.*/
setValid(bool valid)175       void setValid(bool valid) { d_valid = valid; }
176 
177       /** add rep */
178       void addRep( Node n );
179       //take node from region
180       void takeNode( Region* r, Node n );
181       //merge with other region
182       void combine( Region* r );
183       /** merge */
184       void setEqual( Node a, Node b );
185       //set n1 != n2 to value 'valid', type is whether it is internal/external
186       void setDisequal( Node n1, Node n2, int type, bool valid );
187       //get num reps
getNumReps()188       int getNumReps() { return d_reps_size; }
189       //get test clique size
getTestCliqueSize()190       int getTestCliqueSize() { return d_testCliqueSize; }
191       // has representative
hasRep(Node n)192       bool hasRep( Node n ) {
193         return d_nodes.find(n) != d_nodes.end() && d_nodes[n]->valid();
194       }
195       // is disequal
196       bool isDisequal( Node n1, Node n2, int type );
197       /** get must merge */
198       bool getMustCombine( int cardinality );
199       /** has splits */
hasSplits()200       bool hasSplits() { return d_splitsSize>0; }
201       /** get external disequalities */
202       void getNumExternalDisequalities(std::map< Node, int >& num_ext_disequalities );
203       /** check for cliques */
204       bool check( Theory::Effort level, int cardinality, std::vector< Node >& clique );
205       //print debug
206       void debugPrint( const char* c, bool incClique = false );
207 
208       // Returns the first key in d_nodes.
frontKey()209       Node frontKey() const { return d_nodes.begin()->first; }
210     }; /* class Region */
211 
212   private:
213     /** the type this model is for */
214     TypeNode d_type;
215     /** strong solver pointer */
216     StrongSolverTheoryUF* d_thss;
217     /** regions used to d_region_index */
218     context::CDO< unsigned > d_regions_index;
219     /** vector of regions */
220     std::vector< Region* > d_regions;
221     /** map from Nodes to index of d_regions they exist in, -1 means invalid */
222     NodeIntMap d_regions_map;
223     /** the score for each node for splitting */
224     NodeIntMap d_split_score;
225     /** number of valid disequalities in d_disequalities */
226     context::CDO< unsigned > d_disequalities_index;
227     /** list of all disequalities */
228     std::vector< Node > d_disequalities;
229     /** number of representatives in all regions */
230     context::CDO< unsigned > d_reps;
231 
232     /** get number of disequalities from node n to region ri */
233     int getNumDisequalitiesToRegion( Node n, int ri );
234     /** get number of disequalities from Region r to other regions */
235     void getDisequalitiesToRegions( int ri, std::map< int, int >& regions_diseq );
236     /** is valid */
isValid(int ri)237     bool isValid( int ri ) {
238       return ri>=0 && ri<(int)d_regions_index && d_regions[ ri ]->valid();
239     }
240     /** set split score */
241     void setSplitScore( Node n, int s );
242     /** check if we need to combine region ri */
243     void checkRegion( int ri, bool checkCombine = true );
244     /** force combine region */
245     int forceCombineRegion( int ri, bool useDensity = true );
246     /** merge regions */
247     int combineRegions( int ai, int bi );
248     /** move node n to region ri */
249     void moveNode( Node n, int ri );
250     /** allocate cardinality */
251     void allocateCardinality( OutputChannel* out );
252     /**
253      * Add splits. Returns
254      *   0 = no split,
255      *  -1 = entailed disequality added, or
256      *   1 = split added.
257      */
258     int addSplit( Region* r, OutputChannel* out );
259     /** add clique lemma */
260     void addCliqueLemma( std::vector< Node >& clique, OutputChannel* out );
261     /** add totality axiom */
262     void addTotalityAxiom( Node n, int cardinality, OutputChannel* out );
263     /** Are we in conflict */
264     context::CDO<bool> d_conflict;
265     /** cardinality */
266     context::CDO< int > d_cardinality;
267     /** cardinality lemma term */
268     Node d_cardinality_term;
269     /** cardinality totality terms */
270     std::map< int, std::vector< Node > > d_totality_terms;
271     /** cardinality literals */
272     std::map< int, Node > d_cardinality_literal;
273     /** whether a positive cardinality constraint has been asserted */
274     context::CDO< bool > d_hasCard;
275     /** clique lemmas that have been asserted */
276     std::map< int, std::vector< std::vector< Node > > > d_cliques;
277     /** maximum negatively asserted cardinality */
278     context::CDO< int > d_maxNegCard;
279     /** list of fresh representatives allocated */
280     std::vector< Node > d_fresh_aloc_reps;
281     /** whether we are initialized */
282     context::CDO< bool > d_initialized;
283     /** cache for lemmas */
284     NodeBoolMap d_lemma_cache;
285 
286     /** apply totality */
287     bool applyTotality( int cardinality );
288     /** get totality lemma terms */
289     Node getTotalityLemmaTerm( int cardinality, int i );
290     /** simple check cardinality */
291     void simpleCheckCardinality();
292 
293     bool doSendLemma( Node lem );
294 
295   public:
296     SortModel( Node n, context::Context* c, context::UserContext* u,
297                StrongSolverTheoryUF* thss );
298     virtual ~SortModel();
299     /** initialize */
300     void initialize( OutputChannel* out );
301     /** new node */
302     void newEqClass( Node n );
303     /** merge */
304     void merge( Node a, Node b );
305     /** assert terms are disequal */
306     void assertDisequal( Node a, Node b, Node reason );
307     /** are disequal */
308     bool areDisequal( Node a, Node b );
309     /** check */
310     void check( Theory::Effort level, OutputChannel* out );
311     /** presolve */
312     void presolve();
313     /** propagate */
314     void propagate( Theory::Effort level, OutputChannel* out );
315     /** assert cardinality */
316     void assertCardinality( OutputChannel* out, int c, bool val );
317     /** is in conflict */
isConflict()318     bool isConflict() { return d_conflict; }
319     /** get cardinality */
getCardinality()320     int getCardinality() { return d_cardinality; }
321     /** has cardinality */
hasCardinalityAsserted()322     bool hasCardinalityAsserted() { return d_hasCard; }
323     /** get cardinality term */
getCardinalityTerm()324     Node getCardinalityTerm() { return d_cardinality_term; }
325     /** get cardinality literal */
326     Node getCardinalityLiteral(unsigned c);
327     /** get maximum negative cardinality */
getMaximumNegativeCardinality()328     int getMaximumNegativeCardinality() { return d_maxNegCard.get(); }
329     //print debug
330     void debugPrint( const char* c );
331     /** debug a model */
332     bool debugModel( TheoryModel* m );
333     /** get number of regions (for debugging) */
334     int getNumRegions();
335 
336    private:
337     /**
338      * Decision strategy for cardinality constraints. This asserts
339      * the minimal constraint positively in the SAT context. For details, see
340      * Section 6.3 of Reynolds et al, "Constraint Solving for Finite Model
341      * Finding in SMT Solvers", TPLP 2017.
342      */
343     class CardinalityDecisionStrategy : public DecisionStrategyFmf
344     {
345      public:
346       CardinalityDecisionStrategy(Node t,
347                                   context::Context* satContext,
348                                   Valuation valuation);
349       /** make literal (the i^th combined cardinality literal) */
350       Node mkLiteral(unsigned i) override;
351       /** identify */
352       std::string identify() const override;
353 
354      private:
355       /** the cardinality term */
356       Node d_cardinality_term;
357     };
358     /** cardinality decision strategy */
359     std::unique_ptr<CardinalityDecisionStrategy> d_c_dec_strat;
360   }; /** class SortModel */
361 
362 public:
363   StrongSolverTheoryUF(context::Context* c, context::UserContext* u,
364                        OutputChannel& out, TheoryUF* th);
365   ~StrongSolverTheoryUF();
366   /** get theory */
getTheory()367   TheoryUF* getTheory() { return d_th; }
368   /** get sort inference module */
369   SortInference* getSortInference();
370   /** get default sat context */
371   context::Context* getSatContext();
372   /** get default output channel */
373   OutputChannel& getOutputChannel();
374   /** new node */
375   void newEqClass( Node n );
376   /** merge */
377   void merge( Node a, Node b );
378   /** assert terms are disequal */
379   void assertDisequal( Node a, Node b, Node reason );
380   /** assert node */
381   void assertNode( Node n, bool isDecision );
382   /** are disequal */
383   bool areDisequal( Node a, Node b );
384   /** check */
385   void check( Theory::Effort level );
386   /** presolve */
387   void presolve();
388   /** preregister a term */
389   void preRegisterTerm( TNode n );
390   /** identify */
identify()391   std::string identify() const { return std::string("StrongSolverTheoryUF"); }
392   //print debug
393   void debugPrint( const char* c );
394   /** debug a model */
395   bool debugModel( TheoryModel* m );
396   /** get is in conflict */
isConflict()397   bool isConflict() { return d_conflict; }
398   /** get cardinality for node */
399   int getCardinality( Node n );
400   /** get cardinality for type */
401   int getCardinality( TypeNode tn );
402   /** has eqc */
403   bool hasEqc(Node a);
404 
405   class Statistics {
406    public:
407     IntStat d_clique_conflicts;
408     IntStat d_clique_lemmas;
409     IntStat d_split_lemmas;
410     IntStat d_disamb_term_lemmas;
411     IntStat d_totality_lemmas;
412     IntStat d_max_model_size;
413     Statistics();
414     ~Statistics();
415   };
416   /** statistics class */
417   Statistics d_statistics;
418 
419  private:
420   /** get sort model */
421   SortModel* getSortModel(Node n);
422   /** initialize */
423   void initializeCombinedCardinality();
424   /** check */
425   void checkCombinedCardinality();
426   /** ensure eqc */
427   void ensureEqc(SortModel* c, Node a);
428   /** ensure eqc for all subterms of n */
429   void ensureEqcRec(Node n);
430 
431   /** The output channel for the strong solver. */
432   OutputChannel* d_out;
433   /** theory uf pointer */
434   TheoryUF* d_th;
435   /** Are we in conflict */
436   context::CDO<bool> d_conflict;
437   /** rep model structure, one for each type */
438   std::map<TypeNode, SortModel*> d_rep_model;
439 
440   /** minimum positive combined cardinality */
441   context::CDO<int> d_min_pos_com_card;
442   /**
443    * Decision strategy for combined cardinality constraints. This asserts
444    * the minimal combined cardinality constraint positively in the SAT
445    * context. It is enabled by options::ufssFairness(). For details, see
446    * the extension to multiple sorts in Section 6.3 of Reynolds et al,
447    * "Constraint Solving for Finite Model Finding in SMT Solvers", TPLP 2017.
448    */
449   class CombinedCardinalityDecisionStrategy : public DecisionStrategyFmf
450   {
451    public:
452     CombinedCardinalityDecisionStrategy(context::Context* satContext,
453                                         Valuation valuation);
454     /** make literal (the i^th combined cardinality literal) */
455     Node mkLiteral(unsigned i) override;
456     /** identify */
457     std::string identify() const override;
458   };
459   /** combined cardinality decision strategy */
460   std::unique_ptr<CombinedCardinalityDecisionStrategy> d_cc_dec_strat;
461   /** Have we initialized combined cardinality? */
462   context::CDO<bool> d_initializedCombinedCardinality;
463 
464   /** cardinality literals for which we have added */
465   NodeBoolMap d_card_assertions_eqv_lemma;
466   /** the master monotone type (if ufssFairnessMonotone enabled) */
467   TypeNode d_tn_mono_master;
468   std::map<TypeNode, bool> d_tn_mono_slave;
469   context::CDO<int> d_min_pos_tn_master_card;
470   /** relevant eqc */
471   NodeBoolMap d_rel_eqc;
472 }; /* class StrongSolverTheoryUF */
473 
474 
475 }/* CVC4::theory namespace::uf */
476 }/* CVC4::theory namespace */
477 }/* CVC4 namespace */
478 
479 #endif /* __CVC4__THEORY_UF_STRONG_SOLVER_H */
480