# Copyright 2021 The Duet Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Internal implementation details for duet."""

import enum
import functools
import heapq
import itertools
import signal
import threading
import time
from concurrent.futures import Future
from contextvars import ContextVar
from typing import (
    Any,
    Awaitable,
    Callable,
    cast,
    Coroutine,
    Generic,
    Iterator,
    List,
    Optional,
    Set,
    TypeVar,
)

import duet.futuretools as futuretools

T = TypeVar("T")


class Interrupt(BaseException):
    """Wraps an error delivered to a task from somewhere else.

    Attributes:
        task: The task the error originated from. When this is the task being
            interrupted, the raw error is thrown instead of this wrapper.
        error: The underlying error (instance or exception class).
    """

    def __init__(self, task, error):
        self.task, self.error = task, error


class TaskState(enum.Enum):
    """Lifecycle states of a Task."""

    WAITING = 0  # Not finished; waiting to be advanced.
    SUCCEEDED = 1  # Finished and produced a result.
    FAILED = 2  # Finished by raising an error.


class TaskStateError(Exception):
    """Raised when a Task is queried in a state that does not permit it."""

    def __init__(self, state: TaskState, expected_state: TaskState) -> None:
        message = f"state: {state}, expected: {expected_state}"
        super().__init__(message)
        self.state = state
        self.expected_state = expected_state


# Sentinel local variable name that we insert into coroutine frames.
# This lets us detect whether a task is running when we get Ctrl-C.
LOCALS_TASK_SCHEDULER = "__duet_task_scheduler__"


class Task(Generic[T]):
    """A single awaitable being driven step by step by a duet Scheduler.

    The task holds the awaitable's ``__await__`` generator. Each step must
    yield a ``concurrent.futures.Future``. The task becomes ready to advance
    again once that future completes or once the task is interrupted.
    """

    def __init__(
        self,
        awaitable: Awaitable[T],
        scheduler: "Scheduler",
        main_task: Optional["Task"],
    ) -> None:
        self.scheduler = scheduler
        self.main_task = main_task
        self._state = TaskState.WAITING
        self._future: Optional[Future] = None
        self._ready_future = futuretools.AwaitableFuture[None]()
        self._ready_future.set_result(None)  # Ready to advance.
        self.interruptible = True
        self._interrupt: Optional[Interrupt] = None
        self._result: Optional[T] = None
        # May hold an Interrupt (a BaseException), not only an Exception.
        self._error: Optional[BaseException] = None
        self._deadlines: List[DeadlineEntry] = []
        # A subtask inherits the innermost deadline of its main task.
        if main_task and main_task.deadline is not None:
            self.push_deadline(main_task.deadline)
        self._generator = awaitable.__await__()  # Returns coroutine generator.
        if isinstance(awaitable, Coroutine):
            # Tag the coroutine frame so Scheduler._in_task can recognize it
            # when a SIGINT arrives while this task is running.
            awaitable.cr_frame.f_locals.setdefault(LOCALS_TASK_SCHEDULER, scheduler)

    def _check_state(self, expected_state: TaskState) -> None:
        """Raises TaskStateError unless the task is in `expected_state`."""
        if self._state != expected_state:
            raise TaskStateError(self._state, expected_state)

    @property
    def future(self) -> Optional[Future]:
        """The future the task last yielded (None before the first step).

        Raises:
            TaskStateError: If the task has already finished.
        """
        self._check_state(TaskState.WAITING)
        return self._future

    @property
    def result(self) -> T:
        """The value the awaitable returned.

        Raises:
            TaskStateError: If the task has not succeeded.
        """
        self._check_state(TaskState.SUCCEEDED)
        return cast(T, self._result)

    @property
    def done(self) -> bool:
        """True once the task has either succeeded or failed."""
        return self._state == TaskState.SUCCEEDED or self._state == TaskState.FAILED

    def add_ready_callback(self, callback: Callable[["Task"], Any]) -> None:
        """Calls `callback(self)` once the task is ready to advance.

        If the task is already ready, the callback is expected to fire
        immediately (standard Future.add_done_callback semantics).

        Raises:
            TaskStateError: If the task has already finished.
        """
        self._check_state(TaskState.WAITING)
        self._ready_future.add_done_callback(lambda _: callback(self))

    def advance(self):
        """Runs the task's generator until its next yield or completion.

        Any pending interrupt is thrown into the generator instead of resuming
        it normally. If the task fails and has a main task, the error is
        forwarded to the main task as an interrupt. Otherwise the error
        propagates to the caller.

        Raises:
            TypeError: If the generator yields something other than a Future.
        """
        if self.done:
            return
        if self._state == TaskState.WAITING:
            # Tasks are only advanced after their ready future has fired.
            self._ready_future.result()
        token = _current_task.set(self)
        try:
            if self._interrupt:
                interrupt = self._interrupt
                self._interrupt = None
                # A self-interrupt (e.g. deadline or Ctrl-C) delivers the raw
                # error. An error from another task arrives wrapped.
                if interrupt.task is self:
                    error = interrupt.error
                else:
                    error = interrupt
                f = self._generator.throw(error)
            else:
                f = next(self._generator)
        except StopIteration as e:
            self._result = e.value
            self._state = TaskState.SUCCEEDED
            return
        except (Interrupt, Exception) as error:
            self._error = error
            self._state = TaskState.FAILED
            if self.main_task:
                self.main_task.interrupt(self, error)
                return
            else:
                raise
        else:
            if not isinstance(f, Future):
                raise TypeError(f"expected Future, got {type(f)}: {f}")
            ready_future = futuretools.AwaitableFuture()
            f.add_done_callback(lambda _: ready_future.try_set_result(None))
            self._future = f
            self._ready_future = ready_future
            self._state = TaskState.WAITING
        finally:
            _current_task.reset(token)

    def push_deadline(self, deadline: float) -> None:
        """Enters a deadline scope, clamped to any enclosing deadline."""
        if self._deadlines:
            deadline = min(self._deadlines[-1].deadline, deadline)
        entry = self.scheduler.add_deadline(self, deadline)
        self._deadlines.append(entry)

    def pop_deadline(self) -> None:
        """Leaves the innermost deadline scope and invalidates its entry."""
        entry = self._deadlines.pop(-1)
        entry.valid = False

    @property
    def deadline(self) -> Optional[float]:
        """The innermost active deadline, or None if there is none."""
        return self._deadlines[-1].deadline if self._deadlines else None

    def interrupt(self, task, error):
        """Schedules `error` to be thrown into this task at its next step.

        Does nothing if the task is done, not interruptible, or already has
        a pending interrupt. Otherwise it marks the task ready and cancels
        the future it is waiting on.
        """
        if self.done or not self.interruptible or self._interrupt:
            return
        self._interrupt = Interrupt(task, error)
        self._ready_future.try_set_result(None)
        if self._future:
            self._future.cancel()

    def close(self):
        """Closes the generator and drops references to break cycles."""
        self._generator.close()
        self.scheduler = None
        self.main_task = None


