1import multiprocessing
2import threading
3import weakref
4from typing import Any, MutableMapping, Optional
5
6try:
7    from dask.utils import SerializableLock
8except ImportError:
9    # no need to worry about serializing the lock
10    SerializableLock = threading.Lock
11
12try:
13    from dask.distributed import Lock as DistributedLock
14except ImportError:
15    DistributedLock = None
16
17
# Locks used by multiple backends.
# Neither HDF5 nor the netCDF-C library are thread-safe.
# SerializableLock falls back to a plain threading.Lock when dask is not
# installed (see the try/except at the top of this module), so these locks
# remain usable — and picklable for dask — in either case.
HDF5_LOCK = SerializableLock()
NETCDFC_LOCK = SerializableLock()
22
23
24_FILE_LOCKS: MutableMapping[Any, threading.Lock] = weakref.WeakValueDictionary()
25
26
27def _get_threaded_lock(key):
28    try:
29        lock = _FILE_LOCKS[key]
30    except KeyError:
31        lock = _FILE_LOCKS[key] = threading.Lock()
32    return lock
33
34
35def _get_multiprocessing_lock(key):
36    # TODO: make use of the key -- maybe use locket.py?
37    # https://github.com/mwilliamson/locket.py
38    del key  # unused
39    return multiprocessing.Lock()
40
41
# Dispatch table mapping a scheduler name (as returned by _get_scheduler)
# to a callable that builds an appropriate lock for a given resource key.
# NOTE: "distributed" maps to None when dask.distributed is not installed
# (see the import fallback at the top of this module).
_LOCK_MAKERS = {
    None: _get_threaded_lock,
    "threaded": _get_threaded_lock,
    "multiprocessing": _get_multiprocessing_lock,
    "distributed": DistributedLock,
}
48
49
def _get_lock_maker(scheduler=None):
    """Returns an appropriate function for creating resource locks.

    Parameters
    ----------
    scheduler : str or None
        Dask scheduler being used.

    Raises
    ------
    KeyError
        If *scheduler* is not one of the known scheduler names.

    See Also
    --------
    dask.utils.get_scheduler_lock
    """
    maker = _LOCK_MAKERS[scheduler]
    return maker
63
64
65def _get_scheduler(get=None, collection=None) -> Optional[str]:
66    """Determine the dask scheduler that is being used.
67
68    None is returned if no dask scheduler is active.
69
70    See Also
71    --------
72    dask.base.get_scheduler
73    """
74    try:
75        # Fix for bug caused by dask installation that doesn't involve the toolz library
76        # Issue: 4164
77        import dask
78        from dask.base import get_scheduler  # noqa: F401
79
80        actual_get = get_scheduler(get, collection)
81    except ImportError:
82        return None
83
84    try:
85        from dask.distributed import Client
86
87        if isinstance(actual_get.__self__, Client):
88            return "distributed"
89    except (ImportError, AttributeError):
90        pass
91
92    try:
93        # As of dask=2.6, dask.multiprocessing requires cloudpickle to be installed
94        # Dependency removed in https://github.com/dask/dask/pull/5511
95        if actual_get is dask.multiprocessing.get:
96            return "multiprocessing"
97    except AttributeError:
98        pass
99
100    return "threaded"
101
102
def get_write_lock(key):
    """Get a scheduler appropriate lock for writing to the given resource.

    Parameters
    ----------
    key : str
        Name of the resource for which to acquire a lock. Typically a filename.

    Returns
    -------
    Lock object that can be used like a threading.Lock object.
    """
    maker = _get_lock_maker(_get_scheduler())
    return maker(key)
118
119
def acquire(lock, blocking=True):
    """Acquire *lock*, optionally without blocking.

    Smooths over API differences between old versions of Python, dask and
    dask-distributed lock implementations.
    """
    if blocking:
        # every lock flavour accepts a no-argument blocking acquire
        return lock.acquire()
    if DistributedLock is not None and isinstance(lock, DistributedLock):
        # distributed.Lock has no "blocking" argument; a zero timeout is the
        # closest equivalent: https://github.com/dask/distributed/pull/2412
        return lock.acquire(timeout=0)
    # the positional form works where the keyword does not:
    # - threading.Lock on Python 2
    # - dask.SerializableLock with dask v1.0.0 or earlier
    # - multiprocessing.Lock, which names the argument "block"
    return lock.acquire(blocking)
139
140
class CombinedLock:
    """A combination of multiple locks.

    Like a locked door, a CombinedLock is locked if any of its constituent
    locks are locked.
    """

    def __init__(self, locks):
        self.locks = tuple(set(locks))  # remove duplicates

    def acquire(self, blocking=True):
        # NOTE(review): on a failed non-blocking attempt, locks acquired
        # before the failing one are not released again.
        return all(acquire(lock, blocking=blocking) for lock in self.locks)

    def release(self):
        for lock in self.locks:
            lock.release()

    def __enter__(self):
        for lock in self.locks:
            lock.__enter__()

    def __exit__(self, *args):
        for lock in self.locks:
            lock.__exit__(*args)

    def locked(self):
        # Bug fix: ``locked`` must be *called*. The bare bound-method object
        # ``lock.locked`` is always truthy, which made this method report
        # True whenever the lock list was non-empty.
        return any(lock.locked() for lock in self.locks)

    def __repr__(self):
        return f"CombinedLock({list(self.locks)!r})"
171
172
class DummyLock:
    """DummyLock provides the lock API without any actual locking."""

    def acquire(self, blocking=True):
        # Bug fix: real locks return True on a successful acquire; returning
        # None here broke truthiness checks such as the ``all(acquire(...))``
        # in CombinedLock.acquire.
        return True

    def release(self):
        pass

    def __enter__(self):
        pass

    def __exit__(self, *args):
        pass

    def locked(self):
        return False
190
191
def combine_locks(locks):
    """Combine a sequence of locks into a single lock."""
    flattened = []
    for lock in locks:
        if isinstance(lock, CombinedLock):
            # flatten nested combinations instead of wrapping them again
            flattened.extend(lock.locks)
        elif lock is not None:
            flattened.append(lock)

    if not flattened:
        return DummyLock()
    if len(flattened) == 1:
        return flattened[0]
    return CombinedLock(flattened)
208
209
def ensure_lock(lock):
    """Ensure that the given object is a lock."""
    # Identity checks on purpose: 0 or "" must not be treated as "no lock".
    no_lock_requested = lock is None or lock is False
    return DummyLock() if no_lock_requested else lock
215