1""" Tests for btrie.c
2"""
3
import numbers
import os
import re
from subprocess import Popen, PIPE
from tempfile import TemporaryFile, NamedTemporaryFile
import unittest
9
10from rbldnsd import Rbldnsd, ZoneFile
11
# Names exported by ``from <this module> import *``: the TestCase classes
# defined below.
__all__ = [
    'Test_coalesce_lc_node', 'Test_shorten_lc_node',
    'Test_convert_lc_node_1', 'Test_convert_lc_node',
    'Test_insert_lc_node', 'Test_init_tbm_node',
    'Test_add_to_trie', 'Test_search_trie',
]
22
def deduce_pointer_size(makefile='./Makefile'):
    """ Deduce the pointer size (in the current compilation environment)

    Reads the ``CC`` and ``CFLAGS`` assignments from *makefile*, compiles
    a tiny C program with them, runs it and returns the pointer size (in
    bytes) that the program prints.

    Raises:
        EnvironmentError: *makefile* cannot be opened.
        KeyError: ``CC`` or ``CFLAGS`` is not assigned with a plain ``=``.
        RuntimeError: compiling or running the test program failed.
        ValueError: the program's output is not an integer.
    """
    # Only plain ``NAME = value`` lines are recognised; ``?=``/``+=``/``:=``
    # assignments and rules do not match the regex, and ``$(VAR)``
    # references inside values are NOT expanded.
    with open(makefile) as f:
        make_vars = dict(
            m.groups()
            for m in (re.match(r'\s*(\w+)\s*=\s*(.*?)\s*\Z', line)
                      for line in f)
            if m is not None)
    cc = make_vars['CC']
    cflags = make_vars['CFLAGS']

    # Text mode so the (str) C source can be written on Python 3 too; the
    # context manager closes (and thereby deletes) the source file.
    with NamedTemporaryFile(mode='w', suffix='.c') as test_c:
        test_c.write(r'''
#include <stdio.h>
#ifndef __SIZEOF_POINTER__
# define __SIZEOF_POINTER__ sizeof(void *)
#endif
int main () {
  printf("%u\n", (unsigned)__SIZEOF_POINTER__);
  return 0;
}
''')
        # The compiler opens the file by name, so the data must be on disk.
        test_c.flush()
        src = test_c.name
        binary = src + '.bin'

        try:
            proc = Popen("%(cc)s %(cflags)s -o %(binary)s %(src)s"
                         " && %(binary)s"
                         % {'cc': cc, 'cflags': cflags,
                            'src': src, 'binary': binary},
                         shell=True, stdout=PIPE)
            # communicate() drains stdout and reaps the child process.
            output = proc.communicate()[0]
            if proc.returncode != 0:
                raise RuntimeError("test prog exited with code %d"
                                   % proc.returncode)
            return int(output)
        finally:
            # Best-effort cleanup: the binary may never have been built.
            try:
                os.unlink(binary)
            except OSError:
                pass
63
# Pointer size of the compilation environment; it selects the btrie
# geometry constants below.
try:
    sizeof_pointer = deduce_pointer_size()
except Exception as exc:
    # Best effort: without a usable Makefile/compiler assume a 64-bit
    # build, but report why so that a wrong guess is easy to diagnose.
    print("Can not deduce size of pointer (%s). "
          "Assuming pointer size of 8." % (exc,))
    sizeof_pointer = 8

# NOTE(review): these presumably mirror the stride / LC-node size that
# btrie.c selects for this pointer size -- keep them in sync with btrie.c.
if sizeof_pointer == 8:
    STRIDE = 5
    LC_BYTES_PER_NODE = 7
elif sizeof_pointer == 4:
    STRIDE = 4
    LC_BYTES_PER_NODE = 3
else:
    raise RuntimeError("Unsupported pointer size (%d)" % sizeof_pointer)
