1from typing import Optional
2
3import pytest
4
5from anyio import (
6    CancelScope, Condition, Event, Lock, Semaphore, WouldBlock, create_task_group, to_thread,
7    wait_all_tasks_blocked)
8from anyio.abc import CapacityLimiter, TaskStatus
9
# Module-level marker: pytest applies the `anyio` mark to every test in this module, so the
# async test methods below are driven by the anyio pytest plugin.
pytestmark = pytest.mark.anyio
11
12
class TestLock:
    """Tests for :class:`anyio.Lock`."""

    async def test_contextmanager(self) -> None:
        """A second task entering ``async with lock`` waits until the holder leaves."""
        lock = Lock()
        order = []

        async def contender() -> None:
            assert lock.locked()
            async with lock:
                order.append('2')

        async with create_task_group() as tg:
            async with lock:
                tg.start_soon(contender)
                # Let the contender run until it blocks on the held lock
                await wait_all_tasks_blocked()
                order.append('1')

        assert not lock.locked()
        assert order == ['1', '2']

    async def test_manual_acquire(self) -> None:
        """Same scenario as test_contextmanager, using explicit acquire()/release()."""
        lock = Lock()
        order = []

        async def contender() -> None:
            assert lock.locked()
            await lock.acquire()
            try:
                order.append('2')
            finally:
                lock.release()

        async with create_task_group() as tg:
            await lock.acquire()
            try:
                tg.start_soon(contender)
                await wait_all_tasks_blocked()
                order.append('1')
            finally:
                lock.release()

        assert not lock.locked()
        assert order == ['1', '2']

    async def test_acquire_nowait(self) -> None:
        """acquire_nowait() takes a free lock immediately."""
        lock = Lock()
        lock.acquire_nowait()
        assert lock.locked()

    async def test_acquire_nowait_wouldblock(self) -> None:
        """acquire_nowait() from another task raises WouldBlock while the lock is held."""
        lock = Lock()

        async def try_lock() -> None:
            with pytest.raises(WouldBlock):
                lock.acquire_nowait()

        async with lock:
            async with create_task_group() as tg:
                assert lock.locked()
                tg.start_soon(try_lock)

    @pytest.mark.parametrize('release_first', [
        pytest.param(False, id='releaselast'),
        pytest.param(True, id='releasefirst')
    ])
    async def test_cancel_during_acquire(self, release_first: bool) -> None:
        """A task cancelled while blocked on the lock must never end up holding it."""
        acquired = False
        lock = Lock()

        async def waiter(*, task_status: TaskStatus) -> None:
            nonlocal acquired
            task_status.started()
            async with lock:
                acquired = True

        async with create_task_group() as tg:
            await lock.acquire()
            await tg.start(waiter)
            tg.cancel_scope.cancel()
            # Shielded so the release and the checkpoint still happen after cancellation;
            # the parameter only decides which of the two comes first.
            with CancelScope(shield=True):
                if not release_first:
                    await wait_all_tasks_blocked()
                lock.release()
                if release_first:
                    await wait_all_tasks_blocked()

        assert not acquired
        assert not lock.locked()

    async def test_statistics(self) -> None:
        """statistics() reports the locked flag and how many tasks are blocked on the lock."""
        lock = Lock()

        async def waiter() -> None:
            async with lock:
                pass

        async with create_task_group() as tg:
            initial = lock.statistics()
            assert not initial.locked
            assert initial.tasks_waiting == 0
            async with lock:
                held = lock.statistics()
                assert held.locked
                assert held.tasks_waiting == 0
                for expected_waiting in (1, 2):
                    tg.start_soon(waiter)
                    await wait_all_tasks_blocked()
                    assert lock.statistics().tasks_waiting == expected_waiting

        final = lock.statistics()
        assert not final.locked
        assert final.tasks_waiting == 0
