1"""Simple key-value based storage for liquidctl drivers.
2
3Copyright (C) 2019–2021  Jonas Malaco and contributors
4SPDX-License-Identifier: GPL-3.0-or-later
5"""
6
7import logging
8import os
9import sys
10import tempfile
11from ast import literal_eval
12from contextlib import contextmanager
13
14if sys.platform == 'win32':
15    import msvcrt
16else:
17    import fcntl
18
19
# module-level logger, named after this module
_LOGGER = logging.getLogger(__name__)
# read once, when this module is imported; get_runtime_dirs() consults this
# captured value rather than the live environment
XDG_RUNTIME_DIR = os.environ.get('XDG_RUNTIME_DIR')
22
23
def get_runtime_dirs(appname='liquidctl'):
    """Return base directories for application runtime data.

    Directories are returned in order of preference:

    - Windows: `%TEMP%/<appname>`, or `tempfile.gettempdir()/<appname>` when
      the `TEMP` environment variable is not set.
    - macOS: `~/Library/Caches/<appname>`.
    - Linux: `$XDG_RUNTIME_DIR/<appname>` (when `XDG_RUNTIME_DIR` was set at
      module import time), followed by `/var/run/<appname>` (when `/var/run`
      is a directory).
    - Any other platform: `/tmp/<appname>`.

    Raises RuntimeError on Linux when none of the candidates is available.
    """
    if sys.platform == 'win32':
        # TEMP is normally set, but os.path.join(None, ...) would raise a
        # TypeError if it is missing from the environment
        temp = os.getenv('TEMP') or tempfile.gettempdir()
        return [os.path.join(temp, appname)]

    if sys.platform == 'darwin':
        return [os.path.expanduser(os.path.join('~/Library/Caches', appname))]

    if sys.platform == 'linux':
        # conform to the XDG Base Directory specification
        dirs = []
        if XDG_RUNTIME_DIR:
            dirs.append(os.path.join(XDG_RUNTIME_DIR, appname))
        # regardless of whether XDG_RUNTIME_DIR is set, fall back to /var/run
        # if it is available; this allows a user with XDG_RUNTIME_DIR set to
        # still find data stored by another user as long as it is in the
        # fallback path (see #37 for a real world use case)
        if os.path.isdir('/var/run'):
            dirs.append(os.path.join('/var/run', appname))
        if not dirs:
            # explicit check rather than `assert`: assertions are stripped
            # under `python -O`, which would let an empty list reach callers
            raise RuntimeError('Could not get a suitable place to store runtime data')
        return dirs

    # treat all other platforms as generic *nix
    return [os.path.join('/tmp', appname)]
