# coding=utf-8

import collections
import os
import random
import re
import shutil
import subprocess
import tempfile
import time
import zipfile

import gevent.queue
import gevent.server
import gevent.socket

from common import *
from includes import *
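
# Test handling of short reads (truncated RDB payloads) during diskless replication
# ('repl-diskless-load swapdb') of rdbs containing RediSearch and/or RedisJSON data:
# a mock master sends incrementally longer prefixes of an rdb, and the replica is
# expected to survive each truncated transfer and restore its backed-up indices.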

CREATE_INDICES_TARGET_DIR = '/tmp/test'
BASE_RDBS_URL = 'https://s3.amazonaws.com/redismodules/redisearch-enterprise/rdbs/'

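# SHORT_READ_BYTES_DELTA is the step, in bytes, between consecutive truncated rdb
# prefixes sent to the replica; SHORT_READ_FULL_TEST also enables the compatibility rdbs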
SHORT_READ_BYTES_DELTA = int(os.getenv('SHORT_READ_BYTES_DELTA', '1'))
SHORT_READ_FULL_TEST = int(os.getenv('SHORT_READ_FULL_TEST', '0'))

RDBS_SHORT_READS = [
    'short-reads/redisearch_2.2.0.rdb.zip',
    'short-reads/rejson_2.0.0.rdb.zip',
    'short-reads/redisearch_2.2.0_rejson_2.0.0.rdb.zip',
]
RDBS_COMPATIBILITY = [
    'redisearch_2.0.9.rdb',
]

ExpectedIndex = collections.namedtuple('ExpectedIndex', ['count', 'pattern', 'search_result_count'])
RDBS_EXPECTED_INDICES = [
    ExpectedIndex(2, 'shortread_idxSearch_[1-9]', [20, 55]),
    ExpectedIndex(2, 'shortread_idxJson_[1-9]', [55, 20]),  # TODO: why is the order of the indices first _2 and then _1?
    ExpectedIndex(2, 'shortread_idxSearchJson_[1-9]', [10, 35]),
]

RDBS = []
RDBS.extend(RDBS_SHORT_READS)
if not CODE_COVERAGE and SANITIZER == '' and SHORT_READ_FULL_TEST:
    RDBS.extend(RDBS_COMPATIBILITY)
    RDBS_EXPECTED_INDICES.append(ExpectedIndex(1, 'idx', [1000]))


def unzip(zip_path, to_dir):
    if not zipfile.is_zipfile(zip_path):
        return False
    with zipfile.ZipFile(zip_path, 'r') as db_zip:
        for info in db_zip.infolist():
            if not os.path.exists(db_zip.extract(info, to_dir)):
                return False
    return True


def downloadFiles(target_dir):
    for f in RDBS:
        path = os.path.join(target_dir, f)
        path_dir = os.path.dirname(path)
        if not os.path.exists(path_dir):
            os.makedirs(path_dir)
        if not os.path.exists(path):
            subprocess.call(['wget', '-q', BASE_RDBS_URL + f, '-O', path])
            _, ext = os.path.splitext(f)
            if ext == '.zip':
                if not unzip(path, path_dir):
                    return False
            else:
                if not os.path.exists(path) or os.path.getsize(path) == 0:
                    return False
    return True


def rand_name(k):
    # random alphabetic string with between 2 and max(2, k) characters
    return ''.join([chr(random.randint(ord('a'), ord('z'))) for _ in range(0, random.randint(2, max(2, k)))])


def rand_num(k):
    # random positive number string with no leading zero and between 2 and max(1, k) + 1 digits
    return chr(random.randint(ord('1'), ord('9'))) + ''.join([chr(random.randint(ord('0'), ord('9'))) for _ in range(0, random.randint(1, max(1, k)))])


def create_indices(env, rdbFileName, idxNameStem, isHash, isJson):
    env.flush()
    idxNameStem = 'shortread_' + idxNameStem + '_'
    if isHash and isJson:
        # 1 Hash index and 1 Json index
        add_index(env, True, idxNameStem + '1', 'c', 10, 20)
        add_index(env, False, idxNameStem + '2', 'd', 35, 40)
    elif isHash:
        # 2 Hash indices
        add_index(env, True, idxNameStem + '1', 'e', 10, 20)
        add_index(env, True, idxNameStem + '2', 'f', 35, 40)
    elif isJson:
        # 2 Json indices
        add_index(env, False, idxNameStem + '1', 'g', 10, 20)
        add_index(env, False, idxNameStem + '2', 'h', 35, 40)
    else:
        env.assertTrue(False, message="should not reach here")

    # Save the rdb
    env.assertOk(env.cmd('config', 'set', 'dbfilename', rdbFileName))
    dbFileName = env.cmd('config', 'get', 'dbfilename')[1]
    dbDir = env.cmd('config', 'get', 'dir')[1]
    dbFilePath = os.path.join(dbDir, dbFileName)

    env.assertTrue(env.cmd('save'))
    # Copy the rdb, to avoid truncation due to the RLTest flush and save
    dbCopyFilePath = os.path.join(CREATE_INDICES_TARGET_DIR, dbFileName)
    dbCopyFileDir = os.path.dirname(dbCopyFilePath)
    if not os.path.exists(dbCopyFileDir):
        os.makedirs(dbCopyFileDir)
    with zipfile.ZipFile(dbCopyFilePath + '.zip', 'w') as db_zip:
        db_zip.write(dbFilePath, os.path.join(os.path.curdir, os.path.basename(dbCopyFilePath)))


def get_identifier(name, isHash):
    return name if isHash else '$.' + name


def add_index(env, isHash, index_name, key_suffix, num_prefs, num_keys):
    ''' Cover most of the possible options of an index

    FT.CREATE {index}
    [ON {structure}]
    [PREFIX {count} {prefix} [{prefix} ..]]
    [FILTER {filter}]
    [LANGUAGE {default_lang}]
    [LANGUAGE_FIELD {lang_field}]
    [SCORE {default_score}]
    [SCORE_FIELD {score_field}]
    [PAYLOAD_FIELD {payload_field}]
    [MAXTEXTFIELDS] [TEMPORARY {seconds}] [NOOFFSETS] [NOHL] [NOFIELDS] [NOFREQS] [SKIPINITIALSCAN]
    [STOPWORDS {num} {stopword} ...]
    SCHEMA {field} [TEXT [NOSTEM] [WEIGHT {weight}] [PHONETIC {matcher}] | NUMERIC | GEO | TAG [SEPARATOR {sep}]] [SORTABLE] [NOINDEX] ...
    '''

    # Create the index
    cmd_create = ['ft.create', index_name, 'ON', 'HASH' if isHash else 'JSON']
    # With prefixes
    if num_prefs > 0:
        cmd_create.extend(['prefix', num_prefs])
        for i in range(1, num_prefs + 1):
            cmd_create.append('pref' + str(i) + ":")
    # With filter
    cmd_create.extend(['filter', 'startswith(@__key, "p")'])
    # With language
    cmd_create.extend(['language', 'spanish'])
    # With language field
    cmd_create.extend(['language_field', 'myLang'])
    # With score field
    cmd_create.extend(['score_field', 'myScore'])
    # With payload field
    cmd_create.extend(['payload_field', 'myPayload'])
    # With maxtextfields
    cmd_create.append('maxtextfields')
    # With stopwords
    cmd_create.extend(['stopwords', 3, 'stop', 'being', 'silly'])
    # With schema
    cmd_create.extend(['schema',
                       get_identifier('field1', isHash), 'as', 'f1', 'text', 'nostem', 'weight', '0.2', 'phonetic',
                       'dm:es', 'sortable',
                       get_identifier('field2', isHash), 'as', 'f2', 'numeric', 'sortable',
                       get_identifier('field3', isHash), 'as', 'f3', 'geo',
                       get_identifier('field4', isHash), 'as', 'f4', 'tag', 'separator', ';',
                       get_identifier('field11', isHash), 'text', 'nostem',
                       get_identifier('field12', isHash), 'numeric', 'noindex',
                       get_identifier('field13', isHash), 'geo',
                       get_identifier('field14', isHash), 'tag', 'noindex',
                       get_identifier('myLang', isHash), 'text',
                       get_identifier('myScore', isHash), 'numeric',
                       ])
    conn = getConnectionByEnv(env)
    env.assertOk(conn.execute_command(*cmd_create))
    waitForIndex(conn, index_name)
    env.assertOk(conn.execute_command('ft.synupdate', index_name, 'syngrp1', 'pelota', 'bola', 'balón'))
    env.assertOk(conn.execute_command('ft.synupdate', index_name, 'syngrp2', 'jugar', 'tocar'))

    # Add keys
    for i in range(1, num_keys + 1):
        if isHash:
            cmd = ['hset', 'pref' + str(i) + ":k" + str(i) + '_' + rand_num(5) + key_suffix, 'a' + rand_name(5), rand_num(2), 'b' + rand_name(5), rand_num(3)]
            env.assertEqual(conn.execute_command(*cmd), 2)
        else:
            cmd = ['json.set', 'pref' + str(i) + ":k" + str(i) + '_' + rand_num(5) + key_suffix, '$', r'{"field1":"' + rand_name(5) + r'", "field2":' + rand_num(3) + r'}']
            env.assertOk(conn.execute_command(*cmd))
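

# The tests below (re)generate the rdb files, zipped under CREATE_INDICES_TARGET_DIR;
# presumably these are the artifacts served from BASE_RDBS_URL for the short-read tests.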
def testCreateIndexRdbFiles(env):
    create_indices(env, 'redisearch_2.2.0.rdb', 'idxSearch', True, False)


@no_msan
def testCreateIndexRdbFilesWithJSON(env):
    create_indices(env, 'rejson_2.0.0.rdb', 'idxJson', False, True)
    create_indices(env, 'redisearch_2.2.0_rejson_2.0.0.rdb', 'idxSearchJson', True, True)


class Connection(object):
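    '''Thin wrapper over a socket that speaks just enough RESP (REdis Serialization
    Protocol) to emulate either end of a redis connection: helpers to send
    status/error/integer/bulk/mbulk frames, and to parse requests and responses.'''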
    def __init__(self, sock, bufsize=4096, underlying_sock=None):
        self.sock = sock
        self.sockf = sock.makefile('rwb', bufsize)
        self.closed = False
        self.peer_closed = False
        self.underlying_sock = underlying_sock

    def close(self):
        if not self.closed:
            self.closed = True
            self.sockf.close()
            self.sock.close()
            self.sockf = None

    def is_close(self, timeout=2):
        if self.closed:
            return True
        try:
            with TimeLimit(timeout):
                return self.read(1) == ''
        except Exception:
            return False

    def flush(self):
        self.sockf.flush()

    def get_address(self):
        return self.sock.getsockname()[0]

    def get_port(self):
        return self.sock.getsockname()[1]

    def read(self, num_bytes):
        return self.sockf.read(num_bytes)

    def read_at_most(self, num_bytes, timeout=0.01):
        self.sock.settimeout(timeout)
        return self.sock.recv(num_bytes)

    def send(self, data):
        self.sockf.write(data)
        self.sockf.flush()

    def readline(self):
        return self.sockf.readline()

    def send_bulk_header(self, data_len):
        self.sockf.write('$%d\r\n' % data_len)
        self.sockf.flush()

    def send_bulk(self, data):
        self.sockf.write('$%d\r\n%s\r\n' % (len(data), data))
        self.sockf.flush()

    def send_status(self, data):
        self.sockf.write('+%s\r\n' % data)
        self.sockf.flush()

    def send_error(self, data):
        self.sockf.write('-%s\r\n' % data)
        self.sockf.flush()

    def send_integer(self, data):
        self.sockf.write(':%d\r\n' % data)
        self.sockf.flush()

    def send_mbulk(self, data):
        self.sockf.write('*%d\r\n' % len(data))
        for elem in data:
            self.sockf.write('$%d\r\n%s\r\n' % (len(elem), elem))
        self.sockf.flush()

    def read_mbulk(self, args_count=None):
        if args_count is None:
            line = self.readline()
            if not line:
                self.peer_closed = True
            if not line or line[0] != '*':
                self.close()
                return None
            try:
                args_count = int(line[1:])
            except ValueError:
                raise Exception('Invalid mbulk header: %s' % line)
        data = []
        for _ in range(args_count):
            data.append(self.read_response())
        return data

    def read_request(self):
        line = self.readline()
        if not line:
            self.peer_closed = True
            self.close()
            return None
        if line[0] != '*':
            return line.rstrip().split()
        try:
            args_count = int(line[1:])
        except ValueError:
            raise Exception('Invalid mbulk request: %s' % line)
        return self.read_mbulk(args_count)

    def read_request_and_reply_status(self, status):
        req = self.read_request()
        if not req:
            return
        self.current_request = req
        self.send_status(status)

    def wait_until_writable(self, timeout=None):
        try:
            gevent.socket.wait_write(self.sockf.fileno(), timeout)
        except gevent.socket.error:
            return False
        return True

    def wait_until_readable(self, timeout=None):
        if self.closed:
            return False
        try:
            gevent.socket.wait_read(self.sockf.fileno(), timeout)
        except gevent.socket.error:
            return False
        return True

    def read_response(self):
        line = self.readline()
        if not line:
            self.peer_closed = True
            self.close()
            return None
        if line[0] == '+':
            return line.rstrip()
        elif line[0] == ':':
            try:
                return int(line[1:])
            except ValueError:
                raise Exception('Invalid numeric value: %s' % line)
        elif line[0] == '-':
            return line.rstrip()
        elif line[0] == '$':
            try:
                bulk_len = int(line[1:])
            except ValueError:
                raise Exception('Invalid bulk response: %s' % line)
            if bulk_len == -1:
                return None
            data = self.sockf.read(bulk_len + 2)
            if len(data) < bulk_len:
                self.peer_closed = True
                self.close()
            return data[:bulk_len]
        elif line[0] == '*':
            try:
                args_count = int(line[1:])
            except ValueError:
                raise Exception('Invalid mbulk response: %s' % line)
            return self.read_mbulk(args_count)
        else:
            raise Exception('Invalid response: %s' % line)


class ShardMock:
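    '''Mock master shard: a gevent StreamServer that listens on a local port and
    queues every incoming connection (e.g. the replica under test), to be
    retrieved via GetConnection().'''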
    def __init__(self, env):
        self.env = env
        self.new_conns = gevent.queue.Queue()

    def _handle_conn(self, sock, client_addr):
        conn = Connection(sock)
        self.new_conns.put(conn)

    def __enter__(self):
        self.server_port = None
        if not self.StartListening(port=0, attempts=10) and not self.StartListening(port=random.randint(55000, 57000)):
            raise Exception("%s StartListening failed" % self.__class__.__name__)
        return self

    def __exit__(self, type, value, traceback):
        self.stream_server.stop()

    def GetConnection(self, timeout=None):
        conn = self.new_conns.get(block=True, timeout=timeout)
        return conn

    def GetCleanConnection(self):
        return self.new_conns.get(block=True, timeout=None)

    def StopListening(self):
        self.stream_server.stop()
        self.server_port = None

    def StartListening(self, port, attempts=1):
        for i in range(1, attempts + 1):
            self.stream_server = gevent.server.StreamServer(('127.0.0.1', port), self._handle_conn)
            try:
                self.stream_server.start()
            except Exception as e:
                self.env.assertEqual(self.server_port, None, message='%s: StartListening(%d/%d) %d -> %s' % (
                    self.__class__.__name__, i, attempts, port, e.strerror))
                continue
            self.server_port = self.stream_server.address[1]
            self.env.assertNotEqual(self.server_port, None, message='%s: StartListening(%d/%d) %d -> %d' % (
                self.__class__.__name__, i, attempts, port, self.server_port))
            return True
        return False


class Debug:
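    '''Optional decorator for runShortRead: when constructed with enabled=True
    (e.g. @Debug(True)), prints the rdb payload bytes incrementally before
    invoking the wrapped function, for debugging.'''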
    def __init__(self, enabled=False):
        self.enabled = enabled
        self.clear()

    def clear(self):
        self.dbg_str = ''
        self.dbg_ndx = -1

    def __call__(self, f):
        if self.enabled:
            def f_with_debug(*args, **kwds):
                self.print_bytes_incremental(args[0], args[1], args[2], f.__name__)
                if len(args[1]) == args[2]:
                    self.clear()
                return f(*args, **kwds)

            return f_with_debug
        else:
            return f

    def print_bytes_incremental(self, env, data, total_len, name):
        # For debugging: print the binary content before it is sent
        byte_count_width = len(str(total_len))
        if len(data):
            ch = data[self.dbg_ndx]
            printable_ch = ch
            if ord(ch) < 32 or ord(ch) == 127:
                printable_ch = '\?'
        else:
            ch = '\0'
            printable_ch = '\!'  # no data (zero length)
        self.dbg_str = '{} {:=0{}n}:{:<2}({:<3})'.format(self.dbg_str, self.dbg_ndx, byte_count_width, printable_ch, ord(ch))
        if not (self.dbg_ndx + 1) % 10:
            self.dbg_str = self.dbg_str + "\n"
        self.dbg_ndx = self.dbg_ndx + 1

        env.debugPrint(name + ': %d out of %d \n%s' % (self.dbg_ndx, total_len, self.dbg_str))


@no_msan
def testShortReadSearch(env):
    if CODE_COVERAGE or SANITIZER:
        env.skip()  # FIXME: enable coverage test

    env.skipOnCluster()
    if env.env.endswith('existing-env') and os.environ.get('CI'):
        env.skip()

    if OSNICK == 'macos':
        env.skip()

    seed = str(time.time())
    env.assertNotEqual(seed, None, message='random seed ' + seed)  # log the seed, for reproducibility
    random.seed(seed)

    # Create the temp dir before the try block, so the finally clause can always remove it
    temp_dir = tempfile.mkdtemp(prefix="short-read_")
    # TODO: In python3 use "with tempfile.TemporaryDirectory()"
    try:
        if not downloadFiles(temp_dir):
            env.assertTrue(False, message="downloadFiles failed")

        for f, expected_index in zip(RDBS, RDBS_EXPECTED_INDICES):
            name, ext = os.path.splitext(f)
            if ext == '.zip':
                f = name
            fullfilePath = os.path.join(temp_dir, f)
            env.assertNotEqual(fullfilePath, None, message='testShortReadSearch')
            sendShortReads(env, fullfilePath, expected_index)
    finally:
        shutil.rmtree(temp_dir)


def sendShortReads(env, rdb_file, expected_index):
    # Add some initial content (index + keys) to test backup/restore/discard when a short read fails
    # When the entire rdb is successfully sent and loaded (via swapdb), the backup should be discarded
    env.assertCmdOk('replicaof', 'no', 'one')
    env.flush()
    add_index(env, True,  'idxBackup1', 'a', 5, 10)
    add_index(env, False, 'idxBackup2', 'b', 5, 10)

    res = env.cmd('ft.search', 'idxBackup1', '*', 'limit', '0', '0')
    env.assertEqual(res[0], 5)
    res = env.cmd('ft.search', 'idxBackup2', '*', 'limit', '0', '0')
    env.assertEqual(res[0], 5)

    with open(rdb_file, mode='rb') as f:
        full_rdb = f.read()
    total_len = len(full_rdb)

    env.assertGreater(total_len, SHORT_READ_BYTES_DELTA)
    # Try every prefix of the rdb, from empty to complete, stepping by SHORT_READ_BYTES_DELTA
    # (in Python 2 range returns a list, so concatenation is allowed)
    r = range(0, total_len + 1, SHORT_READ_BYTES_DELTA)
    if (total_len % SHORT_READ_BYTES_DELTA) != 0:
        r = r + [total_len]  # make sure the last attempt sends the complete rdb
    for b in r:
        rdb = full_rdb[0:b]
        runShortRead(env, rdb, total_len, expected_index)


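# Mock the master side of a single replication attempt: complete the handshake,
# advertise a full-length RDB payload, but possibly deliver only a prefix of it
# before closing the connection. On a short read the replica must restore its
# pre-existing indices; on a complete read it must load the new ones.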
@Debug(False)
def runShortRead(env, data, total_len, expected_index):
    with ShardMock(env) as shardMock:
        # For debugging: if adding breakpoints in redis, uncomment the following line
        # to avoid the connection being closed on timeout
        # res = env.cmd('CONFIG', 'SET', 'timeout', '0')

        # Notice: Do not use env.expect in this test
        # (since it sends commands to redis, and this test requires strict hand-shaking)
        res = env.cmd('CONFIG', 'SET', 'repl-diskless-load', 'swapdb')
        res = env.cmd('replicaof', '127.0.0.1', shardMock.server_port)
        env.assertTrue(res)
        conn = shardMock.GetConnection()
        # Perform hand-shake with replica
        res = conn.read_request()
        env.assertEqual(res, ['PING'])
        conn.send_status('PONG')

        max_attempt = 100
        res = conn.read_request()
        while max_attempt > 0:
            if res[0] == 'REPLCONF':
                conn.send_status('OK')
            else:
                break
            max_attempt = max_attempt - 1
            res = conn.read_request()

        # Send RDB to replica
        some_guid = 'af4e30b5d14dce9f96fbb7769d0ec794cdc0bbcc'
        conn.send_status('FULLRESYNC ' + some_guid + ' 0')
        is_shortread = total_len != len(data)
        if is_shortread:
            # A short read: advertise the full length but send only a truncated payload
            conn.send('$%d\r\n%s' % (total_len, data))
        else:
            # Allow the replica to succeed with a full read (the protocol expects a trailing '\r\n')
            conn.send('$%d\r\n%s\r\n' % (total_len, data))
        conn.flush()

        # Close while the replica is waiting for more RDB data (so the replica will re-connect to the master)
        conn.close()

        # Make sure the replica did not crash
        res = env.cmd('PING')
        env.assertEqual(res, True)
        conn = shardMock.GetConnection(timeout=3)
        env.assertNotEqual(conn, None)

        res = env.cmd('ft._list')
        if is_shortread:
            # Verify that the original data, which existed before the failed short-read attempt, was restored
            env.assertEqual(res, ['idxBackup2', 'idxBackup1'])
            res = env.cmd('ft.search', 'idxBackup1', '*', 'limit', '0', '0')
            env.assertEqual(res[0], 5)
            res = env.cmd('ft.search', 'idxBackup2', '*', 'limit', '0', '0')
            env.assertEqual(res[0], 5)
        else:
            # Verify that the new data was loaded and the backup was discarded
            # TODO: How to verify that the internal backup was indeed discarded?
            env.assertEqual(len(res), expected_index.count)
            r = re.compile(expected_index.pattern)
            expected_indices = list(filter(lambda x: r.match(x), res))
            env.assertEqual(len(expected_indices), expected_index.count)
            for ind, expected_result_count in zip(expected_indices, expected_index.search_result_count):
                res = env.cmd('ft.search', ind, '*', 'limit', '0', '0')
                env.assertEqual(res[0], expected_result_count)

        # Stop replication (avoid a read-only exception when flushing the replica)
        env.assertCmdOk('replicaof', 'no', 'one')