1"""Classes for instrumenting code to collect various metrics.
2
3Instrumentation consists of creating the metric and then updating it.
4Creation can be performed at the module level or as a class attribute.  Since
5the metric namespace is global, metrics should not be created by instances
6unless that instance is certain to be a singleton.
7
8Sample code:
9
_my_counter = metrics.Counter("my_counter")
11
12
13def foo():
14  _my_counter.inc()  # calls to foo() count as 1 unit.
15
16
17def bar(n):
18  _my_counter.inc(n)  # calls to bar() count as n units.
19"""
20
21import json
22import math
23import re
24import time
25import tracemalloc
26
27
# Metric serialization/deserialization code, taking advantage of the fact that
# all instance vars of metrics are things that json can serialize, so we don't
# need to write custom JsonEncoder and JsonDecoder classes per Metric subclass.

# Registry of metric types for deserialization.  Maps a class name (the same
# string that _serialize writes) to the class object.  Entries are added by
# _RegistryMeta whenever a class using that metaclass is created, and are
# looked up by _deserialize.
_METRIC_TYPES = {}
34
35
class _RegistryMeta(type):
  """Metaclass that records every class it creates in _METRIC_TYPES.

  The registry is keyed by the new class's __name__, so a later class with the
  same name replaces the earlier entry.
  """

  def __new__(cls, name, bases, class_dict):
    """Build the class as usual, then register it under its own name."""
    created = super().__new__(cls, name, bases, class_dict)
    _METRIC_TYPES[created.__name__] = created
    return created
43
44
def _deserialize(typ, payload):
  """Rebuild a metric from its serialized (type name, attribute dict) pair.

  Args:
    typ: Name of a class registered in _METRIC_TYPES.
    payload: Dict of instance attributes, as loaded from json.

  Returns:
    A new instance of the named class, constructed with a name of None and
    then updated with the attributes in payload.

  Raises:
    TypeError: if typ is not a registered type name.
  """
  metric_cls = _METRIC_TYPES.get(typ)
  if metric_cls is None:
    raise TypeError("Could not decode class %s" % typ)
  # The class is instantiated with None as its name; the real attributes,
  # including the name, come from the payload.
  metric = metric_cls(None)
  vars(metric).update(payload)
  return metric
53
54
def _serialize(obj):
  """Return obj as a json-serializable [class name, attribute dict] pair.

  The attribute dict is the object's own __dict__, not a copy.
  """
  type_name = obj.__class__.__name__
  attributes = obj.__dict__
  return [type_name, attributes]
58
59
def dump_all(objs, fp):
  """Serialize each metric in objs and write them to fp as one json list.

  Args:
    objs: An iterable of metric objects.
    fp: A writable text file object.
  """
  serialized = []
  for metric in objs:
    serialized.append(_serialize(metric))
  json.dump(serialized, fp)
63
64
def load_all(fp):
  """Read a json list of serialized metrics from fp and rebuild each one.

  Args:
    fp: A readable file object containing json written by dump_all.

  Returns:
    A list of metric objects, in the order they appear in the file.
  """
  restored = []
  for entry in json.load(fp):
    # Each entry is a [type name, attribute dict] pair.
    restored.append(_deserialize(*entry))
  return restored
69
70
# A legal metric name is a letter or underscore followed by one or more word
# characters.  Names are therefore at least two characters long and may not
# contain characters such as "-" or ".".
_METRIC_NAME_RE = re.compile(r"^[a-zA-Z_]\w+$")

_registered_metrics = {}  # Map from metric name to Metric object.
_enabled = False  # True iff metrics should be collected.
75
76
def _validate_metric_name(name):
  """Check that name is a legal metric name.

  Args:
    name: The candidate metric name (a string).

  Raises:
    ValueError: if name does not match _METRIC_NAME_RE.
  """
  if not _METRIC_NAME_RE.match(name):
    raise ValueError("Illegal metric name: %s" % name)
80
81
def _prepare_for_test(enabled=True):
  """Reset module state for a test.

  Empties the registry of metrics in place and sets the module-level
  collection flag.

  Args:
    enabled: The value to store in the module-level _enabled flag.
  """
  global _enabled
  _enabled = enabled
  _registered_metrics.clear()
87
88
def get_cpu_clock():
  """Return the CPU clock reading, i.e. the value of time.process_time().

  Kept as a separate function for compatibility with various Python versions.
  """
  cpu_time = time.process_time()
  return cpu_time
