"""Tests for anyio's synchronization primitives: Lock, Event, Condition, Semaphore and
CapacityLimiter.

Every test is an async function marked with ``pytest.mark.anyio`` through the module-level
``pytestmark``.
"""

from typing import List, Optional

import pytest

from anyio import (
    CancelScope, Condition, Event, Lock, Semaphore, WouldBlock, create_task_group, to_thread,
    wait_all_tasks_blocked)
from anyio.abc import CapacityLimiter, TaskStatus

pytestmark = pytest.mark.anyio


class TestLock:
    """Behaviour of :class:`anyio.Lock`."""

    async def test_contextmanager(self) -> None:
        """A task blocked on ``async with lock`` only proceeds after the holder releases it."""
        async def task() -> None:
            assert lock.locked()
            async with lock:
                results.append('2')

        results: List[str] = []
        lock = Lock()
        async with create_task_group() as tg:
            async with lock:
                tg.start_soon(task)
                # Let task() run until it is blocked waiting for the lock we hold
                await wait_all_tasks_blocked()
                results.append('1')

        assert not lock.locked()
        assert results == ['1', '2']

    async def test_manual_acquire(self) -> None:
        """Same ordering guarantee as test_contextmanager, using explicit acquire()/release()."""
        async def task() -> None:
            assert lock.locked()
            await lock.acquire()
            try:
                results.append('2')
            finally:
                lock.release()

        results: List[str] = []
        lock = Lock()
        async with create_task_group() as tg:
            await lock.acquire()
            try:
                tg.start_soon(task)
                await wait_all_tasks_blocked()
                results.append('1')
            finally:
                lock.release()

        assert not lock.locked()
        assert results == ['1', '2']

    async def test_acquire_nowait(self) -> None:
        """acquire_nowait() on a free lock takes it immediately."""
        lock = Lock()
        lock.acquire_nowait()
        assert lock.locked()

    async def test_acquire_nowait_wouldblock(self) -> None:
        """acquire_nowait() raises WouldBlock when another task holds the lock."""
        async def try_lock() -> None:
            with pytest.raises(WouldBlock):
                lock.acquire_nowait()

        lock = Lock()
        # The lock is entered before the task group, so it stays held until try_lock() finishes
        async with lock, create_task_group() as tg:
            assert lock.locked()
            tg.start_soon(try_lock)

    @pytest.mark.parametrize('release_first', [
        pytest.param(False, id='releaselast'),
        pytest.param(True, id='releasefirst'),
    ])
    async def test_cancel_during_acquire(self, release_first: bool) -> None:
        """
        A task cancelled while waiting for the lock never acquires it, and the lock ends up
        free whether the holder releases it before or after letting the other tasks settle.

        """
        acquired = False

        async def task(*, task_status: TaskStatus) -> None:
            nonlocal acquired
            task_status.started()
            async with lock:
                acquired = True

        lock = Lock()
        async with create_task_group() as tg:
            await lock.acquire()
            await tg.start(task)
            tg.cancel_scope.cancel()
            # The host task also runs inside the cancelled scope; shield it so it can still await
            with CancelScope(shield=True):
                if release_first:
                    lock.release()
                    await wait_all_tasks_blocked()
                else:
                    await wait_all_tasks_blocked()
                    lock.release()

        assert not acquired
        assert not lock.locked()

    async def test_statistics(self) -> None:
        """statistics() reports the locked flag and the number of tasks queued on the lock."""
        async def waiter() -> None:
            async with lock:
                pass

        lock = Lock()
        async with create_task_group() as tg:
            assert not lock.statistics().locked
            assert lock.statistics().tasks_waiting == 0
            async with lock:
                assert lock.statistics().locked
                assert lock.statistics().tasks_waiting == 0
                for i in range(1, 3):
                    tg.start_soon(waiter)
                    await wait_all_tasks_blocked()
                    assert lock.statistics().tasks_waiting == i

        assert not lock.statistics().locked
        assert lock.statistics().tasks_waiting == 0


