from z3 import *
import copy
import heapq
import numpy
import random
import sys
import time

try:
    from z3 import z3util
except ImportError:  # older z3 layouts ship z3util as a top-level module
    import z3util

verbose = True

# Simplistic (and fragile) converter from
# a class of Horn clauses corresponding to
# a transition system into a transition system
# representation as <init, trans, goal>
# It assumes it is given three Horn clauses
# of the form:
#  init(x) => Invariant(x)
#  Invariant(x) and trans(x,x') => Invariant(x')
#  Invariant(x) and goal(x) => Goal(x)
# where Invariant and Goal are uninterpreted predicates


class Horn2Transitions:
    """Extract <init, trans, goal> from Horn clauses over Invariant/Goal.

    After parse(), the following attributes are set:
      init   -- formula over the fresh state constants x1, x2, ...
      trans  -- formula over x1.., their primed copies xn1.., and inputs i1..
      goal   -- formula over x1.. and inputs
      xs/xns -- the state / next-state constants (same order)
      inputs -- constants introduced for the remaining bound variables
    Every fresh constant is created with Bool(), so only Boolean
    state variables are supported.
    """

    def __init__(self):
        self.trans = True
        self.init = True
        self.inputs = []
        self.goal = True
        self.index = 0
        # Initialized here so a file lacking a transition clause does not
        # surface later as an AttributeError on h2t.xs / h2t.xns.
        self.xs = []
        self.xns = []
        # Substitution targets produced by the most recent subst_vars call.
        self.vars = []

    def parse(self, file):
        """Parse the Horn clauses in *file* and record init/trans/goal."""
        fp = Fixedpoint()
        fp.parse_file(file)  # the returned queries are not needed here
        for r in fp.get_rules():
            if not is_quantifier(r):
                continue
            b = r.body()
            if not is_implies(b):
                continue
            f, g = b.arg(0), b.arg(1)
            # The first matcher that recognizes the clause records it.
            for matcher in (self.is_goal, self.is_transition, self.is_init):
                if matcher(f, g):
                    break

    def is_pred(self, p, name):
        """True if *p* is an application of a function named *name*."""
        return is_app(p) and p.decl().name() == name

    def is_goal(self, body, head):
        """Match `Invariant(x) and goal(x) => Goal`; record self.goal."""
        if not self.is_pred(head, "Goal"):
            return False
        pred, inv = self.is_body(body)
        if pred is None or inv is None:
            return False
        self.goal = self.subst_vars("x", inv, pred)
        self.goal = self.subst_vars("i", self.goal, self.goal)
        self.inputs = list(set(self.inputs + self.vars))
        return True

    def is_body(self, body):
        """Split a conjunction into (And of non-Invariant conjuncts, Invariant atom).

        Returns (None, None) when *body* is not a conjunction; the Invariant
        component is None when no conjunct is an Invariant application.
        """
        if not is_and(body):
            return None, None
        fmls = [f for f in body.children() if self.is_inv(f) is None]
        inv = None
        for f in body.children():
            if self.is_inv(f) is not None:
                inv = f
                break
        return And(fmls), inv

    def is_inv(self, f):
        """Return *f* if it is an Invariant application, otherwise None."""
        if self.is_pred(f, "Invariant"):
            return f
        return None

    def is_transition(self, body, head):
        """Match `Invariant(x) and trans(x, x') => Invariant(x')`; record trans."""
        pred, inv0 = self.is_body(body)
        if pred is None or inv0 is None:
            return False
        inv1 = self.is_inv(head)
        if inv1 is None:
            return False
        pred = self.subst_vars("x", inv0, pred)
        self.xs = self.vars
        pred = self.subst_vars("xn", inv1, pred)
        self.xns = self.vars
        # Whatever bound variables remain are treated as inputs.
        pred = self.subst_vars("i", pred, pred)
        self.inputs = list(set(self.inputs + self.vars))
        self.trans = pred
        return True

    def is_init(self, body, head):
        """Match `init(x) => Invariant(x)` (no Invariant in the body); record init."""
        if any(self.is_inv(f) is not None for f in body.children()):
            return False
        inv = self.is_inv(head)
        if inv is None:
            return False
        self.init = self.subst_vars("x", inv, body)
        return True

    def subst_vars(self, prefix, inv, fml):
        """Substitute fresh Booleans prefix1, prefix2, ... into *fml*.

        The replaced terms are the arguments of *inv* when it is an Invariant
        application, otherwise the bound variables occurring in *inv*.  The
        fresh constants are stored in self.vars.
        """
        subst = self.mk_subst(prefix, inv)
        self.vars = [v for (_, v) in subst]
        return substitute(fml, subst)

    def mk_subst(self, prefix, inv):
        """Build the (term, fresh Bool) pairs used by subst_vars."""
        self.index = 0  # numbering restarts at 1 for every prefix
        if self.is_inv(inv) is not None:
            terms = inv.children()
        else:
            terms = self.get_vars(inv)
        return [(t, self.mk_bool(prefix)) for t in terms]

    def mk_bool(self, prefix):
        """Return the next fresh Bool named prefix<k>."""
        self.index += 1
        return Bool("%s%d" % (prefix, self.index))

    def get_vars(self, f, rs=None):
        """Collect the distinct bound variables (is_var) occurring in *f*.

        *rs* seeds the result; duplicates are removed by their string form.
        """
        rs = [] if rs is None else rs
        if is_var(f):
            return z3util.vset(rs + [f], str)
        for child in f.children():
            rs = self.get_vars(child, rs)
        return z3util.vset(rs, str)


