"""Stress test deadlock detection.

Each worker thread runs ``num_iters`` transactions against ``table_name``
over its own connection. Every transaction touches ROWS_PER_TXN distinct
keys drawn at random from range(NUM_ROWS), in random order, so concurrent
transactions can request locks in conflicting orders. A deadlock error
rolls the transaction back and the worker moves on; any other error fails
that worker, and the script exits with status 1 if any worker failed.

Usage: rocksdb_deadlock_stress.py user host port db_name table_name
       num_iters num_threads
"""
# NOTE: this script targets Python 2 (cStringIO, xrange). print() is only
# ever called with a single argument, so it prints exactly the same text as
# the Python 2 print statement.
import cStringIO
import hashlib
import MySQLdb
from MySQLdb.constants import ER
import os
import random
import signal
import sys
import threading
import time
import string
import traceback

# Keys used by each transaction are sampled from range(NUM_ROWS).
NUM_ROWS = 100
# Number of distinct keys each transaction touches.
ROWS_PER_TXN = 10


def is_deadlock_error(exc):
    """Return True if the MySQLdb exception *exc* reports a deadlock.

    The server error code is expected in ``exc.args[0]``. An exception with
    no args is treated as "not a deadlock" instead of raising IndexError.
    """
    return bool(exc.args) and exc.args[0] == ER.LOCK_DEADLOCK


def get_query(table_name, idx):
    """Return a locking SQL statement for key *idx* of *table_name*.

    Even keys are only ever read with LOCK IN SHARE MODE. This script never
    issues an exclusive-locking statement for an even key, on the assumption
    that those locks are then always acquirable. That pushes the lock
    conflicts onto the odd keys and makes deadlock detection more
    interesting.

    For odd keys, one of three statements is chosen uniformly at random:
    SELECT ... FOR UPDATE, an INSERT ... ON DUPLICATE KEY UPDATE, or a
    DELETE.

    NOTE(review): the INSERT assumes the table has exactly two columns
    (a, b) and a unique key on ``a`` -- confirm against the test setup.
    """
    if idx % 2 == 0:
        return ("SELECT * from %s WHERE a = %d LOCK IN SHARE MODE"
                % (table_name, idx))
    r = random.randint(1, 3)
    if r == 1:
        return "SELECT * from %s WHERE a = %d FOR UPDATE" % (table_name, idx)
    elif r == 2:
        return ("INSERT INTO %s VALUES (%d, 1) "
                "ON DUPLICATE KEY UPDATE b=b+1" % (table_name, idx))
    else:
        return "DELETE from %s WHERE a = %d" % (table_name, idx)


class Worker(threading.Thread):
    """Thread that runs ``num_iters`` random transactions on one connection.

    The thread starts itself from the constructor. When it finishes, the
    connection is closed. ``self.exception`` stays None on success; on
    failure it holds the formatted traceback of the unexpected error, which
    the main thread reports after join().
    """

    def __init__(self, con, table_name, num_iters):
        threading.Thread.__init__(self)
        self.con = con  # connection owned by, and closed by, this worker
        self.table_name = table_name
        self.num_iters = num_iters
        self.exception = None
        self.start()

    def run(self):
        try:
            try:
                self.runme()
            finally:
                # Release the connection as soon as this worker is done,
                # whether it succeeded or failed.
                self.con.close()
        except Exception:
            # Record the failure for the main thread instead of letting the
            # thread die with only a stderr dump.
            self.exception = traceback.format_exc()

    def runme(self):
        cur = self.con.cursor()
        for _ in xrange(self.num_iters):
            try:
                for i in random.sample(xrange(NUM_ROWS), ROWS_PER_TXN):
                    cur.execute(get_query(self.table_name, i))
                self.con.commit()
            except MySQLdb.OperationalError as e:
                if not is_deadlock_error(e):
                    # Not a deadlock, so fail this worker. A bare raise keeps
                    # the original traceback (Python 2's `raise e` would
                    # reset it). Re-raising before touching the connection
                    # means a follow-up rollback error cannot mask this one.
                    raise
                # Deadlocks are expected: roll back and continue with the
                # next transaction on a fresh cursor.
                self.con.rollback()
                cur = self.con.cursor()


def main():
    """Parse argv, run all workers to completion, exit 1 if any failed."""
    if len(sys.argv) != 8:
        print("Usage: rocksdb_deadlock_stress.py user host port db_name "
              "table_name num_iters num_threads")
        sys.exit(1)

    user = sys.argv[1]
    host = sys.argv[2]
    port = int(sys.argv[3])
    db = sys.argv[4]
    table_name = sys.argv[5]
    num_iters = int(sys.argv[6])
    num_workers = int(sys.argv[7])

    # One connection per worker; each Worker starts running on construction.
    workers = []
    for _ in xrange(num_workers):
        con = MySQLdb.connect(user=user, host=host, port=port, db=db)
        workers.append(Worker(con, table_name, num_iters))

    worker_failed = False
    for w in workers:
        w.join()
        if w.exception:
            print("Worker hit an exception:\n%s\n" % w.exception)
            worker_failed = True

    if worker_failed:
        sys.exit(1)


if __name__ == '__main__':
    main()