1from __future__ import annotations
2
3import uuid
4from contextlib import asynccontextmanager, contextmanager
5from datetime import datetime
6from typing import TYPE_CHECKING
7
8from tornado.ioloop import PeriodicCallback
9
10if TYPE_CHECKING:
11    # circular dependencies
12    from ..client import Client
13    from ..scheduler import Scheduler
14
15
class MemorySampler:
    """Sample cluster-wide memory usage every <interval> seconds.

    **Usage**

    .. code-block:: python

       client = Client()
       ms = MemorySampler()

       with ms.sample("run 1"):
           <run first workflow>
       with ms.sample("run 2"):
           <run second workflow>
       ...
       ms.plot()

    or with an asynchronous client:

    .. code-block:: python

       client = await Client(asynchronous=True)
       ms = MemorySampler()

       async with ms.sample("run 1"):
           <run first workflow>
       async with ms.sample("run 2"):
           <run second workflow>
       ...
       ms.plot()
    """

    # Recorded series, keyed by label. Each series is a list of
    # (timestamp in seconds, sampled value) pairs as returned by the scheduler.
    samples: dict[str, list[tuple[float, int]]]

    def __init__(self):
        self.samples = {}

    def sample(
        self,
        label: str | None = None,
        *,
        client: Client | None = None,
        measure: str = "process",
        interval: float = 0.5,
    ):
        """Context manager that records memory usage in the cluster.
        This is synchronous if the client is synchronous and
        asynchronous if the client is asynchronous.

        The samples are recorded in ``self.samples[<label>]``.

        Parameters
        ----------
        label: str, optional
            Tag to record the samples under in the self.samples dict.
            Default: automatically generate a random label
        client: Client, optional
            client used to connect to the scheduler.
            Default: use the global client
        measure: str, optional
            One of the measures from :class:`distributed.scheduler.MemoryState`.
            Default: sample process memory
        interval: float, optional
            sampling interval, in seconds.
            Default: 0.5

        Returns
        -------
        A context manager if ``client.asynchronous`` is false, otherwise an
        asynchronous context manager.
        """
        # Explicit None check: only a missing client triggers the fallback to the
        # global one; the truthiness of a client object is irrelevant here.
        if client is None:
            from ..client import get_client

            client = get_client()

        if client.asynchronous:
            return self._sample_async(label, client, measure, interval)
        else:
            return self._sample_sync(label, client, measure, interval)

    @contextmanager
    def _sample_sync(
        self, label: str | None, client: Client, measure: str, interval: float
    ):
        """Synchronous implementation of :meth:`sample`.

        Starts sampling through ``client.scheduler.memory_sampler_start`` on entry.
        On exit (also when the body raises) it stops sampling and stores the
        returned samples under ``label``, or under the key returned by the start
        call when ``label`` is None or empty.
        """
        key = client.sync(
            client.scheduler.memory_sampler_start,
            client=client.id,
            measure=measure,
            interval=interval,
        )
        try:
            yield
        finally:
            samples = client.sync(client.scheduler.memory_sampler_stop, key=key)
            self.samples[label or key] = samples

    @asynccontextmanager
    async def _sample_async(
        self, label: str | None, client: Client, measure: str, interval: float
    ):
        """Asynchronous implementation of :meth:`sample`.

        Same contract as :meth:`_sample_sync`, but the start/stop calls are
        awaited directly instead of going through ``client.sync``.
        """
        key = await client.scheduler.memory_sampler_start(
            client=client.id, measure=measure, interval=interval
        )
        try:
            yield
        finally:
            samples = await client.scheduler.memory_sampler_stop(key=key)
            self.samples[label or key] = samples

    def to_pandas(self, *, align: bool = False):
        """Return the data series as a pandas.DataFrame.

        Parameters
        ----------
        align : bool, optional
            If True, change the absolute timestamps into time deltas from the first
            sample of each series, so that different series can be visualized side by
            side. If False (the default), use absolute timestamps.

        Returns
        -------
        pandas.DataFrame
            One column per label in ``self.samples``, indexed by timestamp (or by
            time delta when ``align=True``). If a series contains several samples
            with the same timestamp, only the first one is kept.
        """
        import pandas as pd

        ss = {}
        for label, s_list in self.samples.items():
            assert s_list  # There's always at least one sample
            s = pd.DataFrame(s_list).set_index(0)[1]
            s.index = pd.to_datetime(s.index, unit="s")
            s.name = label
            if align:
                # convert datetime to timedelta from the first sample
                s.index -= s.index[0]
            # Two samples may carry the same timestamp (e.g. coarse clock
            # resolution). Combining series with differing indexes requires unique
            # labels, so keep only the first sample of each timestamp.
            ss[label] = s[~s.index.duplicated()]

        df = pd.DataFrame(ss)

        if len(ss) > 1:
            # Forward-fill NaNs in the middle of a series created either by overlapping
            # sampling time range or by align=True. Do not ffill series beyond their
            # last sample.
            df = df.ffill().where(~pd.isna(df.bfill()))

        return df

    def plot(self, *, align: bool = False, **kwargs):
        """Plot data series collected so far

        Parameters
        ----------
        align : bool (optional)
            See :meth:`~distributed.diagnostics.MemorySampler.to_pandas`
        kwargs
            Passed verbatim to :meth:`pandas.DataFrame.plot`

        Returns
        -------
        Whatever :meth:`pandas.DataFrame.plot` returns.
        """
        # Samples are scaled by 2**30 so that the y axis reads in GiB
        df = self.to_pandas(align=align) / 2**30
        return df.plot(
            xlabel="time",
            ylabel="Cluster memory (GiB)",
            **kwargs,
        )
