1#!/usr/local/bin/python3.8
2# vim:fileencoding=utf-8
3# License: GPLv3 Copyright: 2019, Kovid Goyal <kovid at kovidgoyal.net>
4
5import random
6import string
7
8from . import BaseTest
9
10
def run(input_data, query, **kw):
    """Run the choose-kitten matcher over *input_data* for *query*.

    Remaining keyword arguments are forwarded to ``match``; ``threads``
    defaults to 1 when not supplied.  The ``mark`` keyword is consumed here
    and translated into ``mark_before``/``mark_after``:

    * falsy  -> no marking (both empty strings)
    * True   -> ESC[32m before and ESC[39m after each matched character
    * other  -> the given value is used verbatim as both prefix and suffix

    Any ``mark_before``/``mark_after`` passed by the caller are overwritten.
    """
    kw.setdefault('threads', 1)
    mark = kw.pop('mark', False)
    from kittens.choose.match import match
    if not mark:
        before = after = ''
    elif mark is True:
        before, after = '\033[32m', '\033[39m'
    else:
        before = after = mark
    kw.update(mark_before=before, mark_after=after)
    return match(input_data, query, **kw)
23
24
class TestMatcher(BaseTest):
    """Tests for the fuzzy matcher exposed through ``run``."""

    def run_matcher(self, *args, **kwargs):
        """Run the matcher with the given arguments and return its result."""
        return run(*args, **kwargs)

    def basic_test(self, inp, query, out, **k):
        """Match *query* against *inp* and compare the result with *out*.

        *out* may be a string of expected entries separated by the
        ``delimiter`` keyword (default newline), any iterable of expected
        entries, or None to skip the comparison.  Returns *out*, converted
        to a list when it was given as a string.
        """
        result = self.run_matcher(inp, query, **k)
        if out is not None:
            if hasattr(out, 'splitlines'):
                # Empty pieces are dropped, so '' means "no matches".
                out = list(filter(None, out.split(k.get('delimiter', '\n'))))
            self.assertEqual(list(out), result)
        return out

    def test_filtering(self):
        ' Non matching entries must be removed '
        self.basic_test('test\nxyz', 'te', 'test')
        self.basic_test('abc\nxyz', 'ba', '')
        self.basic_test('abc\n123', 'abc', 'abc')

    def test_case_insensitive(self):
        ' Matching ignores case in both the query and the entries '
        self.basic_test('test\nxyz', 'Te', 'test')
        self.basic_test('test\nxyz', 'XY', 'xyz')
        self.basic_test('test\nXYZ', 'xy', 'XYZ')
        self.basic_test('test\nXYZ', 'mn', '')

    def test_marking(self):
        ' Marking of matched characters '
        self.basic_test(
            'test\nxyz',
            'ts',
            '\x1b[32mt\x1b[39me\x1b[32ms\x1b[39mt',
            mark=True)

    def test_positions(self):
        ' Output of positions '
        self.basic_test('abc\nac', 'ac', '0,1:ac\n0,2:abc', positions=True)
        self.basic_test('abc\nv', 'a', '0:abc', positions=True)

    def test_delimiter(self):
        ' Test using a custom line delimiter '
        self.basic_test('abc\n21ac', 'ac', 'ac1abc\n2', delimiter='1')

    def test_scoring(self):
        ' Scoring algorithm '
        # Match at start
        self.basic_test('archer\nelementary', 'e', 'elementary\narcher')
        # Match at level factor
        self.basic_test('xxxy\nxx/y', 'y', 'xx/y\nxxxy')
        # CamelCase
        self.basic_test('xxxy\nxxxY', 'y', 'xxxY\nxxxy')
        # Total length
        self.basic_test('xxxya\nxxxy', 'y', 'xxxy\nxxxya')
        # Distance
        self.basic_test('abbc\nabc', 'ac', 'abc\nabbc')
        # Extreme chars
        self.basic_test('xxa\naxx', 'a', 'axx\nxxa')
        # Highest score
        self.basic_test('xa/a', 'a', 'xa/|a|', mark='|')

    def test_threading(self):
        ' Test matching on a large data set with different number of threads '
        alphabet = string.ascii_lowercase + string.ascii_uppercase + string.digits
        # A private, fixed-seed generator keeps the data set (and therefore
        # any failure) reproducible without disturbing the global RNG.
        rng = random.Random(20190101)

        def random_word():
            sz = rng.randint(2, 10)
            return ''.join(rng.choice(alphabet) for x in range(sz))
        words = [random_word() for i in range(400)]

        def random_item():
            num = rng.randint(2, 7)
            return '/'.join(rng.choice(words) for w in range(num))

        data = '\n'.join(random_item() for x in range(25123))

        reference = None
        for threads in range(4):
            # The thread count must not change which entries match.  Results
            # are compared sorted so that the relative order of equally
            # scored entries cannot make this check flaky.
            matched = sorted(self.run_matcher(data, 'foo', threads=threads))
            if reference is None:
                reference = matched
            else:
                self.assertEqual(reference, matched, 'threads=%d' % threads)
102