from z3 import *
import heapq
import numpy
import time
import random
import sys
6
# When True, progress messages, traces and validation outcomes are printed.
verbose = True
8
9# Simplistic (and fragile) converter from
10# a class of Horn clauses corresponding to
11# a transition system into a transition system
12# representation as <init, trans, goal>
13# It assumes it is given three Horn clauses
14# of the form:
15#  init(x) => Invariant(x)
16#  Invariant(x) and trans(x,x') => Invariant(x')
17#  Invariant(x) and goal(x) => Goal(x)
18# where Invariant and Goal are uninterpreted predicates
19
class Horn2Transitions:
    """Extract a transition system <init, trans, goal> from Horn clauses.

    After ``parse`` the attributes ``init``, ``trans`` and ``goal`` hold
    formulas in which the quantified variables of the rules have been
    replaced by Boolean constants named ``x<k>`` (current state),
    ``xn<k>`` (next state) and ``i<k>`` (remaining variables, treated as
    inputs).  ``xs``, ``xns`` and ``inputs`` list those constants.
    """

    def __init__(self):
        self.trans = True
        self.init = True
        self.inputs = []
        self.goal = True
        self.index = 0
        # Filled in by is_transition(); defined up front so the object is
        # well-formed even when no transition rule has been recognised.
        self.xs = []
        self.xns = []
        # Constants introduced by the most recent subst_vars() call.
        self.vars = []

    def parse(self, file):
        """Load the rules in ``file`` and classify each recognised rule."""
        fp = Fixedpoint()
        # The value returned by parse_file is not used; the rules are read
        # back from the fixedpoint object.
        fp.parse_file(file)
        for r in fp.get_rules():
            if not is_quantifier(r):
                continue
            b = r.body()
            if not is_implies(b):
                continue
            f = b.arg(0)
            g = b.arg(1)
            # Each recogniser stores its result on self; the first one
            # that accepts the rule wins.
            if self.is_goal(f, g):
                continue
            if self.is_transition(f, g):
                continue
            self.is_init(f, g)

    def is_pred(self, p, name):
        """Return True if ``p`` is an application whose symbol is ``name``."""
        return is_app(p) and p.decl().name() == name

    def is_goal(self, body, head):
        """Recognise ``Invariant(x) and goal(x) => Goal(x)`` and set ``goal``."""
        if not self.is_pred(head, "Goal"):
            return False
        pred, inv = self.is_body(body)
        # Without an Invariant atom in the body there is nothing to derive
        # the state variables from, so this is not a goal rule.
        if pred is None or inv is None:
            return False
        self.goal = self.subst_vars("x", inv, pred)
        self.goal = self.subst_vars("i", self.goal, self.goal)
        self.inputs += self.vars
        self.inputs = list(set(self.inputs))
        return True

    def is_body(self, body):
        """Split a conjunction into (And of non-Invariant parts, Invariant atom).

        Returns ``(None, None)`` when ``body`` is not a conjunction; the
        second component is None when no Invariant atom is present.
        """
        if not is_and(body):
            return None, None
        fmls = [f for f in body.children() if self.is_inv(f) is None]
        inv = None
        for f in body.children():
            if self.is_inv(f) is not None:
                inv = f
                break
        return And(fmls), inv

    def is_inv(self, f):
        """Return ``f`` if it is an ``Invariant`` application, else None."""
        if self.is_pred(f, "Invariant"):
            return f
        return None

    def is_transition(self, body, head):
        """Recognise ``Invariant(x) and trans(x,x') => Invariant(x')``.

        On success sets ``trans``, ``xs``, ``xns`` and extends ``inputs``.
        """
        pred, inv0 = self.is_body(body)
        # A rule whose body lacks an Invariant atom (e.g. the init rule
        # with a conjunctive body) is not a transition rule.
        if pred is None or inv0 is None:
            return False
        inv1 = self.is_inv(head)
        if inv1 is None:
            return False
        pred = self.subst_vars("x", inv0, pred)
        self.xs = self.vars
        pred = self.subst_vars("xn", inv1, pred)
        self.xns = self.vars
        pred = self.subst_vars("i", pred, pred)
        self.inputs += self.vars
        self.inputs = list(set(self.inputs))
        self.trans = pred
        return True

    def is_init(self, body, head):
        """Recognise ``init(x) => Invariant(x)`` and set ``init``."""
        if any(self.is_inv(f) is not None for f in body.children()):
            return False
        inv = self.is_inv(head)
        if inv is None:
            return False
        self.init = self.subst_vars("x", inv, body)
        return True

    def subst_vars(self, prefix, inv, fml):
        """Replace variables in ``fml`` by fresh ``prefix``-named constants.

        The constants that were introduced are left in ``self.vars``.
        """
        subst = self.mk_subst(prefix, inv)
        self.vars = [v for (k, v) in subst]
        return substitute(fml, subst)

    def mk_subst(self, prefix, inv):
        """Build (variable, constant) pairs for ``inv``.

        For an Invariant atom the variables are its arguments; for any
        other formula they are the variables collected by ``get_vars``.
        Numbering restarts at 1 on every call.
        """
        self.index = 0
        if self.is_inv(inv) is not None:
            return [(f, self.mk_bool(prefix)) for f in inv.children()]
        free_vars = self.get_vars(inv)
        return [(f, self.mk_bool(prefix)) for f in free_vars]

    def mk_bool(self, prefix):
        """Return the next Boolean constant ``<prefix><index>``."""
        self.index += 1
        return Bool("%s%d" % (prefix, self.index))

    def get_vars(self, f, rs=None):
        """Collect the variables occurring in ``f``, without duplicates."""
        rs = [] if rs is None else rs
        if is_var(f):
            return z3util.vset(rs + [f], str)
        for f_ in f.children():
            rs = self.get_vars(f_, rs)
        return z3util.vset(rs, str)
128
129# Produce a finite domain solver.
130# The theory QF_FD covers bit-vector formulas
131# and pseudo-Boolean constraints.
132# By default cardinality and pseudo-Boolean
133# constraints are converted to clauses. To override
134# this default for cardinality constraints
135# we set sat.cardinality.solver to True
136
def fd_solver():
    """Build a QF_FD solver with the SAT cardinality solver switched on."""
    solver = SolverFor("QF_FD")
    solver.set("sat.cardinality.solver", True)
    return solver
141
142
143# negate, avoid double negation
def negate(f):
    """Negate f; an outer Not is stripped rather than wrapped in another."""
    return f.arg(0) if is_not(f) else Not(f)
149
def cube2clause(cube):
    """Return the clause obtained by negating every literal of the cube."""
    negated_literals = list(map(negate, cube))
    return Or(negated_literals)
152
class State:
    """One frame: the clauses asserted so far and the solver holding them."""

    def __init__(self, s):
        self.solver = s
        self.R = set()

    def add(self, clause):
        """Assert clause on the solver unless it was already recorded."""
        if clause in self.R:
            return
        self.R.add(clause)
        self.solver.add(clause)
162
def is_seq(f):
    """Tell whether f is a list, a tuple or an AstVector."""
    if isinstance(f, (list, tuple)):
        return True
    return isinstance(f, AstVector)
165
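# Check whether two formulas are disjoint, i.e. their conjunction is unsat
# (used by Quip.run to check that the initial state is not bad)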
def check_disjoint(a, b):
    """Return True when a fresh solver reports a and b together as unsat."""
    solver = fd_solver()
    for formula in (a, b):
        solver.add(formula)
    outcome = solver.check()
    return outcome == unsat
172
173
174# Remove clauses that are subsumed
def prune(R):
    """Return R without the clauses implied by the remaining ones.

    A clause is dropped when its negation is unsat together with every
    other clause of R that has not been dropped yet.
    """
    solver = fd_solver()
    subsumed = set()
    for candidate in R:
        solver.push()
        for other in R:
            if other in subsumed:
                continue
            # The candidate itself is asserted negated, everything else as is.
            solver.add(Not(other) if candidate.eq(other) else other)
        if solver.check() == unsat:
            subsumed.add(candidate)
        solver.pop()
    return R - subsumed
187
188# Quip variant of IC3
189
# Values for the `must` flag of a QGoal.  In Quip.quip_blocked a `must` goal
# that reaches level 0 or a known reachable state is returned as the start of
# a counterexample trace; a `may` goal in the same situation only contributes
# a new reachable state.
must = True
may = False
192
class QLemma:
    """A learned lemma: a cube, the clause negating it, and a `bad` mark.

    Two lemmas are equal when their cubes contain the same literals and
    their `bad` flags agree.  The hash depends on the cube only, so it stays
    stable when `bad` is flipped after the lemma was put into a set.
    """

    def __init__(self, c):
        self.cube = c
        self.clause = cube2clause(c)
        self.bad = False

    def __hash__(self):
        # frozenset hashing is independent of iteration order, so lemmas
        # with equal cubes are guaranteed to hash alike.
        return hash(frozenset(self.cube))

    def __eq__(self, qlemma2):
        if not isinstance(qlemma2, QLemma):
            return NotImplemented
        return set(self.cube) == set(qlemma2.cube) and self.bad == qlemma2.bad

    def __ne__(self, qlemma2):
        result = self.__eq__(qlemma2)
        if result is NotImplemented:
            return result
        return not result
213
214
class QGoal:
    """A proof obligation: a cube at a level, linked to the goal it came from."""

    def __init__(self, cube, parent, level, must, encounter):
        # `encounter` is accepted for the callers' sake but not stored.
        self.cube, self.parent = cube, parent
        self.level, self.must = level, must

    def __lt__(self, other):
        """Goals are ordered by level alone."""
        return other.level > self.level
224
225
class QReach:
    """Store of concrete states known to be reachable.

    States are kept as rows of the numpy array ``states``, one column per
    variable of ``xs`` in the same order.
    """

    # it is assumed that there is a single initial state
    # with all latches set to 0 in hardware design, so
    # here init will always give a state where all variable are set to 0
    def __init__(self, init, xs):
        self.xs = xs
        self.constant_xs = [Not(x) for x in self.xs]
        s = fd_solver()
        s.add(init)
        is_sat = s.check()
        # Only satisfiability of init is checked; the model is not inspected.
        assert is_sat == sat
        # xs is a list, "for" will keep the order when iterating
        self.states = numpy.array([[False for x in self.xs]])  # all set to False
        assert not numpy.max(self.states)  # since all element is False, so maximum should be False

    # check if new state exists
    def is_exist(self, state):
        # NOTE(review): this is numpy's `in` on a 2-d array, which is not a
        # row-wise membership test -- confirm it matches the intent.
        if state in self.states:
            return True
        return False

    # Complete the unassigned positions of a partial state, trying False
    # before True, until a row not yet in self.states is found.  The row is
    # appended to self.states and returned; None means every completion is
    # already stored.  state_b is modified in place.
    def enumerate(self, i, state_b, state):
        # skip positions that already carry a value
        while i < len(state) and state[i] not in self.xs:
            i += 1
        if i >= len(state):
            if state_b.tolist() not in self.states.tolist():
                self.states = numpy.append(self.states, [state_b], axis = 0)
                return state_b
            else:
                return None
        state_b[i] = False
        if self.enumerate(i+1, state_b, state) is not None:
            return state_b
        else:
            state_b[i] = True
            return self.enumerate(i+1, state_b, state)

    # A state is full when none of its first len(xs) entries is still one
    # of the variables themselves.
    def is_full_state(self, state):
        for i in range(len(self.xs)):
            if state[i] in self.xs:
                return False
        return True

    # Record the state described by cube.  Returns the (partial) state when
    # a new completion of it was stored, None otherwise.
    def add(self, cube):
        state = self.cube2partial_state(cube)
        assert len(state) == len(self.xs)
        # NOTE(review): the comment on is_exist suggests an already-known
        # state should be skipped, yet this returns when is_exist is False
        # -- confirm the polarity is intended.
        if not self.is_exist(state):
            return None
        if self.is_full_state(state):
            self.states = numpy.append(self.states, [state], axis = 0)
        else:
            # state[i] is instance, state_b[i] is boolean
            state_b = numpy.array(state)
            for i in range(len(state)):  # state is of same length as self.xs
                # i-th literal in state hasn't been assigned value
                # init un-assigned literals in state_b as True
                # make state_b only contain boolean value
                if state[i] in self.xs:
                    state_b[i] = True
                else:
                    state_b[i] = is_true(state[i])
            if self.enumerate(0, state_b, state) is not None:
                # literals of the cube that contradict a "constant" literal
                # show that the corresponding variable is not constant
                lits_to_remove = {negate(f) for f in set(cube) - set(self.constant_xs)}
                self.constant_xs = list(set(self.constant_xs) - lits_to_remove)
                return state
        return None


    # Evaluate every variable of xs in some model of the cube; variables the
    # model leaves open evaluate to themselves.
    def cube2partial_state(self, cube):
        s = fd_solver()
        s.add(And(cube))
        is_sat = s.check()
        assert is_sat == sat
        m = s.model()
        state = numpy.array([m.eval(x) for x in self.xs])
        return state


    # Turn a stored row into a list of literals: x where the row holds a
    # true value, Not(x) otherwise.
    def state2cube(self, s):
        # A fresh list is built, so self.xs is never modified; no deep copy
        # of the expressions is needed because they are not mutated.
        return [x if value else Not(x) for x, value in zip(self.xs, s)]

    # Return And(...) of one stored state that agrees with the cube on all
    # positions the cube fixes, or None when there is no such state.
    def intersect(self, cube):
        state = self.cube2partial_state(cube)
        mask = True
        for i in range(len(self.xs)):
            if is_true(state[i]) or is_false(state[i]):
                mask = (self.states[:, i] == state[i]) & mask
        intersects = numpy.reshape(self.states[mask], (-1, len(self.xs)))
        if intersects.size > 0:
            return And(self.state2cube(intersects[0]))  # only need to return one single intersect
        return None
323
324
class Quip:
    """Quip variant of IC3 for a transition system <init, trans, goal>.

    The system is stated over current-state variables ``x0``, input
    variables ``inputs`` and next-state variables ``xn``.  ``run`` returns
    an inductive invariant, the first QGoal of a counterexample trace, a
    message string, or the solver status when it is neither sat nor unsat.
    """

    def __init__(self, init, trans, goal, x0, inputs, xn):
        self.x0 = x0
        self.inputs = inputs
        self.xn = xn
        self.init = init
        self.bad = goal
        self.trans = trans
        self.min_cube_solver = fd_solver()
        self.min_cube_solver.add(Not(trans))
        self.goals = []
        s = State(fd_solver())
        s.add(init)
        s.solver.add(trans)  # check if a bad state can be reached in one step from current level
        self.states = [s]
        self.s_bad = fd_solver()
        self.s_good = fd_solver()
        self.s_bad.add(self.bad)
        self.s_good.add(Not(self.bad))
        self.reachable = QReach(self.init, x0)
        self.frames = []  # frames[level] is a set of QLemma objects
        self.count_may = 0

    # Rename current-state variables to next-state variables
    def next(self, f):
        if is_seq(f):
           return [self.next(f1) for f1 in f]
        # zip() is an iterator on Python 3; pass a concrete list of pairs
        return substitute(f, list(zip(self.x0, self.xn)))

    # Rename next-state variables to current-state variables
    def prev(self, f):
        if is_seq(f):
           return [self.prev(f1) for f1 in f]
        return substitute(f, list(zip(self.xn, self.x0)))

    # Append a new level whose solver holds the transition relation
    def add_solver(self):
        s = fd_solver()
        s.add(self.trans)
        self.states += [State(s)]

    # Conjunction of the clauses recorded at level i
    def R(self, i):
        return And(self.states[i].R)

    # x, Not(x) or None depending on whether m evaluates x to true, false
    # or neither
    def value2literal(self, m, x):
        value = m.eval(x)
        if is_true(value):
            return x
        if is_false(value):
            return Not(x)
        return None

    def values2literals(self, m, xs):
        p = [self.value2literal(m, x) for x in xs]
        return [x for x in p if x is not None]

    def project0(self, m):
        return self.values2literals(m, self.x0)

    def projectI(self, m):
        return self.values2literals(m, self.inputs)

    def projectN(self, m):
        return self.values2literals(m, self.xn)


    # Block a cube by asserting the clause corresponding to its negation
    def block_cube(self, i, cube):
        self.assert_clause(i, cube2clause(cube))

    # Add a clause to levels 1 until i
    def assert_clause(self, i, clause):
        for j in range(1, i + 1):
            self.states[j].add(clause)
            assert str(self.states[j].solver) != str([False])


    # minimize cube that is core of Dual solver.
    # this assumes that props & cube => Trans
    # which means props & cube can only give us a Tr in Trans,
    # and it will never make !Trans sat
    def minimize_cube(self, cube, inputs, lits):
        # min_cube_solver has !Trans (min_cube.solver.add(!Trans))
        is_sat = self.min_cube_solver.check(lits + [c for c in cube] + [i for i in inputs])
        assert is_sat == unsat
        # unsat_core gives us some lits which make Tr sat,
        # so that we can ignore other lits and include more states
        core = self.min_cube_solver.unsat_core()
        assert core
        cube_lits = set(cube)  # built once instead of once per core literal
        return [c for c in core if c in cube_lits]

    # push a goal on a heap
    def push_heap(self, goal):
        heapq.heappush(self.goals, (goal.level, goal))


    # make sure cube to be blocked excludes all reachable states
    def check_reachable(self, cube):
        s = fd_solver()
        for state in self.reachable.states:
            s.push()
            r = self.reachable.state2cube(state)
            s.add(And(self.prev(r)))
            s.add(self.prev(cube))
            is_sat = s.check()
            s.pop()
            if is_sat == sat:
                # if sat, it means the cube to be blocked contains reachable states
                # so it is an invalid cube
                return False
        # if all fail, is_sat will be unsat
        return True

    # Rudimentary generalization:
    # If the cube is already unsat with respect to transition relation
    # extract a core (not necessarily minimal)
    # otherwise, just return the cube.
    def generalize(self, cube, f):
        s = self.states[f - 1].solver
        if unsat == s.check(cube):
            core = s.unsat_core()
            if self.check_reachable(core):
                return core, f
        return cube, f


    # Debugging aid: for every stored reachable state, print its cube and
    # the result of checking it against init and `level` copies of trans.
    def valid_reachable(self, level):
        s = fd_solver()
        s.add(self.init)
        for i in range(level):
            s.add(self.trans)
        for state in self.reachable.states:
            s.push()
            s.add(And(self.next(self.reachable.state2cube(state))))
            print(self.reachable.state2cube(state))
            print(s.check())
            s.pop()

    def lemmas(self, level):
        return [(l.clause, l.bad) for l in self.frames[level]]

    # whenever a new reachable state is found, we use it to mark some existing lemmas as bad lemmas
    def mark_bad_lemmas(self, new):
        s = fd_solver()
        reset = False
        for frame in self.frames:
            for lemma in frame:
                s.push()
                s.add(lemma.clause)
                is_sat = s.check(new)
                if is_sat == unsat:
                    reset = True
                    lemma.bad = True
                s.pop()
        if reset:
            # rebuild every level above 0 from the lemmas that are still good
            self.states = [self.states[0]]
            for i in range(1, len(self.frames)):
                self.add_solver()
                for lemma in self.frames[i]:
                    if not lemma.bad:
                        self.states[i].add(lemma.clause)

    # prev & tras -> r', such that r' intersects with cube
    def add_reachable(self, prev, cube):
        s = fd_solver()
        s.add(self.trans)
        s.add(prev)
        s.add(self.next(And(cube)))
        is_sat = s.check()
        assert is_sat == sat
        m = s.model()
        new = self.projectN(m)
        state = self.reachable.add(self.prev(new))  # always add as non-primed
        if state is not None:  # if self.states do not have new state yet
            self.mark_bad_lemmas(self.prev(new))


    # Check if the negation of cube is inductive at level f
    def is_inductive(self, f, cube):
        s = self.states[f - 1].solver
        s.push()
        s.add(self.prev(Not(And(cube))))
        is_sat = s.check(cube)
        if is_sat == sat:
            # the model must be fetched before the pop
            m = s.model()
        s.pop()
        if is_sat == sat:
            cube = self.next(self.minimize_cube(self.project0(m), self.projectI(m), self.projectN(m)))
        elif is_sat == unsat:
            cube, f = self.generalize(cube, f)
            cube = self.next(cube)
        return cube, f, is_sat


    # Determine if there is a cube for the current state
    # that is potentially reachable.
    def unfold(self, level):
        core = []
        self.s_bad.push()
        R = self.R(level)
        self.s_bad.add(R)  # check if current frame intersects with bad states, no trans
        is_sat = self.s_bad.check()
        if is_sat == sat:
           m = self.s_bad.model()
           cube = self.project0(m)
           props = cube + self.projectI(m)
           self.s_good.push()
           self.s_good.add(R)
           is_sat2 = self.s_good.check(props)
           assert is_sat2 == unsat
           core = self.s_good.unsat_core()
           assert core
           cube_lits = set(cube)  # built once instead of once per core literal
           core = [c for c in core if c in cube_lits]
           self.s_good.pop()
        self.s_bad.pop()
        return is_sat, core

    # A state s0 and level f0 such that
    # not(s0) is f0-1 inductive
    #
    # Returns the QGoal that starts a counterexample trace, or None when
    # the goal queue runs empty.
    def quip_blocked(self, s0, f0):
        self.push_heap(QGoal(self.next(s0), None, f0, must, 0))
        while self.goals:
            f, g = heapq.heappop(self.goals)
            sys.stdout.write("%d." % f)
            if not g.must:
                self.count_may -= 1
            sys.stdout.flush()
            if f == 0:
                if g.must:
                    s = fd_solver()
                    s.add(self.init)
                    s.add(self.prev(g.cube))
                    # since init is a complete assignment, so g.cube must equal to init in sat solver
                    # (compare against the constant `sat`; the local is_sat
                    # is unbound or stale at this point)
                    assert sat == s.check()
                    if verbose:
                        print("")
                    return g
                self.add_reachable(self.init, g.parent.cube)
                continue

            r0 = self.reachable.intersect(self.prev(g.cube))
            if r0 is not None:
                if g.must:
                    if verbose:
                        print("")
                    s = fd_solver()
                    s.add(self.trans)
                    # make it as a concrete reachable state
                    # intersect returns an And(...), so use children to get cube list
                    g.cube = r0.children()
                    while True:
                        is_sat = s.check(self.next(g.cube))
                        assert is_sat == sat
                        r = self.next(self.project0(s.model()))
                        r = self.reachable.intersect(self.prev(r))
                        child = QGoal(self.next(r.children()), g, 0, g.must, 0)
                        g = child
                        if not check_disjoint(self.init, self.prev(g.cube)):
                            # g is init, break the loop
                            break
                    init = g
                    # renumber the levels along the trace, starting from init
                    while g.parent is not None:
                        g.parent.level = g.level + 1
                        g = g.parent
                    return init
                if g.parent is not None:
                    self.add_reachable(r0, g.parent.cube)
                continue

            cube = None
            is_sat = sat
            f_1 = len(self.frames) - 1
            while f_1 >= f:
                for l in self.frames[f_1]:
                    if not l.bad and len(l.cube) > 0 and set(l.cube).issubset(g.cube):
                        cube = l.cube
                        # NOTE(review): the original had the no-op
                        # comparison `is_sat == unsat` here, presumably
                        # meant as an assignment.  It is left out so that
                        # behaviour is unchanged: is_sat stays `sat` and the
                        # outer loop still runs f_1 down to f - 1.  Turning
                        # it into an assignment would change the algorithm
                        # and needs a decision on the intended level.
                        break
                f_1 -= 1
            if cube is None:
                cube, f_1, is_sat = self.is_inductive(f, g.cube)
            if is_sat == unsat:
                self.frames[f_1].add(QLemma(self.prev(cube)))
                self.block_cube(f_1, self.prev(cube))
                if f_1 < f0:
                    # learned clause might also be able to block same bad states in higher level
                    if set(list(cube)) != set(list(g.cube)):
                        self.push_heap(QGoal(cube, None, f_1 + 1, may, 0))
                        self.count_may += 1
                    else:
                        # re-queue g.cube in higher level, here g.parent is simply for tracking down the trace when output.
                        self.push_heap(QGoal(g.cube, g.parent, f_1 + 1, g.must, 0))
                        if not g.must:
                            self.count_may += 1
            else:
                # qcube is a predecessor of g
                qcube = QGoal(cube, g, f_1 - 1, g.must, 0)
                if not g.must:
                    self.count_may += 1
                self.push_heap(qcube)

        if verbose:
            print("")
        return None

    # Check if there are two states next to each other that have the same clauses.
    def is_valid(self):
        i = 1
        inv = None
        while True:
            # self.states[].R contains full lemmas
            # self.frames[] contains delta-encoded lemmas
            while len(self.states) <= i+1:
                self.add_solver()
            while len(self.frames) <= i+1:
                self.frames.append(set())
            duplicates = set([])
            for l in self.frames[i+1]:
                if l in self.frames[i]:
                    duplicates |= {l}
            self.frames[i] = self.frames[i] - duplicates
            pushed = set([])
            # try to push every good lemma of level i to level i+1
            for l in (self.frames[i] - self.frames[i+1]):
                if not l.bad:
                    s = self.states[i].solver
                    s.push()
                    s.add(self.next(Not(l.clause)))
                    s.add(l.clause)
                    is_sat = s.check()
                    s.pop()
                    if is_sat == unsat:
                        self.frames[i+1].add(l)
                        self.states[i+1].add(l.clause)
                        pushed |= {l}
            self.frames[i] = self.frames[i] - pushed
            if (not (self.states[i].R - self.states[i+1].R)
                and len(self.states[i].R) != 0):
                inv = prune(self.states[i].R)
                F_inf = self.frames[i]
                j = i + 1
                while j < len(self.states):
                    for l in F_inf:
                        self.states[j].add(l.clause)
                    j += 1
                self.frames[len(self.states)-1] = F_inf
                self.frames[i] = set([])
                break
            elif (len(self.states[i].R) == 0
                  and len(self.states[i+1].R) == 0):
                break
            i += 1

        if inv is not None:
            # the candidate only counts when it excludes the bad states
            self.s_bad.push()
            self.s_bad.add(And(inv))
            is_sat = self.s_bad.check()
            if is_sat == unsat:
                self.s_bad.pop()
                return And(inv)
            self.s_bad.pop()
        return None

    # Main loop: alternate between looking for an invariant and blocking
    # the bad cube found at the current level.
    def run(self):
        if not check_disjoint(self.init, self.bad):
            return "goal is reached in initial state"
        level = 0
        while True:
            inv = self.is_valid()  # self.add_solver() here
            if inv is not None:
                return inv
            is_sat, cube = self.unfold(level)
            if is_sat == unsat:
                level += 1
                if verbose:
                    print("Unfold %d" % level)
                sys.stdout.flush()
            elif is_sat == sat:
                cex = self.quip_blocked(cube, level)
                if cex is not None:
                    return cex
            else:
                return is_sat
705
def test(file):
    """Run Quip on the Horn-clause file and report the outcome.

    A QGoal result is printed as a trace (when verbose) and validated; an
    expression result is printed as an invariant (when verbose) and
    validated; any other result is printed as is.
    """
    h2t = Horn2Transitions()
    h2t.parse(file)
    if verbose:
        # format inside the call so this also works where print is a function
        print("Test file: %s" % file)
    mp = Quip(h2t.init, h2t.trans, h2t.goal, h2t.xs, h2t.inputs, h2t.xns)
    start_time = time.time()
    result = mp.run()
    end_time = time.time()
    if isinstance(result, QGoal):
        g = result
        if verbose:
            print("Trace")
        # follow the parent links until the end of the trace
        while g:
           if verbose:
               print(g.level, g.cube)
           g = g.parent
        print("--- used %.3f seconds ---" % (end_time - start_time))
        validate(mp, result, mp.trans)
        return
    if isinstance(result, ExprRef):
        if verbose:
            print("Invariant:\n%s " % result)
        print("--- used %.3f seconds ---" % (end_time - start_time))
        validate(mp, result, mp.trans)
        return
    print(result)
733
def validate(var, result, trans):
    """Sanity-check a result produced by ``Quip.run``.

    var    -- the Quip instance (used for next/prev/projectN and init)
    result -- a QGoal (counterexample trace) or an expression (invariant)
    trans  -- the transition relation

    For a trace, every goal and its parent must be connected by ``trans``.
    For an invariant, it must be preserved by ``trans``; afterwards the
    function looks for a state, starting from init, that intersects it.
    Failed checks raise AssertionError; nothing is returned.
    """
    if isinstance(result, QGoal):
        g = result
        s = fd_solver()
        s.add(trans)
        while g.parent is not None:
            s.push()
            s.add(var.prev(g.cube))
            s.add(var.next(g.parent.cube))
            assert sat == s.check()
            s.pop()
            g = g.parent
        if verbose:
            print("--- validation succeed ----")
        return
    if isinstance(result, ExprRef):
        inv = result
        s = fd_solver()
        s.add(trans)
        # inductiveness: inv & trans & !inv' must be unsat
        s.push()
        s.add(var.prev(inv))
        s.add(Not(var.next(inv)))
        assert unsat == s.check()
        s.pop()
        cube = var.prev(var.init)
        step = 0
        while True:
            step += 1
            # too many steps to reach invariant
            if step > 1000:
                if verbose:
                    print("--- validation failed --")
                return
            if not check_disjoint(var.prev(cube), var.prev(inv)):
                # reach invariant
                break
            s.push()
            # NOTE(review): after the first step `cube` comes from projectN
            # and is therefore over the next-state variables, yet it is
            # asserted here without prev() -- confirm this is intended.
            s.add(cube)
            assert s.check() == sat
            cube = var.projectN(s.model())
            s.pop()
        if verbose:
            print("--- validation succeed ----")
        return
778
779
780
# Run the bundled benchmarks in order.
for _path in (
    "data/horn1.smt2",
    "data/horn2.smt2",
    "data/horn3.smt2",
    "data/horn4.smt2",
    "data/horn5.smt2",
):
    test(_path)
# test("data/horn6.smt2")  # not able to finish
787