1from __future__ import annotations 2 3import uuid 4from contextlib import asynccontextmanager, contextmanager 5from datetime import datetime 6from typing import TYPE_CHECKING 7 8from tornado.ioloop import PeriodicCallback 9 10if TYPE_CHECKING: 11 # circular dependencies 12 from ..client import Client 13 from ..scheduler import Scheduler 14 15 16class MemorySampler: 17 """Sample cluster-wide memory usage every <interval> seconds. 18 19 **Usage** 20 21 .. code-block:: python 22 23 client = Client() 24 ms = MemorySampler() 25 26 with ms.sample("run 1"): 27 <run first workflow> 28 with ms.sample("run 2"): 29 <run second workflow> 30 ... 31 ms.plot() 32 33 or with an asynchronous client: 34 35 .. code-block:: python 36 37 client = await Client(asynchronous=True) 38 ms = MemorySampler() 39 40 async with ms.sample("run 1"): 41 <run first workflow> 42 async with ms.sample("run 2"): 43 <run second workflow> 44 ... 45 ms.plot() 46 """ 47 48 samples: dict[str, list[tuple[float, int]]] 49 50 def __init__(self): 51 self.samples = {} 52 53 def sample( 54 self, 55 label: str | None = None, 56 *, 57 client: Client | None = None, 58 measure: str = "process", 59 interval: float = 0.5, 60 ): 61 """Context manager that records memory usage in the cluster. 62 This is synchronous if the client is synchronous and 63 asynchronous if the client is asynchronous. 64 65 The samples are recorded in ``self.samples[<label>]``. 66 67 Parameters 68 ========== 69 label: str, optional 70 Tag to record the samples under in the self.samples dict. 71 Default: automatically generate a random label 72 client: Client, optional 73 client used to connect to the scheduler. 74 Default: use the global client 75 measure: str, optional 76 One of the measures from :class:`distributed.scheduler.MemoryState`. 77 Default: sample process memory 78 interval: float, optional 79 sampling interval, in seconds. 80 Default: 0.5 81 """ 82 if not client: 83 from ..client import get_client 84 85 client = get_client() 86 87 if client.asynchronous: 88 return self._sample_async(label, client, measure, interval) 89 else: 90 return self._sample_sync(label, client, measure, interval) 91 92 @contextmanager 93 def _sample_sync( 94 self, label: str | None, client: Client, measure: str, interval: float 95 ): 96 key = client.sync( 97 client.scheduler.memory_sampler_start, 98 client=client.id, 99 measure=measure, 100 interval=interval, 101 ) 102 try: 103 yield 104 finally: 105 samples = client.sync(client.scheduler.memory_sampler_stop, key=key) 106 self.samples[label or key] = samples 107 108 @asynccontextmanager 109 async def _sample_async( 110 self, label: str | None, client: Client, measure: str, interval: float 111 ): 112 key = await client.scheduler.memory_sampler_start( 113 client=client.id, measure=measure, interval=interval 114 ) 115 try: 116 yield 117 finally: 118 samples = await client.scheduler.memory_sampler_stop(key=key) 119 self.samples[label or key] = samples 120 121 def to_pandas(self, *, align: bool = False): 122 """Return the data series as a pandas.Dataframe. 123 124 Parameters 125 ========== 126 align : bool, optional 127 If True, change the absolute timestamps into time deltas from the first 128 sample of each series, so that different series can be visualized side by 129 side. If False (the default), use absolute timestamps. 130 """ 131 import pandas as pd 132 133 ss = {} 134 for (label, s_list) in self.samples.items(): 135 assert s_list # There's always at least one sample 136 s = pd.DataFrame(s_list).set_index(0)[1] 137 s.index = pd.to_datetime(s.index, unit="s") 138 s.name = label 139 if align: 140 # convert datetime to timedelta from the first sample 141 s.index -= s.index[0] 142 ss[label] = s 143 144 df = pd.DataFrame(ss) 145 146 if len(ss) > 1: 147 # Forward-fill NaNs in the middle of a series created either by overlapping 148 # sampling time range or by align=True. Do not ffill series beyond their 149 # last sample. 150 df = df.ffill().where(~pd.isna(df.bfill())) 151 152 return df 153 154 def plot(self, *, align: bool = False, **kwargs): 155 """Plot data series collected so far 156 157 Parameters 158 ========== 159 align : bool (optional) 160 See :meth:`~distributed.diagnostics.MemorySampler.to_pandas` 161 kwargs 162 Passed verbatim to :meth:`pandas.DataFrame.plot` 163 """ 164 df = self.to_pandas(align=align) / 2 ** 30 165 return df.plot( 166 xlabel="time", 167 ylabel="Cluster memory (GiB)", 168 **kwargs, 169 ) 170 171 172class MemorySamplerExtension: 173 """Scheduler extension - server side of MemorySampler""" 174 175 scheduler: Scheduler 176 samples: dict[str, list[tuple[float, int]]] 177 178 def __init__(self, scheduler: Scheduler): 179 self.scheduler = scheduler 180 self.scheduler.extensions["memory_sampler"] = self 181 self.scheduler.handlers["memory_sampler_start"] = self.start 182 self.scheduler.handlers["memory_sampler_stop"] = self.stop 183 self.samples = {} 184 185 def start(self, comm, client: str, measure: str, interval: float) -> str: 186 """Start periodically sampling memory""" 187 assert not measure.startswith("_") 188 assert isinstance(getattr(self.scheduler.memory, measure), int) 189 190 key = str(uuid.uuid4()) 191 self.samples[key] = [] 192 193 def sample(): 194 if client in self.scheduler.clients: 195 ts = datetime.now().timestamp() 196 nbytes = getattr(self.scheduler.memory, measure) 197 self.samples[key].append((ts, nbytes)) 198 else: 199 self.stop(comm, key) 200 201 pc = PeriodicCallback(sample, interval * 1000) 202 self.scheduler.periodic_callbacks["MemorySampler-" + key] = pc 203 pc.start() 204 205 # Immediately collect the first sample; this also ensures there's always at 206 # least one sample 207 sample() 208 209 return key 210 211 def stop(self, comm, key: str): 212 """Stop sampling and return the samples""" 213 pc = self.scheduler.periodic_callbacks.pop("MemorySampler-" + key) 214 pc.stop() 215 return self.samples.pop(key) 216