class TestEvent:
    """Behaviour of :class:`anyio.Event`."""

    async def test_event(self) -> None:
        """wait() returns once another task calls set()."""
        async def setter() -> None:
            assert not event.is_set()
            event.set()

        event = Event()
        async with create_task_group() as tg:
            tg.start_soon(setter)
            await event.wait()

        assert event.is_set()

    async def test_event_cancel(self) -> None:
        """
        A task waiting on the event is not resumed past wait() when its task group is
        cancelled, even though the event is set afterwards.

        """
        task_started = event_set = False

        async def task() -> None:
            nonlocal task_started, event_set
            task_started = True
            await event.wait()
            event_set = True

        event = Event()
        async with create_task_group() as tg:
            tg.start_soon(task)
            tg.cancel_scope.cancel()
            # Setting the event after cancelling must not let the waiter continue
            event.set()

        assert task_started
        assert not event_set

    async def test_statistics(self) -> None:
        """statistics().tasks_waiting counts tasks blocked in wait() and drops to 0 after set()."""
        async def waiter() -> None:
            await event.wait()

        event = Event()
        async with create_task_group() as tg:
            assert event.statistics().tasks_waiting == 0
            for i in range(1, 3):
                tg.start_soon(waiter)
                await wait_all_tasks_blocked()
                assert event.statistics().tasks_waiting == i

            event.set()

        assert event.statistics().tasks_waiting == 0


class TestCondition:
    """Behaviour of :class:`anyio.Condition`."""

    async def test_contextmanager(self) -> None:
        """wait() inside ``async with condition`` returns after another task calls notify_all()."""
        async def notifier() -> None:
            async with condition:
                condition.notify_all()

        condition = Condition()
        async with create_task_group() as tg:
            async with condition:
                assert condition.locked()
                tg.start_soon(notifier)
                await condition.wait()

    async def test_manual_acquire(self) -> None:
        """Same as test_contextmanager, using explicit acquire()/release() calls."""
        async def notifier() -> None:
            await condition.acquire()
            try:
                condition.notify_all()
            finally:
                condition.release()

        condition = Condition()
        async with create_task_group() as tg:
            await condition.acquire()
            try:
                assert condition.locked()
                tg.start_soon(notifier)
                await condition.wait()
            finally:
                condition.release()

    async def test_acquire_nowait(self) -> None:
        """acquire_nowait() on an unheld condition leaves it locked."""
        condition = Condition()
        condition.acquire_nowait()
        assert condition.locked()

    async def test_acquire_nowait_wouldblock(self) -> None:
        """acquire_nowait() raises WouldBlock when another task holds the condition."""
        async def try_lock() -> None:
            with pytest.raises(WouldBlock):
                condition.acquire_nowait()

        condition = Condition()
        # The condition is entered before the task group, so it stays held until try_lock() ends
        async with condition, create_task_group() as tg:
            assert condition.locked()
            tg.start_soon(try_lock)

    async def test_wait_cancel(self) -> None:
        """A task cancelled while blocked in wait() never gets past it."""
        task_started = notified = False

        async def task() -> None:
            nonlocal task_started, notified
            task_started = True
            async with condition:
                event.set()
                await condition.wait()
                notified = True

        event = Event()
        condition = Condition()
        async with create_task_group() as tg:
            tg.start_soon(task)
            # Make sure task() holds the condition and is blocked in wait() before cancelling
            await event.wait()
            await wait_all_tasks_blocked()
            tg.cancel_scope.cancel()

        assert task_started
        assert not notified

    async def test_statistics(self) -> None:
        """
        statistics() reports the underlying lock state and the number of tasks in wait();
        each notify(1) wakes exactly one waiter.

        """
        async def waiter() -> None:
            async with condition:
                await condition.wait()

        condition = Condition()
        async with create_task_group() as tg:
            assert not condition.statistics().lock_statistics.locked
            assert condition.statistics().tasks_waiting == 0
            async with condition:
                assert condition.statistics().lock_statistics.locked
                assert condition.statistics().tasks_waiting == 0

            for i in range(1, 3):
                tg.start_soon(waiter)
                await wait_all_tasks_blocked()
                assert condition.statistics().tasks_waiting == i

            # Wake the two waiters one at a time: the count goes 2 -> 1 -> 0
            for i in range(1, -1, -1):
                async with condition:
                    condition.notify(1)

                await wait_all_tasks_blocked()
                assert condition.statistics().tasks_waiting == i

        assert not condition.statistics().lock_statistics.locked
        assert condition.statistics().tasks_waiting == 0