78
def pad_prefix(prefix, plen):
    """Pad prefix on the right with zeros to a full 128 bits

    :param prefix: the ``plen`` leading bits of the prefix, as an integer
        with ``0 <= prefix < 2 ** plen``.
    :param plen: prefix length in bits, ``0 <= plen <= 128``.
    :returns: ``prefix`` shifted left so it occupies the top ``plen`` bits
        of a 128-bit integer.
    :raises TypeError: if ``prefix`` or ``plen`` is not an integer.
    :raises ValueError: if ``plen`` or ``prefix`` is out of range.
    """
    # numbers.Integral covers int (and long on Python 2) without naming
    # ``long``, which does not exist on Python 3.
    if not isinstance(prefix, numbers.Integral):
        raise TypeError("prefix must be an integer")
    # plen feeds the shifts below, so it must be a real integer as well.
    if not isinstance(plen, numbers.Integral):
        raise TypeError("plen must be an integer")
    if not 0 <= plen <= 128:
        raise ValueError("plen out of range")
    if not 0 <= prefix < (1 << plen):
        raise ValueError("prefix out of range")
    return prefix << (128 - plen)
89
class BTrie(object):
    """ Build a btrie from a list of prefixes and look addresses up in it.

    There are no Python bindings for btrie, so this takes the long way
    round: the prefixes are loaded into an rbldnsd instance as a single
    ``ip6trie`` dataset, and every lookup is a query against that rbldnsd.

    Instances are context managers; entering and exiting delegate to the
    wrapped ``Rbldnsd`` object.
    """

    def __init__(self, prefixes, **kwargs):
        """ *prefixes* is an iterable of ``(prefix, plen, data)`` triples
        (``prefix``/``plen`` as accepted by ``pad_prefix``); any keyword
        arguments are passed straight on to ``Rbldnsd``.
        """
        self.rbldnsd = Rbldnsd(**kwargs)
        # Lazily formatted zone lines, one per triple.
        entries = (self._zone_entry(*triple) for triple in prefixes)
        self.rbldnsd.add_dataset('ip6trie', ZoneFile(entries))

    def __enter__(self):
        # Enter the wrapped rbldnsd, but hand the BTrie itself back.
        self.rbldnsd.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Pure delegation; the return value decides exception suppression.
        return self.rbldnsd.__exit__(exc_type, exc_value, exc_tb)

    def lookup(self, prefix, plen):
        """ Query the address ``prefix/plen`` (zero-padded to 128 bits) and
        return whatever ``Rbldnsd.query`` answers.

        The query name has one hex label per nibble, least significant
        nibble first, under ``example.com``.
        """
        padded = pad_prefix(prefix, plen)
        labels = []
        for shift in range(0, 128, 4):
            labels.append("%x" % ((padded >> shift) & 0xf))
        qname = '.'.join(labels) + '.example.com'
        return self.rbldnsd.query(qname)

    def _zone_entry(self, prefix, plen, data):
        """ Format one zone-file line: ``<ipv6-address>/<plen> :1:<data>``.
        """
        padded = pad_prefix(prefix, plen)
        # Eight 16-bit groups, most significant first.
        groups = ["%x" % ((padded >> shift) & 0xffff)
                  for shift in reversed(range(0, 128, 16))]
        return "%s/%u :1:%s" % (':'.join(groups), plen, data)
121
class CaptureOutput(object):
    """ Capture output written to a file descriptor for later inspection.

    Only ``fileno()`` is exposed on the writing side, so an instance can
    be passed wherever a file-descriptor holder is accepted (the tests pass
    it as the ``stderr`` keyword forwarded to ``Rbldnsd`` -- presumably it
    ends up as a child process's stderr).  What was written can be read
    back with ``str()`` or searched with ``in``.
    """

    def __init__(self):
        self._file = TemporaryFile()

    def fileno(self):
        return self._file.fileno()

    def close(self):
        """ Close (and thereby delete) the underlying temporary file. """
        self._file.close()

    def __contains__(self, substr):
        return substr in str(self)

    def __str__(self):
        # Always read everything written so far, from the very beginning.
        self._file.seek(0, 0)
        data = self._file.read()
        # TemporaryFile is binary; on Python 3 read() returns bytes, but
        # __str__ must return text.  On Python 2 data is already a str.
        if not isinstance(data, str):
            data = data.decode('utf-8', 'replace')
        return data