116
117
class TestEvent:
    """Tests for :class:`anyio.Event`."""

    async def test_event(self) -> None:
        """wait() returns once another task calls set()."""
        event = Event()

        async def setter() -> None:
            assert not event.is_set()
            event.set()

        async with create_task_group() as tg:
            tg.start_soon(setter)
            await event.wait()

        assert event.is_set()

    async def test_event_cancel(self) -> None:
        """A cancelled task must not return from wait() even if set() is called right after."""
        task_started = False
        event_set = False
        event = Event()

        async def task() -> None:
            nonlocal task_started, event_set
            task_started = True
            await event.wait()
            event_set = True

        async with create_task_group() as tg:
            tg.start_soon(task)
            tg.cancel_scope.cancel()
            event.set()

        assert task_started
        assert not event_set

    async def test_statistics(self) -> None:
        """statistics().tasks_waiting counts the tasks currently blocked in wait()."""
        event = Event()

        async def waiter() -> None:
            await event.wait()

        async with create_task_group() as tg:
            assert event.statistics().tasks_waiting == 0
            for expected_waiting in (1, 2):
                tg.start_soon(waiter)
                await wait_all_tasks_blocked()
                assert event.statistics().tasks_waiting == expected_waiting

            event.set()

        assert event.statistics().tasks_waiting == 0
164
165
class TestCondition:
    """Tests for :class:`anyio.Condition`."""

    async def test_contextmanager(self) -> None:
        """wait() under ``async with condition`` is woken by notify_all() from another task."""
        condition = Condition()

        async def notifier() -> None:
            async with condition:
                condition.notify_all()

        async with create_task_group() as tg:
            async with condition:
                assert condition.locked()
                tg.start_soon(notifier)
                await condition.wait()

    async def test_manual_acquire(self) -> None:
        """Same scenario as test_contextmanager, using explicit acquire()/release()."""
        condition = Condition()

        async def notifier() -> None:
            await condition.acquire()
            try:
                condition.notify_all()
            finally:
                condition.release()

        async with create_task_group() as tg:
            await condition.acquire()
            try:
                assert condition.locked()
                tg.start_soon(notifier)
                await condition.wait()
            finally:
                condition.release()

    async def test_acquire_nowait(self) -> None:
        """acquire_nowait() takes the condition's free lock immediately."""
        condition = Condition()
        condition.acquire_nowait()
        assert condition.locked()

    async def test_acquire_nowait_wouldblock(self) -> None:
        """acquire_nowait() from another task raises WouldBlock while the lock is held."""
        condition = Condition()

        async def try_lock() -> None:
            with pytest.raises(WouldBlock):
                condition.acquire_nowait()

        async with condition:
            async with create_task_group() as tg:
                assert condition.locked()
                tg.start_soon(try_lock)

    async def test_wait_cancel(self) -> None:
        """Cancelling a task blocked in wait() must not let it continue as if notified."""
        task_started = False
        notified = False
        event = Event()
        condition = Condition()

        async def task() -> None:
            nonlocal task_started, notified
            task_started = True
            async with condition:
                event.set()
                await condition.wait()
                notified = True

        async with create_task_group() as tg:
            tg.start_soon(task)
            # Make sure the task is inside the condition and parked in wait() before cancelling
            await event.wait()
            await wait_all_tasks_blocked()
            tg.cancel_scope.cancel()

        assert task_started
        assert not notified

    async def test_statistics(self) -> None:
        """statistics() reports the underlying lock state and the number of waiting tasks."""
        condition = Condition()

        async def waiter() -> None:
            async with condition:
                await condition.wait()

        async with create_task_group() as tg:
            before = condition.statistics()
            assert not before.lock_statistics.locked
            assert before.tasks_waiting == 0
            async with condition:
                held = condition.statistics()
                assert held.lock_statistics.locked
                assert held.tasks_waiting == 0

            for expected_waiting in (1, 2):
                tg.start_soon(waiter)
                await wait_all_tasks_blocked()
                assert condition.statistics().tasks_waiting == expected_waiting

            # Wake the waiters one at a time and check the count drops by one each round
            for remaining in (1, 0):
                async with condition:
                    condition.notify(1)

                await wait_all_tasks_blocked()
                assert condition.statistics().tasks_waiting == remaining

        after = condition.statistics()
        assert not after.lock_statistics.locked
        assert after.tasks_waiting == 0
