1"""Provides shared memory for direct access across processes.
2
3The API of this package is currently provisional. Refer to the
4documentation for details.
5"""
6
7
# Names exported by ``from <this module> import *``.
__all__ = [
    'SharedMemory',
    'ShareableList',
]
9
10
11from functools import partial
12import mmap
13import os
14import errno
15import struct
16import secrets
17
# Pick the backend once: Windows ("nt") uses named file mappings through
# _winapi, every other platform uses POSIX shared memory via _posixshmem.
_USE_POSIX = os.name != "nt"
if _USE_POSIX:
    import _posixshmem
else:
    import _winapi


# Flags requesting exclusive creation (fail if the name already exists).
_O_CREX = os.O_CREAT | os.O_EXCL

# FreeBSD (and perhaps other BSDs) limit names to 14 characters.
_SHM_SAFE_NAME_LENGTH = 14

# Shared memory block name prefix
_SHM_NAME_PREFIX = '/psm_' if _USE_POSIX else 'wnsm_'
36
37
def _make_filename():
    """Return a random name for a new shared memory object.

    The name is _SHM_NAME_PREFIX followed by random hex digits and never
    exceeds _SHM_SAFE_NAME_LENGTH characters."""
    # Each random byte is rendered as two hex digits, so only half of the
    # characters remaining after the prefix can be spent on random bytes.
    spare_chars = _SHM_SAFE_NAME_LENGTH - len(_SHM_NAME_PREFIX)
    nbytes = spare_chars // 2
    assert nbytes >= 2, '_SHM_NAME_PREFIX too long'
    candidate = f"{_SHM_NAME_PREFIX}{secrets.token_hex(nbytes)}"
    assert len(candidate) <= _SHM_SAFE_NAME_LENGTH
    return candidate