# Produce a finite domain solver.
# The theory QF_FD covers bit-vector formulas
# and pseudo-Boolean constraints.
# By default cardinality and pseudo-Boolean
# constraints are converted to clauses.
To override 134# this default for cardinality constraints 135# we set sat.cardinality.solver to True 136 137def fd_solver(): 138 s = SolverFor("QF_FD") 139 s.set("sat.cardinality.solver", True) 140 return s 141 142 143# negate, avoid double negation 144def negate(f): 145 if is_not(f): 146 return f.arg(0) 147 else: 148 return Not(f) 149 150def cube2clause(cube): 151 return Or([negate(f) for f in cube]) 152 153class State: 154 def __init__(self, s): 155 self.R = set([]) 156 self.solver = s 157 158 def add(self, clause): 159 if clause not in self.R: 160 self.R |= { clause } 161 self.solver.add(clause) 162 163def is_seq(f): 164 return isinstance(f, list) or isinstance(f, tuple) or isinstance(f, AstVector) 165 166# Check if the initial state is bad 167def check_disjoint(a, b): 168 s = fd_solver() 169 s.add(a) 170 s.add(b) 171 return unsat == s.check() 172 173 174# Remove clauses that are subsumed 175def prune(R): 176 removed = set([]) 177 s = fd_solver() 178 for f1 in R: 179 s.push() 180 for f2 in R: 181 if f2 not in removed: 182 s.add(Not(f2) if f1.eq(f2) else f2) 183 if s.check() == unsat: 184 removed |= { f1 } 185 s.pop() 186 return R - removed 187 188# Quip variant of IC3 189 190must = True 191may = False 192 193class QLemma: 194 def __init__(self, c): 195 self.cube = c 196 self.clause = cube2clause(c) 197 self.bad = False 198 199 def __hash__(self): 200 return hash(tuple(set(self.cube))) 201 202 def __eq__(self, qlemma2): 203 if set(self.cube) == set(qlemma2.cube) and self.bad == qlemma2.bad: 204 return True 205 else: 206 return False 207 208 def __ne__(): 209 if not self.__eq__(self, qlemma2): 210 return True 211 else: 212 return False 213 214 215class QGoal: 216 def __init__(self, cube, parent, level, must, encounter): 217 self.level = level 218 self.cube = cube 219 self.parent = parent 220 self.must = must 221 222 def __lt__(self, other): 223 return self.level < other.level 224 225 226class QReach: 227 228 # it is assumed that there is a single initial state 229 
# with all latches set to 0 in hardware design, so 230 # here init will always give a state where all variable are set to 0 231 def __init__(self, init, xs): 232 self.xs = xs 233 self.constant_xs = [Not(x) for x in self.xs] 234 s = fd_solver() 235 s.add(init) 236 is_sat = s.check() 237 assert is_sat == sat 238 m = s.model() 239 # xs is a list, "for" will keep the order when iterating 240 self.states = numpy.array([[False for x in self.xs]]) # all set to False 241 assert not numpy.max(self.states) # since all element is False, so maximum should be False 242 243 # check if new state exists 244 def is_exist(self, state): 245 if state in self.states: 246 return True 247 return False 248 249 def enumerate(self, i, state_b, state): 250 while i < len(state) and state[i] not in self.xs: 251 i += 1 252 if i >= len(state): 253 if state_b.tolist() not in self.states.tolist(): 254 self.states = numpy.append(self.states, [state_b], axis = 0) 255 return state_b 256 else: 257 return None 258 state_b[i] = False 259 if self.enumerate(i+1, state_b, state) is not None: 260 return state_b 261 else: 262 state_b[i] = True 263 return self.enumerate(i+1, state_b, state) 264 265 def is_full_state(self, state): 266 for i in range(len(self.xs)): 267 if state[i] in self.xs: 268 return False 269 return True 270 271 def add(self, cube): 272 state = self.cube2partial_state(cube) 273 assert len(state) == len(self.xs) 274 if not self.is_exist(state): 275 return None 276 if self.is_full_state(state): 277 self.states = numpy.append(self.states, [state], axis = 0) 278 else: 279 # state[i] is instance, state_b[i] is boolean 280 state_b = numpy.array(state) 281 for i in range(len(state)): # state is of same length as self.xs 282 # i-th literal in state hasn't been assigned value 283 # init un-assigned literals in state_b as True 284 # make state_b only contain boolean value 285 if state[i] in self.xs: 286 state_b[i] = True 287 else: 288 state_b[i] = is_true(state[i]) 289 if self.enumerate(0, state_b, 
state) is not None: 290 lits_to_remove = set([negate(f) for f in list(set(cube) - set(self.constant_xs))]) 291 self.constant_xs = list(set(self.constant_xs) - lits_to_remove) 292 return state 293 return None 294 295 296 def cube2partial_state(self, cube): 297 s = fd_solver() 298 s.add(And(cube)) 299 is_sat = s.check() 300 assert is_sat == sat 301 m = s.model() 302 state = numpy.array([m.eval(x) for x in self.xs]) 303 return state 304 305 306 def state2cube(self, s): 307 result = copy.deepcopy(self.xs) # x1, x2, ... 308 for i in range(len(self.xs)): 309 if not s[i]: 310 result[i] = Not(result[i]) 311 return result 312 313 def intersect(self, cube): 314 state = self.cube2partial_state(cube) 315 mask = True 316 for i in range(len(self.xs)): 317 if is_true(state[i]) or is_false(state[i]): 318 mask = (self.states[:, i] == state[i]) & mask 319 intersects = numpy.reshape(self.states[mask], (-1, len(self.xs))) 320 if intersects.size > 0: 321 return And(self.state2cube(intersects[0])) # only need to return one single intersect 322 return None 323 324 325class Quip: 326 327 def __init__(self, init, trans, goal, x0, inputs, xn): 328 self.x0 = x0 329 self.inputs = inputs 330 self.xn = xn 331 self.init = init 332 self.bad = goal 333 self.trans = trans 334 self.min_cube_solver = fd_solver() 335 self.min_cube_solver.add(Not(trans)) 336 self.goals = [] 337 s = State(fd_solver()) 338 s.add(init) 339 s.solver.add(trans) # check if a bad state can be reached in one step from current level 340 self.states = [s] 341 self.s_bad = fd_solver() 342 self.s_good = fd_solver() 343 self.s_bad.add(self.bad) 344 self.s_good.add(Not(self.bad)) 345 self.reachable = QReach(self.init, x0) 346 self.frames = [] # frames is a 2d list, each row (representing level) is a set containing several (clause, bad) pairs 347 self.count_may = 0 348 349 def next(self, f): 350 if is_seq(f): 351 return [self.next(f1) for f1 in f] 352 return substitute(f, zip(self.x0, self.xn)) 353 354 def prev(self, f): 355 if 
is_seq(f): 356 return [self.prev(f1) for f1 in f] 357 return substitute(f, zip(self.xn, self.x0)) 358 359 def add_solver(self): 360 s = fd_solver() 361 s.add(self.trans) 362 self.states += [State(s)] 363 364 def R(self, i): 365 return And(self.states[i].R) 366 367 def value2literal(self, m, x): 368 value = m.eval(x) 369 if is_true(value): 370 return x 371 if is_false(value): 372 return Not(x) 373 return None 374 375 def values2literals(self, m, xs): 376 p = [self.value2literal(m, x) for x in xs] 377 return [x for x in p if x is not None] 378 379 def project0(self, m): 380 return self.values2literals(m, self.x0) 381 382 def projectI(self, m): 383 return self.values2literals(m, self.inputs) 384 385 def projectN(self, m): 386 return self.values2literals(m, self.xn) 387 388 389 # Block a cube by asserting the clause corresponding to its negation 390 def block_cube(self, i, cube): 391 self.assert_clause(i, cube2clause(cube)) 392 393 # Add a clause to levels 1 until i 394 def assert_clause(self, i, clause): 395 for j in range(1, i + 1): 396 self.states[j].add(clause) 397 assert str(self.states[j].solver) != str([False]) 398 399 400 # minimize cube that is core of Dual solver. 
401 # this assumes that props & cube => Trans 402 # which means props & cube can only give us a Tr in Trans, 403 # and it will never make !Trans sat 404 def minimize_cube(self, cube, inputs, lits): 405 # min_cube_solver has !Trans (min_cube.solver.add(!Trans)) 406 is_sat = self.min_cube_solver.check(lits + [c for c in cube] + [i for i in inputs]) 407 assert is_sat == unsat 408 # unsat_core gives us some lits which make Tr sat, 409 # so that we can ignore other lits and include more states 410 core = self.min_cube_solver.unsat_core() 411 assert core 412 return [c for c in core if c in set(cube)] 413 414 # push a goal on a heap 415 def push_heap(self, goal): 416 heapq.heappush(self.goals, (goal.level, goal)) 417 418 419 # make sure cube to be blocked excludes all reachable states 420 def check_reachable(self, cube): 421 s = fd_solver() 422 for state in self.reachable.states: 423 s.push() 424 r = self.reachable.state2cube(state) 425 s.add(And(self.prev(r))) 426 s.add(self.prev(cube)) 427 is_sat = s.check() 428 s.pop() 429 if is_sat == sat: 430 # if sat, it means the cube to be blocked contains reachable states 431 # so it is an invalid cube 432 return False 433 # if all fail, is_sat will be unsat 434 return True 435 436 # Rudimentary generalization: 437 # If the cube is already unsat with respect to transition relation 438 # extract a core (not necessarily minimal) 439 # otherwise, just return the cube. 
440 def generalize(self, cube, f): 441 s = self.states[f - 1].solver 442 if unsat == s.check(cube): 443 core = s.unsat_core() 444 if self.check_reachable(core): 445 return core, f 446 return cube, f 447 448 449 def valid_reachable(self, level): 450 s = fd_solver() 451 s.add(self.init) 452 for i in range(level): 453 s.add(self.trans) 454 for state in self.reachable.states: 455 s.push() 456 s.add(And(self.next(self.reachable.state2cube(state)))) 457 print self.reachable.state2cube(state) 458 print s.check() 459 s.pop() 460 461 def lemmas(self, level): 462 return [(l.clause, l.bad) for l in self.frames[level]] 463 464 # whenever a new reachable state is found, we use it to mark some existing lemmas as bad lemmas 465 def mark_bad_lemmas(self, new): 466 s = fd_solver() 467 reset = False 468 for frame in self.frames: 469 for lemma in frame: 470 s.push() 471 s.add(lemma.clause) 472 is_sat = s.check(new) 473 if is_sat == unsat: 474 reset = True 475 lemma.bad = True 476 s.pop() 477 if reset: 478 self.states = [self.states[0]] 479 for i in range(1, len(self.frames)): 480 self.add_solver() 481 for lemma in self.frames[i]: 482 if not lemma.bad: 483 self.states[i].add(lemma.clause) 484 485 # prev & tras -> r', such that r' intersects with cube 486 def add_reachable(self, prev, cube): 487 s = fd_solver() 488 s.add(self.trans) 489 s.add(prev) 490 s.add(self.next(And(cube))) 491 is_sat = s.check() 492 assert is_sat == sat 493 m = s.model() 494 new = self.projectN(m) 495 state = self.reachable.add(self.prev(new)) # always add as non-primed 496 if state is not None: # if self.states do not have new state yet 497 self.mark_bad_lemmas(self.prev(new)) 498 499 500 # Check if the negation of cube is inductive at level f 501 def is_inductive(self, f, cube): 502 s = self.states[f - 1].solver 503 s.push() 504 s.add(self.prev(Not(And(cube)))) 505 is_sat = s.check(cube) 506 if is_sat == sat: 507 m = s.model() 508 s.pop() 509 if is_sat == sat: 510 cube = 
self.next(self.minimize_cube(self.project0(m), self.projectI(m), self.projectN(m))) 511 elif is_sat == unsat: 512 cube, f = self.generalize(cube, f) 513 cube = self.next(cube) 514 return cube, f, is_sat 515 516 517 # Determine if there is a cube for the current state 518 # that is potentially reachable. 519 def unfold(self, level): 520 core = [] 521 self.s_bad.push() 522 R = self.R(level) 523 self.s_bad.add(R) # check if current frame intersects with bad states, no trans 524 is_sat = self.s_bad.check() 525 if is_sat == sat: 526 m = self.s_bad.model() 527 cube = self.project0(m) 528 props = cube + self.projectI(m) 529 self.s_good.push() 530 self.s_good.add(R) 531 is_sat2 = self.s_good.check(props) 532 assert is_sat2 == unsat 533 core = self.s_good.unsat_core() 534 assert core 535 core = [c for c in core if c in set(cube)] 536 self.s_good.pop() 537 self.s_bad.pop() 538 return is_sat, core 539 540 # A state s0 and level f0 such that 541 # not(s0) is f0-1 inductive 542 def quip_blocked(self, s0, f0): 543 self.push_heap(QGoal(self.next(s0), None, f0, must, 0)) 544 while self.goals: 545 f, g = heapq.heappop(self.goals) 546 sys.stdout.write("%d." 
% f) 547 if not g.must: 548 self.count_may -= 1 549 sys.stdout.flush() 550 if f == 0: 551 if g.must: 552 s = fd_solver() 553 s.add(self.init) 554 s.add(self.prev(g.cube)) 555 # since init is a complete assignment, so g.cube must equal to init in sat solver 556 assert is_sat == s.check() 557 if verbose: 558 print("") 559 return g 560 self.add_reachable(self.init, g.parent.cube) 561 continue 562 563 r0 = self.reachable.intersect(self.prev(g.cube)) 564 if r0 is not None: 565 if g.must: 566 if verbose: 567 print "" 568 s = fd_solver() 569 s.add(self.trans) 570 # make it as a concrete reachable state 571 # intersect returns an And(...), so use children to get cube list 572 g.cube = r0.children() 573 while True: 574 is_sat = s.check(self.next(g.cube)) 575 assert is_sat == sat 576 r = self.next(self.project0(s.model())) 577 r = self.reachable.intersect(self.prev(r)) 578 child = QGoal(self.next(r.children()), g, 0, g.must, 0) 579 g = child 580 if not check_disjoint(self.init, self.prev(g.cube)): 581 # g is init, break the loop 582 break 583 init = g 584 while g.parent is not None: 585 g.parent.level = g.level + 1 586 g = g.parent 587 return init 588 if g.parent is not None: 589 self.add_reachable(r0, g.parent.cube) 590 continue 591 592 cube = None 593 is_sat = sat 594 f_1 = len(self.frames) - 1 595 while f_1 >= f: 596 for l in self.frames[f_1]: 597 if not l.bad and len(l.cube) > 0 and set(l.cube).issubset(g.cube): 598 cube = l.cube 599 is_sat == unsat 600 break 601 f_1 -= 1 602 if cube is None: 603 cube, f_1, is_sat = self.is_inductive(f, g.cube) 604 if is_sat == unsat: 605 self.frames[f_1].add(QLemma(self.prev(cube))) 606 self.block_cube(f_1, self.prev(cube)) 607 if f_1 < f0: 608 # learned clause might also be able to block same bad states in higher level 609 if set(list(cube)) != set(list(g.cube)): 610 self.push_heap(QGoal(cube, None, f_1 + 1, may, 0)) 611 self.count_may += 1 612 else: 613 # re-queue g.cube in higher level, here g.parent is simply for tracking down the 
trace when output. 614 self.push_heap(QGoal(g.cube, g.parent, f_1 + 1, g.must, 0)) 615 if not g.must: 616 self.count_may += 1 617 else: 618 # qcube is a predecessor of g 619 qcube = QGoal(cube, g, f_1 - 1, g.must, 0) 620 if not g.must: 621 self.count_may += 1 622 self.push_heap(qcube) 623 624 if verbose: 625 print("") 626 return None 627 628 # Check if there are two states next to each other that have the same clauses. 629 def is_valid(self): 630 i = 1 631 inv = None 632 while True: 633 # self.states[].R contains full lemmas 634 # self.frames[] contains delta-encoded lemmas 635 while len(self.states) <= i+1: 636 self.add_solver() 637 while len(self.frames) <= i+1: 638 self.frames.append(set()) 639 duplicates = set([]) 640 for l in self.frames[i+1]: 641 if l in self.frames[i]: 642 duplicates |= {l} 643 self.frames[i] = self.frames[i] - duplicates 644 pushed = set([]) 645 for l in (self.frames[i] - self.frames[i+1]): 646 if not l.bad: 647 s = self.states[i].solver 648 s.push() 649 s.add(self.next(Not(l.clause))) 650 s.add(l.clause) 651 is_sat = s.check() 652 s.pop() 653 if is_sat == unsat: 654 self.frames[i+1].add(l) 655 self.states[i+1].add(l.clause) 656 pushed |= {l} 657 self.frames[i] = self.frames[i] - pushed 658 if (not (self.states[i].R - self.states[i+1].R) 659 and len(self.states[i].R) != 0): 660 inv = prune(self.states[i].R) 661 F_inf = self.frames[i] 662 j = i + 1 663 while j < len(self.states): 664 for l in F_inf: 665 self.states[j].add(l.clause) 666 j += 1 667 self.frames[len(self.states)-1] = F_inf 668 self.frames[i] = set([]) 669 break 670 elif (len(self.states[i].R) == 0 671 and len(self.states[i+1].R) == 0): 672 break 673 i += 1 674 675 if inv is not None: 676 self.s_bad.push() 677 self.s_bad.add(And(inv)) 678 is_sat = self.s_bad.check() 679 if is_sat == unsat: 680 self.s_bad.pop() 681 return And(inv) 682 self.s_bad.pop() 683 return None 684 685 def run(self): 686 if not check_disjoint(self.init, self.bad): 687 return "goal is reached in initial 
state" 688 level = 0 689 while True: 690 inv = self.is_valid() # self.add_solver() here 691 if inv is not None: 692 return inv 693 is_sat, cube = self.unfold(level) 694 if is_sat == unsat: 695 level += 1 696 if verbose: 697 print("Unfold %d" % level) 698 sys.stdout.flush() 699 elif is_sat == sat: 700 cex = self.quip_blocked(cube, level) 701 if cex is not None: 702 return cex 703 else: 704 return is_sat 705 706def test(file): 707 h2t = Horn2Transitions() 708 h2t.parse(file) 709 if verbose: 710 print("Test file: %s") % file 711 mp = Quip(h2t.init, h2t.trans, h2t.goal, h2t.xs, h2t.inputs, h2t.xns) 712 start_time = time.time() 713 result = mp.run() 714 end_time = time.time() 715 if isinstance(result, QGoal): 716 g = result 717 if verbose: 718 print("Trace") 719 while g: 720 if verbose: 721 print(g.level, g.cube) 722 g = g.parent 723 print("--- used %.3f seconds ---" % (end_time - start_time)) 724 validate(mp, result, mp.trans) 725 return 726 if isinstance(result, ExprRef): 727 if verbose: 728 print("Invariant:\n%s " % result) 729 print("--- used %.3f seconds ---" % (end_time - start_time)) 730 validate(mp, result, mp.trans) 731 return 732 print(result) 733 734def validate(var, result, trans): 735 if isinstance(result, QGoal): 736 g = result 737 s = fd_solver() 738 s.add(trans) 739 while g.parent is not None: 740 s.push() 741 s.add(var.prev(g.cube)) 742 s.add(var.next(g.parent.cube)) 743 assert sat == s.check() 744 s.pop() 745 g = g.parent 746 if verbose: 747 print "--- validation succeed ----" 748 return 749 if isinstance(result, ExprRef): 750 inv = result 751 s = fd_solver() 752 s.add(trans) 753 s.push() 754 s.add(var.prev(inv)) 755 s.add(Not(var.next(inv))) 756 assert unsat == s.check() 757 s.pop() 758 cube = var.prev(var.init) 759 step = 0 760 while True: 761 step += 1 762 # too many steps to reach invariant 763 if step > 1000: 764 if verbose: 765 print "--- validation failed --" 766 return 767 if not check_disjoint(var.prev(cube), var.prev(inv)): 768 # reach 
invariant 769 break 770 s.push() 771 s.add(cube) 772 assert s.check() == sat 773 cube = var.projectN(s.model()) 774 s.pop() 775 if verbose: 776 print "--- validation succeed ----" 777 return 778 779 780 781test("data/horn1.smt2") 782test("data/horn2.smt2") 783test("data/horn3.smt2") 784test("data/horn4.smt2") 785test("data/horn5.smt2") 786# test("data/horn6.smt2") # not able to finish 787