1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3# Originally based on ./sam.py
4from __future__ import print_function
5from unicodedata import normalize
6import locale
7locale.setlocale(locale.LC_ALL, ('en_US', 'UTF-8'))
8
9import optparse
10import sys
11import os
12import re
13
14sys.path.insert(0, "bin/python")
15import samba
16from samba.tests.subunitrun import SubunitOptions, TestProgram
17from samba.compat import cmp_fn
18from samba.compat import cmp_to_key_fn
19from samba.compat import text_type
20import samba.getopt as options
21
22from samba.auth import system_session
23import ldb
24from samba.samdb import SamDB
25
# Assemble the option parser from Samba's shared option groups plus the
# test-specific --elements switch.
parser = optparse.OptionParser("sort.py [options] <host>")
sambaopts = options.SambaOptions(parser)
parser.add_option_group(sambaopts)
parser.add_option_group(options.VersionOptions(parser))
# use command line creds if available
credopts = options.CredentialsOptions(parser)
parser.add_option_group(credopts)
subunitopts = SubunitOptions(parser)
parser.add_option_group(subunitopts)

# Number of users each test case creates (read in setUp as opts.elements).
parser.add_option('--elements', type='int', default=33,
                  help="use this many elements in the tests")

opts, args = parser.parse_args()

# At least one positional argument is required.
# NOTE(review): the value of that argument is not read in this part of
# the file; the server comes from the SERVER environment variable below.
# Confirm that this is intended.
if len(args) < 1:
    parser.print_usage()
    sys.exit(1)

# Directory containing the *.expected files named by the test classes'
# results_file attributes.
datadir = os.getenv("DATA_DIR", None)
if not datadir:
    print("Please specify the location of the sort expected results with env variable DATA_DIR")
    sys.exit(1)

# Server to test against; a URL scheme is prepended near the end of the
# file when the value does not already contain "://".
host = os.getenv("SERVER", None)
if not host:
    print("Please specify the host with env variable SERVER")
    sys.exit(1)

# Load parameters and credentials derived from the parsed options.
lp = sambaopts.get_loadparm()
creds = credopts.get_credentials(lp)
57
58
def norm(x):
    """Return *x* as NFKC-normalised, upper-cased text.

    A value that is not already ``text_type`` is decoded as UTF-8
    before it is normalised.
    """
    text = x if isinstance(x, text_type) else x.decode('utf8')
    return normalize('NFKC', text).upper()
63
64
# Python, Windows, and Samba all sort the following sequence in
# drastically different ways. The order here is what you get from
# Windows2012R2.
#
# create_user() indexes into this list to fill the streetAddress and
# postalAddress attributes of the generated users.
#
# NOTE(review): '1\xe2\x81\x845' spells out E2 81 84, the UTF-8 encoding
# of U+2044 FRACTION SLASH, as three separate \x escapes. In a Python 3
# str literal these are the three code points U+00E2, U+0081 and U+0084,
# not a fraction slash. This is presumably a Python 2 holdover; check the
# expected-results data before changing it.
FIENDISH_TESTS = [' ', ' e', '\t-\t', '\n\t\t', '!@#!@#!', '¼', '¹', '1',
                  '1/4', '1⁄4', '1\xe2\x81\x845', '3', 'abc', 'fo\x00od',

                  # Here we also had '\x00food', but that seems to sort
                  # non-deterministically on Windows vis-a-vis 'fo\x00od'.

                  'kōkako', 'ŋđ¼³ŧ “«đð', 'ŋđ¼³ŧ“«đð',
                  'sorttest', 'sorttēst11,', 'śorttest2', 'śoRttest2',
                  'ś-o-r-t-t-e-s-t-2', 'soRTTēst2,', 'ṡorttest4', 'ṡorttesT4',
                  'sörttest-5', 'sÖrttest-5', 'so-rttest7,', '桑巴']
