1import sys
2
3from .common import ZictBase
4
5
6def _encode_key(key):
7    return key.encode("utf-8")
8
9
10def _decode_key(key):
11    return key.decode("utf-8")
12
13
class LMDB(ZictBase):
    """ Mutable Mapping interface to a LMDB database.

    Keys must be strings, values must be bytes

    Keys are encoded to UTF-8 bytes before being handed to LMDB and decoded
    back to strings when they are read during iteration.

    Parameters
    ----------
    directory: string
        Passed to ``lmdb.open`` with ``subdir=True``.

    Examples
    --------
    >>> z = LMDB('/tmp/somedir/')  # doctest: +SKIP
    >>> z['x'] = b'123'  # doctest: +SKIP
    >>> z['x']  # doctest: +SKIP
    b'123'
    """

    def __init__(self, directory):
        # Imported lazily so that ``lmdb`` is only required when this class
        # is actually instantiated.
        import lmdb

        # map_size is the maximum database size but shouldn't fill up the
        # virtual address space
        map_size = 1 << 40 if sys.maxsize >= 2 ** 32 else 1 << 28
        # writemap requires sparse file support otherwise the whole
        # `map_size` may be reserved up front on disk
        writemap = sys.platform.startswith("linux")
        self.db = lmdb.open(
            directory, subdir=True, map_size=map_size, sync=False, writemap=writemap,
        )

    def __getitem__(self, key):
        """Return the value stored under ``key``.

        Raises
        ------
        KeyError
            If the lookup returns ``None`` (no entry for ``key``).
        """
        with self.db.begin() as txn:
            value = txn.get(_encode_key(key))
        if value is None:
            raise KeyError(key)
        return value

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` in a write transaction."""
        with self.db.begin(write=True) as txn:
            txn.put(_encode_key(key), value)

    def __contains__(self, key):
        """Return whether positioning a cursor on ``key`` succeeds."""
        with self.db.begin() as txn:
            return txn.cursor().set_key(_encode_key(key))

    def items(self):
        """Iterate over ``(key, value)`` pairs, with keys decoded to strings.

        The read transaction is opened when iteration starts and is released
        as soon as the iterator is exhausted or closed, rather than lingering
        until the iterator happens to be garbage collected.
        """
        with self.db.begin() as txn:
            for k, v in txn.cursor().iternext(keys=True, values=True):
                yield _decode_key(k), v

    def keys(self):
        """Iterate over the keys, decoded to strings.

        The read transaction is opened when iteration starts and is released
        as soon as the iterator is exhausted or closed.
        """
        with self.db.begin() as txn:
            for k in txn.cursor().iternext(keys=True, values=False):
                yield _decode_key(k)

    def values(self):
        """Iterate over the stored values.

        The read transaction is opened when iteration starts and is released
        as soon as the iterator is exhausted or closed.
        """
        with self.db.begin() as txn:
            for v in txn.cursor().iternext(keys=False, values=True):
                yield v

    def _do_update(self, items):
        """Write all ``(key, value)`` pairs from ``items`` in one transaction."""
        # Optimized version of update() using a single putmulti() call.
        items = [(_encode_key(k), v) for k, v in items]
        with self.db.begin(write=True) as txn:
            consumed, added = txn.cursor().putmulti(items)
            # Internal sanity check: every pair must have been read and stored.
            assert consumed == added == len(items)

    def __iter__(self):
        """Iterate over the keys; equivalent to :meth:`keys`."""
        return self.keys()

    def __delitem__(self, key):
        """Delete the entry stored under ``key``.

        Raises
        ------
        KeyError
            If the delete call reports that nothing was removed.
        """
        with self.db.begin(write=True) as txn:
            if not txn.delete(_encode_key(key)):
                raise KeyError(key)

    def __len__(self):
        """Return the ``"entries"`` count reported by the environment's stat()."""
        return self.db.stat()["entries"]

    def close(self):
        """Close the underlying LMDB environment."""
        self.db.close()
91