92
93
def get_metric(name, constructor, *args, **kwargs):
  """Look up the metric registered under name, creating one if there is none.

  Args:
    name: The name of the metric.
    constructor: A callable used to build the metric when none is registered.
    *args: Extra positional arguments forwarded to the constructor.
    **kwargs: Keyword arguments forwarded to the constructor.

  Returns:
    The metric already registered under name if there is one (the constructor
    is then not called), otherwise the result of
    constructor(name, *args, **kwargs).
  """
  existing = _registered_metrics.get(name)
  if existing is None:
    return constructor(name, *args, **kwargs)
  return existing
112
113
def get_report():
  """Return a string with str() of every registered metric, one per line.

  Metrics are listed in sorted order of their registered names, and every
  line (including the last) ends with a newline.
  """
  pieces = []
  for metric_name in sorted(_registered_metrics):
    pieces.append(str(_registered_metrics[metric_name]))
    pieces.append("\n")
  return "".join(pieces)
119
120
def merge_from_file(metrics_file):
  """Merge the metrics stored in metrics_file into the registered metrics.

  A loaded metric whose name is not yet registered is validated and registered
  as-is.  Otherwise it is merged into the registered metric of that name.

  Args:
    metrics_file: A readable file object, passed to load_all.

  Raises:
    TypeError: if a loaded metric and the registered metric of the same name
      are of different types.
    ValueError: if a loaded metric that is not yet registered has an illegal
      name.
  """
  for loaded in load_all(metrics_file):
    metric_name = loaded.name
    current = _registered_metrics.get(metric_name)
    if current is not None:
      # Exact type match is required; a subclass does not count.
      if type(loaded) != type(current):  # pylint: disable=unidiomatic-typecheck
        raise TypeError("Cannot merge metrics of different types.")
      current._merge(loaded)  # pylint: disable=protected-access
      continue
    _validate_metric_name(metric_name)
    _registered_metrics[metric_name] = loaded
132
133
class Metric(metaclass=_RegistryMeta):
  """Abstract base class for metrics.

  Subclasses provide _summary() and _merge().  Constructing a metric with a
  name registers it in the module-level registry of metrics.
  """

  def __init__(self, name):
    """Validate the name and register this metric under it.

    Args:
      name: The metric name.  If None, the metric is neither validated nor
        registered and its _name attribute is left unset (e.g. we are
        deserializing a metric from file and need to merge it into the
        existing metric with the same name).

    Raises:
      ValueError: if the name is illegal or already registered.
    """
    if name is not None:
      _validate_metric_name(name)
      if name in _registered_metrics:
        raise ValueError("Metric %s has already been defined." % name)
      self._name = name
      _registered_metrics[name] = self

  @property
  def name(self):
    """The name this metric was created (or deserialized) with."""
    return self._name

  def _summary(self):
    """Return a string summarizing the value of the metric."""
    raise NotImplementedError

  def _merge(self, other):
    """Merge data from another metric of the same type."""
    raise NotImplementedError

  def __str__(self):
    """Return the metric formatted as "<name>: <summary>"."""
    label = self._name
    summary = self._summary()
    return "%s: %s" % (label, summary)
164
165
class Counter(Metric):
  """A monotonically increasing metric."""

  def __init__(self, name):
    """Create the counter with a total of zero."""
    super().__init__(name)
    self._total = 0

  def inc(self, count=1):
    """Add count to the counter.

    The increment is dropped when metrics collection is disabled, but the
    sign of count is checked either way.

    Args:
      count: The non-negative amount to add.

    Raises:
      ValueError: if count is negative.
    """
    if count < 0:
      raise ValueError("Counter must be monotonically increasing.")
    if _enabled:
      self._total += count

  def _summary(self):
    """Return the total as a string."""
    return str(self._total)

  def _merge(self, other):
    """Add the total of another Counter to this one."""
    other_total = other._total  # pylint: disable=protected-access
    self._total += other_total
187
188
class StopWatch(Metric):
  """A counter that measures the time spent in a "with" statement.

  Time is read from get_cpu_clock().  Leaving a "with" block stores the
  duration of that block in _total; _merge() adds another watch's total to
  this one.
  """

  def __init__(self, name):
    """Initialize the watch with a total of zero.

    Args:
      name: The metric name, or None to create an unregistered instance.
    """
    super().__init__(name)
    # Start at zero so that _summary(), str() and _merge() work on a watch
    # that has never been entered (for instance one that is registered but
    # unused when a report is produced) instead of raising AttributeError.
    self._total = 0

  def __enter__(self):
    """Record the clock reading at the start of the timed block."""
    self._start_time = get_cpu_clock()

  def __exit__(self, exc_type, exc_value, traceback):
    """Store the time elapsed since __enter__ and discard the start time."""
    # NOTE(review): this overwrites _total, so a watch that is entered more
    # than once reports only its most recent block.  Kept as-is for backward
    # compatibility -- confirm whether accumulation was intended.
    self._total = get_cpu_clock() - self._start_time
    del self._start_time

  def _summary(self):
    """Return the stored time formatted as "<float> seconds"."""
    return "%f seconds" % self._total

  def _merge(self, other):
    """Add the total of another StopWatch to this one."""
    # pylint: disable=protected-access
    self._total += other._total