46
47
class SharedMemory:
    """Creates a new shared memory block or attaches to an existing
    shared memory block.

    Every shared memory block is assigned a unique name.  This enables
    one process to create a shared memory block with a particular name
    so that a different process can attach to that same shared memory
    block using that same name.

    As a resource for sharing data across processes, shared memory blocks
    may outlive the original process that created them.  When one process
    no longer needs access to a shared memory block that might still be
    needed by other processes, the close() method should be called.
    When a shared memory block is no longer needed by any process, the
    unlink() method should be called to ensure proper cleanup."""

    # Defaults; enables close() and unlink() to run without errors.
    _name = None
    _fd = -1
    _mmap = None
    _buf = None
    _flags = os.O_RDWR
    _mode = 0o600
    _prepend_leading_slash = True if _USE_POSIX else False

    def __init__(self, name=None, create=False, size=0):
        """Create a new block (create=True) or attach to an existing one.

        name:   name of the block; may be None only when create=True, in
                which case a random unused name is generated.
        create: when true, the block is created exclusively and must not
                already exist.
        size:   requested size in bytes when creating; must be non-zero
                for create=True.  When attaching, the size is read back
                from the existing block instead.

        Raises ValueError for a negative size, for create=True with
        size == 0, and for name=None without create=True.  On Windows,
        FileExistsError is raised when create=True and the given name is
        already in use.  Errors from the underlying platform calls
        propagate."""
        if not size >= 0:
            raise ValueError("'size' must be a positive integer")
        if create:
            self._flags = _O_CREX | os.O_RDWR
            if size == 0:
                raise ValueError("'size' must be a positive number different from zero")
        if name is None and not self._flags & os.O_EXCL:
            raise ValueError("'name' can only be None if create=True")

        if _USE_POSIX:

            # POSIX Shared Memory

            if name is None:
                # Keep drawing random names until one is not yet taken.
                while True:
                    name = _make_filename()
                    try:
                        self._fd = _posixshmem.shm_open(
                            name,
                            self._flags,
                            mode=self._mode
                        )
                    except FileExistsError:
                        continue
                    self._name = name
                    break
            else:
                name = "/" + name if self._prepend_leading_slash else name
                self._fd = _posixshmem.shm_open(
                    name,
                    self._flags,
                    mode=self._mode
                )
                self._name = name
            try:
                if create and size:
                    os.ftruncate(self._fd, size)
                stats = os.fstat(self._fd)
                size = stats.st_size
                self._mmap = mmap.mmap(self._fd, size)
            except OSError:
                # The name is not registered with the resource tracker
                # yet, so do not go through unlink() (which unregisters).
                # Only remove the block if this call created it: a failed
                # attach must not destroy a block other processes may
                # still be using.  The descriptor is released by close(),
                # which __del__ invokes.
                if create:
                    _posixshmem.shm_unlink(self._name)
                raise

            from .resource_tracker import register
            register(self._name, "shared_memory")

        else:

            # Windows Named Shared Memory

            if create:
                while True:
                    temp_name = _make_filename() if name is None else name
                    # Create and reserve shared memory block with this name
                    # until it can be attached to by mmap.
                    h_map = _winapi.CreateFileMapping(
                        _winapi.INVALID_HANDLE_VALUE,
                        _winapi.NULL,
                        _winapi.PAGE_READWRITE,
                        (size >> 32) & 0xFFFFFFFF,
                        size & 0xFFFFFFFF,
                        temp_name
                    )
                    try:
                        last_error_code = _winapi.GetLastError()
                        if last_error_code == _winapi.ERROR_ALREADY_EXISTS:
                            if name is not None:
                                raise FileExistsError(
                                    errno.EEXIST,
                                    os.strerror(errno.EEXIST),
                                    name,
                                    _winapi.ERROR_ALREADY_EXISTS
                                )
                            else:
                                # Random name collided; try another one.
                                continue
                        self._mmap = mmap.mmap(-1, size, tagname=temp_name)
                    finally:
                        _winapi.CloseHandle(h_map)
                    self._name = temp_name
                    break

            else:
                self._name = name
                # Dynamically determine the existing named shared memory
                # block's size which is likely a multiple of mmap.PAGESIZE.
                h_map = _winapi.OpenFileMapping(
                    _winapi.FILE_MAP_READ,
                    False,
                    name
                )
                try:
                    p_buf = _winapi.MapViewOfFile(
                        h_map,
                        _winapi.FILE_MAP_READ,
                        0,
                        0,
                        0
                    )
                finally:
                    _winapi.CloseHandle(h_map)
                try:
                    size = _winapi.VirtualQuerySize(p_buf)
                finally:
                    # The view was mapped only to measure the block; unmap
                    # it so each attach does not leak a mapping.  Guarded
                    # because not every _winapi build may expose
                    # UnmapViewOfFile.
                    unmap_view = getattr(_winapi, "UnmapViewOfFile", None)
                    if unmap_view is not None:
                        unmap_view(p_buf)
                self._mmap = mmap.mmap(-1, size, tagname=name)

        self._size = size
        self._buf = memoryview(self._mmap)

    def __del__(self):
        "Best-effort close() on garbage collection; OSError is ignored."
        try:
            self.close()
        except OSError:
            pass

    def __reduce__(self):
        """Pickle support: the copy re-attaches to the block by name
        (create=False) rather than creating a new block."""
        return (
            self.__class__,
            (
                self.name,
                False,
                self.size,
            ),
        )

    def __repr__(self):
        return f'{self.__class__.__name__}({self.name!r}, size={self.size})'

    @property
    def buf(self):
        "A memoryview of contents of the shared memory block."
        return self._buf

    @property
    def name(self):
        "Unique name that identifies the shared memory block."
        reported_name = self._name
        # Hide the leading slash that __init__ prepends on POSIX.
        if _USE_POSIX and self._prepend_leading_slash:
            if self._name.startswith("/"):
                reported_name = self._name[1:]
        return reported_name

    @property
    def size(self):
        "Size in bytes."
        return self._size

    def close(self):
        """Closes access to the shared memory from this instance but does
        not destroy the shared memory block."""
        # Release in reverse order of acquisition: view, mapping, then fd.
        if self._buf is not None:
            self._buf.release()
            self._buf = None
        if self._mmap is not None:
            self._mmap.close()
            self._mmap = None
        if _USE_POSIX and self._fd >= 0:
            os.close(self._fd)
            self._fd = -1

    def unlink(self):
        """Requests that the underlying shared memory block be destroyed.

        In order to ensure proper cleanup of resources, unlink should be
        called once (and only once) across all processes which have access
        to the shared memory block."""
        # Only the POSIX branch does anything here; on Windows this
        # method is a no-op.
        if _USE_POSIX and self._name:
            from .resource_tracker import unregister
            _posixshmem.shm_unlink(self._name)
            unregister(self._name, "shared_memory")
242
243
_encoding = "utf8"

