1from __future__ import absolute_import, division, unicode_literals
2from six import text_type
3
4from bisect import bisect_left
5
6from ._base import Trie as ABCTrie
7
8
class Trie(ABCTrie):
    """Pure-Python trie backed by a mapping plus a sorted list of its keys.

    Prefix queries are answered by binary search over the sorted key list:
    all keys sharing a prefix form one contiguous run in sorted order.  The
    index range found by the most recent ``keys(prefix)`` call is cached so
    that a follow-up query for a longer prefix only searches that slice.

    NOTE: ``data`` is stored by reference, but the sorted key list is a
    snapshot taken at construction time; keys added to ``data`` afterwards
    are not seen by the prefix queries.
    """

    def __init__(self, data):
        """Build the trie from *data*, a mapping whose keys are all text.

        :raises TypeError: if any key of *data* is not ``text_type``.
        """
        if not all(isinstance(x, text_type) for x in data.keys()):
            raise TypeError("All keys must be strings")

        self._data = data
        self._keys = sorted(data.keys())
        # Cache of the last keys(prefix) lookup.  Invariant: every key that
        # starts with _cachestr lies in self._keys[lo:hi], where
        # (lo, hi) == _cachepoints.  Initially "" matches everything.
        self._cachestr = ""
        self._cachepoints = (0, len(data))

    def __contains__(self, key):
        return key in self._data

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        return iter(self._data)

    def __getitem__(self, key):
        return self._data[key]

    def _lower_bound(self, prefix):
        """Return the index of the first sorted key that is >= *prefix*.

        When *prefix* extends the cached prefix, the search is restricted to
        the cached range.  This is safe because keys before that range sort
        below the cached prefix, and keys after it sort above every string
        that starts with the cached prefix.
        """
        if prefix.startswith(self._cachestr):
            lo, hi = self._cachepoints
            return bisect_left(self._keys, prefix, lo, hi)
        return bisect_left(self._keys, prefix)

    def keys(self, prefix=None):
        """Return a new set of all keys starting with *prefix*.

        A *prefix* of ``None`` or ``""`` returns every key.  As a side
        effect, a non-empty *prefix* and its matching index range become
        the new lookup cache.
        """
        if prefix is None or prefix == "" or not self._keys:
            return set(self._keys)

        start = self._lower_bound(prefix)
        end = start
        count = len(self._keys)
        # Bounds check first: the run of matching keys may extend to the
        # very end of the list (e.g. keys ["a", "ab"] with prefix "a").
        while end < count and self._keys[end].startswith(prefix):
            end += 1

        self._cachestr = prefix
        self._cachepoints = (start, end)

        return set(self._keys[start:end])

    def has_keys_with_prefix(self, prefix):
        """Return True if at least one key starts with *prefix*.

        Reads the lookup cache but does not update it.
        """
        if prefix in self._data:
            return True

        i = self._lower_bound(prefix)
        # The first key >= prefix is the only candidate that can match.
        return i < len(self._keys) and self._keys[i].startswith(prefix)
68