260
261
class TestSemaphore:
    """Tests for :class:`anyio.Semaphore`."""

    async def test_contextmanager(self) -> None:
        """Tasks using ``async with`` keep the value in range and restore it afterwards."""
        semaphore = Semaphore(2)

        async def acquire() -> None:
            async with semaphore:
                assert semaphore.value in (0, 1)

        async with create_task_group() as tg:
            for task_name in ('task 1', 'task 2'):
                tg.start_soon(acquire, name=task_name)

        assert semaphore.value == 2

    async def test_manual_acquire(self) -> None:
        """Same scenario as test_contextmanager, using explicit acquire()/release()."""
        semaphore = Semaphore(2)

        async def acquire() -> None:
            await semaphore.acquire()
            try:
                assert semaphore.value in (0, 1)
            finally:
                semaphore.release()

        async with create_task_group() as tg:
            for task_name in ('task 1', 'task 2'):
                tg.start_soon(acquire, name=task_name)

        assert semaphore.value == 2

    async def test_acquire_nowait(self) -> None:
        """acquire_nowait() consumes a slot and raises WouldBlock once none are left."""
        semaphore = Semaphore(1)
        semaphore.acquire_nowait()
        assert semaphore.value == 0
        with pytest.raises(WouldBlock):
            semaphore.acquire_nowait()

    @pytest.mark.parametrize('release_first', [
        pytest.param(False, id='releaselast'),
        pytest.param(True, id='releasefirst')
    ])
    async def test_cancel_during_acquire(self, release_first: bool) -> None:
        """A task cancelled while blocked on the semaphore must never take its slot."""
        acquired = False
        semaphore = Semaphore(1)

        async def waiter(*, task_status: TaskStatus) -> None:
            nonlocal acquired
            task_status.started()
            async with semaphore:
                acquired = True

        async with create_task_group() as tg:
            await semaphore.acquire()
            await tg.start(waiter)
            tg.cancel_scope.cancel()
            # Shielded so the release and the checkpoint still happen after cancellation;
            # the parameter only decides which of the two comes first.
            with CancelScope(shield=True):
                if not release_first:
                    await wait_all_tasks_blocked()
                semaphore.release()
                if release_first:
                    await wait_all_tasks_blocked()

        assert not acquired
        assert semaphore.value == 1

    @pytest.mark.parametrize('max_value', [2, None])
    async def test_max_value(self, max_value: Optional[int]) -> None:
        """The max_value constructor argument is exposed unchanged as an attribute."""
        assert Semaphore(0, max_value=max_value).max_value == max_value

    async def test_max_value_exceeded(self) -> None:
        """release() raises ValueError when it would push the value past max_value."""
        semaphore = Semaphore(1, max_value=2)
        semaphore.release()
        with pytest.raises(ValueError):
            semaphore.release()

    async def test_statistics(self) -> None:
        """statistics().tasks_waiting counts the tasks blocked on the semaphore."""
        semaphore = Semaphore(1)

        async def waiter() -> None:
            async with semaphore:
                pass

        async with create_task_group() as tg:
            assert semaphore.statistics().tasks_waiting == 0
            async with semaphore:
                assert semaphore.statistics().tasks_waiting == 0
                for expected_waiting in (1, 2):
                    tg.start_soon(waiter)
                    await wait_all_tasks_blocked()
                    assert semaphore.statistics().tasks_waiting == expected_waiting

        assert semaphore.statistics().tasks_waiting == 0

    async def test_acquire_race(self) -> None:
        """
        Guard against a race between a rescheduled waiter and a competing acquirer.

        After release(), the freed slot must already belong to the task blocked in acquire(),
        so a concurrent acquire_nowait() from another task gets WouldBlock instead of stealing
        the slot (which would leave the woken waiter to fail with WouldBlock).

        """
        semaphore = Semaphore(1)
        async with create_task_group() as tg:
            semaphore.acquire_nowait()
            tg.start_soon(semaphore.acquire)
            await wait_all_tasks_blocked()
            semaphore.release()
            with pytest.raises(WouldBlock):
                semaphore.acquire_nowait()