78
79
80class BaseSortTests(samba.tests.TestCase):
81    avoid_tricky_sort = False
82    maxDiff = 2000
83
84    def create_user(self, i, n, prefix='sorttest', suffix='', attrs=None,
85                    tricky=False):
86        name = "%s%d%s" % (prefix, i, suffix)
87        user = {
88            'cn': name,
89            "objectclass": "user",
90            'givenName': "abcdefghijklmnopqrstuvwxyz"[i % 26],
91            "roomNumber": "%sb\x00c" % (n - i),
92            # with python3 re.sub(r'[^\w,.]', repl, string) doesn't
93            # work as expected with unicode as value for carLicense
94            "carLicense": "XXXXXXXXX" if self.avoid_tricky_sort else "后来经",
95            "employeeNumber": "%s%sx" % (abs(i * (99 - i)), '\n' * (i & 255)),
96            "accountExpires": "%s" % (10 ** 9 + 1000000 * i),
97            "msTSExpireDate4": "19%02d0101010000.0Z" % (i % 100),
98            "flags": str(i * (n - i)),
99            "serialNumber": "abc %s%s%s" % ('AaBb |-/'[i & 7],
100                                            ' 3z}'[i & 3],
101                                            '"@'[i & 1],),
102            "comment": "Favourite colour is %d" % (n % (i + 1)),
103        }
104
105        if self.avoid_tricky_sort:
106            # We are not even going to try passing tests that assume
107            # some kind of Unicode awareness.
108            for k, v in user.items():
109                user[k] = re.sub(r'[^\w,.]', 'X', v)
110        else:
111            # Add some even trickier ones!
112            fiendish_index = i % len(FIENDISH_TESTS)
113            user.update({
114                # Sort doesn't look past a NUL byte.
115                "photo": "\x00%d" % (n - i),
116                "audio": "%sn octet string %s%s ♫♬\x00lalala" % ('Aa'[i & 1],
117                                                                 chr(i & 255),
118                                                                 i),
119                "displayNamePrintable": "%d\x00%c" % (i, i & 255),
120                "adminDisplayName": "%d\x00b" % (n - i),
121                "title": "%d%sb" % (n - i, '\x00' * i),
122
123                # Names that vary only in case. Windows returns
124                # equivalent addresses in the order they were put
125                # in ('a st', 'A st',...). We don't check that.
126                "street": "%s st" % (chr(65 | (i & 14) | ((i & 1) * 32))),
127
128                "streetAddress": FIENDISH_TESTS[fiendish_index],
129                "postalAddress": FIENDISH_TESTS[-fiendish_index],
130            })
131
132        if attrs is not None:
133            user.update(attrs)
134
135        user['dn'] = "cn=%s,%s" % (user['cn'], self.ou)
136
137        self.users.append(user)
138        self.ldb.add(user)
139        return user
140
141    def setUp(self):
142        super(BaseSortTests, self).setUp()
143        self.ldb = SamDB(host, credentials=creds,
144                         session_info=system_session(lp), lp=lp)
145
146        self.base_dn = self.ldb.domain_dn()
147        self.ou = "ou=sort,%s" % self.base_dn
148        if False:
149            try:
150                self.ldb.delete(self.ou, ['tree_delete:1'])
151            except ldb.LdbError as e:
152                print("tried deleting %s, got error %s" % (self.ou, e))
153
154        self.ldb.add({
155            "dn": self.ou,
156            "objectclass": "organizationalUnit"})
157        self.users = []
158        n = opts.elements
159        for i in range(n):
160            self.create_user(i, n)
161
162        attrs = set(self.users[0].keys()) - set([
163            'objectclass', 'dn'])
164        self.binary_sorted_keys = attrs.intersection(['audio',
165                                                      'photo',
166                                                      "msTSExpireDate4",
167                                                      'serialNumber',
168                                                      "displayNamePrintable"])
169
170        self.numeric_sorted_keys = attrs.intersection(['flags',
171                                                       'accountExpires'])
172
173        self.timestamp_keys = attrs.intersection(['msTSExpireDate4'])
174
175        self.int64_keys = set(['accountExpires'])
176
177        self.locale_sorted_keys = [x for x in attrs if
178                                   x not in (self.binary_sorted_keys |
179                                             self.numeric_sorted_keys)]
180
181        self.expected_results = {}
182        self.expected_results_binary = {}
183
184        for k in self.binary_sorted_keys:
185            forward = sorted((x[k] for x in self.users))
186            reverse = list(reversed(forward))
187            self.expected_results_binary[k] = (forward, reverse)
188
189        # FYI: Expected result data was generated from the old
190        # code that was manually sorting (while executing with
191        # python2)
192        # The resulting data was injected into the data file with
193        # code similar to:
194        #
195        # for k in self.expected_results:
196        #     f.write("%s = %s\n" % (k,  repr(self.expected_results[k][0])))
197
198        f = open(self.results_file, "r")
199        for line in f:
200            if len(line.split('=', 1)) == 2:
201                key = line.split('=', 1)[0].strip()
202                value = line.split('=', 1)[1].strip()
203                if value.startswith('['):
204                    import ast
205                    fwd_list = ast.literal_eval(value)
206                    rev_list = list(reversed(fwd_list))
207                    self.expected_results[key] = (fwd_list, rev_list)
208        f.close()
    def tearDown(self):
        """Run the base-class teardown, then delete the test OU.

        The delete is issued with the 'tree_delete:1' control.
        NOTE(review): presumably that removes the users created beneath
        the OU as well -- confirm against the server's control handling.
        """
        super(BaseSortTests, self).tearDown()
        self.ldb.delete(self.ou, ['tree_delete:1'])