_current_task: ContextVar[Task] = ContextVar("current_task")


def current_task() -> Task:
    """Gets the currently-running duet task.

    This must be called from within a running async function, or else it will
    raise a RuntimeError.
    """
    try:
        return _current_task.get()
    except LookupError:
        raise RuntimeError("Can only be called from an async function.") from None


def current_scheduler() -> "Scheduler":
    """Gets the currently-running duet scheduler.

    This must be called from within a running async function, or else it will
    raise a RuntimeError.
    """
    return current_task().scheduler


def any_ready(tasks: Set[Task]) -> futuretools.AwaitableFuture[None]:
    """Returns a Future that will fire when any of the given tasks is ready.

    The future is already complete if `tasks` is empty or any task is done.
    """
    if not tasks or any(task.done for task in tasks):
        return futuretools.completed_future(None)
    f = futuretools.AwaitableFuture[None]()
    for task in tasks:
        task.add_ready_callback(lambda _: f.try_set_result(None))
    return f


class ReadySet:
    """Container for an ordered set of tasks that are ready to advance."""

    def __init__(self):
        self._cond = threading.Condition()
        self._buffer = futuretools.BufferGroup()
        self._tasks: List[Task] = []
        self._task_set: Set[Task] = set()

    def register(self, task: Task) -> None:
        """Registers task to be added to this set when it is ready."""
        self._buffer.add(task.future)
        task.add_ready_callback(self._add)

    def _add(self, task: Task) -> None:
        """Adds the given task to the ready set, if it is not already there."""
        with self._cond:
            if task not in self._task_set:
                self._task_set.add(task)
                self._tasks.append(task)
                self._cond.notify()

    def get_all(self, timeout: Optional[float] = None) -> List[Task]:
        """Gets all ready tasks and clears the ready set.

        If no tasks are ready yet, we flush buffered futures to notify them
        that they should proceed, and then block until one or more tasks become
        ready. The result may be empty if the wait was woken by interrupt().

        Raises:
            ValueError if timeout is < 0 or > threading.TIMEOUT_MAX
            TimeoutError if timeout elapses before any task becomes ready
        """
        if timeout is not None and (timeout < 0 or timeout > threading.TIMEOUT_MAX):
            raise ValueError(f"invalid timeout: {timeout}")
        with self._cond:
            if self._tasks:
                return self._pop_tasks()
        # Flush buffered futures to ensure we make progress. Note that we must
        # release the condition lock before flushing to avoid a deadlock if
        # buffered futures complete and trigger a call to self._add.
        self._buffer.flush()
        with self._cond:
            if not self._tasks:
                if not self._cond.wait(timeout):
                    raise TimeoutError()
            return self._pop_tasks()

    def _pop_tasks(self) -> List[Task]:
        """Returns the ready tasks in order and resets the set (lock held)."""
        tasks = self._tasks
        self._tasks = []
        self._task_set.clear()
        return tasks

    def interrupt(self) -> None:
        """Wakes a thread blocked in get_all, which may then return []."""
        with self._cond:
            self._cond.notify()