205
206
class ReentrantStopWatch(Metric):
  """A watch that supports being called multiple times and recursively.

  Only the outermost "with" block is timed: the clock is started when the
  nesting depth goes from zero to one and stopped when it returns to zero.
  Time accumulates across successive outermost blocks.
  """

  def __init__(self, name):
    """Create the watch with no accumulated time and a nesting depth of zero."""
    super().__init__(name)
    self._time = 0
    self._calls = 0

  def __enter__(self):
    """Enter one nesting level, starting the clock at the outermost level."""
    if self._calls == 0:
      self._start_time = get_cpu_clock()
    self._calls += 1

  def __exit__(self, exc_type, exc_value, traceback):
    """Leave one nesting level, banking the elapsed time at the outermost."""
    self._calls -= 1
    if self._calls:
      # Still inside an enclosing "with" block of this watch.
      return
    elapsed = get_cpu_clock() - self._start_time
    self._time += elapsed
    del self._start_time

  def _merge(self, other):
    """Add the accumulated time of another ReentrantStopWatch to this one."""
    other_time = other._time  # pylint: disable=protected-access
    self._time += other_time

  def _summary(self):
    """Return a sentence reporting the accumulated time."""
    return "time spend below this StopWatch: %s" % self._time
231
232
class MapCounter(Metric):
  """A set of related counters keyed by an arbitrary string.

  Tracks one count per key together with the sum over all keys.
  """

  def __init__(self, name):
    """Create the metric with no keys and a total of zero."""
    super().__init__(name)
    self._counts = {}
    self._total = 0

  def inc(self, key, count=1):
    """Add count to the counter for key and to the overall total.

    The increment is dropped when metrics collection is disabled, but the
    sign of count is checked either way.

    Args:
      key: A string to be used as the key.
      count: The amount to increment by (non-negative integer).

    Raises:
      ValueError: if the count is less than 0.
    """
    if count < 0:
      raise ValueError("Counter must be monotonically increasing.")
    if _enabled:
      previous = self._counts.get(key, 0)
      self._counts[key] = previous + count
      self._total += count

  def _summary(self):
    """Return "<total> {<key>=<count>, ...}" with the keys in sorted order."""
    parts = []
    for key in sorted(self._counts):
      parts.append("%s=%d" % (key, self._counts[key]))
    return "%d {%s}" % (self._total, ", ".join(parts))

  def _merge(self, other):
    """Add every per-key count of another MapCounter into this one."""
    # pylint: disable=protected-access
    for key, extra in other._counts.items():
      self._counts[key] = self._counts.get(key, 0) + extra
      self._total += extra
268
269
class Distribution(Metric):
  """A metric to track simple statistics from a distribution of values.

  Keeps the number of values, their sum, the sum of their squares and the
  smallest and largest value seen; mean and standard deviation are derived
  from those on demand.
  """

  def __init__(self, name):
    """Create an empty distribution."""
    super().__init__(name)
    self._count = 0  # Number of values.
    self._total = 0.0  # Sum of the values.
    self._squared = 0.0  # Sum of the squares of the values.
    self._min = None  # Smallest value seen, or None if there are no values.
    self._max = None  # Largest value seen, or None if there are no values.

  def add(self, value):
    """Record one value; a no-op when metrics collection is disabled."""
    if not _enabled:
      return
    self._count += 1
    self._total += value
    self._squared += value * value
    # On the first add the value is both the min and the max.
    first_value = self._min is None
    self._min = value if first_value else min(self._min, value)
    self._max = value if first_value else max(self._max, value)

  def _mean(self):
    """Return the mean of the values, or None if there are no values."""
    if not self._count:
      return None
    return self._total / float(self._count)

  def _stdev(self):
    """Return the standard deviation, or None if there are no values."""
    n = self._count
    if not n:
      return None
    variance = (self._squared * n - self._total * self._total) / (n * n)
    if variance < 0.0:
      # This can only happen as the result of rounding error when the actual
      # variance is very, very close to 0.  Assume it is 0.
      return 0.0
    return math.sqrt(variance)

  def _summary(self):
    """Return total, count, min, max, mean and stdev as one string."""
    fields = (self._total, self._count, self._min, self._max, self._mean(),
              self._stdev())
    return "total=%s, count=%d, min=%s, max=%s, mean=%s, stdev=%s" % fields

  def _merge(self, other):
    """Fold the statistics of another Distribution into this one."""
    # pylint: disable=protected-access
    if other._count == 0:
      # Nothing to merge; also avoids comparing against a min/max of None.
      return
    self._count += other._count
    self._total += other._total
    self._squared += other._squared
    if self._min is not None:
      self._min = min(self._min, other._min)
      self._max = max(self._max, other._max)
    else:
      # This distribution is empty, so adopt the other's extremes.
      self._min = other._min
      self._max = other._max
