1############################################
2# Copyright (c) 2016 Microsoft Corporation
3#
4# MSS enumeration based on maximal resolution.
5#
6# Author: Nikolaj Bjorner (nbjorner)
7############################################
8
9"""
10
11The following is a procedure for enumerating maximal satisfying subsets.
It uses max-resolution to eliminate cores from the state space.
Whenever the hard constraints are satisfiable, it finds a model that
satisfies a maximal (with respect to set inclusion) subset of the soft
constraints.
15During this process it collects the set of cores that are encountered.
16It then reduces the set of soft constraints using max-resolution in
17the style of [Narodytska & Bacchus, AAAI'14]. In other words,
let F1, ..., F_k be a core among the soft constraints F1, ..., F_n.
Replace F1, ..., F_k by the k-1 constraints
20          F1 or F2, F3 or (F2 & F1), F4 or (F3 & (F2 & F1)), ...,
21          F_k or (F_{k-1} & (...))
Optionally, add the core clause ~F1 or ... or ~F_k to the hard constraints F.
The current model M satisfies the new set F, F1,...,F_{n-1} if the core is minimal.
Whenever we modify the soft constraints by this core reduction, any assignment
that satisfies all of the reduced constraints satisfies at least k-1 of the
original soft constraints F1, ..., F_k.
26
27"""
28
29from z3 import *
30
def main():
    """Enumerate and print every MSS of a small example over reals x and y.

    There are no real hard constraints (just ``True``); the seven soft
    constraints below are partially contradictory, so several distinct
    maximal satisfying subsets exist.
    """
    x, y = Reals('x y')
    soft = [
        x > 2,
        x < 1,
        x < 0,
        Or(x + y > 0, y < 0),
        Or(y >= 0, x >= 0),
        Or(y < 0, x < 0),
        Or(y > 0, x < 0),
    ]
    hard = BoolVal(True)
    mss_solver = MSSSolver(hard, soft)
    for subset in enumerate_sets(mss_solver):
        print("%s" % subset)
38
39
def enumerate_sets(solver):
    """Yield maximal satisfying subsets from *solver* until none remain.

    Each round asks ``solver.s`` whether its current assertions are
    satisfiable.  While the answer is ``sat``, ``solver.grow()`` is called
    and its result (the next MSS) is yielded; for ``MSSSolver``, ``grow``
    also asserts a clause that blocks the MSS it just returned.  Any other
    answer (``unsat`` or ``unknown``) ends the enumeration.
    """
    while solver.s.check() == sat:
        yield solver.grow()
