1"""
This script stress-tests deadlock detection by running concurrent transactions that take conflicting row locks.
3
4Usage: rocksdb_deadlock_stress.py user host port db_name table_name
5       num_iters num_threads
6"""
7import cStringIO
8import hashlib
9import MySQLdb
10from MySQLdb.constants import ER
11import os
12import random
13import signal
14import sys
15import threading
16import time
17import string
18import traceback
19
def is_deadlock_error(exc):
  """Return True if ``exc`` is a MySQL deadlock error (ER.LOCK_DEADLOCK).

  MySQLdb exceptions carry the MySQL error code as their first argument.
  An exception with no arguments is reported as "not a deadlock" instead of
  raising IndexError. This matters because callers check the error from
  inside an ``except`` block, where a new exception would hide the original.
  """
  return bool(exc.args) and exc.args[0] == ER.LOCK_DEADLOCK
23
def get_query(table_name, idx):
  """Build one locking SQL statement that targets key ``idx`` of ``table_name``.

  Even keys only ever get a shared-lock read (LOCK IN SHARE MODE), on the
  assumption that they will therefore always be acquirable. This makes
  deadlock detection more interesting. Odd keys get one of three
  statements, picked uniformly at random:
    1. SELECT ... FOR UPDATE
    2. an INSERT ... ON DUPLICATE KEY UPDATE that bumps column ``b``
    3. a DELETE
  """
  if idx % 2 != 0:
    choice = random.randint(1, 3)
    if choice == 1:
      template = "SELECT * from %s WHERE a = %d FOR UPDATE"
    elif choice == 2:
      template = ("INSERT INTO %s VALUES (%d, 1)\n"
                  "                ON DUPLICATE KEY UPDATE b=b+1")
    else:
      template = "DELETE from %s WHERE a = %d"
  else:
    template = "SELECT * from %s WHERE a = %d LOCK IN SHARE MODE"
  return template % (table_name, idx)
38
class Worker(threading.Thread):
  """Thread that runs random lock-taking transactions against one table.

  The thread starts itself at the end of __init__. It runs ``num_iters``
  transactions. Each one executes get_query() for 10 distinct keys, drawn at
  random from range(100), and then commits.

  Deadlock errors roll the transaction back and the loop continues. Any other
  error ends the thread. Its formatted traceback is stored in
  ``self.exception``, which stays None if the thread finished cleanly, so the
  caller can inspect it after join().
  """

  def __init__(self, con, table_name, num_iters):
    threading.Thread.__init__(self)
    self.con = con  # DB-API connection all of this worker's statements use
    self.table_name = table_name
    self.num_iters = num_iters
    self.exception = None
    # Every attribute is set above, so run() can safely start right away.
    self.start()

  def run(self):
    # An exception raised in a thread is not propagated to join(). Keep the
    # traceback text so the main thread can report it.
    try:
      self.runme()
    except Exception:
      self.exception = traceback.format_exc()

  def runme(self):
    """Run the transactions, tolerating only deadlock errors."""
    cur = self.con.cursor()
    for _ in xrange(self.num_iters):
      try:
        # Distinct keys in random order, so concurrent workers take row
        # locks in conflicting orders.
        for i in random.sample(xrange(100), 10):
          cur.execute(get_query(self.table_name, i))
        self.con.commit()
      except MySQLdb.OperationalError as e:
        if not is_deadlock_error(e):
          # Re-raise before touching the connection: if it is broken,
          # rollback() would raise a new error that hides this one. A bare
          # raise keeps the original traceback; in Python 2, ``raise e``
          # would restart it at this line.
          raise
        self.con.rollback()
        # Start the next transaction on a fresh cursor.
        cur = self.con.cursor()
    cur.close()
64
65if __name__ == '__main__':
66  if len(sys.argv) != 8:
67    print "Usage: rocksdb_deadlock_stress.py user host port db_name " \
68          "table_name num_iters num_threads"
69    sys.exit(1)
70
71  user = sys.argv[1]
72  host = sys.argv[2]
73  port = int(sys.argv[3])
74  db = sys.argv[4]
75  table_name = sys.argv[5]
76  num_iters = int(sys.argv[6])
77  num_workers = int(sys.argv[7])
78
79  worker_failed = False
80  workers = []
81  for i in xrange(num_workers):
82    w = Worker(
83      MySQLdb.connect(user=user, host=host, port=port, db=db), table_name,
84      num_iters)
85    workers.append(w)
86
87  for w in workers:
88    w.join()
89    if w.exception:
90      print "Worker hit an exception:\n%s\n" % w.exception
91      worker_failed = True
92
93  if worker_failed:
94    sys.exit(1)
95