class ShareableList:
    """Pattern for a mutable list-like object shareable via a shared
    memory block.  It differs from the built-in list type in that these
    lists can not change their overall length (i.e. no append, insert,
    etc.)

    Because values are packed into a memoryview as bytes, the struct
    packing format for any storable value must require no more than 8
    characters to describe its format.

    Buffer layout, in order: the item count ("q"), one "q" per item
    holding its allocated byte size, the packed item data, one "8s" per
    item holding its struct format, and one "b" per item holding its
    back-transform code."""

    # struct format used to store each supported Python type.
    _types_mapping = {
        int: "q",
        float: "d",
        bool: "xxxxxxx?",
        str: "%ds",
        bytes: "%ds",
        None.__class__: "xxxxxx?x",
    }
    _alignment = 8
    # Converters from the unpacked struct value back to the Python object,
    # keyed by the code produced by _extract_recreation_code().
    _back_transforms_mapping = {
        0: lambda value: value,                   # int, float, bool
        1: lambda value: value.rstrip(b'\x00').decode(_encoding),  # str
        2: lambda value: value.rstrip(b'\x00'),   # bytes
        3: lambda _value: None,                   # None
    }

    @staticmethod
    def _extract_recreation_code(value):
        """Used in concert with _back_transforms_mapping to convert values
        into the appropriate Python objects when retrieving them from
        the list as well as when storing them."""
        if isinstance(value, str):
            return 1
        if isinstance(value, bytes):
            return 2
        if value is None:
            return 3
        return 0

    def __init__(self, sequence=None, *, name=None):
        """Create a new list from `sequence`, or attach to an existing one.

        sequence: iterable of int, float, bool, str, bytes or None items
                  used to populate a newly created block.
        name:     name of the shared memory block.  With a name and no
                  sequence, an existing block is attached to; otherwise
                  a new block is created (under `name` if given)."""
        if sequence is not None:
            # Materialize once: the items are traversed several times
            # below, which would exhaust a one-shot iterator.
            sequence = tuple(sequence)
            # Encode text up front so storage is sized from the number of
            # bytes actually written, not the number of characters
            # (multi-byte text would otherwise be silently truncated).
            _values = [
                item.encode(_encoding) if isinstance(item, str) else item
                for item in sequence
            ]
            _formats = [
                self._types_mapping[type(item)]
                    if not isinstance(item, (str, bytes))
                    else self._types_mapping[type(item)] % (
                        self._alignment * (len(raw) // self._alignment + 1),
                    )
                for item, raw in zip(sequence, _values)
            ]
            self._list_len = len(_formats)
            # Each format is stored in an "8s" slot, so it must fit.
            assert sum(len(fmt) <= 8 for fmt in _formats) == self._list_len
            self._allocated_bytes = tuple(
                    self._alignment if fmt[-1] != "s" else int(fmt[:-1])
                    for fmt in _formats
            )
            _recreation_codes = [
                self._extract_recreation_code(item) for item in sequence
            ]
            requested_size = struct.calcsize(
                "q" + self._format_size_metainfo +
                "".join(_formats) +
                self._format_packing_metainfo +
                self._format_back_transform_codes
            )

        else:
            requested_size = 8  # Some platforms require > 0.

        if name is not None and sequence is None:
            self.shm = SharedMemory(name)
        else:
            self.shm = SharedMemory(name, create=True, size=requested_size)

        if sequence is not None:
            # Header: item count followed by per-item allocated sizes.
            struct.pack_into(
                "q" + self._format_size_metainfo,
                self.shm.buf,
                0,
                self._list_len,
                *(self._allocated_bytes)
            )
            # Item data.
            struct.pack_into(
                "".join(_formats),
                self.shm.buf,
                self._offset_data_start,
                *_values
            )
            # Per-item struct formats.
            struct.pack_into(
                self._format_packing_metainfo,
                self.shm.buf,
                self._offset_packing_formats,
                *(v.encode(_encoding) for v in _formats)
            )
            # Per-item back-transform codes.
            struct.pack_into(
                self._format_back_transform_codes,
                self.shm.buf,
                self._offset_back_transform_codes,
                *(_recreation_codes)
            )

        else:
            self._list_len = len(self)  # Obtains size from offset 0 in buffer.
            self._allocated_bytes = struct.unpack_from(
                self._format_size_metainfo,
                self.shm.buf,
                1 * 8
            )

    def _normalize_position(self, position):
        """Return `position` as a non-negative index into the list.

        Negative positions count from the end.  Raises IndexError when
        the position falls outside the list at either end, so callers
        never compute an offset into a neighbouring buffer region."""
        if position < 0:
            position += self._list_len
        if not 0 <= position < self._list_len:
            raise IndexError("Requested position out of range.")
        return position

    def _get_packing_format(self, position):
        "Gets the packing format for a single value stored in the list."
        position = self._normalize_position(position)

        v = struct.unpack_from(
            "8s",
            self.shm.buf,
            self._offset_packing_formats + position * 8
        )[0]
        fmt = v.rstrip(b'\x00')
        fmt_as_str = fmt.decode(_encoding)

        return fmt_as_str

    def _get_back_transform(self, position):
        "Gets the back transformation function for a single value."
        position = self._normalize_position(position)

        transform_code = struct.unpack_from(
            "b",
            self.shm.buf,
            self._offset_back_transform_codes + position
        )[0]
        transform_function = self._back_transforms_mapping[transform_code]

        return transform_function

    def _set_packing_format_and_transform(self, position, fmt_as_str, value):
        """Sets the packing format and back transformation code for a
        single value in the list at the specified position."""
        position = self._normalize_position(position)

        struct.pack_into(
            "8s",
            self.shm.buf,
            self._offset_packing_formats + position * 8,
            fmt_as_str.encode(_encoding)
        )

        transform_code = self._extract_recreation_code(value)
        struct.pack_into(
            "b",
            self.shm.buf,
            self._offset_back_transform_codes + position,
            transform_code
        )

    def __getitem__(self, position):
        "Return the item at `position`; IndexError if out of range."
        try:
            position = self._normalize_position(position)
        except IndexError:
            raise IndexError("index out of range")

        # An item starts after the storage of all items preceding it.
        offset = self._offset_data_start \
                 + sum(self._allocated_bytes[:position])
        (v,) = struct.unpack_from(
            self._get_packing_format(position),
            self.shm.buf,
            offset
        )

        back_transform = self._get_back_transform(position)
        v = back_transform(v)

        return v

    def __setitem__(self, position, value):
        """Store `value` at `position`.

        Raises IndexError if the position is out of range and ValueError
        if a str/bytes value does not fit in the slot's allocated
        storage."""
        try:
            position = self._normalize_position(position)
        except IndexError:
            raise IndexError("assignment index out of range")

        offset = self._offset_data_start \
                 + sum(self._allocated_bytes[:position])
        current_format = self._get_packing_format(position)

        if not isinstance(value, (str, bytes)):
            new_format = self._types_mapping[type(value)]
            encoded_value = value
        else:
            encoded_value = (value.encode(_encoding)
                             if isinstance(value, str) else value)
            if len(encoded_value) > self._allocated_bytes[position]:
                raise ValueError("bytes/str item exceeds available storage")
            if current_format[-1] == "s":
                new_format = current_format
            else:
                # Slot previously held a non-string; use its full storage.
                new_format = self._types_mapping[str] % (
                    self._allocated_bytes[position],
                )

        self._set_packing_format_and_transform(
            position,
            new_format,
            value
        )
        struct.pack_into(new_format, self.shm.buf, offset, encoded_value)

    def __reduce__(self):
        "Pickle support: the copy attaches to the same block by name."
        return partial(self.__class__, name=self.shm.name), ()

    def __len__(self):
        # The item count is stored as the first "q" in the buffer.
        return struct.unpack_from("q", self.shm.buf, 0)[0]

    def __repr__(self):
        return f'{self.__class__.__name__}({list(self)}, name={self.shm.name!r})'

    @property
    def format(self):
        "The struct packing format used by all currently stored values."
        return "".join(
            self._get_packing_format(i) for i in range(self._list_len)
        )

    @property
    def _format_size_metainfo(self):
        "The struct packing format used for metainfo on storage sizes."
        return f"{self._list_len}q"

    @property
    def _format_packing_metainfo(self):
        "The struct packing format used for the values' packing formats."
        return "8s" * self._list_len

    @property
    def _format_back_transform_codes(self):
        "The struct packing format used for the values' back transforms."
        return "b" * self._list_len

    @property
    def _offset_data_start(self):
        "Offset of the first item, just past the count and size header."
        return (self._list_len + 1) * 8  # 8 bytes per "q"

    @property
    def _offset_packing_formats(self):
        "Offset of the per-item struct formats, just past the item data."
        return self._offset_data_start + sum(self._allocated_bytes)

    @property
    def _offset_back_transform_codes(self):
        "Offset of the per-item back-transform codes."
        return self._offset_packing_formats + self._list_len * 8

    def count(self, value):
        "L.count(value) -> integer -- return number of occurrences of value."

        return sum(value == entry for entry in self)

    def index(self, value):
        """L.index(value) -> integer -- return first index of value.
        Raises ValueError if the value is not present."""

        for position, entry in enumerate(self):
            if value == entry:
                return position
        else:
            raise ValueError(f"{value!r} not in this container")
515