212
213    def _test_server_sort_default(self):
214        attrs = self.locale_sorted_keys
215
216        for attr in attrs:
217            for rev in (0, 1):
218                res = self.ldb.search(self.ou,
219                                      scope=ldb.SCOPE_ONELEVEL, attrs=[attr],
220                                      controls=["server_sort:1:%d:%s" %
221                                                (rev, attr)])
222                self.assertEqual(len(res), len(self.users))
223
224                expected_order = self.expected_results[attr][rev]
225                received_order = [norm(x[attr][0]) for x in res]
226                if expected_order != received_order:
227                    print(attr, ['forward', 'reverse'][rev])
228                    print("expected", expected_order)
229                    print("received", received_order)
230                    print("unnormalised:", [x[attr][0] for x in res])
231                    print("unnormalised: «%s»" % '»  «'.join(str(x[attr][0])
232                                                             for x in res))
233                self.assertEquals(expected_order, received_order)
234
235    def _test_server_sort_binary(self):
236        for attr in self.binary_sorted_keys:
237            for rev in (0, 1):
238                res = self.ldb.search(self.ou,
239                                      scope=ldb.SCOPE_ONELEVEL, attrs=[attr],
240                                      controls=["server_sort:1:%d:%s" %
241                                                (rev, attr)])
242
243                self.assertEqual(len(res), len(self.users))
244                expected_order = self.expected_results_binary[attr][rev]
245                received_order = [str(x[attr][0]) for x in res]
246                if expected_order != received_order:
247                    print(attr)
248                    print(expected_order)
249                    print(received_order)
250                self.assertEquals(expected_order, received_order)
251
252    def _test_server_sort_us_english(self):
253        # Windows doesn't support many matching rules, but does allow
254        # the locale specific sorts -- if it has the locale installed.
255        # The most reliable locale is the default US English, which
256        # won't change the sort order.
257
258        for lang, oid in [('en_US', '1.2.840.113556.1.4.1499'),
259                          ]:
260
261            for attr in self.locale_sorted_keys:
262                for rev in (0, 1):
263                    res = self.ldb.search(self.ou,
264                                          scope=ldb.SCOPE_ONELEVEL,
265                                          attrs=[attr],
266                                          controls=["server_sort:1:%d:%s:%s" %
267                                                    (rev, attr, oid)])
268
269                    self.assertTrue(len(res) == len(self.users))
270                    expected_order = self.expected_results[attr][rev]
271                    received_order = [norm(x[attr][0]) for x in res]
272                    if expected_order != received_order:
273                        print(attr, lang)
274                        print(['forward', 'reverse'][rev])
275                        print("expected: ", expected_order)
276                        print("received: ", received_order)
277                        print("unnormalised:", [x[attr][0] for x in res])
278                        print("unnormalised: «%s»" % '»  «'.join(str(x[attr][0])
279                                                                 for x in res))
280
281                    self.assertEquals(expected_order, received_order)
282
283    def _test_server_sort_different_attr(self):
284
285        def cmp_locale(a, b):
286            return locale.strcoll(a[0], b[0])
287
288        def cmp_binary(a, b):
289            return cmp_fn(a[0], b[0])
290
291        def cmp_numeric(a, b):
292            return cmp_fn(int(a[0]), int(b[0]))
293
294        # For testing simplicity, the attributes in here need to be
295        # unique for each user. Otherwise there are multiple possible
296        # valid answers.
297        sort_functions = {'cn': cmp_binary,
298                          "employeeNumber": cmp_locale,
299                          "accountExpires": cmp_numeric,
300                          "msTSExpireDate4": cmp_binary}
301        attrs = list(sort_functions.keys())
302        attr_pairs = zip(attrs, attrs[1:] + attrs[:1])
303
304        for sort_attr, result_attr in attr_pairs:
305            forward = sorted(((norm(x[sort_attr]), norm(x[result_attr]))
306                             for x in self.users),
307                             key=cmp_to_key_fn(sort_functions[sort_attr]))
308            reverse = list(reversed(forward))
309
310            for rev in (0, 1):
311                res = self.ldb.search(self.ou,
312                                      scope=ldb.SCOPE_ONELEVEL,
313                                      attrs=[result_attr],
314                                      controls=["server_sort:1:%d:%s" %
315                                                (rev, sort_attr)])
316                self.assertEqual(len(res), len(self.users))
317                pairs = (forward, reverse)[rev]
318
319                expected_order = [x[1] for x in pairs]
320                received_order = [norm(x[result_attr][0]) for x in res]
321
322                if expected_order != received_order:
323                    print(sort_attr, result_attr, ['forward', 'reverse'][rev])
324                    print("expected", expected_order)
325                    print("received", received_order)
326                    print("unnormalised:", [x[result_attr][0] for x in res])
327                    print("unnormalised: «%s»" % '»  «'.join(str(x[result_attr][0])
328                                                             for x in res))
329                    print("pairs:", pairs)
330                    # There are bugs in Windows that we don't want (or
331                    # know how) to replicate regarding timestamp sorting.
332                    # Let's remind ourselves.
333                    if result_attr == "msTSExpireDate4":
334                        print('-' * 72)
335                        print("This test fails against Windows with the "
336                              "default number of elements (33).")
337                        print("Try with --elements=27 (or similar).")
338                        print('-' * 72)
339
340                self.assertEquals(expected_order, received_order)
341                for x in res:
342                    if sort_attr in x:
343                        self.fail('the search for %s should not return %s' %
344                                  (result_attr, sort_attr))
345
346
class SimpleSortTests(BaseSortTests):
    """Run all four sort checks against the simplified data set.

    With ``avoid_tricky_sort`` set, ``create_user`` replaces every
    character other than a word character, ',' or '.' with 'X'.
    Expected orderings are read from ``simplesort.expected`` in the
    DATA_DIR directory.
    """
    results_file = os.path.join(datadir, "simplesort.expected")
    avoid_tricky_sort = True

    # Each test delegates to the shared BaseSortTests implementation of
    # the same name (with a leading underscore).

    def test_server_sort_default(self):
        self._test_server_sort_default()

    def test_server_sort_us_english(self):
        self._test_server_sort_us_english()

    def test_server_sort_binary(self):
        self._test_server_sort_binary()

    def test_server_sort_different_attr(self):
        self._test_server_sort_different_attr()
361
362
class UnicodeSortTests(BaseSortTests):
    """Run the sort checks against the full (tricky) data set.

    With ``avoid_tricky_sort`` unset, ``create_user`` adds the extra
    NUL-containing and FIENDISH_TESTS-derived attributes. Expected
    orderings are read from ``unicodesort.expected`` in the DATA_DIR
    directory. The binary sort check is not run by this class.
    """
    results_file = os.path.join(datadir, "unicodesort.expected")
    avoid_tricky_sort = False

    # Each test delegates to the shared BaseSortTests implementation of
    # the same name (with a leading underscore).

    def test_server_sort_different_attr(self):
        self._test_server_sort_different_attr()

    def test_server_sort_us_english(self):
        self._test_server_sort_us_english()

    def test_server_sort_default(self):
        self._test_server_sort_default()
375
376
# Give a scheme-less host a URL scheme: an existing local file is opened
# as a tdb database, anything else is contacted over LDAP.
if "://" not in host:
    host = ("tdb://%s" if os.path.isfile(host) else "ldap://%s") % host
382