365
366
class TestCapacityLimiter:
    """Tests for the ``CapacityLimiter`` synchronization primitive."""

    async def test_bad_init_type(self) -> None:
        """A float total_tokens (here 1.0) is rejected with TypeError."""
        # ``match`` is a regular expression (re.search); escape the dot so "math.inf" is literal
        with pytest.raises(TypeError, match=r'total_tokens must be an int or math\.inf'):
            CapacityLimiter(1.0)

    async def test_bad_init_value(self) -> None:
        """total_tokens=0 is rejected with ValueError."""
        with pytest.raises(ValueError, match='total_tokens must be >= 1'):
            CapacityLimiter(0)

    async def test_borrow(self) -> None:
        """Borrowing a token moves it from available to borrowed; the total is unchanged."""
        limiter = CapacityLimiter(2)
        assert limiter.total_tokens == 2
        assert limiter.available_tokens == 2
        assert limiter.borrowed_tokens == 0
        async with limiter:
            assert limiter.total_tokens == 2
            assert limiter.available_tokens == 1
            assert limiter.borrowed_tokens == 1

    async def test_limit(self) -> None:
        """With a single token, at most one task is ever inside the limiter at a time."""
        value = 0

        async def taskfunc() -> None:
            nonlocal value
            for _ in range(5):
                async with limiter:
                    assert value == 0
                    value = 1
                    # Yield to all other tasks while holding the token; if the limiter let a
                    # second task in, that task would observe value == 1 and fail above.
                    await wait_all_tasks_blocked()
                    value = 0

        limiter = CapacityLimiter(1)
        async with create_task_group() as tg:
            for _ in range(3):
                tg.start_soon(taskfunc)

    async def test_borrow_twice(self) -> None:
        """Acquiring a second token from the same task while holding one raises RuntimeError."""
        limiter = CapacityLimiter(1)
        await limiter.acquire()
        expected = "this borrower is already holding one of this CapacityLimiter's tokens"
        with pytest.raises(RuntimeError, match=expected):
            await limiter.acquire()

    async def test_bad_release(self) -> None:
        """Releasing without holding a token raises RuntimeError."""
        limiter = CapacityLimiter(1)
        expected = "this borrower isn't holding any of this CapacityLimiter's tokens"
        with pytest.raises(RuntimeError, match=expected):
            limiter.release()

    async def test_increase_tokens(self) -> None:
        """Raising total_tokens at runtime lets an already-blocked task proceed."""
        async def setter() -> None:
            # Wait until waiter() is inside the limiter block
            await event1.wait()
            async with limiter:
                # This can only happen when total_tokens has been increased
                event2.set()

        async def waiter() -> None:
            async with limiter:
                event1.set()
                await event2.wait()

        limiter = CapacityLimiter(1)
        event1, event2 = Event(), Event()
        async with create_task_group() as tg:
            tg.start_soon(setter)
            tg.start_soon(waiter)
            await wait_all_tasks_blocked()
            # waiter() holds the only token and setter() is blocked waiting for one
            assert event1.is_set()
            assert not event2.is_set()
            limiter.total_tokens = 2

        assert event2.is_set()

    async def test_current_default_thread_limiter(self) -> None:
        """The default worker-thread limiter is a CapacityLimiter with 40 tokens."""
        limiter = to_thread.current_default_thread_limiter()
        assert isinstance(limiter, CapacityLimiter)
        assert limiter.total_tokens == 40

    async def test_statistics(self) -> None:
        """statistics() reports total, borrowed and waiting counts."""
        async def waiter() -> None:
            async with limiter:
                pass

        limiter = CapacityLimiter(1)
        assert limiter.statistics().total_tokens == 1
        assert limiter.statistics().borrowed_tokens == 0
        assert limiter.statistics().tasks_waiting == 0
        async with create_task_group() as tg:
            async with limiter:
                assert limiter.statistics().borrowed_tokens == 1
                assert limiter.statistics().tasks_waiting == 0
                for i in range(1, 3):
                    tg.start_soon(waiter)
                    await wait_all_tasks_blocked()
                    assert limiter.statistics().tasks_waiting == i

        assert limiter.statistics().tasks_waiting == 0
        assert limiter.statistics().borrowed_tokens == 0
468