# Copyright 2021 The Duet Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Internal implementation details for duet."""

import enum
import functools
import heapq
import itertools
import signal
import threading
import time
from concurrent.futures import Future
from contextvars import ContextVar
from typing import (
    Any,
    Awaitable,
    Callable,
    cast,
    Coroutine,
    Generic,
    Iterator,
    List,
    Optional,
    Set,
    TypeVar,
)

import duet.futuretools as futuretools

T = TypeVar("T")


class Interrupt(BaseException):
    def __init__(self, task, error):
        self.task = task
        self.error = error


class TaskState(enum.Enum):
    WAITING = 0
    SUCCEEDED = 1
    FAILED = 2


class TaskStateError(Exception):
    def __init__(self, state: TaskState, expected_state: TaskState) -> None:
        self.state = state
        self.expected_state = expected_state
        super().__init__(f"state: {state}, expected: {expected_state}")


# Sentinel local variable name that we insert into coroutines.
# This allows us to detect whether a task is running when we get Ctrl-C.
LOCALS_TASK_SCHEDULER = "__duet_task_scheduler__"


class Task(Generic[T]):
    def __init__(
        self,
        awaitable: Awaitable[T],
        scheduler: "Scheduler",
        main_task: Optional["Task"],
    ) -> None:
        self.scheduler = scheduler
        self.main_task = main_task
        self._state = TaskState.WAITING
        self._future: Optional[Future] = None
        self._ready_future = futuretools.AwaitableFuture[None]()
        self._ready_future.set_result(None)  # Ready to advance.
        self.interruptible = True
        self._interrupt: Optional[Interrupt] = None
        self._result: Optional[T] = None
        self._error: Optional[Exception] = None
        self._deadlines: List[DeadlineEntry] = []
        if main_task and main_task.deadline is not None:
            self.push_deadline(main_task.deadline)
        self._generator = awaitable.__await__()  # Returns coroutine generator.
        if isinstance(awaitable, Coroutine):
            awaitable.cr_frame.f_locals.setdefault(LOCALS_TASK_SCHEDULER, scheduler)

    def _check_state(self, expected_state: TaskState) -> None:
        if self._state != expected_state:
            raise TaskStateError(self._state, expected_state)

    @property
    def future(self) -> Optional[Future]:
        self._check_state(TaskState.WAITING)
        return self._future

    @property
    def result(self) -> T:
        self._check_state(TaskState.SUCCEEDED)
        return cast(T, self._result)

    @property
    def done(self) -> bool:
        return self._state == TaskState.SUCCEEDED or self._state == TaskState.FAILED

    def add_ready_callback(self, callback: Callable[["Task"], Any]) -> None:
        self._check_state(TaskState.WAITING)
        self._ready_future.add_done_callback(lambda _: callback(self))

    def advance(self):
        if self.done:
            return
        if self._state == TaskState.WAITING:
            self._ready_future.result()
        token = _current_task.set(self)
        try:
            if self._interrupt:
                interrupt = self._interrupt
                self._interrupt = None
                if interrupt.task is self:
                    error = interrupt.error
                else:
                    error = interrupt
                f = self._generator.throw(error)
            else:
                f = next(self._generator)
        except StopIteration as e:
            self._result = e.value
            self._state = TaskState.SUCCEEDED
            return
        except (Interrupt, Exception) as error:
            self._error = error
            self._state = TaskState.FAILED
            if self.main_task:
                self.main_task.interrupt(self, error)
                return
            else:
                raise
        else:
            if not isinstance(f, Future):
                raise TypeError(f"expected Future, got {type(f)}: {f}")
            ready_future = futuretools.AwaitableFuture()
            f.add_done_callback(lambda _: ready_future.try_set_result(None))
            self._future = f
            self._ready_future = ready_future
            self._state = TaskState.WAITING
        finally:
            _current_task.reset(token)

    def push_deadline(self, deadline: float) -> None:
        if self._deadlines:
            deadline = min(self._deadlines[-1].deadline, deadline)
        entry = self.scheduler.add_deadline(self, deadline)
        self._deadlines.append(entry)

    def pop_deadline(self) -> None:
        entry = self._deadlines.pop(-1)
        entry.valid = False

    @property
    def deadline(self) -> Optional[float]:
        return self._deadlines[-1].deadline if self._deadlines else None

    def interrupt(self, task, error):
        if self.done or not self.interruptible or self._interrupt:
            return
        self._interrupt = Interrupt(task, error)
        self._ready_future.try_set_result(None)
        if self._future:
            self._future.cancel()

    def close(self):
        self._generator.close()
        self.scheduler = None
        self.main_task = None


_current_task: ContextVar[Task] = ContextVar("current_task")


def current_task() -> Task:
    """Gets the currently-running duet task.

    This must be called from within a running async function, or else it will
    raise a RuntimeError.
    """
    try:
        return _current_task.get()
    except LookupError:
        raise RuntimeError("Can only be called from an async function.")


def current_scheduler() -> "Scheduler":
    """Gets the currently-running duet scheduler.

    This must be called from within a running async function, or else it will
    raise a RuntimeError.
    """
    return current_task().scheduler
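# Illustrative usage of current_task() and current_scheduler() (a sketch, not
# executed here; it assumes the public duet.run entry point drives the coroutine):
#
#     import duet
#
#     async def main():
#         task = current_task()          # the Task driving this coroutine
#         assert current_scheduler() is task.scheduler
#
#     duet.run(main)   # calling current_task() outside any task raises RuntimeError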


def any_ready(tasks: Set[Task]) -> futuretools.AwaitableFuture[None]:
    """Returns a Future that will fire when any of the given tasks is ready."""
    if not tasks or any(task.done for task in tasks):
        return futuretools.completed_future(None)
    f = futuretools.AwaitableFuture[None]()
    for task in tasks:
        task.add_ready_callback(lambda _: f.try_set_result(None))
    return f
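# A sketch of how a supervisor coroutine might wait on any_ready (illustrative only;
# `tasks` is a hypothetical set of in-flight Tasks tracked by the caller):
#
#     async def wait_for_progress(tasks: Set[Task]) -> None:
#         await any_ready(tasks)   # resumes once some task is done or ready to advance
#         finished = {task for task in tasks if task.done}
#         ...                      # e.g. collect results and drop finished tasks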


class ReadySet:
    """Container for an ordered set of tasks that are ready to advance."""

    def __init__(self):
        self._cond = threading.Condition()
        self._buffer = futuretools.BufferGroup()
        self._tasks: List[Task] = []
        self._task_set: Set[Task] = set()

    def register(self, task: Task) -> None:
        """Registers task to be added to this set when it is ready."""
        self._buffer.add(task.future)
        task.add_ready_callback(self._add)

    def _add(self, task: Task) -> None:
        """Adds the given task to the ready set, if it is not already there."""
        with self._cond:
            if task not in self._task_set:
                self._task_set.add(task)
                self._tasks.append(task)
                self._cond.notify()

    def get_all(self, timeout: Optional[float] = None) -> List[Task]:
        """Gets all ready tasks and clears the ready set.

        If no tasks are ready yet, we flush buffered futures to notify them
        that they should proceed, and then block until one or more tasks become
        ready.

        Raises:
            ValueError if timeout is < 0 or > threading.TIMEOUT_MAX
        """
        if timeout is not None and (timeout < 0 or timeout > threading.TIMEOUT_MAX):
            raise ValueError(f"invalid timeout: {timeout}")
        with self._cond:
            if self._tasks:
                return self._pop_tasks()
        # Flush buffered futures to ensure we make progress. Note that we must
        # release the condition lock before flushing to avoid a deadlock if
        # buffered futures complete and trigger a call to self._add.
        self._buffer.flush()
        with self._cond:
            if not self._tasks:
                if not self._cond.wait(timeout):
                    raise TimeoutError()
            return self._pop_tasks()

    def _pop_tasks(self) -> List[Task]:
        tasks = self._tasks
        self._tasks = []
        self._task_set.clear()
        return tasks

    def interrupt(self) -> None:
        with self._cond:
            self._cond.notify()
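# In outline, the Scheduler uses a ReadySet like this (an illustrative sketch of the
# flow implemented in Scheduler.tick below, not a separate API):
#
#     ready = ReadySet()
#     ready.register(task)                # added to the set once its future completes
#     batch = ready.get_all(timeout=1.0)  # raises TimeoutError if none became ready
#     for task in batch:
#         task.advance()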


@functools.total_ordering
class DeadlineEntry:
    """An entry for one deadline in the Scheduler's priority queue.

    This follows the implementation notes in the stdlib heapq docs:
    https://docs.python.org/3/library/heapq.html#priority-queue-implementation-notes

    Attributes:
        task: The task associated with this deadline.
        deadline: Absolute time when the deadline will elapse.
        count: Monotonically-increasing counter to preserve creation order when
            comparing entries with the same deadline.
        valid: Flag indicating whether the deadline is still valid. If the task
            exits its scope before the deadline elapses, we mark the deadline as
            invalid but leave it in the scheduler's priority queue since removal
            would require an O(n) scan. The scheduler ignores invalid deadlines
            when they elapse.
    """

    _counter = itertools.count()

    def __init__(self, task: Task, deadline: float):
        self.task = task
        self.deadline = deadline
        self.count = next(self._counter)
        self._cmp_val = (deadline, self.count)
        self.valid = True

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, DeadlineEntry):
            return NotImplemented
        return self._cmp_val == other._cmp_val

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, DeadlineEntry):
            return NotImplemented
        return self._cmp_val < other._cmp_val

    def __repr__(self) -> str:
        return f"DeadlineEntry({self.task}, {self.deadline}, {self.count})"
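# The lazy-invalidation pattern described in the DeadlineEntry docstring, in brief
# (an illustrative sketch; task_a and task_b are hypothetical tasks):
#
#     heap: List[DeadlineEntry] = []
#     entry = DeadlineEntry(task_a, deadline=time.time() + 1.0)
#     heapq.heappush(heap, entry)
#     heapq.heappush(heap, DeadlineEntry(task_b, deadline=time.time() + 2.0))
#     entry.valid = False               # task_a left its scope; entry stays in the heap
#     while heap and not heap[0].valid:
#         heapq.heappop(heap)           # invalid entries are discarded only when popped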


class Scheduler:
    def __init__(self) -> None:
        self.active_tasks: Set[Task] = set()
        self._ready_tasks = ReadySet()
        self._prev_signal: Optional[Callable] = None
        self._interrupted = False
        self._deadlines: List[DeadlineEntry] = []

    def spawn(self, awaitable: Awaitable[Any], main_task: Optional[Task] = None) -> Task:
        """Spawns a new Task to run an awaitable in this Scheduler.

        Note that the task will not be advanced until the next scheduler tick.
        Also note that this function is safe to call from sync code (such as
        duet.run) or async code (such as within a scope).

        Args:
            awaitable: The awaitable (typically a coroutine) to run in the new task.
            main_task: Optional parent task. If given, the new task inherits the
                parent's current deadline, and failures in the new task are
                delivered to the parent as interrupts.

        Returns:
            A Task to run the given awaitable.
        """
        task = Task(awaitable, scheduler=self, main_task=main_task)
        self.active_tasks.add(task)
        self._ready_tasks.register(task)
        return task
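    # A sketch of spawning from already-running async code (illustrative; child_coro
    # is a hypothetical coroutine function supplied by the caller):
    #
    #     async def parent():
    #         scheduler = current_scheduler()
    #         child = scheduler.spawn(child_coro(), main_task=current_task())
    #         ...   # the child is only advanced on subsequent scheduler ticks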

    def time(self) -> float:
        return time.time()

    def add_deadline(self, task: Task, deadline: float) -> DeadlineEntry:
        entry = DeadlineEntry(task, deadline=deadline)
        heapq.heappush(self._deadlines, entry)
        return entry

    def get_next_deadline(self) -> Optional[float]:
        while self._deadlines:
            if not self._deadlines[0].valid:
                heapq.heappop(self._deadlines)
                continue
            return self._deadlines[0].deadline
        return None

    def get_deadline_tasks(self, deadline: float) -> Iterator[Task]:
        while self._deadlines and self._deadlines[0].deadline <= deadline:
            entry = heapq.heappop(self._deadlines)
            if entry.valid:
                yield entry.task

    def tick(self):
        """Runs the scheduler ahead by one tick.

        This waits for at least one active task to complete, then advances all
        ready tasks and sets up a new future to be notified later by tasks that
        are still active (or yet to be spawned). Raises a RuntimeError if there
        are no currently active tasks.
        """
        if not self.active_tasks:
            raise RuntimeError("tick called with no active tasks")

        if self._interrupted:
            task = next(iter(self.active_tasks))
            task.interrupt(task, KeyboardInterrupt)
            self._interrupted = False

        deadline = self.get_next_deadline()
        if deadline is None:
            ready_tasks = self._ready_tasks.get_all(None)
        else:
            ready_tasks: List[Task] = []
            for i in itertools.count():
                timeout = deadline - self.time()
                if i and timeout < 0:
                    break
                try:
                    # Clamp to the non-negative range accepted by Condition.wait.
                    ready_tasks = self._ready_tasks.get_all(
                        max(0, min(timeout, threading.TIMEOUT_MAX))
                    )
                    break
                except TimeoutError:
                    pass
            if not ready_tasks:
                for task in self.get_deadline_tasks(deadline):
                    task.interrupt(task, TimeoutError())
                ready_tasks = self._ready_tasks.get_all(None)
        for task in ready_tasks:
            try:
                task.advance()
            finally:
                if task.done:
                    task.close()
                    self.active_tasks.discard(task)
                else:
                    self._ready_tasks.register(task)
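    # How a scheduler is typically driven (an illustrative sketch, roughly what the
    # public duet.run entry point does; main is a hypothetical async function):
    #
    #     with Scheduler() as scheduler:
    #         task = scheduler.spawn(main())
    #         while not task.done:
    #             scheduler.tick()
    #     result = task.result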

    def _interrupt(self, signum: int, frame: Optional[Any]) -> None:
        """Interrupt signal handler used while this scheduler is running.

        This is inspired by trio's interrupt handling, described here:
        https://vorpus.org/blog/control-c-handling-in-python-and-trio/

        If the interrupted frame is inside a running task, which we detect by
        looking for a special local variable inserted into the task coroutine,
        we simply raise a KeyboardInterrupt as usual. Otherwise we set a flag
        which will get checked on the next tick() and cause a task to be
        interrupted.

        One important difference from trio is that duet is reentrant, so when
        detecting whether we are in a task we have to check whether the task's
        scheduler is self. If the interrupted frame is running in a task of a
        different scheduler, that should not raise KeyboardInterrupt directly.
        """
        if self._in_task(frame):
            raise KeyboardInterrupt
        else:
            self._interrupted = True
            self._ready_tasks.interrupt()

    def _in_task(self, frame) -> bool:
        while frame is not None:
            if frame.f_locals.get(LOCALS_TASK_SCHEDULER, None) is self:
                return True
            frame = frame.f_back
        return False

    def __enter__(self):
        if (
            threading.current_thread() == threading.main_thread()
            and signal.getsignal(signal.SIGINT) == signal.default_int_handler
        ):
            self._prev_signal = signal.signal(signal.SIGINT, self._interrupt)
        return self

    def __exit__(self, exc_type, exc, tb):
        def finish_tasks(error=None):
            if error:
                for task in self.active_tasks:
                    task.interrupt(None, error)
            while self.active_tasks:
                try:
                    self.tick()
                except Exception:
                    if not error:
                        raise

        try:
            if exc:
                finish_tasks(exc)
            else:
                try:
                    finish_tasks()
                except Exception as exc:
                    finish_tasks(exc)
                    raise
        finally:
            if self._prev_signal:
                signal.signal(signal.SIGINT, self._prev_signal)