47
class MSSSolver:
    """Enumerate maximal satisfying subsets (MSSes) of a list of soft constraints.

    Each soft constraint ``soft[i]`` is represented by a Boolean indicator
    literal (see ``c_var``) that is asserted equivalent to the constraint.
    The indicators can then be passed to ``Solver.check`` as assumptions
    and appear in unsat cores.  ``grow`` extracts one MSS from the current
    model and then weakens the working soft set with max-resolution over
    the cores it encountered.

    Attributes:
        s: the underlying z3 ``Solver`` (owned by this instance).
        soft: the original list of soft constraints.
        n: number of original soft constraints.
        soft_vars: current, possibly relaxed, set of soft indicator literals.
        orig_soft_vars: indicator literals of the original soft constraints.
        varcache: index -> indicator literal memo used by ``c_var``.
        idcache: indicator literal -> index, filled by ``c_var``.
    """

    def __init__(self, hard, soft):
        """Assert *hard* and link one indicator literal to each of *soft*."""
        # Per-instance state.  These used to be class attributes, which made
        # every MSSSolver share a single Solver and a single cache, so a
        # second instance inherited the first one's assertions.
        self.s = Solver()
        self.varcache = {}
        self.idcache = {}
        self.n = len(soft)
        self.soft = soft
        self.s.add(hard)
        self.soft_vars = {self.c_var(i) for i in range(self.n)}
        self.orig_soft_vars = {self.c_var(i) for i in range(self.n)}
        # Tie each indicator to the constraint it names.
        self.s.add([self.c_var(i) == soft[i] for i in range(self.n)])

    def c_var(self, i):
        """Return the (memoised) indicator literal for soft constraint ``abs(i)``.

        The indicator is a Bool named after the printed form of
        ``soft[abs(i)]``.  A negative ``i`` yields its negation.
        NOTE(review): index 0 has no negated form, since -0 == 0.
        """
        if i not in self.varcache:
            v = Bool(str(self.soft[abs(i)]))
            self.idcache[v] = abs(i)
            self.varcache[i] = v if i >= 0 else Not(v)
        return self.varcache[i]

    def update_unknown(self):
        """Fetch the latest model and move unknown literals it satisfies into the MSS.

        Call this only right after a ``sat`` check, because it reads
        ``self.s.model()``.  Literals of ``self.unknown`` that are true in
        the model are appended to ``self.mss``.  The remaining literals
        become the new ``self.unknown``, which is always a freshly built set.
        """
        self.model = self.s.model()
        still_unknown = set()
        for lit in self.unknown:
            if is_true(self.model[lit]):
                self.mss.append(lit)
            else:
                still_unknown.add(lit)
        self.unknown = still_unknown

    def add_def(self, fml):
        """Return a Bool named after *fml*, asserted equivalent to *fml*.

        NOTE(review): the name is the formula's printed form.  Two different
        formulas that print identically would share one Bool.
        """
        name = Bool("%s" % fml)
        self.s.add(name == fml)
        return name

    def relax_core(self, Fs):
        """Replace the core literals *Fs* in ``soft_vars`` by their max-resolvents.

        *Fs* must be a subset of ``soft_vars``.  With f0, ..., f_{k-1} the
        elements of *Fs* in iteration order, they are removed, and the k-1
        named literals
            Or(f0 & True, f1), Or(f1 & (f0 & True), f2), ...
        are added in their place.  Each conjunction prefix and each
        disjunction is introduced via ``add_def``.
        """
        assert Fs <= self.soft_vars
        self.soft_vars -= Fs
        Fs = list(Fs)
        prefix = BoolVal(True)
        for prev, nxt in zip(Fs, Fs[1:]):
            prefix = self.add_def(And(prev, prefix))
            self.soft_vars.add(self.add_def(Or(prefix, nxt)))

    def resolve_core(self, core):
        """Return *core* with every explained literal replaced by its explanation.

        A literal is explained when it is a key of ``self.mcs_explain``.
        That happens for ``Not(x)`` after ``x`` was rejected by ``grow``.
        Such literals are replaced by the literals that implied them.  All
        other literals are kept.
        """
        resolved = set()
        for lit in core:
            if lit in self.mcs_explain:
                resolved |= self.mcs_explain[lit]
            else:
                resolved.add(lit)
        return resolved

    def grow(self):
        """Extract one MSS from the current model and weaken the soft set.

        Precondition: the last ``self.s.check()`` returned ``sat``.

        Starting from the soft literals true in that model, each remaining
        soft literal is tested in turn:
          * it joins the MSS if it is consistent with the MSS so far and
            with the negations of the literals already rejected;
          * otherwise it is rejected, and its resolved unsat core is recorded.
        Then a clause is asserted that requires at least one original soft
        constraint false in the final model to hold, which blocks this MSS
        in later rounds.  Finally, pairwise-disjoint cores, smallest first,
        are relaxed with ``relax_core``.

        Returns:
            The original soft indicator literals true in the final model.

        Raises:
            RuntimeError: if the solver answers neither ``sat`` nor ``unsat``.
        """
        self.mss = []
        self.mcs = []
        self.nmcs = []
        self.mcs_explain = {}
        # update_unknown() rebinds self.unknown to a fresh set before
        # anything is popped, so soft_vars itself is never mutated here.
        self.unknown = self.soft_vars
        self.update_unknown()
        cores = []
        while self.unknown:
            x = self.unknown.pop()
            is_sat = self.s.check(self.mss + [x] + self.nmcs)
            if is_sat == sat:
                self.mss.append(x)
                self.update_unknown()
            elif is_sat == unsat:
                core = self.resolve_core(self.s.unsat_core())
                # Not(x) is implied by the other core literals.  Record them
                # so that later cores mentioning Not(x) can be expanded.
                self.mcs_explain[Not(x)] = {y for y in core if not eq(x, y)}
                self.mcs.append(x)
                self.nmcs.append(Not(x))
                cores.append(core)
            else:
                # Previously printed and called the interactive-only exit()
                # helper, terminating the process with status 0.
                raise RuntimeError("solver returned %s" % is_sat)
        mss = [x for x in self.orig_soft_vars if is_true(self.model[x])]
        mcs = [x for x in self.orig_soft_vars if not is_true(self.model[x])]
        # Every future model must satisfy some constraint outside this MSS.
        self.s.add(Or(mcs))
        relaxed = set()
        for core in sorted(cores, key=len):
            if core.isdisjoint(relaxed):
                self.relax_core(core)
                relaxed |= core
        return mss
161
162
# Run the example only when executed as a script, not on import.
if __name__ == "__main__":
    main()
164