135
class Test_coalesce_lc_node(unittest.TestCase):
    # Both tests target coalesce_lc_node(): a long prefix goes in first,
    # then a prefix at depth 8 - STRIDE adds a TBM node above it.

    def test_merge(self):
        # 8 * (LC_BYTES_PER_NODE + 1) bits are too many for a single LC
        # node, but once the TBM node sits at 0/0 they should just fit
        # into a single LC extension path.
        long_plen = 8 * (LC_BYTES_PER_NODE + 1)
        prefixes = [(0, long_plen, "term"),
                    (0, 8 - STRIDE, "root")]    # shortens the LC node
        with BTrie(prefixes) as trie:
            for pfx, want in ((0, "term"), (1, "root")):
                self.assertEqual(trie.lookup(pfx, 8), want)

    def test_steal_bits(self):
        # One bit longer than in test_merge: even with the TBM node at 0/0
        # this is still too long for a single LC node, so the upper LC
        # node should steal bits from the terminal LC node.
        long_plen = 8 * (LC_BYTES_PER_NODE + 1) + 1
        prefixes = [(0, long_plen, "term"),
                    (0, 8 - STRIDE, "root")]
        with BTrie(prefixes) as trie:
            for pfx, plen, want in ((0, 0, "term"), (1, 8, "root")):
                self.assertEqual(trie.lookup(pfx, plen), want)
164
class Test_shorten_lc_node(unittest.TestCase):
    def test_steal_child(self):
        # test coverage of shorten_lc_node (per the class name)
        prefixes = [
            # two nested prefixes just past the first byte boundary
            (0, 9, "tbm root"),
            (0, 10, "term"),
            # Add a TBM node at depth 9 - STRIDE to shorten the LC node
            # above the two prefixes
            (0, 9 - STRIDE, "root"),
            ]
        with BTrie(prefixes) as btrie:
            self.assertEqual(btrie.lookup(0, 8), "term")
            # 1/10 (0000000001) lies under 0/9 but not under 0/10
            self.assertEqual(btrie.lookup(1, 10), "tbm root")
            self.assertEqual(btrie.lookup(1, 8), "root")
180
class Test_convert_lc_node_1(unittest.TestCase):
    def test_left_child(self):
        # test coverage of convert_lc_node_1 (per the class name)
        prefixes = [
            # create TBM node at depth 1
            (0, 2, "term"),
            (0, 1, "tbm node"),
            # promote to depth 0
            (0, 0, "root"),
            ]
        with BTrie(prefixes) as btrie:
            self.assertEqual(btrie.lookup(0, 0), "term")
            self.assertEqual(btrie.lookup(1, 2), "tbm node")
            self.assertEqual(btrie.lookup(1, 1), "root")

    def test_right_child(self):
        # mirror image of test_left_child, along the right (1) branches
        prefixes = [
            # create TBM node at depth 1
            (3, 2, "term"),
            (1, 1, "tbm node"),
            # promote to depth 0
            (0, 0, "root"),
            ]
        with BTrie(prefixes) as btrie:
            self.assertEqual(btrie.lookup(3, 2), "term")
            self.assertEqual(btrie.lookup(2, 2), "tbm node")
            self.assertEqual(btrie.lookup(0, 1), "root")
207
class Test_convert_lc_node(unittest.TestCase):
    # Same shape as Test_convert_lc_node_1, but with the TBM node created
    # at depth STRIDE - 1 instead of depth 1.

    def test_left_child(self):
        # test coverage of convert_lc_node (per the class name)
        prefixes = [
            # create TBM node at depth STRIDE - 1
            (0, STRIDE, "term"),
            (0, STRIDE - 1, "tbm node"),
            # promote to depth 0
            (0, 0, "root"),
            ]
        with BTrie(prefixes) as btrie:
            self.assertEqual(btrie.lookup(0, STRIDE), "term")
            self.assertEqual(btrie.lookup(1, STRIDE), "tbm node")
            self.assertEqual(btrie.lookup(1, 1), "root")

    def test_right_child(self):
        # mirror image of test_left_child along the all-ones path, so the
        # right-hand branch is covered as well
        ones = (1 << STRIDE) - 1
        prefixes = [
            # create TBM node at depth STRIDE - 1
            (ones, STRIDE, "term"),
            (ones >> 1, STRIDE - 1, "tbm node"),
            # promote to depth 0
            (0, 0, "root"),
            ]
        with BTrie(prefixes) as btrie:
            self.assertEqual(btrie.lookup(ones, STRIDE), "term")
            self.assertEqual(btrie.lookup(ones - 1, STRIDE), "tbm node")
            self.assertEqual(btrie.lookup(0, 1), "root")
