import array
import hashlib
import MySQLdb
from MySQLdb.constants import CR
from MySQLdb.constants import ER
from collections import deque
import os
import random
import signal
import sys
import threading
import time
import string
import traceback
import logging
import argparse

# This is a generic load_generator for mysqld which persists across server
# restarts and attempts to verify that both committed and uncommitted
# transactions are persisted correctly.
#
# The table schema used should look something like:
#
# CREATE TABLE t1(id INT PRIMARY KEY,
#                 thread_id INT NOT NULL,
#                 request_id BIGINT UNSIGNED NOT NULL,
#                 update_count INT UNSIGNED NOT NULL DEFAULT 0,
#                 zero_sum INT DEFAULT 0,
#                 msg VARCHAR(1024),
#                 msg_length int,
#                 msg_checksum varchar(128),
#                 KEY msg_i(msg(255), zero_sum))
# ENGINE=RocksDB DEFAULT CHARSET=latin1 COLLATE=latin1_bin;
#
#   zero_sum should always sum up to 0 across the whole table, no matter when
#   it is read. Each transaction always maintains this sum at 0.
#
#   request_id should be unique across transactions. It is used during
#   transaction verification and is monotonically increasing.
#
# Several threads are spawned at the start of the test to populate the table.
# Once the table is populated, both loader and checker threads are created.
#
# The row id space is split into two sections: exclusive and shared. Each
# loader thread owns some part of the exclusive section, for which it
# maintains complete information on inserts/updates/deletes. Since this
# section is only modified by one thread, the thread can maintain an accurate
# picture of all changes. The shared section contains rows which multiple
# threads can update/delete/insert.  For checking purposes, the request_id is
# used to determine if a row is consistent with a committed transaction.
#
# Each loader thread's transaction consists of selecting some number of rows
# randomly. The thread can choose to delete the row, update the row or insert
# the row if it doesn't exist.  The state of rows that are owned by the loader
# thread is tracked within the thread's id_map. This map contains the row id
# and the request_id of the latest update. Deleted rows are marked with
# -request_id. Thus, at any point in time, the thread's id_map should reflect
# the exact state of the rows that it owns.
#
# The loader thread also maintains the state of older transactions that were
# successfully processed in addition to the current transaction, which may or
# may not be committed. Each transaction state consists of the row id and the
# request_id. Again, -request_id is used to indicate a delete. For committed
# transactions, the thread can verify the request_id of the row is at least as
# large as what the thread has recorded. For uncommitted transactions, the
# thread verifies the request_id of the row does not match that of the
# transaction. To determine whether or not a transaction succeeded in the case
# of a crash right at commit, each thread always includes a particular row in
# the transaction which it can use to check the request id against.
#
# Checker threads run continuously to verify the checksums on the rows and to
# verify the zero_sum column sums up to zero at any point in time. The checker
# threads run both point lookups and range scans for selecting the rows.
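#
# As a concrete illustration of the id space split (hypothetical option
# values, not the defaults):
#
#   num_loaders, ids_per_loader, max_id = 4, 100, 1000
#   start_id       = 2 * ids_per_loader            # loader 2 owns [200, 300)
#   start_share_id = num_loaders * ids_per_loader  # shared rows: [400, 1000)
#
# Any id in [start_share_id, max_id) may be touched by any loader, while each
# slice [start_id, start_id + ids_per_loader) belongs to exactly one loader.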

class ValidateError(Exception):
  """Raised when validation fails"""
  pass

class TestError(Exception):
  """Raised when the test cannot make forward progress"""
  pass

CHARS = string.ascii_letters + string.digits
OPTIONS = {}

# max number of rows per transaction
MAX_ROWS_PER_REQ = 10

# global variable checked by threads to determine if the test is stopping
TEST_STOP = False
LOADERS_READY = 0

# global monotonically increasing request id counter
REQUEST_ID = 1
REQUEST_ID_LOCK = threading.Lock()

def get_next_request_id():
  global REQUEST_ID
  with REQUEST_ID_LOCK:
    REQUEST_ID += 1
    return REQUEST_ID

# given a percentage value, rolls a 100-sided die and returns whether the
# given value is greater than or equal to the die roll
#
# passing 0 always returns False and passing 100 always returns True
def roll_d100(p):
  assert p >= 0 and p <= 100
  return p >= random.randint(1, 100)
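# e.g. roll_d100(10) is True with probability 0.10; the loaders below use
# this to roll back ~10% of their transactions and commit the rest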

def sha1(x):
  return hashlib.sha1(str(x).encode('utf-8')).hexdigest()

def is_connection_error(exc):
  error_code = exc.args[0]
  return (error_code == CR.CONNECTION_ERROR or
          error_code == CR.CONN_HOST_ERROR or
          error_code == CR.SERVER_LOST or
          error_code == CR.SERVER_GONE_ERROR or
          error_code == ER.QUERY_INTERRUPTED or
          error_code == ER.SERVER_SHUTDOWN)

def is_deadlock_error(exc):
  error_code = exc.args[0]
  return (error_code == ER.LOCK_DEADLOCK or
          error_code == ER.LOCK_WAIT_TIMEOUT)

# the generated blob should be deterministic for a given idx (note that this
# reseeds the global random module)
def gen_msg(idx, thread_id, request_id):
  random.seed(idx)
  # field length is 1024 bytes, but 32 are reserved for the tid and req tag
  blob_length = random.randint(1, 1024 - 32)

  if roll_d100(50):
    # blob that cannot be compressed (well, compresses to 85% of original size)
    msg = ''.join([random.choice(CHARS) for x in range(blob_length)])
  else:
    # blob that can be compressed
    msg = random.choice(CHARS) * blob_length

  # append the thread_id and request_id to the end of the msg
  return ''.join([msg, ' tid: %d req: %d' % (thread_id, request_id)])
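# e.g. gen_msg(42, 3, 1042) returns a string of the form (blob abbreviated)
#
#   'xK3f...Qa tid: 3 req: 1042'
#
# so a row's msg always records the thread and request that last wrote it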

def execute(cur, stmt):
  # MySQLdb can report a failed statement's rowcount as -1 cast to an
  # unsigned 64-bit value, i.e. 2**64 - 1
  ROW_COUNT_ERROR = 18446744073709551615
  logging.debug("Executing %s" % stmt)
  cur.execute(stmt)
  if cur.rowcount < 0 or cur.rowcount == ROW_COUNT_ERROR:
    raise MySQLdb.OperationalError(CR.CONNECTION_ERROR,
                                   "Possible connection error, rowcount is %d"
                                   % cur.rowcount)

def wait_for_workers(workers, min_active=0):
  logging.info("Waiting for %d workers", len(workers))
  # min_active needs to include the current waiting thread
  min_active += 1

  # polling here allows this thread to be responsive to keyboard interrupt
  # exceptions, otherwise a user hitting ctrl-c would see the load_generator as
  # hanging and unresponsive
  try:
    while threading.active_count() > min_active:
      time.sleep(1)
  except KeyboardInterrupt:
    os._exit(1)

  num_failures = 0
  for w in workers:
    w.join()
    if w.exception:
      logging.error(w.exception)
      num_failures += 1

  return num_failures

# base class for worker threads; contains logic for reconnecting to the
# mysqld server after a connection failure
class WorkerThread(threading.Thread):
  def __init__(self, name):
    threading.Thread.__init__(self)
    self.name = name
    self.exception = None
    self.con = None
    self.cur = None
    self.isolation_level = None
    self.start_time = time.time()
    self.total_time = 0

  def run(self):
    global TEST_STOP

    try:
      logging.info("Started")
      self.runme()
      logging.info("Completed successfully")
    except Exception:
      self.exception = traceback.format_exc()
      logging.error(self.exception)
      TEST_STOP = True
    finally:
      self.total_time = time.time() - self.start_time
      logging.info("Total run time: %.2f s" % self.total_time)
      self.finish()

  # returns True if a connection could not be established within the timeout
  def reconnect(self, timeout=900):
    global TEST_STOP

    self.con = None
    SECONDS_BETWEEN_RETRY = 10
    attempts = 1
    logging.info("Attempting to connect to MySQL Server")
    while not self.con and timeout > 0 and not TEST_STOP:
      try:
        self.con = MySQLdb.connect(user=OPTIONS.user, host=OPTIONS.host,
                                   port=OPTIONS.port, db=OPTIONS.db)
        if self.con:
          self.con.autocommit(False)
          self.cur = self.con.cursor()
          self.set_isolation_level(self.isolation_level)
          logging.info("Connection successful after attempt %d" % attempts)
          break
      except MySQLdb.Error:
        logging.debug(traceback.format_exc())
      time.sleep(SECONDS_BETWEEN_RETRY)
      timeout -= SECONDS_BETWEEN_RETRY
      attempts += 1
    return self.con is None

  def get_isolation_level(self):
    execute(self.cur, "SELECT @@SESSION.transaction_isolation")
    if self.cur.rowcount != 1:
      raise TestError("Unable to retrieve transaction_isolation")
    return self.cur.fetchone()[0]

  def set_isolation_level(self, isolation_level, persist=False):
    if isolation_level is not None:
      execute(self.cur,
              "SET @@SESSION.transaction_isolation = '%s'" % isolation_level)
      if self.cur.rowcount != 0:
        raise TestError("Unable to set the isolation level to %s"
                        % isolation_level)

    if isolation_level is None or persist:
      self.isolation_level = isolation_level

# periodically kills the server
class ReaperWorker(WorkerThread):
  def __init__(self):
    WorkerThread.__init__(self, 'reaper')
    # initialize state before starting the thread so finish() always sees it
    self.kills = 0
    self.start()

  def finish(self):
    logging.info('complete with %d kills' % self.kills)
    if self.con:
      self.con.close()

  def get_server_pid(self):
    execute(self.cur, "SELECT @@pid_file")
    if self.cur.rowcount != 1:
      raise TestError("Unable to retrieve pid_file")
    return int(open(self.cur.fetchone()[0]).read())

  def runme(self):
    global TEST_STOP
    time_remain = random.randint(10, 30)
    while not TEST_STOP:
      if time_remain > 0:
        time_remain -= 1
        time.sleep(1)
        continue
      if self.reconnect():
        raise Exception("Unable to connect to MySQL server")
      logging.info('killing server...')
      with open(OPTIONS.expect_file, 'w+') as expect_file:
        expect_file.write('restart')
      os.kill(self.get_server_pid(), signal.SIGKILL)
      self.kills += 1
      time_remain = random.randint(0, 30) + OPTIONS.reap_delay

# runs initially to populate the table with the given number of rows
class PopulateWorker(WorkerThread):
  def __init__(self, thread_id, start_id, num_to_add):
    WorkerThread.__init__(self, 'populate-%d' % thread_id)
    self.thread_id = thread_id
    self.start_id = start_id
    self.num_to_add = num_to_add
    self.table = OPTIONS.table
    self.start()

  def finish(self):
    if self.con:
      self.con.commit()
      self.con.close()

  def runme(self):
    if self.reconnect():
      raise Exception("Unable to connect to MySQL server")

    stmt = None
    for i in range(self.start_id, self.start_id + self.num_to_add):
      stmt = gen_insert(self.table, i, 0, 0, 0)
      execute(self.cur, stmt)
      if i % 101 == 0:
        self.con.commit()
    self.con.commit()
    logging.info("Inserted %d rows starting at id %d" %
                 (self.num_to_add, self.start_id))

def populate_table(num_records):
  logging.info("Populate_table started for %d records" % num_records)
  if num_records == 0:
    return False

  # use up to 10 workers, but always at least one
  num_workers = min(10, max(1, num_records // 100))
  workers = []

  N = num_records // num_workers
  start_id = 0
  for i in range(num_workers):
    workers.append(PopulateWorker(i, start_id, N))
    start_id += N
  if num_records > start_id:
    workers.append(PopulateWorker(num_workers, start_id,
                                  num_records - start_id))

  # wait for the populate threads to complete
  return wait_for_workers(workers) > 0

def gen_insert(table, idx, thread_id, request_id, zero_sum):
  msg = gen_msg(idx, thread_id, request_id)
  return ("INSERT INTO %s (id, thread_id, request_id, zero_sum, "
          "msg, msg_length, msg_checksum) VALUES (%d,%d,%d,%d,'%s',%d,'%s')"
          % (table, idx, thread_id, request_id,
             zero_sum, msg, len(msg), sha1(msg)))
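# e.g. gen_insert('t1', 42, 3, 1042, 0) produces a statement of the form
# (placeholders <blob>, <len> and <sha1> stand in for the generated values):
#
#   INSERT INTO t1 (id, thread_id, request_id, zero_sum, msg, msg_length,
#   msg_checksum) VALUES (42,3,1042,0,'<blob> tid: 3 req: 1042',<len>,'<sha1>')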

def gen_update(table, idx, thread_id, request_id, zero_sum):
  msg = gen_msg(idx, thread_id, request_id)
  return ("UPDATE %s SET thread_id = %d, request_id = %d, "
          "update_count = update_count + 1, zero_sum = zero_sum + (%d), "
          "msg = '%s', msg_length = %d, msg_checksum = '%s' WHERE id = %d "
          % (table, thread_id, request_id, zero_sum, msg, len(msg),
             sha1(msg), idx))

def gen_delete(table, idx):
  return "DELETE FROM %s WHERE id = %d" % (table, idx)

def gen_insert_on_dup(table, idx, thread_id, request_id, zero_sum):
  msg = gen_msg(idx, thread_id, request_id)
  msg_checksum = sha1(msg)
  return ("INSERT INTO %s (id, thread_id, request_id, zero_sum, "
          "msg, msg_length, msg_checksum) VALUES (%d,%d,%d,%d,'%s',%d,'%s') "
          "ON DUPLICATE KEY UPDATE "
          "thread_id=%d, request_id=%d, "
          "update_count=update_count+1, "
          "zero_sum=zero_sum + (%d), msg='%s', msg_length=%d, "
          "msg_checksum='%s'" %
          (table, idx, thread_id, request_id,
           zero_sum, msg, len(msg), msg_checksum, thread_id, request_id,
           zero_sum, msg, len(msg), msg_checksum))

# Each loader thread owns a part of the id space which it maintains inventory
# for. The loader thread generates inserts, updates and deletes for the table.
# The latest successful transaction and the latest open transaction are kept to
# verify after a disconnect that the rows were recovered properly.
class LoadGenWorker(WorkerThread):
  TXN_UNCOMMITTED = 0
  TXN_COMMIT_STARTED = 1
  TXN_COMMITTED = 2

  def __init__(self, thread_id):
    WorkerThread.__init__(self, 'loader-%02d' % thread_id)
    self.thread_id = thread_id
    self.rand = random.Random()
    self.rand.seed(thread_id)
    self.loop_num = 0

    # id_map tracks the latest request_id for each id owned by this worker
    # thread, indexed by (id - start_id)
    self.id_map = array.array('l')
    self.start_id = thread_id * OPTIONS.ids_per_loader
    self.num_id = OPTIONS.ids_per_loader
    self.start_share_id = OPTIONS.num_loaders * OPTIONS.ids_per_loader
    self.max_id = OPTIONS.max_id
    self.table = OPTIONS.table
    self.num_requests = OPTIONS.num_requests

    # stores information about the latest series of successful transactions
    #
    # each transaction is simply a map of id -> request_id
    # deleted rows are indicated by -request_id
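    #
    # e.g. {200: 1042, 207: 1042, 213: -1042} records a transaction with
    # request_id 1042 that wrote rows 200 and 207 and deleted row 213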
    self.prev_txn = deque()
    self.cur_txn = None
    self.cur_txn_state = None

    self.start()

  def finish(self):
    if self.total_time:
      req_per_sec = self.loop_num / self.total_time
    else:
      req_per_sec = -1
    logging.info("total txns: %d, txn/s: %.2f rps" %
                 (self.loop_num, req_per_sec))

  # constructs the internal map recording, for each id owned by this thread,
  # the request_id that last wrote it
  def populate_id_map(self):
    logging.info("Populating id map")

    REQ_ID_COL = 0
    stmt = "SELECT request_id FROM %s WHERE id = %d"

    # the start_id is used for tracking active transactions, so the row needs
    # to exist
    idx = self.start_id
    execute(self.cur, stmt % (self.table, idx))
    if self.cur.rowcount > 0:
      request_id = self.cur.fetchone()[REQ_ID_COL]
    else:
      request_id = get_next_request_id()
      execute(self.cur, gen_insert(self.table, idx, self.thread_id,
                                   request_id, 0))
      self.con.commit()

    self.id_map.append(request_id)

    self.cur_txn = {idx:request_id}
    self.cur_txn_state = self.TXN_COMMITTED
    for i in range(OPTIONS.committed_txns):
      self.prev_txn.append(self.cur_txn)

    # fetch the rest of the rows for the id space owned by this thread
    for idx in range(self.start_id + 1, self.start_id + self.num_id):
      execute(self.cur, stmt % (self.table, idx))
      if self.cur.rowcount == 0:
        # a negative number is used to indicate a missing row
        self.id_map.append(-1)
      else:
        res = self.cur.fetchone()
        self.id_map.append(res[REQ_ID_COL])

    self.con.commit()

  def apply_cur_txn_changes(self):
    # apply the changes to the id_map
    for idx in self.cur_txn:
      if idx < self.start_id + self.num_id:
        assert idx >= self.start_id
        self.id_map[idx - self.start_id] = self.cur_txn[idx]
    self.cur_txn_state = self.TXN_COMMITTED

    self.prev_txn.append(self.cur_txn)
    self.prev_txn.popleft()

  def verify_txn(self, txn, committed):
    request_id = txn[self.start_id]
    if not committed:
      # if the transaction was not committed, then there should be no rows
      # in the table that have this request_id
      cond = '='
      # it is possible the start_id used to track this transaction is in
      # the process of being deleted
      if request_id < 0:
        request_id = -request_id
    else:
      # if the transaction was committed, then no rows modified by this
      # transaction should have a request_id less than this transaction's id
      cond = '<'
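    # e.g. for a transaction that touched ids 200 and 213 with request_id
    # 1042, the query below takes one of these two shapes (both must return
    # a count of 0):
    #
    #   uncommitted: ... WHERE id IN (200,213) AND request_id = 1042
    #   committed:   ... WHERE id IN (200,213) AND request_id < 1042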
    stmt = ("SELECT COUNT(*) FROM %s WHERE id IN (%s) AND request_id %s %d" %
            (self.table, ','.join(str(x) for x in txn), cond, request_id))
    execute(self.cur, stmt)
    if self.cur.rowcount != 1:
      raise TestError("Unable to retrieve results for query '%s'" % stmt)
    count = self.cur.fetchone()[0]
    if count > 0:
      raise TestError("Expected '%s' to return 0 rows, but %d returned "
                      "instead" % (stmt, count))
    self.con.commit()

  def verify_data(self):
    # if the state of the current transaction is unknown (i.e. a commit was
    # issued, but the connection failed before the result came back), check
    # the start_id row to determine if it was committed
    request_id = self.cur_txn[self.start_id]
    if self.cur_txn_state == self.TXN_COMMIT_STARTED:
      assert request_id >= 0
      idx = self.start_id
      stmt = "SELECT id, request_id FROM %s where id = %d" % (self.table, idx)
      execute(self.cur, stmt)
      if self.cur.rowcount == 0:
        raise TestError("Fetching start_id %d via '%s' returned no data! "
                        "This row should never be deleted!" % (idx, stmt))
      REQUEST_ID_COL = 1
      res = self.cur.fetchone()
      if res[REQUEST_ID_COL] == self.cur_txn[idx]:
        self.apply_cur_txn_changes()
      else:
        self.cur_txn_state = self.TXN_UNCOMMITTED
      self.con.commit()

    # if the transaction was not committed, verify there are no rows at this
    # request id
    #
    # however, if the transaction was committed, then verify none of the rows
    # have a request_id below the request_id recorded by the start_id row.
    if self.cur_txn_state == self.TXN_UNCOMMITTED:
      self.verify_txn(self.cur_txn, False)

    # verify all committed transactions
    for txn in self.prev_txn:
      self.verify_txn(txn, True)

    # verify the rows owned by this worker match the request_id at which
    # they were set.
    idx = self.start_id
    max_map_id = self.start_id + self.num_id
    row_count = 0
    ID_COL = 0
    REQ_ID_COL = ID_COL + 1

    while idx < max_map_id:
      if row_count == 0:
        num_rows_to_check = random.randint(50, 100)
        execute(self.cur,
                "SELECT id, request_id FROM %s where id >= %d and id < %d "
                "ORDER BY id LIMIT %d"
                % (self.table, idx, max_map_id, num_rows_to_check))

        # prevent future queries from being issued since we've hit the end of
        # the rows that exist in the table
        row_count = self.cur.rowcount if self.cur.rowcount != 0 else -1

      # determine the id of the next available row in the table
      if row_count > 0:
        res = self.cur.fetchone()
        assert idx <= res[ID_COL]
        next_id = res[ID_COL]
        row_count -= 1
      else:
        next_id = max_map_id

      # rows up to the next id don't exist within the table, verify our
      # map has them as removed
      while idx < next_id:
        # the row is missing from the table, so the id_map entry must be
        # negative (i.e. marked as deleted)
        if self.id_map[idx - self.start_id] >= 0:
          raise ValidateError("Row id %d was not found in table, but "
                              "id_map has it at request_id %d" %
                              (idx, self.id_map[idx - self.start_id]))
        idx += 1

      if idx == max_map_id:
        break

      if self.id_map[idx - self.start_id] != res[REQ_ID_COL]:
        raise ValidateError("Row id %d has req id %d, but %d is the "
                            "expected value!" %
                            (idx, res[REQ_ID_COL],
                             self.id_map[idx - self.start_id]))
      idx += 1

    self.con.commit()
    logging.debug("Verified data successfully")

  def execute_one(self):
    # select a number of rows and perform an insert, update or delete
    # operation on them
    num_rows = random.randint(1, MAX_ROWS_PER_REQ)
    ids = array.array('L')

    # allocate at least one row in the id space owned by this worker
    idx = random.randint(self.start_id, self.start_id + self.num_id - 1)
    ids.append(idx)

    for i in range(1, num_rows):
      # The valid ranges for ids are from start_id to start_id + num_id and
      # from start_share_id to max_id. The randint() uses the range from
      # start_share_id to max_id + num_id - 1. start_share_id to max_id covers
      # the shared range. The exclusive range is covered by max_id to max_id +
      # num_id - 1. If any number lands in this >= max_id section, it is
      # remapped to start_id and used for selecting a row in the exclusive
      # section.
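      #
      # e.g. with start_id=200, num_id=100, start_share_id=400 and max_id=1000
      # (hypothetical values), randint() draws from [400, 1099]; a draw of
      # 1050 is >= max_id and is remapped to 1050 - (1000 - 200) = 250, a row
      # in this loader's exclusive section.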
      idx = random.randint(self.start_share_id, self.max_id + self.num_id - 1)
      if idx >= self.max_id:
        idx -= self.max_id - self.start_id
      if ids.count(idx) == 0:
        ids.append(idx)

    # perform a read of these rows
    ID_COL = 0
    ZERO_SUM_COL = ID_COL + 1

    # For repeatable-read isolation levels on MyRocks, during the lock
    # acquisition part of this transaction, it is possible the selected rows
    # conflict with another thread's transaction. This results in a deadlock
    # error that requires the whole transaction to be rolled back because the
    # transaction's current snapshot will always be reading an older version of
    # the row. MyRocks will prevent any updates to this row until the
    # snapshot is released and re-acquired.
    NUM_RETRIES = 100
    for i in range(NUM_RETRIES):
      ids_found = {}
      try:
        for idx in ids:
          stmt = ("SELECT id, zero_sum FROM %s WHERE id = %d "
                  "FOR UPDATE" % (self.table, idx))
          execute(self.cur, stmt)
          if self.cur.rowcount > 0:
            res = self.cur.fetchone()
            ids_found[res[ID_COL]] = res[ZERO_SUM_COL]
        break
      except MySQLdb.OperationalError as e:
        if not is_deadlock_error(e):
          raise e

      # if a deadlock occurred, rollback the transaction and wait a short time
      # before retrying.
      logging.debug("%s generated deadlock, retry %d of %d" %
                    (stmt, i + 1, NUM_RETRIES))
      self.con.rollback()
      time.sleep(0.2)
    else:
      # the loop never hit the break, so every retry deadlocked
      raise TestError("Unable to acquire locks after %d retries "
                      "for query '%s'" % (NUM_RETRIES, stmt))

    # ensure that the zero_sum column remains summed up to zero at the
    # end of this operation
    current_sum = 0

    # all row locks acquired at this point, so allocate a request_id
    request_id = get_next_request_id()
    self.cur_txn = {self.start_id:request_id}
    self.cur_txn_state = self.TXN_UNCOMMITTED

    for idx in ids:
      stmt = None
      zero_sum = self.rand.randint(-1000, 1000)
      action = self.rand.randint(0, 3)
      is_delete = False

      if idx in ids_found:
        # for each row found, determine if it should be updated or deleted
        if action == 0:
          stmt = gen_delete(self.table, idx)
          is_delete = True
          current_sum -= ids_found[idx]
        else:
          stmt = gen_update(self.table, idx, self.thread_id, request_id,
                            zero_sum)
          current_sum += zero_sum
      else:
        # if it does not exist, then determine if an insert should happen
        if action <= 1:
          stmt = gen_insert(self.table, idx, self.thread_id, request_id,
                            zero_sum)
          current_sum += zero_sum

      if stmt is not None:
        # mark in self.cur_txn what these new changes will be
        if is_delete:
          self.cur_txn[idx] = -request_id
        else:
          self.cur_txn[idx] = request_id
        execute(self.cur, stmt)
        if self.cur.rowcount == 0:
          raise TestError("Executing %s returned row count of 0!" % stmt)

    # the start_id row is used to determine if this transaction has been
    # committed if the connection fails, and it is used to adjust the zero_sum
    # correctly
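    #
    # e.g. if the deltas applied above sum to current_sum = +37, this row is
    # written with zero_sum = -37, so SUM(zero_sum) over the whole table is
    # still 0 once the transaction commits.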
    idx = self.start_id
    ids.append(idx)
    self.cur_txn[idx] = request_id
    stmt = gen_insert_on_dup(self.table, idx, self.thread_id, request_id,
                             -current_sum)
    execute(self.cur, stmt)
    if self.cur.rowcount == 0:
      raise TestError("Executing '%s' returned row count of 0!" % stmt)

    # 90% commit, 10% rollback
    if roll_d100(10):
      self.con.rollback()
      logging.debug("request %s was rolled back" % request_id)
    else:
      self.cur_txn_state = self.TXN_COMMIT_STARTED
      self.con.commit()
      if not self.con.get_server_info():
        raise MySQLdb.OperationalError(CR.CONNECTION_ERROR,
                                       "Possible connection error on commit")
      self.apply_cur_txn_changes()

    self.loop_num += 1
    if self.loop_num % 1000 == 0:
      logging.info("Processed %d transactions so far" % self.loop_num)

  def runme(self):
    global TEST_STOP, LOADERS_READY

    self.start_time = time.time()
    if self.reconnect():
      raise Exception("Unable to connect to MySQL server")

    self.populate_id_map()
    self.verify_data()

    logging.info("Starting load generator")
    reconnected = False
    LOADERS_READY += 1

    while self.loop_num < self.num_requests and not TEST_STOP:
      try:
        # verify our data on each reconnect and also on occasion
        if reconnected or random.randint(1, 500) == 1:
          self.verify_data()
          reconnected = False

        # execute_one() increments loop_num once per completed transaction
        self.execute_one()
      except MySQLdb.OperationalError as e:
        if not is_connection_error(e):
          raise e
        if self.reconnect():
          raise Exception("Unable to connect to MySQL server")
        reconnected = True

# the checker thread runs read-only transactions to verify that the row
# checksums match the messages.
class CheckerWorker(WorkerThread):
  def __init__(self, thread_id):
    WorkerThread.__init__(self, 'checker-%02d' % thread_id)
    self.thread_id = thread_id
    self.rand = random.Random()
    self.rand.seed(thread_id)
    self.max_id = OPTIONS.max_id
    self.table = OPTIONS.table
    self.loop_num = 0
    self.start()

  def finish(self):
    logging.info("total loops: %d" % self.loop_num)

  def check_zerosum(self):
    # two methods for checking zero sum
    #   1. request the server to do it (90% of the time for now)
    #   2. read all rows and calculate directly
    if roll_d100(90):
      stmt = "SELECT SUM(zero_sum) FROM %s" % self.table
      if roll_d100(50):
        stmt += " FORCE INDEX(msg_i)"
      execute(self.cur, stmt)

      if self.cur.rowcount != 1:
        raise ValidateError("Error with query '%s'" % stmt)
      res = self.cur.fetchone()[0]
      if res != 0:
        raise ValidateError("Expected zero_sum to be 0, but %d returned "
                            "instead" % res)
    else:
      cur_isolation_level = self.get_isolation_level()
      self.set_isolation_level('REPEATABLE-READ')
      num_rows_to_check = random.randint(500, 1000)
      idx = 0
      total = 0

      stmt = "SELECT id, zero_sum FROM %s where id >= %d ORDER BY id LIMIT %d"
      ID_COL = 0
      ZERO_SUM_COL = 1

      while idx < self.max_id:
        execute(self.cur, stmt % (self.table, idx, num_rows_to_check))
        if self.cur.rowcount == 0:
          break

        for i in range(self.cur.rowcount - 1):
          total += self.cur.fetchone()[ZERO_SUM_COL]

        last_row = self.cur.fetchone()
        idx = last_row[ID_COL] + 1
        total += last_row[ZERO_SUM_COL]

      if total != 0:
        raise ValidateError("Zero sum column expected to total 0, but sum is "
                            "%d instead!" % total)
      self.set_isolation_level(cur_isolation_level)

  def check_rows(self):
    class id_range():
      def __init__(self, min_id, min_inclusive, max_id, max_inclusive):
        self.min_id = min_id if min_inclusive else min_id + 1
        self.max_id = max_id if max_inclusive else max_id - 1
      def count(self, idx):
        return idx >= self.min_id and idx <= self.max_id

    stmt = ("SELECT id, msg, msg_length, msg_checksum FROM %s WHERE " %
            self.table)

    # two methods for checking rows
    #  1. pick a number of rows at random
    #  2. range scan
    if roll_d100(90):
      ids = []
      for i in range(random.randint(1, MAX_ROWS_PER_REQ)):
        ids.append(random.randint(0, self.max_id - 1))
      stmt += "id in (%s)" % ','.join(str(x) for x in ids)
    else:
      id1 = random.randint(0, self.max_id - 1)
      id2 = random.randint(0, self.max_id - 1)
      min_inclusive = random.randint(0, 1)
      cond1 = '>=' if min_inclusive else '>'
      max_inclusive = random.randint(0, 1)
      cond2 = '<=' if max_inclusive else '<'
      stmt += ("id %s %d AND id %s %d" %
               (cond1, min(id1, id2), cond2, max(id1, id2)))
      ids = id_range(min(id1, id2), min_inclusive, max(id1, id2),
                     max_inclusive)

    execute(self.cur, stmt)

    ID_COL = 0
    MSG_COL = ID_COL + 1
    MSG_LENGTH_COL = MSG_COL + 1
    MSG_CHECKSUM_COL = MSG_LENGTH_COL + 1

    # both the list and the id_range helper expose count(), so either way we
    # can check each returned id against the query's predicate
    for row in self.cur.fetchall():
      idx = row[ID_COL]
      msg = row[MSG_COL]
      msg_length = row[MSG_LENGTH_COL]
      msg_checksum = row[MSG_CHECKSUM_COL]
      if ids.count(idx) < 1:
        raise ValidateError(
            "id %d returned from database, but query was '%s'" % (idx, stmt))
      if len(msg) != msg_length:
        raise ValidateError(
            "id %d contains msg_length %d, but msg '%s' is only %d "
            "characters long" % (idx, msg_length, msg, len(msg)))
      if sha1(msg) != msg_checksum:
        raise ValidateError("id %d has checksum '%s', but expected checksum "
                            "is '%s'" % (idx, msg_checksum, sha1(msg)))

  def runme(self):
    global TEST_STOP

    self.start_time = time.time()
    if self.reconnect():
      raise Exception("Unable to connect to MySQL server")
    logging.info("Starting checker")

    while not TEST_STOP:
      try:
        # choose one of two options:
        #   1. verify that zero_sum sums to 0 across all rows
        #   2. read a number of rows and verify their checksums
        if roll_d100(25):
          self.check_zerosum()
        else:
          self.check_rows()

        self.con.commit()
        self.loop_num += 1
        if self.loop_num % 10000 == 0:
          logging.info("Processed %d transactions so far" % self.loop_num)
      except MySQLdb.OperationalError as e:
        if not is_connection_error(e):
          raise e
        if self.reconnect():
          raise Exception("Unable to reconnect to MySQL server")

if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='Concurrent load generator.')

  parser.add_argument('-C', '--committed-txns', dest='committed_txns',
                      default=3, type=int,
                      help="number of committed txns to verify")

  parser.add_argument('-c', '--num-checkers', dest='num_checkers', type=int,
                      default=4,
                      help="number of reader/checker threads to test with")

  parser.add_argument('-d', '--db', dest='db', default='test',
                      help="mysqld server database to test with")

  parser.add_argument('-H', '--host', dest='host', default='127.0.0.1',
                      help="mysqld server host ip address")

  parser.add_argument('-i', '--ids-per-loader', dest='ids_per_loader',
                      type=int, default=100,
                      help="number of records which each loader owns "
                           "exclusively, up to max-id / 2 / num-loaders")

  parser.add_argument('-L', '--log-file', dest='log_file', default=None,
                      help="log file for output")

  parser.add_argument('-l', '--num-loaders', dest='num_loaders', type=int,
                      default=16,
                      help="number of loader threads to test with")

  parser.add_argument('-m', '--max-id', dest='max_id', type=int, default=1000,
                      help="maximum number of records which the table "
                           "extends to, must be larger than ids_per_loader * "
                           "num_loaders")

  parser.add_argument('-n', '--num-records', dest='num_records', type=int,
                      default=0,
                      help="number of records to populate the table with")

  parser.add_argument('-P', '--port', dest='port', default=3307, type=int,
                      help='mysqld server host port')

  parser.add_argument('-r', '--num-requests', dest='num_requests', type=int,
                      default=100000000,
                      help="number of requests issued per worker thread")

  parser.add_argument('-T', '--truncate', dest='truncate', action='store_true',
                      help="truncates or creates the table before the test")

  parser.add_argument('-t', '--table', dest='table', default='t1',
                      help="mysqld server table to test with")

  parser.add_argument('-u', '--user', dest='user', default='root',
                      help="user to log into the mysql server")

  parser.add_argument('-v', '--verbose', dest='verbose', action='store_true',
                      help="enable debug logging")

  parser.add_argument('-E', '--expect-file', dest='expect_file', default=None,
                      help="expect file for server restart")

  parser.add_argument('-D', '--reap-delay', dest='reap_delay', type=int,
                      default=0,
                      help="seconds to sleep after each server reap")
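  # an illustrative invocation (script name and values are examples only):
  #
  #   python load_generator.py -u root -H 127.0.0.1 -P 3306 -d test -t t1 \
  #       -n 1000 -m 1000 -l 4 -i 100 -c 2 -T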

  OPTIONS = parser.parse_args()

  if OPTIONS.verbose:
    log_level = logging.DEBUG
  else:
    log_level = logging.INFO

  logging.basicConfig(level=log_level,
                      format='%(asctime)s %(threadName)s [%(levelname)s] '
                             '%(message)s',
                      datefmt='%Y-%m-%d %H:%M:%S',
                      filename=OPTIONS.log_file)

  logging.info("Command line given: %s" % ' '.join(sys.argv))

  if (OPTIONS.max_id < 0 or OPTIONS.ids_per_loader <= 0 or
      OPTIONS.max_id < OPTIONS.ids_per_loader * OPTIONS.num_loaders):
    logging.error("ids-per-loader must be larger than 0 and max-id must be "
                  "at least ids-per-loader * num-loaders")
    exit(1)

  logging.info("Using table %s.%s for test" % (OPTIONS.db, OPTIONS.table))

  if OPTIONS.truncate:
    logging.info("Truncating table")
    con = MySQLdb.connect(user=OPTIONS.user, host=OPTIONS.host,
                          port=OPTIONS.port, db=OPTIONS.db)
    if not con:
      raise TestError("Unable to connect to mysqld server to create/truncate "
                      "table")
    cur = con.cursor()
    cur.execute("SELECT COUNT(*) FROM INFORMATION_SCHEMA.tables WHERE "
                "table_schema = '%s' AND table_name = '%s'" %
                (OPTIONS.db, OPTIONS.table))
    if cur.rowcount != 1:
      logging.error("Unable to retrieve information about table %s "
                    "from information_schema!" % OPTIONS.table)
      exit(1)

    if cur.fetchone()[0] == 0:
      logging.info("Table %s not found, creating a new one" % OPTIONS.table)
      cur.execute("CREATE TABLE %s (id INT PRIMARY KEY, "
                  "thread_id INT NOT NULL, "
                  "request_id BIGINT UNSIGNED NOT NULL, "
                  "update_count INT UNSIGNED NOT NULL DEFAULT 0, "
                  "zero_sum INT DEFAULT 0, "
                  "msg VARCHAR(1024), "
                  "msg_length int, "
                  "msg_checksum varchar(128), "
                  "KEY msg_i(msg(255), zero_sum)) "
                  "ENGINE=RocksDB DEFAULT CHARSET=latin1 COLLATE=latin1_bin" %
                  OPTIONS.table)
    else:
      logging.info("Table %s found, truncating" % OPTIONS.table)
      cur.execute("TRUNCATE TABLE %s" % OPTIONS.table)
    con.commit()

  if populate_table(OPTIONS.num_records):
    logging.error("Populate table returned an error")
    exit(1)

  logging.info("Starting %d loaders" % OPTIONS.num_loaders)
  loaders = []
  for i in range(OPTIONS.num_loaders):
    loaders.append(LoadGenWorker(i))

  logging.info("Starting %d checkers" % OPTIONS.num_checkers)
  checkers = []
  for i in range(OPTIONS.num_checkers):
    checkers.append(CheckerWorker(i))

  while LOADERS_READY < OPTIONS.num_loaders:
    time.sleep(0.5)

  if OPTIONS.expect_file and OPTIONS.reap_delay > 0:
    logging.info('Starting reaper')
    checkers.append(ReaperWorker())

  workers_failed = 0
  workers_failed += wait_for_workers(loaders, len(checkers))

  if TEST_STOP:
    logging.error("Detected test failure, aborting")
    os._exit(1)

  TEST_STOP = True

  workers_failed += wait_for_workers(checkers)

  if workers_failed > 0:
    logging.error("Test detected %d failures, aborting" % workers_failed)
    sys.exit(1)

  logging.info("Test completed successfully")
  sys.exit(0)