# coding=utf-8

import collections
import os
import random
import re
import shutil
import subprocess
import tempfile
import zipfile

import gevent.queue
import gevent.server
import gevent.socket
import time

from common import *
from includes import *

CREATE_INDICES_TARGET_DIR = '/tmp/test'
BASE_RDBS_URL = 'https://s3.amazonaws.com/redismodules/redisearch-enterprise/rdbs/'

SHORT_READ_BYTES_DELTA = int(os.getenv('SHORT_READ_BYTES_DELTA', '1'))
SHORT_READ_FULL_TEST = int(os.getenv('SHORT_READ_FULL_TEST', '0'))

RDBS_SHORT_READS = [
    'short-reads/redisearch_2.2.0.rdb.zip',
    'short-reads/rejson_2.0.0.rdb.zip',
    'short-reads/redisearch_2.2.0_rejson_2.0.0.rdb.zip',
]
RDBS_COMPATIBILITY = [
    'redisearch_2.0.9.rdb',
]

ExpectedIndex = collections.namedtuple('ExpectedIndex', ['count', 'pattern', 'search_result_count'])
RDBS_EXPECTED_INDICES = [
    ExpectedIndex(2, 'shortread_idxSearch_[1-9]', [20, 55]),
    ExpectedIndex(2, 'shortread_idxJson_[1-9]', [55, 20]),  # TODO: why are the indices ordered _2 before _1?
    ExpectedIndex(2, 'shortread_idxSearchJson_[1-9]', [10, 35])
    ]

RDBS = []
RDBS.extend(RDBS_SHORT_READS)
if not CODE_COVERAGE and SANITIZER == '' and SHORT_READ_FULL_TEST:
    RDBS.extend(RDBS_COMPATIBILITY)
    RDBS_EXPECTED_INDICES.append(ExpectedIndex(1, 'idx', [1000]))


def unzip(zip_path, to_dir):
    if not zipfile.is_zipfile(zip_path):
        return False
    with zipfile.ZipFile(zip_path, 'r') as db_zip:
        for info in db_zip.infolist():
            if not os.path.exists(db_zip.extract(info, to_dir)):
                return False
    return True


def downloadFiles(target_dir):
    for f in RDBS:
        path = os.path.join(target_dir, f)
        path_dir = os.path.dirname(path)
        if not os.path.exists(path_dir):
            os.makedirs(path_dir)
        if not os.path.exists(path):
            subprocess.call(['wget', '-q', BASE_RDBS_URL + f, '-O', path])
            _, ext = os.path.splitext(f)
            if ext == '.zip':
                if not unzip(path, path_dir):
                    return False
            else:
                if not os.path.exists(path) or os.path.getsize(path) == 0:
                    return False
    return True


def rand_name(k):
    # Random alphabetic string with 2 to k chars
    return ''.join([chr(random.randint(ord('a'), ord('z'))) for _ in range(0, random.randint(2, max(2, k)))])
    # return random.choices(''.join(string.ascii_letters), k=k)


def rand_num(k):
    # Random positive number with 2 to k digits
    return chr(random.randint(ord('1'), ord('9'))) + ''.join([chr(random.randint(ord('0'), ord('9'))) for _ in range(0, random.randint(1, max(1, k)))])
    # return random.choices(''.join(string.digits), k=k)


def create_indices(env, rdbFileName, idxNameStem, isHash, isJson):
    env.flush()
    idxNameStem = 'shortread_' + idxNameStem + '_'
    if isHash and isJson:
        # 1 Hash index and 1 Json index
        add_index(env, True, idxNameStem + '1', 'c', 10, 20)
        add_index(env, False, idxNameStem + '2', 'd', 35, 40)
    elif isHash:
        # 2 Hash indices
        add_index(env, True, idxNameStem + '1', 'e', 10, 20)
        add_index(env, True, idxNameStem + '2', 'f', 35, 40)
    elif isJson:
        # 2 Json indices
        add_index(env, False, idxNameStem + '1', 'g', 10, 20)
        add_index(env, False, idxNameStem + '2', 'h', 35, 40)
    else:
        env.assertTrue(False, "should not reach here")

    # Save the rdb
    env.assertOk(env.cmd('config', 'set', 'dbfilename', rdbFileName))
    dbFileName = env.cmd('config', 'get', 'dbfilename')[1]
    dbDir = env.cmd('config', 'get', 'dir')[1]
    dbFilePath = os.path.join(dbDir, dbFileName)

    env.assertTrue(env.cmd('save'))
    # Copy the rdb to avoid its truncation by the RLTest flush and save
    dbCopyFilePath = os.path.join(CREATE_INDICES_TARGET_DIR, dbFileName)
    dbCopyFileDir = os.path.dirname(dbCopyFilePath)
    if not os.path.exists(dbCopyFileDir):
        os.makedirs(dbCopyFileDir)
    with zipfile.ZipFile(dbCopyFilePath + '.zip', 'w') as db_zip:
        db_zip.write(dbFilePath, os.path.join(os.path.curdir, os.path.basename(dbCopyFilePath)))


def get_identifier(name, isHash):
    return '$.' + name if not isHash else name

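# For Hash indices the schema identifier is the plain field name; for JSON
# indices it is a JSONPath expression, e.g.:
#   get_identifier('field1', isHash=True)   -> 'field1'
#   get_identifier('field1', isHash=False)  -> '$.field1'
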

def add_index(env, isHash, index_name, key_suffix, num_prefs, num_keys):
    ''' Cover most of the possible options of an index

    FT.CREATE {index}
        [ON {structure}]
        [PREFIX {count} {prefix} [{prefix} ..]]
        [FILTER {filter}]
        [LANGUAGE {default_lang}]
        [LANGUAGE_FIELD {lang_field}]
        [SCORE {default_score}]
        [SCORE_FIELD {score_field}]
        [PAYLOAD_FIELD {payload_field}]
        [MAXTEXTFIELDS] [TEMPORARY {seconds}] [NOOFFSETS] [NOHL] [NOFIELDS] [NOFREQS] [SKIPINITIALSCAN]
        [STOPWORDS {num} {stopword} ...]
        SCHEMA {field} [TEXT [NOSTEM] [WEIGHT {weight}] [PHONETIC {matcher}] | NUMERIC | GEO | TAG [SEPARATOR {sep}]] [SORTABLE] [NOINDEX] ...
    '''

    # Create the index
    cmd_create = ['ft.create', index_name, 'ON', 'HASH' if isHash else 'JSON']
    # With prefixes
    if num_prefs > 0:
        cmd_create.extend(['prefix', num_prefs])
        for i in range(1, num_prefs + 1):
            cmd_create.append('pref' + str(i) + ":")
    # With filter
    cmd_create.extend(['filter', 'startswith(@__key' + ', "p")'])
    # With language
    cmd_create.extend(['language', 'spanish'])
    # With language field
    cmd_create.extend(['language_field', 'myLang'])
    # With score field
    cmd_create.extend(['score_field', 'myScore'])
    # With payload field
    cmd_create.extend(['payload_field', 'myPayload'])
    # With maxtextfields
    cmd_create.append('maxtextfields')
    # With stopwords
    cmd_create.extend(['stopwords', 3, 'stop', 'being', 'silly'])
    # With schema
    cmd_create.extend(['schema',
                       get_identifier('field1', isHash), 'as', 'f1', 'text', 'nostem', 'weight', '0.2', 'phonetic', 'dm:es', 'sortable',
                       get_identifier('field2', isHash), 'as', 'f2', 'numeric', 'sortable',
                       get_identifier('field3', isHash), 'as', 'f3', 'geo',
                       get_identifier('field4', isHash), 'as', 'f4', 'tag', 'separator', ';',
                       get_identifier('field11', isHash), 'text', 'nostem',
                       get_identifier('field12', isHash), 'numeric', 'noindex',
                       get_identifier('field13', isHash), 'geo',
                       get_identifier('field14', isHash), 'tag', 'noindex',
                       get_identifier('myLang', isHash), 'text',
                       get_identifier('myScore', isHash), 'numeric',
                       ])
    conn = getConnectionByEnv(env)
    env.assertOk(conn.execute_command(*cmd_create))
    waitForIndex(conn, index_name)
    env.assertOk(conn.execute_command('ft.synupdate', index_name, 'syngrp1', 'pelota', 'bola', 'balón'))
    env.assertOk(conn.execute_command('ft.synupdate', index_name, 'syngrp2', 'jugar', 'tocar'))

    # Add keys
    for i in range(1, num_keys + 1):
        if isHash:
            cmd = ['hset', 'pref' + str(i) + ":k" + str(i) + '_' + rand_num(5) + key_suffix, 'a' + rand_name(5), rand_num(2), 'b' + rand_name(5), rand_num(3)]
            env.assertEqual(conn.execute_command(*cmd), 2L)
        else:
            cmd = ['json.set', 'pref' + str(i) + ":k" + str(i) + '_' + rand_num(5) + key_suffix, '$', r'{"field1":"' + rand_name(5) + r'", "field2":' + rand_num(3) + r'}']
            env.assertOk(conn.execute_command(*cmd))


def testCreateIndexRdbFiles(env):
    create_indices(env, 'redisearch_2.2.0.rdb', 'idxSearch', True, False)


@no_msan
def testCreateIndexRdbFilesWithJSON(env):
    create_indices(env, 'rejson_2.0.0.rdb', 'idxJson', False, True)
    create_indices(env, 'redisearch_2.2.0_rejson_2.0.0.rdb', 'idxSearchJson', True, True)

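# For illustration, add_index(env, True, 'idx1', 'a', 2, num_keys) builds roughly
# the following command (schema abbreviated):
#
#   FT.CREATE idx1 ON HASH PREFIX 2 pref1: pref2:
#       FILTER 'startswith(@__key, "p")' LANGUAGE spanish LANGUAGE_FIELD myLang
#       SCORE_FIELD myScore PAYLOAD_FIELD myPayload MAXTEXTFIELDS
#       STOPWORDS 3 stop being silly
#       SCHEMA field1 AS f1 TEXT NOSTEM WEIGHT 0.2 PHONETIC dm:es SORTABLE ...
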

class Connection(object):
    def __init__(self, sock, bufsize=4096, underlying_sock=None):
        self.sock = sock
        self.sockf = sock.makefile('rwb', bufsize)
        self.closed = False
        self.peer_closed = False
        self.underlying_sock = underlying_sock

    def close(self):
        if not self.closed:
            self.closed = True
            self.sockf.close()
            self.sock.close()
            self.sockf = None

    def is_close(self, timeout=2):
        if self.closed:
            return True
        try:
            with TimeLimit(timeout):
                return self.read(1) == ''
        except Exception:
            return False

    def flush(self):
        self.sockf.flush()

    def get_address(self):
        return self.sock.getsockname()[0]

    def get_port(self):
        return self.sock.getsockname()[1]

    def read(self, bytes):
        return self.sockf.read(bytes)

    def read_at_most(self, bytes, timeout=0.01):
        self.sock.settimeout(timeout)
        return self.sock.recv(bytes)

    def send(self, data):
        self.sockf.write(data)
        self.sockf.flush()

    def readline(self):
        return self.sockf.readline()

    def send_bulk_header(self, data_len):
        self.sockf.write('$%d\r\n' % data_len)
        self.sockf.flush()

    def send_bulk(self, data):
        self.sockf.write('$%d\r\n%s\r\n' % (len(data), data))
        self.sockf.flush()

    def send_status(self, data):
        self.sockf.write('+%s\r\n' % data)
        self.sockf.flush()

    def send_error(self, data):
        self.sockf.write('-%s\r\n' % data)
        self.sockf.flush()

    def send_integer(self, data):
        self.sockf.write(':%u\r\n' % data)
        self.sockf.flush()

    def send_mbulk(self, data):
        self.sockf.write('*%d\r\n' % len(data))
        for elem in data:
            self.sockf.write('$%d\r\n%s\r\n' % (len(elem), elem))
        self.sockf.flush()

    def read_mbulk(self, args_count=None):
        if args_count is None:
            line = self.readline()
            if not line:
                self.peer_closed = True
            if not line or line[0] != '*':
                self.close()
                return None
            try:
                args_count = int(line[1:])
            except ValueError:
                raise Exception('Invalid mbulk header: %s' % line)
        data = []
        for arg in range(args_count):
            data.append(self.read_response())
        return data

    def read_request(self):
        line = self.readline()
        if not line:
            self.peer_closed = True
            self.close()
            return None
        if line[0] != '*':
            return line.rstrip().split()
        try:
            args_count = int(line[1:])
        except ValueError:
            raise Exception('Invalid mbulk request: %s' % line)
        return self.read_mbulk(args_count)

    def read_request_and_reply_status(self, status):
        req = self.read_request()
        if not req:
            return
        self.current_request = req
        self.send_status(status)

    def wait_until_writable(self, timeout=None):
        try:
            gevent.socket.wait_write(self.sockf.fileno(), timeout)
        except gevent.socket.error:
            return False
        return True

    def wait_until_readable(self, timeout=None):
        if self.closed:
            return False
        try:
            gevent.socket.wait_read(self.sockf.fileno(), timeout)
        except gevent.socket.error:
            return False
        return True

    def read_response(self):
        line = self.readline()
        if not line:
            self.peer_closed = True
            self.close()
            return None
        if line[0] == '+':
            return line.rstrip()
        elif line[0] == ':':
            try:
                return int(line[1:])
            except ValueError:
                raise Exception('Invalid numeric value: %s' % line)
        elif line[0] == '-':
            return line.rstrip()
        elif line[0] == '$':
            try:
                bulk_len = int(line[1:])
            except ValueError:
                raise Exception('Invalid bulk response: %s' % line)
            if bulk_len == -1:
                return None
            data = self.sockf.read(bulk_len + 2)
            if len(data) < bulk_len:
                self.peer_closed = True
                self.close()
            return data[:bulk_len]
        elif line[0] == '*':
            try:
                args_count = int(line[1:])
            except ValueError:
                raise Exception('Invalid mbulk response: %s' % line)
            return self.read_mbulk(args_count)
        else:
            raise Exception('Invalid response: %s' % line)

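# A quick RESP reference for the helpers above (exact wire bytes):
#   send_status('PONG')      -> "+PONG\r\n"
#   send_error('ERR oops')   -> "-ERR oops\r\n"
#   send_integer(42)         -> ":42\r\n"
#   send_bulk('abc')         -> "$3\r\nabc\r\n"
#   send_mbulk(['a', 'bc'])  -> "*2\r\n$1\r\na\r\n$2\r\nbc\r\n"
# send_bulk_header() sends only the "$<len>\r\n" part, which is what lets the
# short-read tests announce more payload bytes than they actually deliver.
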

class ShardMock:
    def __init__(self, env):
        self.env = env
        self.new_conns = gevent.queue.Queue()

    def _handle_conn(self, sock, client_addr):
        conn = Connection(sock)
        self.new_conns.put(conn)

    def __enter__(self):
        self.server_port = None
        if not self.StartListening(port=0, attempts=10) and not self.StartListening(port=random.randint(55000, 57000)):
            raise Exception("%s StartListening failed" % self.__class__.__name__)
        return self

    def __exit__(self, type, value, traceback):
        self.stream_server.stop()

    def GetConnection(self, timeout=None):
        conn = self.new_conns.get(block=True, timeout=timeout)
        return conn

    def GetCleanConnection(self):
        return self.new_conns.get(block=True, timeout=None)

    def StopListening(self):
        self.stream_server.stop()
        self.server_port = None

    def StartListening(self, port, attempts=1):
        for i in range(1, attempts + 1):
            self.stream_server = gevent.server.StreamServer(('127.0.0.1', port), self._handle_conn)
            try:
                self.stream_server.start()
            except Exception as e:
                self.env.assertEqual(self.server_port, None, message='%s: StartListening(%d/%d) %d -> %s' % (
                    self.__class__.__name__, i, attempts, port, e.strerror))
                continue
            self.server_port = self.stream_server.address[1]
            self.env.assertNotEqual(self.server_port, None, message='%s: StartListening(%d/%d) %d -> %d' % (
                self.__class__.__name__, i, attempts, port, self.server_port))
            return True
        else:
            return False


class Debug:
    def __init__(self, enabled=False):
        self.enabled = enabled
        self.clear()

    def clear(self):
        self.dbg_str = ''
        self.dbg_ndx = -1

    def __call__(self, f):
        if self.enabled:
            def f_with_debug(*args, **kwds):
                self.print_bytes_incremental(args[0], args[1], args[2], f.__name__)
                if len(args[1]) == args[2]:
                    self.clear()
                f(*args, **kwds)

            return f_with_debug
        else:
            return f

    def print_bytes_incremental(self, env, data, total_len, name):
        # For debugging: print the binary content before it is sent
        byte_count_width = len(str(total_len))
        if len(data):
            ch = data[self.dbg_ndx]
            printable_ch = ch
            if ord(ch) < 32 or ord(ch) == 127:
                printable_ch = '\?'
        else:
            ch = '\0'
            printable_ch = '\!'  # no data (zero length)
        self.dbg_str = '{} {:=0{}n}:{:<2}({:<3})'.format(self.dbg_str, self.dbg_ndx, byte_count_width, printable_ch, ord(ch))
        if not (self.dbg_ndx + 1) % 10:
            self.dbg_str = self.dbg_str + "\n"
        self.dbg_ndx = self.dbg_ndx + 1

        env.debugPrint(name + ': %d out of %d \n%s' % (self.dbg_ndx, total_len, self.dbg_str))

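# Usage note for Debug: it is applied as a decorator (see runShortRead below).
# With Debug(True), each call first prints the bytes of the payload argument
# before it is sent, which helps pinpoint the truncation offset of a failing run:
#
#   @Debug(True)
#   def runShortRead(env, data, total_len, expected_index):
#       ...
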

@no_msan
def testShortReadSearch(env):
    if CODE_COVERAGE or SANITIZER:
        env.skip()  # FIXME: enable coverage test

    env.skipOnCluster()
    if env.env.endswith('existing-env') and os.environ.get('CI'):
        env.skip()

    if OSNICK == 'macos':
        env.skip()

    seed = str(time.time())
    env.assertNotEqual(seed, None, message='random seed ' + seed)
    random.seed(seed)

    try:
        temp_dir = tempfile.mkdtemp(prefix="short-read_")
        # TODO: In python3 use "with tempfile.TemporaryDirectory()"
        if not downloadFiles(temp_dir):
            env.assertTrue(False, "downloadFiles failed")

        for f, expected_index in zip(RDBS, RDBS_EXPECTED_INDICES):
            name, ext = os.path.splitext(f)
            if ext == '.zip':
                f = name
            fullfilePath = os.path.join(temp_dir, f)
            env.assertNotEqual(fullfilePath, None, message='testShortReadSearch')
            sendShortReads(env, fullfilePath, expected_index)
    finally:
        shutil.rmtree(temp_dir)


def sendShortReads(env, rdb_file, expected_index):
    # Add some initial content (indices and keys) to test backup/restore/discard when a short read fails.
    # When the entire rdb is successfully sent and loaded (via swapdb), the backup should be discarded.
    env.assertCmdOk('replicaof', 'no', 'one')
    env.flush()
    add_index(env, True, 'idxBackup1', 'a', 5, 10)
    add_index(env, False, 'idxBackup2', 'b', 5, 10)

    res = env.cmd('ft.search ', 'idxBackup1', '*', 'limit', '0', '0')
    env.assertEqual(res[0], 5L)
    res = env.cmd('ft.search ', 'idxBackup2', '*', 'limit', '0', '0')
    env.assertEqual(res[0], 5L)

    with open(rdb_file, mode='rb') as f:
        full_rdb = f.read()
    total_len = len(full_rdb)

    env.assertGreater(total_len, SHORT_READ_BYTES_DELTA)
    r = range(0, total_len + 1, SHORT_READ_BYTES_DELTA)
    if (total_len % SHORT_READ_BYTES_DELTA) != 0:
        r = r + range(total_len, total_len + 1)
    for b in r:
        rdb = full_rdb[0:b]
        runShortRead(env, rdb, total_len, expected_index)

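# How the short read is produced on the wire: the fake master (ShardMock) sends a
# RESP bulk header announcing the full RDB size but delivers only a prefix of the
# payload, e.g. for a 1000-byte RDB truncated after 10 bytes:
#
#   conn.send('$1000\r\n' + full_rdb[0:10])   # no trailing '\r\n'
#
# The replica then blocks waiting for the remaining bytes until the connection is
# closed, forcing it to abort the diskless (swapdb) load and restore its backup.
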

@Debug(False)
def runShortRead(env, data, total_len, expected_index):
    with ShardMock(env) as shardMock:
        # For debugging: when setting breakpoints in redis, uncomment the following
        # line to avoid the connection being closed on timeout
        # res = env.cmd('CONFIG', 'SET', 'timeout', '0')

        # Notice: Do not use env.expect in this test
        # (it sends commands to redis, and this test must follow a strict hand-shake)
        res = env.cmd('CONFIG', 'SET', 'repl-diskless-load', 'swapdb')
        res = env.cmd('replicaof', '127.0.0.1', shardMock.server_port)
        env.assertTrue(res)
        conn = shardMock.GetConnection()
        # Perform hand-shake with the replica
        res = conn.read_request()
        env.assertEqual(res, ['PING'])
        conn.send_status('PONG')

        max_attempt = 100
        res = conn.read_request()
        while max_attempt > 0:
            if res[0] == 'REPLCONF':
                conn.send_status('OK')
            else:
                break
            max_attempt = max_attempt - 1
            res = conn.read_request()

        # Send the RDB to the replica
        some_guid = 'af4e30b5d14dce9f96fbb7769d0ec794cdc0bbcc'
        conn.send_status('FULLRESYNC ' + some_guid + ' 0')
        is_shortread = total_len != len(data)
        if is_shortread:
            conn.send('$%d\r\n%s' % (total_len, data))
        else:
            # Allow the full read to succeed (the protocol expects a trailing '\r\n')
            conn.send('$%d\r\n%s\r\n' % (total_len, data))
        conn.flush()

        # Close the connection while the replica is waiting for more RDB data
        # (so the replica will re-connect to the master)
        conn.close()

        # Make sure the replica did not crash
        res = env.cmd('PING')
        env.assertEqual(res, True)
        conn = shardMock.GetConnection(timeout=3)
        env.assertNotEqual(conn, None)

        res = env.cmd('ft._list')
        if is_shortread:
            # Verify that the original data, which existed before the failed short-read attempt, is restored
            env.assertEqual(res, ['idxBackup2', 'idxBackup1'])
            res = env.cmd('ft.search ', 'idxBackup1', '*', 'limit', '0', '0')
            env.assertEqual(res[0], 5L)
            res = env.cmd('ft.search ', 'idxBackup2', '*', 'limit', '0', '0')
            env.assertEqual(res[0], 5L)
        else:
            # Verify that the new data was loaded and the backup was discarded
            # TODO: How to verify that the internal backup was indeed discarded?
            env.assertEqual(len(res), expected_index.count)
            r = re.compile(expected_index.pattern)
            expected_indices = list(filter(lambda x: r.match(x), res))
            env.assertEqual(len(expected_indices), expected_index.count)
            for ind, expected_result_count in zip(expected_indices, expected_index.search_result_count):
                res = env.cmd('ft.search ', ind, '*', 'limit', '0', '0')
                env.assertEqual(res[0], expected_result_count)

        # Exit (avoid a read-only exception with flush on the replica)
        env.assertCmdOk('replicaof', 'no', 'one')