@functools.total_ordering
class DeadlineEntry:
    """An entry for one Deadline in the Scheduler's priority queue.

    This follows the implementation notes in the stdlib heapq docs:
    https://docs.python.org/3/library/heapq.html#priority-queue-implementation-notes

    Attributes:
        task: The task associated with this deadline.
        deadline: Absolute time when the deadline will elapse.
        count: Monotonically-increasing counter to preserve creation order when
            comparing entries with the same deadline.
        valid: Flag indicating whether the deadline is still valid. If the task
            exits its scope before the deadline elapses, we mark the deadline as
            invalid but leave it in the scheduler's priority queue since removal
            would require an O(n) scan. The scheduler ignores invalid deadlines
            when they elapse.
    """

    _counter = itertools.count()

    def __init__(self, task: Task, deadline: float):
        self.task = task
        self.deadline = deadline
        self.count = next(self._counter)
        # Order by deadline first, then by creation order as a tie-breaker.
        self._cmp_val = (deadline, self.count)
        self.valid = True

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, DeadlineEntry):
            return NotImplemented
        return self._cmp_val == other._cmp_val

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, DeadlineEntry):
            return NotImplemented
        return self._cmp_val < other._cmp_val

    def __repr__(self) -> str:
        return f"DeadlineEntry({self.task}, {self.deadline}, {self.count})"