222
class Test_insert_lc_node(unittest.TestCase):
    # Each test builds a TBM node below the top of the trie and then adds a
    # shorter prefix that promotes it one level.

    def _check(self, prefixes, expected):
        # expected: ((prefix, plen), data) pairs, looked up in order
        with BTrie(prefixes) as btrie:
            for (pfx, plen), data in expected:
                self.assertEqual(btrie.lookup(pfx, plen), data)

    def test_insert_lc_len_1(self):
        # TBM node at depth 1 with a TBM extending path, promoted to depth 0
        ext = STRIDE + 1
        self._check(
            [(0, ext + 1, "term"),
             (0, ext, "tbm ext path"),
             (0, 1, "tbm node"),
             (0, 0, "root")],
            [((0, 0), "term"),
             ((1, ext + 1), "tbm ext path"),
             ((1, 2), "tbm node"),
             ((1, 1), "root")])

    def test_extend_lc_tail_optimization(self):
        # TBM node at depth 1 with an LC extending path, promoted to depth 0
        tail = STRIDE + 2
        self._check(
            [(1, tail, "term"),
             (0, 1, "tbm node"),
             (0, 0, "root")],
            [((1, tail), "term"),
             ((0, 0), "tbm node"),
             ((1, 1), "root")])

    def test_coalesce_lc_tail(self):
        # TBM node whose LC extending path starts at a byte boundary,
        # then promoted one level
        depth = 8 - STRIDE
        self._check(
            [(0, 10, "term"),
             (0, depth, "tbm node"),
             (0, depth - 1, "promoted")],
            [((0, 0), "term"),
             ((1, depth + 1), "tbm node"),
             ((1, depth), "promoted")])
265
class Test_init_tbm_node(unittest.TestCase):
    # Coverage for init_tbm_node(): every test builds a TBM node and then
    # adds a shorter prefix that promotes it one level.

    def test_short_lc_children(self):
        # Exercises the convert_lc_node calls in init_tbm_node(): a TBM
        # node at depth 1 with two LC extending paths from a deep internal
        # node.
        ext_len = STRIDE + 2
        prefixes = [(0, 1, "tbm"),
                    (0, ext_len, "term0"),
                    (2, ext_len, "term1"),
                    (0, 0, "root")]             # promote one level
        checks = [(0, 0, "term0"),
                  (1, STRIDE + 1, "term1"),
                  (1, 2, "tbm"),
                  (1, 1, "root")]
        with BTrie(prefixes) as btrie:
            for pfx, plen, want in checks:
                self.assertEqual(btrie.lookup(pfx, plen), want)

    def test_long_lc_children(self):
        # Exercises the shorten_lc_node calls in init_tbm_node(): same
        # layout, but the two LC extending paths are long.
        ext_len = STRIDE + 9
        prefixes = [(0, 1, "tbm"),
                    (0, ext_len, "term0"),
                    (0x100, ext_len, "term1"),
                    (0, 0, "root")]             # promote one level
        checks = [(0, 0, "term0"),
                  (1, STRIDE + 1, "term1"),
                  (1, 2, "tbm"),
                  (1, 1, "root")]
        with BTrie(prefixes) as btrie:
            for pfx, plen, want in checks:
                self.assertEqual(btrie.lookup(pfx, plen), want)

    def test_set_internal_data_for_root_prefix(self):
        # Exercises the "set internal data for root prefix" code: a TBM
        # node at depth 1 with internal prefix data and an extending path
        # on a deep internal node.
        prefixes = [(0, 1, "tbm"),
                    (0, STRIDE, "int data"),
                    (0, STRIDE + 1, "ext path"),
                    (0, 0, "root")]             # promote one level
        checks = [(0, 0, "ext path"),
                  (1, STRIDE + 1, "int data"),
                  (1, 2, "tbm"),
                  (1, 1, "root")]
        with BTrie(prefixes) as btrie:
            for pfx, plen, want in checks:
                self.assertEqual(btrie.lookup(pfx, plen), want)

    def test_set_right_ext_path(self):
        # Exercises the insert_lc_node(right_ext) call in init_tbm_node(),
        # and next_pbyte with (pos + TBM_STRIDE) % 8 == 0: a TBM node at
        # depth 9 - STRIDE with a right TBM extending path on a deep
        # internal node, promoted to depth 8 - STRIDE.
        top = 8 - STRIDE
        prefixes = [(0, top + 1, "tbm"),
                    (1, 9, "ext path"),
                    (2, 10, "term"),
                    (0, top, "top")]
        checks = [(0, 0, "tbm"),
                  (2, 10, "term"),
                  (3, 10, "ext path"),
                  (1, top + 1, "top")]
        with BTrie(prefixes) as btrie:
            for pfx, plen, want in checks:
                self.assertEqual(btrie.lookup(pfx, plen), want)