170
171
class MemorySamplerExtension:
    """Scheduler extension - server side of MemorySampler.

    On construction it registers itself in ``scheduler.extensions`` and installs
    the ``memory_sampler_start`` and ``memory_sampler_stop`` handlers.
    """

    # Scheduler this extension is attached to
    scheduler: Scheduler
    # Samples gathered so far by each running sampler, keyed by sampler key.
    # Each sample is a (timestamp, value of the sampled measure) pair.
    samples: dict[str, list[tuple[float, int]]]

    def __init__(self, scheduler: Scheduler):
        self.scheduler = scheduler
        scheduler.extensions["memory_sampler"] = self
        for handler_name, handler in (
            ("memory_sampler_start", self.start),
            ("memory_sampler_stop", self.stop),
        ):
            scheduler.handlers[handler_name] = handler
        self.samples = {}

    def start(self, comm, client: str, measure: str, interval: float) -> str:
        """Begin sampling ``scheduler.memory.<measure>`` periodically.

        Parameters
        ----------
        comm
            Not used for sampling; only forwarded to :meth:`stop` when the sampler
            shuts itself down.
        client : str
            Client identifier. Sampling continues only while it is found in
            ``scheduler.clients``.
        measure : str
            Name of an integer attribute of ``scheduler.memory``; must not start
            with an underscore.
        interval : float
            Sampling period; multiplied by 1000 before being handed to
            ``PeriodicCallback``.

        Returns
        -------
        str
            Freshly generated key identifying this sampler; pass it to :meth:`stop`.
        """
        assert not measure.startswith("_")
        assert isinstance(getattr(self.scheduler.memory, measure), int)

        key = str(uuid.uuid4())
        callback_name = "MemorySampler-" + key
        self.samples[key] = []

        def record():
            # The client went away: nobody will collect the samples, so tear down
            if client not in self.scheduler.clients:
                self.stop(comm, key)
                return

            now = datetime.now().timestamp()
            value = getattr(self.scheduler.memory, measure)
            self.samples[key].append((now, value))

        periodic = PeriodicCallback(record, interval * 1000)
        self.scheduler.periodic_callbacks[callback_name] = periodic
        periodic.start()

        # Take one sample right away instead of waiting for the first period to
        # elapse, so that the series is never empty
        record()

        return key

    def stop(self, comm, key: str):
        """Stop the sampler identified by ``key`` and return its samples.

        The periodic callback is unregistered from the scheduler and stopped, and
        the samples are removed from ``self.samples``. ``comm`` is unused.
        """
        periodic = self.scheduler.periodic_callbacks.pop("MemorySampler-" + key)
        periodic.stop()
        collected = self.samples.pop(key)
        return collected
216