import array
import hashlib
import MySQLdb
from MySQLdb.constants import CR
from MySQLdb.constants import ER
from collections import deque
import os
import random
import signal
import sys
import threading
import time
import string
import traceback
import logging
import argparse

# This is a generic load_generator for mysqld which persists across server
# restarts and attempts to verify that both committed and uncommitted
# transactions are persisted correctly.
#
# The table schema used should look something like:
#
# CREATE TABLE t1(id INT PRIMARY KEY,
#                 thread_id INT NOT NULL,
#                 request_id BIGINT UNSIGNED NOT NULL,
#                 update_count INT UNSIGNED NOT NULL DEFAULT 0,
#                 zero_sum INT DEFAULT 0,
#                 msg VARCHAR(1024),
#                 msg_length int,
#                 msg_checksum varchar(128),
#                 KEY msg_i(msg(255), zero_sum))
# ENGINE=RocksDB DEFAULT CHARSET=latin1 COLLATE=latin1_bin;
#
#   zero_sum should always sum to 0 across the table, no matter which
#   transactions are in flight. Each transaction maintains this invariant.
#
#   request_id should be unique across transactions. It is used during
#   transaction verification and is monotonically increasing.
#
# Several threads are spawned at the start of the test to populate the table.
# Once the table is populated, both loader and checker threads are created.
#
# The row id space is split into two sections: exclusive and shared. Each
# loader thread owns part of the exclusive section and maintains complete
# information on its inserts, updates and deletes. Since this section is only
# modified by one thread, the thread can maintain an accurate picture of all
# changes. The shared section contains rows which multiple threads can
# update, delete or insert. For checking purposes, the request_id is used to
# determine if a row is consistent with a committed transaction.
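#
# For example (illustrative values): with 16 loaders and 100 ids per loader,
# loader 0 exclusively owns ids [0, 100), loader 15 owns [1500, 1600), and
# the shared section covers [1600, max_id).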
#
# Each loader thread's transaction consists of selecting some number of rows
# randomly. The thread can choose to delete the row, update the row or insert
# the row if it doesn't exist. The state of rows that are owned by the loader
# thread is tracked within the thread's id_map. This map contains the row id
# and the request_id of the latest update. Deleted rows are indicated by the
# -request_id marker. Thus, at any point in time, the thread's id_map should
# reflect the exact state of the rows that are owned.
#
# The loader thread also maintains the state of older transactions that were
# successfully processed in addition to the current transaction, which may or
# may not be committed. Each transaction state consists of the row id and the
# request_id. Again, -request_id is used to indicate a delete. For committed
# transactions, the thread can verify that the request_id of the row is larger
# than what the thread has recorded. For uncommitted transactions, the thread
# verifies that the request_id of the row does not match that of the
# transaction. To determine whether or not a transaction succeeded in case of
# a crash right at commit, each thread always includes a particular row in the
# transaction which it can use to check the request id against.
#
# Checker threads run continuously to verify the checksums on the rows and to
# verify that the zero_sum column sums to zero at any point in time. The
# checker threads run both point lookups and range scans when selecting rows.
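#
# Example invocation (illustrative; assumes the script is saved as
# load_generator.py and that the connection options match your server):
#
#   python load_generator.py -u root -H 127.0.0.1 -P 3306 -d test -t t1 \
#       -l 16 -c 4 -n 1600 -m 2000 -T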

class ValidateError(Exception):
  """Raised when validation fails"""
  pass

class TestError(Exception):
  """Raised when the test cannot make forward progress"""
  pass

CHARS = string.letters + string.digits
OPTIONS = {}

# max number of rows per transaction
MAX_ROWS_PER_REQ = 10

# global variable checked by threads to determine if the test is stopping
TEST_STOP = False
LOADERS_READY = 0

# global monotonically increasing request id counter
REQUEST_ID = 1
REQUEST_ID_LOCK = threading.Lock()

INSERT_ID_SET = set()

def get_next_request_id():
  global REQUEST_ID
  with REQUEST_ID_LOCK:
    REQUEST_ID += 1
    return REQUEST_ID

# given a percentage value, rolls a 100-sided die and returns whether the
# given value is greater than or equal to the die roll
#
# passing 0 always returns false and 100 always returns true
def roll_d100(p):
  assert p >= 0 and p <= 100
  return p >= random.randint(1, 100)

def sha1(x):
  return hashlib.sha1(str(x)).hexdigest()

def is_connection_error(exc):
  error_code = exc.args[0]
  return (error_code == CR.CONNECTION_ERROR or
          error_code == CR.CONN_HOST_ERROR or
          error_code == CR.SERVER_LOST or
          error_code == CR.SERVER_GONE_ERROR or
          error_code == ER.QUERY_INTERRUPTED or
          error_code == ER.SERVER_SHUTDOWN)

def is_deadlock_error(exc):
  error_code = exc.args[0]
  return (error_code == ER.LOCK_DEADLOCK or
          error_code == ER.LOCK_WAIT_TIMEOUT)

# should be deterministic given an idx
def gen_msg(idx, thread_id, request_id):
  random.seed(idx)
  # field length is 1024 bytes, but 32 are reserved for the tid and req tag
  blob_length = random.randint(1, 1024 - 32)

  if roll_d100(50):
    # blob that cannot be compressed (well, compresses to 85% of original size)
    msg = ''.join([random.choice(CHARS) for x in xrange(blob_length)])
  else:
    # blob that can be compressed
    msg = random.choice(CHARS) * blob_length

  # append the thread_id and request_id to the end of the msg
  return ''.join([msg, ' tid: %d req: %d' % (thread_id, request_id)])
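
# For example (illustrative): gen_msg(3, 7, 42) always produces the same blob
# body for idx 3, since random is re-seeded with idx, and the result always
# ends with the tag ' tid: 7 req: 42'.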

def execute(cur, stmt):
  # 2^64 - 1 is -1 viewed as an unsigned 64-bit row count, which the client
  # library reports when it could not determine the real row count
  ROW_COUNT_ERROR = 18446744073709551615L
  logging.debug("Executing %s" % stmt)
  cur.execute(stmt)
  if cur.rowcount < 0 or cur.rowcount == ROW_COUNT_ERROR:
    raise MySQLdb.OperationalError(CR.CONNECTION_ERROR,
                                   "Possible connection error, rowcount is %d"
                                   % cur.rowcount)

def wait_for_workers(workers, min_active=0):
  logging.info("Waiting for %d workers", len(workers))
  # min_active needs to include the current waiting thread
  min_active += 1

  # polling here allows this thread to be responsive to keyboard interrupt
  # exceptions, otherwise a user hitting ctrl-c would see the load_generator as
  # hanging and unresponsive
  try:
    while threading.active_count() > min_active:
      time.sleep(1)
  except KeyboardInterrupt:
    os._exit(1)

  num_failures = 0
  for w in workers:
    w.join()
    if w.exception:
      logging.error(w.exception)
      num_failures += 1

  return num_failures

# base class for worker threads; contains the logic for reconnecting to the
# mysqld server after a connection failure
class WorkerThread(threading.Thread):
  def __init__(self, name):
    threading.Thread.__init__(self)
    self.name = name
    self.exception = None
    self.con = None
    self.cur = None
    self.isolation_level = None
    self.start_time = time.time()
    self.total_time = 0

  def run(self):
    global TEST_STOP

    try:
      logging.info("Started")
      self.runme()
      logging.info("Completed successfully")
    except Exception:
      self.exception = traceback.format_exc()
      logging.error(self.exception)
      TEST_STOP = True
    finally:
      self.total_time = time.time() - self.start_time
      logging.info("Total run time: %.2f s" % self.total_time)
      self.finish()

  # attempts to (re)connect, returning True if the connection could NOT be
  # established within the timeout
  def reconnect(self, timeout=900):
    global TEST_STOP

    self.con = None
    SECONDS_BETWEEN_RETRY = 10
    attempts = 1
    logging.info("Attempting to connect to MySQL Server")
    while not self.con and timeout > 0 and not TEST_STOP:
      try:
        self.con = MySQLdb.connect(user=OPTIONS.user, host=OPTIONS.host,
                                   port=OPTIONS.port, db=OPTIONS.db)
        if self.con:
          self.con.autocommit(False)
          self.cur = self.con.cursor()
          self.set_isolation_level(self.isolation_level)
          logging.info("Connection successful after attempt %d" % attempts)
          break
      except MySQLdb.Error:
        logging.debug(traceback.format_exc())
      time.sleep(SECONDS_BETWEEN_RETRY)
      timeout -= SECONDS_BETWEEN_RETRY
      attempts += 1
    return self.con is None

  def get_isolation_level(self):
    execute(self.cur, "SELECT @@SESSION.tx_isolation")
    if self.cur.rowcount != 1:
      raise TestError("Unable to retrieve tx_isolation")
    return self.cur.fetchone()[0]

  def set_isolation_level(self, isolation_level, persist=False):
    if isolation_level is not None:
      execute(self.cur, "SET @@SESSION.tx_isolation = '%s'" % isolation_level)
      if self.cur.rowcount != 0:
        raise TestError("Unable to set the isolation level to %s"
                        % isolation_level)

    if isolation_level is None or persist:
      self.isolation_level = isolation_level

# periodically kills the server
class ReaperWorker(WorkerThread):
  def __init__(self):
    WorkerThread.__init__(self, 'reaper')
    self.kills = 0
    self.start()

  def finish(self):
    logging.info('complete with %d kills' % self.kills)
    if self.con:
      self.con.close()

  def get_server_pid(self):
    execute(self.cur, "SELECT @@pid_file")
    if self.cur.rowcount != 1:
      raise TestError("Unable to retrieve pid_file")
    return int(open(self.cur.fetchone()[0]).read())

  def runme(self):
    global TEST_STOP
    time_remain = random.randint(10, 30)
    while not TEST_STOP:
      if time_remain > 0:
        time_remain -= 1
        time.sleep(1)
        continue
      if self.reconnect():
        raise Exception("Unable to connect to MySQL server")
      logging.info('killing server...')
      with open(OPTIONS.expect_file, 'w+') as expect_file:
        expect_file.write('restart')
      os.kill(self.get_server_pid(), signal.SIGTERM)
      self.kills += 1
      time_remain = random.randint(0, 30) + OPTIONS.reap_delay

# runs initially to populate the table with the given number of rows
class PopulateWorker(WorkerThread):
  def __init__(self, thread_id, start_id, num_to_add):
    WorkerThread.__init__(self, 'populate-%d' % thread_id)
    self.thread_id = thread_id
    self.start_id = start_id
    self.num_to_add = num_to_add
    self.table = OPTIONS.table
    self.start()

  def finish(self):
    if self.con:
      self.con.commit()
      self.con.close()

  def runme(self):
    if self.reconnect():
      raise Exception("Unable to connect to MySQL server")

    stmt = None
    for i in xrange(self.start_id, self.start_id + self.num_to_add):
      stmt = gen_insert(self.table, i, 0, 0, 0)
      execute(self.cur, stmt)
      # commit in batches, checking the auto_inc id on each commit
      if i % 101 == 0:
        self.con.commit()
        check_id(self.con.insert_id())
    self.con.commit()
    check_id(self.con.insert_id())
    logging.info("Inserted %d rows starting at id %d" %
                 (self.num_to_add, self.start_id))

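# NOTE: INSERT_ID_SET is shared across worker threads without a lock. Under
# CPython the GIL keeps set.add() itself safe, but the membership test and
# the add below are separate steps, so a duplicate reported here is real,
# while a duplicate missed in a narrow race is theoretically possible.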
def check_id(id):
  if id == 0:
    return
  if id in INSERT_ID_SET:
    raise Exception("Duplicate auto_inc id %d" % id)
  INSERT_ID_SET.add(id)

def populate_table(num_records):
  logging.info("Populate_table started for %d records" % num_records)
  if num_records == 0:
    return False

  # use at least one worker, and aim for batches of at least 100 rows
  num_workers = min(10, max(1, num_records / 100))
  workers = []

  N = num_records / num_workers
  start_id = 0
  for i in xrange(num_workers):
    workers.append(PopulateWorker(i, start_id, N))
    start_id += N
  if num_records > start_id:
    workers.append(PopulateWorker(num_workers, start_id,
                   num_records - start_id))

  # Wait for the populate threads to complete
  return wait_for_workers(workers) > 0
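
# For example (illustrative): populate_table(250) starts two workers of 125
# rows each (min(10, max(1, 250 / 100)) == 2), while populate_table(1001)
# starts ten workers of 100 rows each plus one remainder worker for the
# final row.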

def gen_insert(table, idx, thread_id, request_id, zero_sum):
  msg = gen_msg(idx, thread_id, request_id)
  return ("INSERT INTO %s (id, thread_id, request_id, zero_sum, "
          "msg, msg_length, msg_checksum) VALUES (%d,%d,%d,%d,'%s',%d,'%s')"
          % (table, idx, thread_id, request_id,
             zero_sum, msg, len(msg), sha1(msg)))

def gen_update(table, idx, thread_id, request_id, zero_sum):
  msg = gen_msg(idx, thread_id, request_id)
  return ("UPDATE %s SET thread_id = %d, request_id = %d, "
          "update_count = update_count + 1, zero_sum = zero_sum + (%d), "
          "msg = '%s', msg_length = %d, msg_checksum = '%s' WHERE id = %d "
          % (table, thread_id, request_id, zero_sum, msg, len(msg),
             sha1(msg), idx))

def gen_delete(table, idx):
  return "DELETE FROM %s WHERE id = %d" % (table, idx)

def gen_insert_on_dup(table, idx, thread_id, request_id, zero_sum):
  msg = gen_msg(idx, thread_id, request_id)
  msg_checksum = sha1(msg)
  return ("INSERT INTO %s (id, thread_id, request_id, zero_sum, "
          "msg, msg_length, msg_checksum) VALUES (%d,%d,%d,%d,'%s',%d,'%s') "
          "ON DUPLICATE KEY UPDATE "
          "thread_id=%d, request_id=%d, "
          "update_count=update_count+1, "
          "zero_sum=zero_sum + (%d), msg='%s', msg_length=%d, "
          "msg_checksum='%s'" %
          (table, idx, thread_id, request_id,
           zero_sum, msg, len(msg), msg_checksum, thread_id, request_id,
           zero_sum, msg, len(msg), msg_checksum))

# Each loader thread owns a part of the id space for which it maintains
# inventory. The loader thread generates inserts, updates and deletes for the
# table. The latest successful transactions and the latest open transaction
# are kept to verify after a disconnect that the rows were recovered properly.
class LoadGenWorker(WorkerThread):
  TXN_UNCOMMITTED = 0
  TXN_COMMIT_STARTED = 1
  TXN_COMMITTED = 2

  def __init__(self, thread_id):
    WorkerThread.__init__(self, 'loader-%02d' % thread_id)
    self.thread_id = thread_id
    self.rand = random.Random()
    self.rand.seed(thread_id)
    self.loop_num = 0

    # id_map contains the request id for each id owned by this worker
    # thread; it is indexed by the row id offset by start_id
    self.id_map = array.array('l')
    self.start_id = thread_id * OPTIONS.ids_per_loader
    self.num_id = OPTIONS.ids_per_loader
    self.start_share_id = OPTIONS.num_loaders * OPTIONS.ids_per_loader
    self.max_id = OPTIONS.max_id
    self.table = OPTIONS.table
    self.num_requests = OPTIONS.num_requests

    # stores information about the latest series of successful transactions
    #
    # each transaction is simply a map of id -> request_id
    # deleted rows are indicated by -request_id
    self.prev_txn = deque()
    self.cur_txn = None
    self.cur_txn_state = None

    self.start()

  def finish(self):
    if self.total_time:
      req_per_sec = self.loop_num / self.total_time
    else:
      req_per_sec = -1
    logging.info("total txns: %d, txns/s: %.2f" %
                 (self.loop_num, req_per_sec))

  # constructs the internal map of the ids owned by this thread to the
  # request id of the latest update for each id
  def populate_id_map(self):
    logging.info("Populating id map")

    REQ_ID_COL = 0
    stmt = "SELECT request_id FROM %s WHERE id = %d"

    # the start_id is used for tracking active transactions, so the row needs
    # to exist
    idx = self.start_id
    execute(self.cur, stmt % (self.table, idx))
    if self.cur.rowcount > 0:
      request_id = self.cur.fetchone()[REQ_ID_COL]
    else:
      request_id = get_next_request_id()
      execute(self.cur, gen_insert(self.table, idx, self.thread_id,
                                   request_id, 0))
      self.con.commit()
      check_id(self.con.insert_id())

    self.id_map.append(request_id)

    self.cur_txn = {idx:request_id}
    self.cur_txn_state = self.TXN_COMMITTED
    for i in xrange(OPTIONS.committed_txns):
      self.prev_txn.append(self.cur_txn)

    # fetch the rest of the rows for the id space owned by this thread
    for idx in xrange(self.start_id + 1, self.start_id + self.num_id):
      execute(self.cur, stmt % (self.table, idx))
      if self.cur.rowcount == 0:
        # a negative number is used to indicate a missing row
        self.id_map.append(-1)
      else:
        res = self.cur.fetchone()
        self.id_map.append(res[REQ_ID_COL])

    self.con.commit()

  def apply_cur_txn_changes(self):
    # apply the changes to the id_map
    for idx in self.cur_txn:
      if idx < self.start_id + self.num_id:
        assert idx >= self.start_id
        self.id_map[idx - self.start_id] = self.cur_txn[idx]
    self.cur_txn_state = self.TXN_COMMITTED

    self.prev_txn.append(self.cur_txn)
    self.prev_txn.popleft()

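  # verify_txn checks a transaction against the table. For example
  # (illustrative): for a committed txn {start_id: 10, 5: 10}, no row with
  # id IN (start_id, 5) may have request_id < 10; for an uncommitted txn,
  # no such row may have request_id = 10.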
  def verify_txn(self, txn, committed):
    request_id = txn[self.start_id]
    if not committed:
      # if the transaction was not committed, then there should be no rows
      # in the table that have this request_id
      cond = '='
      # it is possible the start_id used to track this transaction is in
      # the process of being deleted
      if request_id < 0:
        request_id = -request_id
    else:
      # if the transaction was committed, then no rows modified by this
      # transaction should have a request_id less than this transaction's id
      cond = '<'
    stmt = ("SELECT COUNT(*) FROM %s WHERE id IN (%s) AND request_id %s %d" %
            (self.table, ','.join(str(x) for x in txn), cond, request_id))
    execute(self.cur, stmt)
    if self.cur.rowcount != 1:
      raise TestError("Unable to retrieve results for query '%s'" % stmt)
    count = self.cur.fetchone()[0]
    if count > 0:
      raise TestError("Expected '%s' to return 0 rows, but %d returned "
                      "instead" % (stmt, count))
    self.con.commit()

  def verify_data(self):
    # if the state of the current transaction is unknown (i.e. a commit was
    # issued, but the connection failed before the result was returned),
    # check the start_id row to determine if it was committed
    request_id = self.cur_txn[self.start_id]
    if self.cur_txn_state == self.TXN_COMMIT_STARTED:
      assert request_id >= 0
      idx = self.start_id
      stmt = "SELECT id, request_id FROM %s WHERE id = %d" % (self.table, idx)
      execute(self.cur, stmt)
      if self.cur.rowcount == 0:
        raise TestError("Fetching start_id %d via '%s' returned no data! "
                        "This row should never be deleted!" % (idx, stmt))
      REQUEST_ID_COL = 1
      res = self.cur.fetchone()
      if res[REQUEST_ID_COL] == self.cur_txn[idx]:
        self.apply_cur_txn_changes()
      else:
        self.cur_txn_state = self.TXN_UNCOMMITTED
      self.con.commit()

    # if the transaction was not committed, verify there are no rows at this
    # request id
    #
    # however, if the transaction was committed, then verify none of the rows
    # have a request_id below the request_id recorded by the start_id row.
    if self.cur_txn_state == self.TXN_UNCOMMITTED:
      self.verify_txn(self.cur_txn, False)

    # verify all committed transactions
    for txn in self.prev_txn:
      self.verify_txn(txn, True)

    # verify the rows owned by this worker match the request_id at which
    # they were set.
    idx = self.start_id
    max_map_id = self.start_id + self.num_id
    row_count = 0
    ID_COL = 0
    REQ_ID_COL = ID_COL + 1

    while idx < max_map_id:
      if row_count == 0:
        num_rows_to_check = random.randint(50, 100)
        execute(self.cur,
          "SELECT id, request_id FROM %s WHERE id >= %d AND id < %d "
          "ORDER BY id LIMIT %d"
          % (self.table, idx, max_map_id, num_rows_to_check))

        # prevent future queries from being issued since we've hit the end of
        # the rows that exist in the table
        row_count = self.cur.rowcount if self.cur.rowcount != 0 else -1

      # determine the id of the next available row in the table
      if row_count > 0:
        res = self.cur.fetchone()
        assert idx <= res[ID_COL]
        next_id = res[ID_COL]
        row_count -= 1
      else:
        next_id = max_map_id

      # rows up to the next id don't exist within the table, so verify our
      # map has them marked as removed
      while idx < next_id:
        # the map records deleted rows with a negative request_id
        if self.id_map[idx - self.start_id] >= 0:
          raise ValidateError("Row id %d was not found in table, but "
                              "id_map has it at request_id %d" %
                              (idx, self.id_map[idx - self.start_id]))
        idx += 1

      if idx == max_map_id:
        break

      if self.id_map[idx - self.start_id] != res[REQ_ID_COL]:
        raise ValidateError("Row id %d has req id %d, but %d is the "
                            "expected value!" %
                            (idx, res[REQ_ID_COL],
                             self.id_map[idx - self.start_id]))
      idx += 1

    self.con.commit()
    logging.debug("Verified data successfully")

  def execute_one(self):
    # select a number of rows and perform an insert, update or delete
    # operation on them
    num_rows = random.randint(1, MAX_ROWS_PER_REQ)
    ids = array.array('L')

    # allocate at least one row in the id space owned by this worker
    idx = random.randint(self.start_id, self.start_id + self.num_id - 1)
    ids.append(idx)

    for i in xrange(1, num_rows):
      # The valid ranges for ids are from start_id to start_id + num_id and
      # from start_share_id to max_id. The randint() uses the range from
      # start_share_id to max_id + num_id - 1. start_share_id to max_id covers
      # the shared range. The exclusive range is covered by max_id to max_id +
      # num_id - 1. If any number lands in this >= max_id section, it is
      # remapped to start_id and used for selecting a row in the exclusive
      # section.
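      #
      # For example (illustrative): with start_id = 0, num_id = 100,
      # start_share_id = 1600 and max_id = 2000, randint() draws from
      # [1600, 2099]; a draw of 1730 stays in the shared range, while a draw
      # of 2050 is remapped to 2050 - (2000 - 0) = 50 in this worker's
      # exclusive range.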
      idx = random.randint(self.start_share_id, self.max_id + self.num_id - 1)
      if idx >= self.max_id:
        idx -= self.max_id - self.start_id
      if ids.count(idx) == 0:
        ids.append(idx)

    # perform a read of these rows
    ID_COL = 0
    ZERO_SUM_COL = ID_COL + 1

    # For repeatable-read isolation levels on MyRocks, during the lock
    # acquisition part of this transaction, it is possible the selected rows
    # conflict with another thread's transaction. This results in a deadlock
    # error that requires the whole transaction to be rolled back because the
    # transaction's current snapshot will always be reading an older version
    # of the row. MyRocks will prevent any updates to this row until the
    # snapshot is released and re-acquired.
    NUM_RETRIES = 100
    for i in xrange(NUM_RETRIES):
      ids_found = {}
      try:
        for idx in ids:
          stmt = ("SELECT id, zero_sum FROM %s WHERE id = %d "
                  "FOR UPDATE" % (self.table, idx))
          execute(self.cur, stmt)
          if self.cur.rowcount > 0:
            res = self.cur.fetchone()
            ids_found[res[ID_COL]] = res[ZERO_SUM_COL]
        break
      except MySQLdb.OperationalError, e:
        if not is_deadlock_error(e):
          raise e

      # if a deadlock occurred, rollback the transaction and wait a short time
      # before retrying
      logging.debug("%s generated deadlock, retry %d of %d" %
                    (stmt, i + 1, NUM_RETRIES))
      self.con.rollback()
      time.sleep(0.2)
    else:
      # the retry loop exhausted all attempts without acquiring the locks
      raise TestError("Unable to acquire locks after %d retries "
                      "for query '%s'" % (NUM_RETRIES, stmt))

    # ensure that the zero_sum column remains summed up to zero at the
    # end of this operation
    current_sum = 0

    # all row locks acquired at this point, so allocate a request_id
    request_id = get_next_request_id()
    self.cur_txn = {self.start_id:request_id}
    self.cur_txn_state = self.TXN_UNCOMMITTED

    for idx in ids:
      stmt = None
      zero_sum = self.rand.randint(-1000, 1000)
      action = self.rand.randint(0, 3)
      is_delete = False

      if idx in ids_found:
        # for each row found, determine if it should be updated or deleted
        if action == 0:
          stmt = gen_delete(self.table, idx)
          is_delete = True
          current_sum -= ids_found[idx]
        else:
          stmt = gen_update(self.table, idx, self.thread_id, request_id,
                            zero_sum)
          current_sum += zero_sum
      else:
        # if it does not exist, then determine if an insert should happen
        if action <= 1:
          stmt = gen_insert(self.table, idx, self.thread_id, request_id,
                            zero_sum)
          current_sum += zero_sum

      if stmt is not None:
        # mark in self.cur_txn what these new changes will be
        if is_delete:
          self.cur_txn[idx] = -request_id
        else:
          self.cur_txn[idx] = request_id
        execute(self.cur, stmt)
        if self.cur.rowcount == 0:
          raise TestError("Executing %s returned row count of 0!" % stmt)

    # the start_id row is used to determine if this transaction has been
    # committed if the connection fails, and it is used to adjust the
    # zero_sum correctly
    idx = self.start_id
    ids.append(idx)
    self.cur_txn[idx] = request_id
    stmt = gen_insert_on_dup(self.table, idx, self.thread_id, request_id,
                             -current_sum)
    execute(self.cur, stmt)
    if self.cur.rowcount == 0:
      raise TestError("Executing '%s' returned row count of 0!" % stmt)

    # 90% commit, 10% rollback
    if roll_d100(10):
      self.con.rollback()
      logging.debug("request %d was rolled back" % request_id)
    else:
      self.cur_txn_state = self.TXN_COMMIT_STARTED
      self.con.commit()
      check_id(self.con.insert_id())
      if not self.con.get_server_info():
        raise MySQLdb.OperationalError(CR.CONNECTION_ERROR,
                                       "Possible connection error on commit")
      self.apply_cur_txn_changes()

    self.loop_num += 1
    if self.loop_num % 1000 == 0:
      logging.info("Processed %d transactions so far" % self.loop_num)

  def runme(self):
    global TEST_STOP, LOADERS_READY

    self.start_time = time.time()
    if self.reconnect():
      raise Exception("Unable to connect to MySQL server")

    self.populate_id_map()
    self.verify_data()

    logging.info("Starting load generator")
    reconnected = False
    LOADERS_READY += 1

    while self.loop_num < self.num_requests and not TEST_STOP:
      try:
        # verify our data on each reconnect and also on occasion
        if reconnected or random.randint(1, 500) == 1:
          self.verify_data()
          reconnected = False

        # execute_one() increments self.loop_num on completion
        self.execute_one()
      except MySQLdb.OperationalError, e:
        if not is_connection_error(e):
          raise e
        if self.reconnect():
          raise Exception("Unable to connect to MySQL server")
        reconnected = True

# the checker thread runs read-only transactions to verify that the row
# checksums match the messages and that the zero_sum column sums to zero
class CheckerWorker(WorkerThread):
  def __init__(self, thread_id):
    WorkerThread.__init__(self, 'checker-%02d' % thread_id)
    self.thread_id = thread_id
    self.rand = random.Random()
    self.rand.seed(thread_id)
    self.max_id = OPTIONS.max_id
    self.table = OPTIONS.table
    self.loop_num = 0
    self.start()

  def finish(self):
    logging.info("total loops: %d" % self.loop_num)

  def check_zerosum(self):
    # two methods for checking zero sum
    #   1. request the server to do it (90% of the time for now)
    #   2. read all rows and calculate directly
    if roll_d100(90):
      stmt = "SELECT SUM(zero_sum) FROM %s" % self.table
      if roll_d100(50):
        stmt += " FORCE INDEX(msg_i)"
      execute(self.cur, stmt)

      if self.cur.rowcount != 1:
        raise ValidateError("Error with query '%s'" % stmt)
      res = self.cur.fetchone()[0]
      if res != 0:
        raise ValidateError("Expected zero_sum to be 0, but %d returned "
                            "instead" % res)
    else:
      cur_isolation_level = self.get_isolation_level()
      self.set_isolation_level('REPEATABLE-READ')
      num_rows_to_check = random.randint(500, 1000)
      idx = 0
      sum = 0

      stmt = "SELECT id, zero_sum FROM %s WHERE id >= %d ORDER BY id LIMIT %d"
      ID_COL = 0
      ZERO_SUM_COL = 1

      while idx < self.max_id:
        execute(self.cur, stmt % (self.table, idx, num_rows_to_check))
        if self.cur.rowcount == 0:
          break

        # total all but the last row, then use the last row to advance idx
        for i in xrange(self.cur.rowcount - 1):
          sum += self.cur.fetchone()[ZERO_SUM_COL]

        last_row = self.cur.fetchone()
        idx = last_row[ID_COL] + 1
        sum += last_row[ZERO_SUM_COL]

      if sum != 0:
        raise ValidateError("Zero sum column expected to total 0, but sum is "
                            "%d instead!" % sum)
      self.set_isolation_level(cur_isolation_level)

  def check_rows(self):
    class id_range():
      def __init__(self, min_id, min_inclusive, max_id, max_inclusive):
        self.min_id = min_id if min_inclusive else min_id + 1
        self.max_id = max_id if max_inclusive else max_id - 1
      def count(self, idx):
        return idx >= self.min_id and idx <= self.max_id

    stmt = ("SELECT id, msg, msg_length, msg_checksum FROM %s WHERE " %
            self.table)

    # two methods for checking rows
    #  1. pick a number of rows at random
    #  2. range scan
    if roll_d100(90):
      ids = []
      for i in xrange(random.randint(1, MAX_ROWS_PER_REQ)):
        ids.append(random.randint(0, self.max_id - 1))
      stmt += "id in (%s)" % ','.join(str(x) for x in ids)
    else:
      id1 = random.randint(0, self.max_id - 1)
      id2 = random.randint(0, self.max_id - 1)
      min_inclusive = random.randint(0, 1)
      cond1 = '>=' if min_inclusive else '>'
      max_inclusive = random.randint(0, 1)
      cond2 = '<=' if max_inclusive else '<'
      stmt += ("id %s %d AND id %s %d" %
               (cond1, min(id1, id2), cond2, max(id1, id2)))
      ids = id_range(min(id1, id2), min_inclusive, max(id1, id2),
                     max_inclusive)

    execute(self.cur, stmt)

    ID_COL = 0
    MSG_COL = ID_COL + 1
    MSG_LENGTH_COL = MSG_COL + 1
    MSG_CHECKSUM_COL = MSG_LENGTH_COL + 1

    for row in self.cur.fetchall():
      idx = row[ID_COL]
      msg = row[MSG_COL]
      msg_length = row[MSG_LENGTH_COL]
      msg_checksum = row[MSG_CHECKSUM_COL]
      if ids.count(idx) < 1:
        raise ValidateError(
            "id %d returned from database, but query was '%s'" % (idx, stmt))
      if len(msg) != msg_length:
        raise ValidateError(
            "id %d contains msg_length %d, but msg '%s' is only %d "
            "characters long" % (idx, msg_length, msg, len(msg)))
      if sha1(msg) != msg_checksum:
        raise ValidateError("id %d has checksum '%s', but expected checksum "
                            "is '%s'" % (idx, msg_checksum, sha1(msg)))

  def runme(self):
    global TEST_STOP

    self.start_time = time.time()
    if self.reconnect():
      raise Exception("Unable to connect to MySQL server")
    logging.info("Starting checker")

    while not TEST_STOP:
      try:
        # choose one of two checks:
        #   1. verify that zero_sum across all rows totals 0
        #   2. read a number of rows and verify checksums
        if roll_d100(25):
          self.check_zerosum()
        else:
          self.check_rows()

        self.con.commit()
        self.loop_num += 1
        if self.loop_num % 10000 == 0:
          logging.info("Processed %d transactions so far" % self.loop_num)
      except MySQLdb.OperationalError, e:
        if not is_connection_error(e):
          raise e
        if self.reconnect():
          raise Exception("Unable to reconnect to MySQL server")

if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='Concurrent load generator.')

  parser.add_argument('-C', '--committed-txns', dest='committed_txns',
                      default=3, type=int,
                      help="number of committed txns to verify")

  parser.add_argument('-c', '--num-checkers', dest='num_checkers', type=int,
                      default=4,
                      help="number of reader/checker threads to test with")

  parser.add_argument('-d', '--db', dest='db', default='test',
                      help="mysqld server database to test with")

  parser.add_argument('-H', '--host', dest='host', default='127.0.0.1',
                      help="mysqld server host ip address")

  parser.add_argument('-i', '--ids-per-loader', dest='ids_per_loader',
                      type=int, default=100,
                      help="number of records which each loader owns "
                           "exclusively, up to max-id / 2 / num-loaders")

  parser.add_argument('-L', '--log-file', dest='log_file', default=None,
                      help="log file for output")

  parser.add_argument('-l', '--num-loaders', dest='num_loaders', type=int,
                      default=16,
                      help="number of loader threads to test with")

  parser.add_argument('-m', '--max-id', dest='max_id', type=int, default=1000,
                      help="maximum number of records which the table "
                           "extends to, must be larger than ids-per-loader * "
                           "num-loaders")

  parser.add_argument('-n', '--num-records', dest='num_records', type=int,
                      default=0,
                      help="number of records to populate the table with")

  parser.add_argument('-P', '--port', dest='port', default=3307, type=int,
                      help='mysqld server host port')

  parser.add_argument('-r', '--num-requests', dest='num_requests', type=int,
                      default=100000000,
                      help="number of requests issued per worker thread")

  parser.add_argument('-T', '--truncate', dest='truncate',
                      action='store_true',
                      help="truncates or creates the table before the test")

  parser.add_argument('-t', '--table', dest='table', default='t1',
                      help="mysqld server table to test with")

  parser.add_argument('-u', '--user', dest='user', default='root',
                      help="user to log into the mysql server")

  parser.add_argument('-v', '--verbose', dest='verbose', action='store_true',
                      help="enable debug logging")

  parser.add_argument('-E', '--expect-file', dest='expect_file', default=None,
                      help="expect file for server restart")

  parser.add_argument('-D', '--reap-delay', dest='reap_delay', type=int,
                      default=0,
                      help="seconds to sleep after each server reap")

  OPTIONS = parser.parse_args()

  if OPTIONS.verbose:
    log_level = logging.DEBUG
  else:
    log_level = logging.INFO

  logging.basicConfig(level=log_level,
                      format='%(asctime)s %(threadName)s [%(levelname)s] '
                             '%(message)s',
                      datefmt='%Y-%m-%d %H:%M:%S',
                      filename=OPTIONS.log_file)

  logging.info("Command line given: %s" % ' '.join(sys.argv))

  if (OPTIONS.max_id < 0 or OPTIONS.ids_per_loader <= 0 or
      OPTIONS.max_id < OPTIONS.ids_per_loader * OPTIONS.num_loaders):
    logging.error("ids-per-loader must be larger than 0 and max-id must be "
                  "at least ids-per-loader * num-loaders")
    exit(1)

  logging.info("Using table %s.%s for test" % (OPTIONS.db, OPTIONS.table))

  if OPTIONS.truncate:
    logging.info("Truncating table")
    con = MySQLdb.connect(user=OPTIONS.user, host=OPTIONS.host,
                          port=OPTIONS.port, db=OPTIONS.db)
    if not con:
      raise TestError("Unable to connect to mysqld server to create/truncate "
                      "table")
    cur = con.cursor()
    cur.execute("SELECT COUNT(*) FROM INFORMATION_SCHEMA.tables WHERE "
                "table_schema = '%s' AND table_name = '%s'" %
                (OPTIONS.db, OPTIONS.table))
    if cur.rowcount != 1:
      logging.error("Unable to retrieve information about table %s "
                    "from information_schema!" % OPTIONS.table)
      exit(1)

    if cur.fetchone()[0] == 0:
      logging.info("Table %s not found, creating a new one" % OPTIONS.table)
      cur.execute("CREATE TABLE %s (id INT PRIMARY KEY, "
                  "thread_id INT NOT NULL, "
                  "request_id BIGINT UNSIGNED NOT NULL, "
                  "update_count INT UNSIGNED NOT NULL DEFAULT 0, "
                  "zero_sum INT DEFAULT 0, "
                  "msg VARCHAR(1024), "
                  "msg_length int, "
                  "msg_checksum varchar(128), "
                  "KEY msg_i(msg(255), zero_sum)) "
                  "ENGINE=RocksDB DEFAULT CHARSET=latin1 COLLATE=latin1_bin" %
                  OPTIONS.table)
    else:
      logging.info("Table %s found, truncating" % OPTIONS.table)
      cur.execute("TRUNCATE TABLE %s" % OPTIONS.table)
    con.commit()

  if populate_table(OPTIONS.num_records):
    logging.error("Populate table returned an error")
    exit(1)

  logging.info("Starting %d loaders" % OPTIONS.num_loaders)
  loaders = []
  for i in xrange(OPTIONS.num_loaders):
    loaders.append(LoadGenWorker(i))

  logging.info("Starting %d checkers" % OPTIONS.num_checkers)
  checkers = []
  for i in xrange(OPTIONS.num_checkers):
    checkers.append(CheckerWorker(i))

  while LOADERS_READY < OPTIONS.num_loaders:
    time.sleep(0.5)

  if OPTIONS.expect_file and OPTIONS.reap_delay > 0:
    logging.info('Starting reaper')
    checkers.append(ReaperWorker())

  workers_failed = 0
  # the checker (and reaper) threads stay active while the loaders run
  workers_failed += wait_for_workers(loaders, len(checkers))

  if TEST_STOP:
    logging.error("Detected test failure, aborting")
    os._exit(1)

  TEST_STOP = True

  workers_failed += wait_for_workers(checkers)

  if workers_failed > 0:
    logging.error("Test detected %d failures, aborting" % workers_failed)
    sys.exit(1)

  logging.info("Test completed successfully")
  sys.exit(0)