1"""
2This script tests concurrent inserts on a given table.
Example usage (in the MySQL Test Framework):
4
5  CREATE TABLE t1 (a INT) ENGINE=rocksdb;
6
7  let $exec = python suite/rocksdb/t/rocksdb_concurrent_insert.py \
8                     root 127.0.0.1 $MASTER_MYPORT test t1 100 4;
9  exec $exec;
10
11"""
12import io
13import hashlib
14import MySQLdb
15import os
16import random
17import signal
18import sys
19import threading
20import time
21import string
22
def get_insert(table_name, idx):
  """Return the INSERT statement storing integer `idx` in column `a`.

  Args:
    table_name: name of the target table, interpolated verbatim.
    idx: value for column `a`; rendered with an integer conversion.
  """
  values = (table_name, idx)
  return "INSERT INTO %s (a) VALUES (%d)" % values
25
class Inserter(threading.Thread):
  """Worker thread that inserts the values 0..num_inserts-1 into one table.

  The thread starts itself from the constructor. Each value is inserted with
  its own statement; the open transaction is committed after roughly 3 out of
  10 inserts (chosen at random) and once more after the last insert.

  Attributes:
    finished: True once run() has completed, whether or not it failed.
    exception: None while no failure was recorded, otherwise the formatted
      traceback of the error that made the worker fail.
    Instance: class attribute holding the most recently constructed Inserter.
  """
  Instance = None

  def __init__(self, con, table_name, num_inserts):
    """Store the settings, turn autocommit off on `con` and start the thread.

    Args:
      con: open database connection; this class calls autocommit(), cursor()
        and commit() on it.
      table_name: table that receives the rows.
      num_inserts: number of INSERT statements to issue.
    """
    threading.Thread.__init__(self)
    self.finished = False
    self.num_inserts = num_inserts
    # Transactions are committed explicitly in runme().
    con.autocommit(False)
    self.con = con
    self.rand = random.Random()
    self.exception = None
    self.table_name = table_name
    Inserter.Instance = self
    self.start()

  def run(self):
    """Thread entry point: run the workload and record any failure."""
    try:
      self.runme()
    except Exception as e:
      self._record_failure(e)
    finally:
      self.finish()

  def runme(self):
    """Issue the inserts, committing at random points and once at the end."""
    cur = self.con.cursor()
    for i in range(self.num_inserts):
      try:
        cur.execute(get_insert(self.table_name, i))
        # Commit after roughly 3 out of 10 inserts; the others stay in the
        # currently open transaction.
        if self.rand.randint(1, 10) < 4:
          self.con.commit()
      except Exception:
        # Best effort: a failed insert or commit is deliberately not treated
        # as a worker failure (presumably such errors are expected under
        # concurrency -- TODO confirm). Continue with a fresh cursor.
        cur = self.con.cursor()
    try:
      self.con.commit()
    except Exception as e:
      # A failing final commit does count as a worker failure.
      self._record_failure(e)

  def _record_failure(self, exc):
    """Save the active exception's traceback in self.exception and print it.

    Must be called from inside an `except` block so that format_exc() sees
    the exception currently being handled.
    """
    # Imported locally because this file has no file-level traceback import.
    import traceback
    self.exception = traceback.format_exc()
    print("caught (%s)" % (exc,))

  def finish(self):
    """Mark the worker as done."""
    self.finished = True
65
if __name__ == '__main__':
  # Exactly seven arguments are expected after the script name.
  cli_args = sys.argv[1:]
  if len(cli_args) != 7:
    print("Usage: rocksdb_concurrent_insert.py user host port db_name "
          "table_name num_inserts num_threads")
    sys.exit(1)

  (user, host, port_arg, db, table_name,
   inserts_arg, workers_arg) = cli_args
  port = int(port_arg)
  num_inserts = int(inserts_arg)
  num_workers = int(workers_arg)

  # Every worker gets its own connection; constructing an Inserter also
  # starts its thread.
  threads = []
  for _ in range(num_workers):
    connection = MySQLdb.connect(user=user, host=host, port=port, db=db)
    threads.append(Inserter(connection, table_name, num_inserts))

  # Wait for the workers in creation order and report each recorded failure
  # as soon as that worker has been joined.
  any_failed = False
  for thread in threads:
    thread.join()
    if not thread.exception:
      continue
    print("Worker hit an exception:\n%s\n" % thread.exception)
    any_failed = True

  # A non-zero exit status signals that at least one worker failed.
  if any_failed:
    sys.exit(1)
96