class TestSemaphore:
    """Behaviour of :class:`anyio.Semaphore`."""

    async def test_contextmanager(self) -> None:
        """Holding a slot via ``async with`` lowers value; all slots return once tasks finish."""
        async def acquire() -> None:
            async with semaphore:
                assert semaphore.value in (0, 1)

        semaphore = Semaphore(2)
        async with create_task_group() as tg:
            tg.start_soon(acquire, name='task 1')
            tg.start_soon(acquire, name='task 2')

        assert semaphore.value == 2

    async def test_manual_acquire(self) -> None:
        """Same as test_contextmanager, using explicit acquire()/release() calls."""
        async def acquire() -> None:
            await semaphore.acquire()
            try:
                assert semaphore.value in (0, 1)
            finally:
                semaphore.release()

        semaphore = Semaphore(2)
        async with create_task_group() as tg:
            tg.start_soon(acquire, name='task 1')
            tg.start_soon(acquire, name='task 2')

        assert semaphore.value == 2

    async def test_acquire_nowait(self) -> None:
        """acquire_nowait() takes the last slot, then raises WouldBlock when none are left."""
        semaphore = Semaphore(1)
        semaphore.acquire_nowait()
        assert semaphore.value == 0
        with pytest.raises(WouldBlock):
            semaphore.acquire_nowait()

    @pytest.mark.parametrize('release_first', [
        pytest.param(False, id='releaselast'),
        pytest.param(True, id='releasefirst'),
    ])
    async def test_cancel_during_acquire(self, release_first: bool) -> None:
        """
        A task cancelled while waiting for a slot never gets one, and the slot is returned
        whether the holder releases it before or after letting the other tasks settle.

        """
        acquired = False

        async def task(*, task_status: TaskStatus) -> None:
            nonlocal acquired
            task_status.started()
            async with semaphore:
                acquired = True

        semaphore = Semaphore(1)
        async with create_task_group() as tg:
            await semaphore.acquire()
            await tg.start(task)
            tg.cancel_scope.cancel()
            # The host task also runs inside the cancelled scope; shield it so it can still await
            with CancelScope(shield=True):
                if release_first:
                    semaphore.release()
                    await wait_all_tasks_blocked()
                else:
                    await wait_all_tasks_blocked()
                    semaphore.release()

        assert not acquired
        assert semaphore.value == 1

    @pytest.mark.parametrize('max_value', [2, None])
    async def test_max_value(self, max_value: Optional[int]) -> None:
        """max_value reflects the constructor argument, including None."""
        semaphore = Semaphore(0, max_value=max_value)
        assert semaphore.max_value == max_value

    async def test_max_value_exceeded(self) -> None:
        """release() raises ValueError when it would push value above max_value."""
        semaphore = Semaphore(1, max_value=2)
        semaphore.release()
        with pytest.raises(ValueError):
            semaphore.release()

    async def test_statistics(self) -> None:
        """statistics().tasks_waiting counts tasks queued for a slot."""
        async def waiter() -> None:
            async with semaphore:
                pass

        semaphore = Semaphore(1)
        async with create_task_group() as tg:
            assert semaphore.statistics().tasks_waiting == 0
            async with semaphore:
                assert semaphore.statistics().tasks_waiting == 0
                for i in range(1, 3):
                    tg.start_soon(waiter)
                    await wait_all_tasks_blocked()
                    assert semaphore.statistics().tasks_waiting == i

        assert semaphore.statistics().tasks_waiting == 0

    async def test_acquire_race(self) -> None:
        """
        Test against a race condition: when a task waiting on acquire() is rescheduled but another
        task snatches the last available slot, the task should not raise WouldBlock.

        """
        semaphore = Semaphore(1)
        async with create_task_group() as tg:
            semaphore.acquire_nowait()
            tg.start_soon(semaphore.acquire)
            await wait_all_tasks_blocked()
            semaphore.release()
            # The freed slot already belongs to the waiting task, so this attempt must fail
            with pytest.raises(WouldBlock):
                semaphore.acquire_nowait()