328
329
class Snapshot(Metric):
  """A metric to track memory usage via tracemalloc snapshots.

  Each stored snapshot is a string: a label line followed by the top memory
  block statistics.  Can be used as a context manager, which takes one
  snapshot on entry and one on exit.
  """

  def __init__(self, name, enabled=False, groupby="lineno",
               nframes=1, count=10):
    """Initialize the metric.

    Args:
      name: The metric name.
      enabled: Whether snapshots were explicitly requested.
      groupby: The key to group memory blocks by, passed to
        tracemalloc's Snapshot.statistics().
      nframes: The number of stack frames to store per memory block, passed
        to tracemalloc.start().
      count: The number of memory block statistics to keep per snapshot.
    """
    super().__init__(name)
    self.snapshots = []
    # How memory blocks are grouped. The default, "lineno", groups by the file
    # and line that allocated the block. The other useful value is
    # "traceback", which groups by the stack frames leading to each allocation.
    self.groupby = groupby
    # Stack frames stored per memory block. Values greater than 1 are only
    # useful if groupby = "traceback".
    self.nframes = nframes
    # How many memory block statistics to save.
    self.count = count
    self.running = False
    # Snapshots are only taken when both of these hold:
    # 1. Metrics have been enabled (global _enabled)
    # 2. The constructor arg enabled them explicitly (which should be the
    # options.memory_snapshot flag set by the --memory-snapshots option)
    self.enabled = _enabled and enabled

  def _start_tracemalloc(self):
    """Start tracemalloc with the configured frame depth and mark it running."""
    tracemalloc.start(self.nframes)
    self.running = True

  def _stop_tracemalloc(self):
    """Stop tracemalloc and mark it as not running."""
    tracemalloc.stop()
    self.running = False

  def take_snapshot(self, where=""):
    """Store a tracemalloc snapshot labelled with where.

    Does nothing unless this metric is enabled.  Starts tracemalloc first if
    this metric has not started it yet.

    Args:
      where: A label that is written on the first line of the stored snapshot.
    """
    if not self.enabled:
      return
    if not self.running:
      self._start_tracemalloc()
    snap = tracemalloc.take_snapshot()
    # Keep the top self.count memory consumers, grouped by self.groupby.
    # The statistics are stored as strings rather than as Statistic objects:
    # Statistic.__eq__ doesn't take None into account during comparisons, and
    # json will compare it to None when trying to process it, causing an
    # error.
    top_stats = snap.statistics(self.groupby)[:self.count]
    details = "\n".join(str(stat) for stat in top_stats)
    self.snapshots.append("%s:\n%s" % (where, details))

  def __enter__(self):
    """Start tracemalloc and take an "__enter__" snapshot, if enabled."""
    if not self.enabled:
      return
    self._start_tracemalloc()
    self.take_snapshot("__enter__")

  def __exit__(self, exc_type, exc_value, traceback):
    """Take an "__exit__" snapshot and stop tracemalloc, if it is running."""
    if self.running:
      self.take_snapshot("__exit__")
      self._stop_tracemalloc()

  def _summary(self):
    """Return all stored snapshots, separated by blank lines."""
    return "\n\n".join(self.snapshots)
390
391
class MetricsContext:
  """A context manager that configures metrics and writes their output.

  On entry, metrics collection is switched on or off depending on whether an
  output path was given.  On exit, the previous setting is restored and, if
  there is an output path, all registered metrics are written to it.
  """

  def __init__(self, output_path, open_function=open):
    """Initialize.

    Args:
      output_path: The path for the metrics data.  If empty, no metrics are
          collected.
      open_function: A custom file opening function.
    """
    self._output_path = output_path
    self._open_function = open_function
    self._old_enabled = None  # Set in __enter__.

  def __enter__(self):
    """Remember the current collection flag and set it from the output path."""
    global _enabled
    self._old_enabled = _enabled
    _enabled = bool(self._output_path)

  def __exit__(self, exc_type, exc_value, traceback):
    """Restore the collection flag and dump all metrics to the output path."""
    global _enabled
    _enabled = self._old_enabled
    if not self._output_path:
      return
    with self._open_function(self._output_path, "w") as out_file:
      dump_all(_registered_metrics.values(), out_file)
418