class Scheduler:
    """Drives duet Tasks by repeatedly advancing the ones that are ready."""

    def __init__(self) -> None:
        self.active_tasks: Set[Task] = set()
        self._ready_tasks = ReadySet()
        self._prev_signal: Optional[Callable] = None
        self._interrupted = False
        self._deadlines: List[DeadlineEntry] = []

    def spawn(self, awaitable: Awaitable[Any], main_task: Optional[Task] = None) -> Task:
        """Spawns a new Task to run an awaitable in this Scheduler.

        Note that the task will not be advanced until the next scheduler tick.
        Also, note that this function is safe to call from sync code (such as
        duet.run) or async code (such as within a scope).

        Args:
            awaitable: The awaitable (typically a coroutine) to run.
            main_task: Optional task that is interrupted if the new task fails.
                The new task also inherits this task's current deadline.

        Returns:
            A Task to run the given awaitable.
        """
        task = Task(awaitable, scheduler=self, main_task=main_task)
        self.active_tasks.add(task)
        self._ready_tasks.register(task)
        return task

    def time(self) -> float:
        """Current wall-clock time used for deadline comparisons."""
        return time.time()

    def add_deadline(self, task: Task, deadline: float) -> DeadlineEntry:
        """Pushes a deadline for `task` onto the priority queue."""
        entry = DeadlineEntry(task, deadline=deadline)
        heapq.heappush(self._deadlines, entry)
        return entry

    def get_next_deadline(self) -> Optional[float]:
        """Returns the earliest valid deadline, discarding stale entries."""
        while self._deadlines:
            if not self._deadlines[0].valid:
                heapq.heappop(self._deadlines)
                continue
            return self._deadlines[0].deadline
        return None

    def get_deadline_tasks(self, deadline: float) -> Iterator[Task]:
        """Pops all entries due at or before `deadline`; yields valid tasks."""
        while self._deadlines and self._deadlines[0].deadline <= deadline:
            entry = heapq.heappop(self._deadlines)
            if entry.valid:
                yield entry.task

    def tick(self):
        """Runs the scheduler ahead by one tick.

        This waits until at least one active task is ready to advance (or the
        next deadline elapses), then advances all ready tasks and re-registers
        those still active so that they will be picked up on a later tick.
        Raises a RuntimeError if there are no currently active tasks.
        """
        if not self.active_tasks:
            raise RuntimeError("tick called with no active tasks")

        if self._interrupted:
            task = next(iter(self.active_tasks))
            task.interrupt(task, KeyboardInterrupt)
            self._interrupted = False

        ready_tasks = self._wait_for_ready_tasks()
        self._advance_ready_tasks(ready_tasks)

    def _wait_for_ready_tasks(self) -> List[Task]:
        """Blocks until tasks are ready, expiring deadlines that elapse first.

        May return an empty list if the wait was woken by a SIGINT before any
        task became ready and before the next deadline elapsed.
        """
        deadline = self.get_next_deadline()
        if deadline is None:
            return self._ready_tasks.get_all(None)

        ready_tasks: List[Task] = []
        expired = False
        for attempt in itertools.count():
            timeout = deadline - self.time()
            # Always poll at least once, even if the deadline already passed.
            if attempt and timeout < 0:
                expired = True
                break
            try:
                # Clamp into the range ReadySet.get_all accepts: never negative
                # and never above threading.TIMEOUT_MAX.
                ready_tasks = self._ready_tasks.get_all(
                    max(0, min(timeout, threading.TIMEOUT_MAX))
                )
                break
            except TimeoutError:
                pass

        # Only time tasks out if the deadline really elapsed. An empty result
        # from a SIGINT wake-up must not be mistaken for an expired deadline.
        if expired:
            for task in self.get_deadline_tasks(deadline):
                task.interrupt(task, TimeoutError())
            ready_tasks = self._ready_tasks.get_all(None)
        return ready_tasks

    def _advance_ready_tasks(self, ready_tasks: List[Task]) -> None:
        """Advances each ready task, then retires or re-registers it."""
        remaining = iter(ready_tasks)
        try:
            for task in remaining:
                try:
                    task.advance()
                finally:
                    if task.done:
                        task.close()
                        self.active_tasks.discard(task)
                    else:
                        self._ready_tasks.register(task)
        except BaseException:
            # The rest of this batch was already popped from the ready set but
            # never advanced. Re-register those tasks (their ready futures are
            # already complete, so they go straight back into the set).
            # Otherwise later ticks, e.g. the cleanup in __exit__, would block
            # forever waiting for them.
            for task in remaining:
                self._ready_tasks.register(task)
            raise

    def _interrupt(self, signum: int, frame: Optional[Any]) -> None:
        """Interrupt signal handler used while this scheduler is running.

        This is inspired by trio's interrupt handling, described here:
        https://vorpus.org/blog/control-c-handling-in-python-and-trio/

        If the interrupted frame is inside a running task, which we detect by
        looking for a special local variable inserted into the task coroutine,
        we simply raise a KeyboardInterrupt as usual. Otherwise we set a flag
        which will get checked on the next tick() and cause a task to be
        interrupted.

        One important difference from trio is that duet is reentrant, so when
        detecting whether we are in a task we have to check whether the task's
        scheduler is self. If the interrupted frame is running in a task of a
        different scheduler, that should not raise KeyboardInterrupt directly.
        """
        if self._in_task(frame):
            raise KeyboardInterrupt
        else:
            self._interrupted = True
            self._ready_tasks.interrupt()

    def _in_task(self, frame) -> bool:
        """True if any frame on the stack belongs to a task of this scheduler."""
        while frame is not None:
            if frame.f_locals.get(LOCALS_TASK_SCHEDULER, None) is self:
                return True
            frame = frame.f_back
        return False

    def __enter__(self):
        """Installs the SIGINT handler if safe, and returns the scheduler.

        The handler is only installed on the main thread, and only when the
        current SIGINT handler is Python's default, so custom handlers are
        left alone.
        """
        if (
            threading.current_thread() == threading.main_thread()
            and signal.getsignal(signal.SIGINT) == signal.default_int_handler
        ):
            self._prev_signal = signal.signal(signal.SIGINT, self._interrupt)
        return self

    def __exit__(self, exc_type, exc, tb):
        """Runs all remaining tasks to completion, then restores SIGINT.

        If exiting with an error, or if finishing the tasks raises, every
        active task is interrupted with that error and ticked until done.
        The original error then propagates.
        """

        def finish_tasks(error=None):
            if error:
                for task in self.active_tasks:
                    task.interrupt(None, error)
            while self.active_tasks:
                try:
                    self.tick()
                except Exception:
                    # While already handling an error, keep draining tasks.
                    if not error:
                        raise

        try:
            if exc:
                finish_tasks(exc)
            else:
                try:
                    finish_tasks()
                except Exception as error:
                    finish_tasks(error)
                    raise
        finally:
            if self._prev_signal:
                signal.signal(signal.SIGINT, self._prev_signal)