class TestCapacityLimiter:
    """Behaviour of :class:`anyio.abc.CapacityLimiter`."""

    async def test_bad_init_type(self) -> None:
        """A float total_tokens (other than math.inf) is rejected with TypeError."""
        with pytest.raises(TypeError, match='total_tokens must be an int or math.inf'):
            CapacityLimiter(1.0)

    async def test_bad_init_value(self) -> None:
        """total_tokens below 1 is rejected with ValueError."""
        with pytest.raises(ValueError, match='total_tokens must be >= 1'):
            CapacityLimiter(0)

    async def test_borrow(self) -> None:
        """Borrowing a token moves it from available_tokens to borrowed_tokens."""
        limiter = CapacityLimiter(2)
        assert limiter.total_tokens == 2
        assert limiter.available_tokens == 2
        assert limiter.borrowed_tokens == 0
        async with limiter:
            assert limiter.total_tokens == 2
            assert limiter.available_tokens == 1
            assert limiter.borrowed_tokens == 1

    async def test_limit(self) -> None:
        """With a single token, at most one task is inside the limiter at any time."""
        value = 0

        async def taskfunc() -> None:
            nonlocal value
            for _ in range(5):
                async with limiter:
                    # Another task inside the limiter would have left value at 1
                    assert value == 0
                    value = 1
                    await wait_all_tasks_blocked()
                    value = 0

        limiter = CapacityLimiter(1)
        async with create_task_group() as tg:
            for _ in range(3):
                tg.start_soon(taskfunc)

    async def test_borrow_twice(self) -> None:
        """The same borrower cannot hold two tokens of one limiter at once."""
        limiter = CapacityLimiter(1)
        await limiter.acquire()
        message = "this borrower is already holding one of this CapacityLimiter's tokens"
        with pytest.raises(RuntimeError, match=message):
            await limiter.acquire()

    async def test_bad_release(self) -> None:
        """Releasing without holding a token raises RuntimeError."""
        limiter = CapacityLimiter(1)
        message = "this borrower isn't holding any of this CapacityLimiter's tokens"
        with pytest.raises(RuntimeError, match=message):
            limiter.release()

    async def test_increase_tokens(self) -> None:
        """Raising total_tokens wakes a task that is waiting for a token."""
        async def setter() -> None:
            # Wait until waiter() is inside the limiter block
            await event1.wait()
            async with limiter:
                # This can only happen when total_tokens has been increased
                event2.set()

        async def waiter() -> None:
            async with limiter:
                event1.set()
                await event2.wait()

        limiter = CapacityLimiter(1)
        event1, event2 = Event(), Event()
        async with create_task_group() as tg:
            tg.start_soon(setter)
            tg.start_soon(waiter)
            await wait_all_tasks_blocked()
            assert event1.is_set()
            assert not event2.is_set()
            limiter.total_tokens = 2

        assert event2.is_set()

    async def test_current_default_thread_limiter(self) -> None:
        """The default worker-thread limiter is a CapacityLimiter with 40 tokens."""
        limiter = to_thread.current_default_thread_limiter()
        assert isinstance(limiter, CapacityLimiter)
        assert limiter.total_tokens == 40

    async def test_statistics(self) -> None:
        """statistics() reports total, borrowed and waiting counts."""
        async def waiter() -> None:
            async with limiter:
                pass

        limiter = CapacityLimiter(1)
        assert limiter.statistics().total_tokens == 1
        assert limiter.statistics().borrowed_tokens == 0
        assert limiter.statistics().tasks_waiting == 0
        async with create_task_group() as tg:
            async with limiter:
                assert limiter.statistics().borrowed_tokens == 1
                assert limiter.statistics().tasks_waiting == 0
                for i in range(1, 3):
                    tg.start_soon(waiter)
                    await wait_all_tasks_blocked()
                    assert limiter.statistics().tasks_waiting == i

        assert limiter.statistics().tasks_waiting == 0
        assert limiter.statistics().borrowed_tokens == 0