48
49
50@contextmanager
51def _open_with_lock(path, flags, *, shared=False):
52    if flags | os.O_RDWR:
53        write_mode = 'r+'
54    elif flags | os.O_RDONLY:
55        write_mode = 'r'
56    elif flags | os.O_WRONLY:
57        write_mode = 'w'
58    else:
59        assert False, 'unreachable'
60
61    with os.fdopen(os.open(path, flags), mode=write_mode) as f:
62        if sys.platform == 'win32':
63            msvcrt.locking(f.fileno(), msvcrt.LK_LOCK, 1)
64        elif shared:
65            fcntl.flock(f, fcntl.LOCK_SH)
66        else:
67            fcntl.flock(f, fcntl.LOCK_EX)
68
69        yield f
70
71
class _FilesystemBackend:
    """Key-value storage backed by one small text file per key.

    A value is stored as its `repr()` and loaded back with
    `ast.literal_eval()`, so only values whose `repr()` is a Python literal
    round-trip.  Files live at `<runtime dir>/<prefix>/.../<key>`: reads try
    every runtime directory in order of preference, while writes always go
    to the first (preferred) one.
    """

    # exceptions `literal_eval` may raise on malformed input; besides
    # ValueError, a truncated or garbled file easily yields a SyntaxError.
    # A file that fails to decode is logged as corrupted and skipped.
    _DECODE_ERRORS = (ValueError, TypeError, SyntaxError, RecursionError)

    def _sanitize(self, key):
        """Return `key` if it is a str that is also a valid Python identifier.

        Raises TypeError if `key` is not a str, and ValueError if it is not an
        identifier (which, among other things, rules out path separators).
        """
        if not isinstance(key, str):
            raise TypeError('key must be str')
        if not key.isidentifier():
            raise ValueError('key must be valid Python identifier')
        return key

    def __init__(self, key_prefixes, runtime_dirs=None):
        """Set up the read and write directories for `key_prefixes`.

        `key_prefixes` are nested as subdirectories under each base runtime
        directory and must be valid Python identifiers.  `runtime_dirs` lists
        the base directories in order of preference; when omitted it is
        computed by `get_runtime_dirs()` each time an instance is created
        (not once at import time, which could fail the module import).

        The preferred directory is created if missing and, on Linux, its mode
        is set to 0o1700.
        """
        if runtime_dirs is None:
            runtime_dirs = get_runtime_dirs()
        key_prefixes = [self._sanitize(p) for p in key_prefixes]
        # compute read and write dirs from base runtime dirs: the first base
        # dir is selected for writes and preferred for reads
        self._read_dirs = [os.path.join(x, *key_prefixes) for x in runtime_dirs]
        self._write_dir = self._read_dirs[0]
        os.makedirs(self._write_dir, exist_ok=True)
        if sys.platform == 'linux':
            # set the sticky bit to prevent removal during cleanup
            os.chmod(self._write_dir, 0o1700)
        _LOGGER.debug('data in %s', self._write_dir)

    def load(self, key):
        """Return the value stored for `key`, or None if none is found.

        Directories are tried in order of preference; files that are empty,
        unreadable or fail to decode are skipped (the latter two are logged
        as warnings).
        """
        for base in self._read_dirs:
            path = os.path.join(base, key)
            if not os.path.isfile(path):
                continue
            try:
                with _open_with_lock(path, os.O_RDONLY, shared=True) as f:
                    data = f.read().strip()

                if not data:
                    continue

                value = literal_eval(data)
            except OSError as err:
                _LOGGER.warning('%s exists but could not be read: %s', path, err)
            except self._DECODE_ERRORS as err:
                _LOGGER.warning('%s exists but was corrupted: %s', path, err)
            else:
                _LOGGER.debug('loaded %s=%r (from %s)', key, value, path)
                return value

        _LOGGER.debug('no data (file) found for %s', key)
        return None

    def store(self, key, value):
        """Store `value` under `key` in the preferred directory.

        With assertions enabled, a value whose `repr()` does not round-trip
        through `literal_eval` is rejected before anything is written.
        """
        data = repr(value)
        assert literal_eval(data) == value, 'encode/decode roundtrip fails'
        path = os.path.join(self._write_dir, key)

        # no O_TRUNC: it would truncate the file at open time, before the
        # lock is acquired, letting concurrent readers observe an empty file;
        # instead truncate after writing, while holding the exclusive lock
        with _open_with_lock(path, os.O_WRONLY | os.O_CREAT) as f:
            f.write(data)
            f.truncate()
            f.flush()  # ensure flushing before automatic unlocking

        _LOGGER.debug('stored %s=%r (in %s)', key, value, path)

    def load_store(self, key, func):
        """Replace the value of `key` with `func(old_value)`.

        The destination file in the preferred directory is exclusively locked
        for the whole read-modify-write sequence.  The current value is looked
        up in every directory in order of preference; `old_value` is None when
        no readable value is found.

        Returns `(old_value, new_value)`.
        """
        value = None

        path = os.path.join(self._write_dir, key)

        # lock the destination as soon as possible
        with _open_with_lock(path, os.O_RDWR | os.O_CREAT) as f:

            # still traverse all possible locations to find the current value
            for base in self._read_dirs:
                read_path = os.path.join(base, key)
                if not os.path.isfile(read_path):
                    continue
                try:
                    if os.path.samefile(read_path, path):
                        # we already have an exclusive lock to this file
                        data = f.read().strip()
                    else:
                        with _open_with_lock(read_path, os.O_RDONLY, shared=True) as aux:
                            data = aux.read().strip()

                    if not data:
                        continue

                    value = literal_eval(data)
                except OSError as err:
                    _LOGGER.warning('%s exists but could not be read: %s', read_path, err)
                except self._DECODE_ERRORS as err:
                    _LOGGER.warning('%s exists but was corrupted: %s', read_path, err)
                else:
                    _LOGGER.debug('loaded %s=%r (from %s)', key, value, read_path)
                    break
            else:
                _LOGGER.debug('no data (file) found for %s', key)

            new_value = func(value)

            data = repr(new_value)
            assert literal_eval(data) == new_value, 'encode/decode roundtrip fails'
            # rewind unconditionally: reading `f` above moved its position,
            # even when that read or its decoding then failed
            f.seek(0)
            f.write(data)
            f.truncate()
            f.flush()  # ensure flushing before automatic unlocking

            _LOGGER.debug('replaced with %s=%r (stored in %s)', key, new_value, path)

        return (value, new_value)
174
175
class RuntimeStorage:
    """Unstable API.

    Thin front end over a key-value backend that adds optional type checking
    and default values on load.
    """

    def __init__(self, key_prefixes, backend=None):
        """Unstable API.

        `backend` must provide `load(key)`, `store(key, value)` and
        `load_store(key, func)`.  When it is None, a `_FilesystemBackend` is
        created for `key_prefixes`.  A backend that happens to be falsy (e.g.
        an empty mapping-based one) is kept as given.
        """
        if backend is None:
            backend = _FilesystemBackend(key_prefixes)
        self._backend = backend

    @staticmethod
    def _coerce(value, of_type, default):
        """Return `default` if `value` is None or not an `of_type`, else `value`.

        The type check is skipped when `of_type` is falsy.
        """
        if value is None:
            return default
        if of_type and not isinstance(value, of_type):
            return default
        return value

    def load(self, key, of_type=None, default=None):
        """Unstable API.

        Return the stored value for `key`, or `default` when nothing is
        stored or the stored value is not an instance of `of_type`.
        """
        return self._coerce(self._backend.load(key), of_type, default)

    def load_store(self, key, func, of_type=None, default=None):
        """Unstable API.

        Replace the value of `key` with `func(current)`, where `current` has
        the same default/type substitution as `load`.  Returns whatever the
        backend's `load_store` returns; the substitution only applies to the
        argument passed to `func`, not to the returned old value.
        """

        def apply_func(value):
            return func(self._coerce(value, of_type, default))

        return self._backend.load_store(key, apply_func)

    def store(self, key, value):
        """Unstable API.

        Store `value` under `key` and return `value`.
        """
        self._backend.store(key, value)
        return value
212