335
class Test_add_to_trie(unittest.TestCase):
    def _assert_duplicate_reported(self, stderr):
        # The captured stderr must carry the duplicate-entry complaint.
        self.assertTrue("duplicated entry for" in stderr,
                        "No duplicated entry error message in stderr: %r"
                        % str(stderr))

    def test_duplicate_terminal_lc(self):
        entry = (0, 1, "term")
        captured = CaptureOutput()
        with BTrie([entry, entry], stderr=captured) as btrie:
            self.assertEqual(btrie.lookup(0, 0), "term")
        self._assert_duplicate_reported(captured)

    def test_duplicate_internal_data(self):
        entry = (2, 3, "term")
        captured = CaptureOutput()
        with BTrie([(0, 0, "root"), entry, entry],
                   stderr=captured) as btrie:
            self.assertEqual(btrie.lookup(4, 4), "term")
            self.assertEqual(btrie.lookup(0, 0), "root")
        self._assert_duplicate_reported(captured)

    def test_split_first_byte_of_lc_prefix(self):
        # this is for coverage of common_prefix()
        expected = [(0x1234, "long"), (0x1000, "splitter")]
        with BTrie([(pfx, 16, data) for pfx, data in expected]) as btrie:
            for pfx, data in expected:
                self.assertEqual(btrie.lookup(pfx, 16), data)

    def test_split_last_byte_of_lc_prefix(self):
        # this is for coverage of common_prefix()
        expected = [(0x1234, "long"), (0x1238, "splitter")]
        with BTrie([(pfx, 15, data) for pfx, data in expected]) as btrie:
            for pfx, data in expected:
                self.assertEqual(btrie.lookup(pfx, 15), data)
382
class Test_search_trie(unittest.TestCase):
    def test_tbm_root_data(self):
        # test access to root internal node in a TBM node
        prefixes = [(0, 127, "tbm root"),
                    (1, 128, "int data")]
        with BTrie(prefixes) as btrie:
            self.assertEqual(btrie.lookup(0, 0), "tbm root")
            # the more specific /128 entry must still be found
            self.assertEqual(btrie.lookup(1, 128), "int data")

    def test_tbm_internal_data(self):
        # test access to each (non-root) internal node in a TBM node
        for plen in range(1, STRIDE):
            # TBM node
            prefixes = [(0, 128 - plen, "tbm root")]
            prefixes.extend((pfx, 128, "%u/%u" % (pfx, plen))
                            for pfx in range(1 << plen))
            with BTrie(prefixes) as btrie:
                for pfx in range(1 << plen):
                    self.assertEqual(btrie.lookup(pfx, 128),
                                     "%u/%u" % (pfx, plen))

    def test_tbm_extending_paths(self):
        # test access to each extended path of a TBM node
        prefixes = [(0, 0, "root")]  # make sure to create top-level TBM node
        prefixes.extend((pfx, STRIDE, str(pfx)) for pfx in range(1 << STRIDE))
        with BTrie(prefixes) as btrie:
            for pfx in range(1 << STRIDE):
                self.assertEqual(btrie.lookup(pfx, STRIDE), str(pfx))

    def test_no_match(self):
        # the all-zero address is not covered by 01/2
        prefixes = [
            (1, 2, "term"),
            ]
        with BTrie(prefixes) as btrie:
            self.assertIsNone(btrie.lookup(0, 0))

    def test_follow_lc(self):
        prefixes = [
            (0, 2 * STRIDE, "term"),
            ]
        with BTrie(prefixes) as btrie:
            self.assertEqual(btrie.lookup(0, 0), "term")

    def test_parents_internal_data(self):
        # 0x201/10 differs from 0x200/10 in its last bit, so the best
        # match is the shorter 2/2 prefix
        prefixes = [
            (0, 0, "root"),
            (2, 2, "int data"),
            (0x200, 10, "term"),
            ]
        with BTrie(prefixes) as btrie:
            self.assertEqual(btrie.lookup(0x201, 10), "int data")
433
if __name__ == '__main__':
    # Run every TestCase in this module with the standard unittest runner.
    unittest.main(module='__main__')
436