1# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6import os
7import site
8import glob
9import ctypes
10import time
11import inspect
12import warnings
13import itertools
14import contextlib
15import collections
16import typing as tp
17import numpy as np
18
19
def pytorch_import_fix() -> None:
    """Preloads the libgomp shared objects shipped within the torch package

    Hackfix needed before pytorch import ("dlopen: cannot load any more object with static TLS")
    See issue #305

    Note
    ----
    This is best-effort: any failure (missing site-packages, missing torch,
    loading error) is silently ignored.
    """
    with contextlib.suppress(Exception):  # best-effort hack, must never break the caller
        library_paths = (
            path
            for site_dir in site.getsitepackages()
            for path in glob.glob(f"{site_dir}/torch/lib/libgomp*.so*")
        )
        for path in library_paths:
            ctypes.cdll.LoadLibrary(path)
30
31
def pairwise(iterable: tp.Iterable[tp.Any]) -> tp.Iterator[tp.Tuple[tp.Any, tp.Any]]:
    """Iterates over overlapping pairs of consecutive elements
    s -> (s0,s1), (s1,s2), (s2, s3), ...

    Note
    ----
    Nothing will be returned if length of iterator is strictly less
    than 2.
    """  # adapted from the itertools documentation
    current_items, following_items = itertools.tee(iterable, 2)
    # move the second copy one step ahead (no-op if the input is empty)
    next(following_items, None)
    return zip(current_items, following_items)
44
45
def grouper(
    iterable: tp.Iterable[tp.Any], n: int, fillvalue: tp.Any = None
) -> tp.Iterator[tp.Tuple[tp.Any, ...]]:
    """Collect data into fixed-length chunks or blocks
    Copied from itertools recipe documentation
    Example: grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx

    Parameters
    ----------
    iterable: Iterable
        the data to split into chunks
    n: int
        size of each chunk, must be strictly positive
    fillvalue: Any
        value used to pad the last chunk if it is incomplete

    Returns
    -------
    Iterator[Tuple]
        an iterator over tuples of length n

    Raises
    ------
    ValueError
        if n is not strictly positive
    """
    if n < 1:
        # zip_longest without any argument yields nothing, which would silently drop all the data
        raise ValueError(f"n must be at least 1, but got {n}")
    # the same iterator is repeated n times, so each output tuple consumes n consecutive items
    args = [iter(iterable)] * n
    return itertools.zip_longest(*args, fillvalue=fillvalue)
53
54
def roundrobin(*iterables: tp.Iterable[tp.Any]) -> tp.Iterator[tp.Any]:
    """Yields one element of each iterable in turn, skipping the exhausted ones

    roundrobin('ABC', 'D', 'EF') --> A D E B F C
    """
    # Recipe from the itertools documentation (originally credited to George Sakkis)
    iterators: tp.Iterator[tp.Iterator[tp.Any]] = map(iter, iterables)
    for num_active in range(len(iterables), 0, -1):
        # keep the num_active iterators following the one which was just exhausted
        iterators = itertools.cycle(itertools.islice(iterators, num_active))
        # map stops as soon as next() raises StopIteration, i.e. when one iterator is exhausted
        yield from map(next, iterators)
68
69
class Sleeper:
    """Simple object for managing the waiting time of a job

    The advised sleep duration is the average of the most recently registered
    durations divided by 10 (so as to waste around 10% of time), clipped
    between min_sleep and max_sleep.

    Parameters
    ----------
    min_sleep: float
        minimum sleep time, in seconds
    max_sleep: float
        maximum sleep time, in seconds
    averaging_size: int
        number of most recent registered durations used for the average

    Note
    ----
    Durations are measured with time.monotonic, so they are not affected by
    system clock updates.
    """

    def __init__(self, min_sleep: float = 1e-7, max_sleep: float = 1.0, averaging_size: int = 10) -> None:
        self._min = min_sleep
        self._max = max_sleep
        self._start: tp.Optional[float] = None  # start time of the running timer, if any
        # bounded queue: only the last averaging_size durations are kept
        self._queue: tp.Deque[float] = collections.deque(maxlen=averaging_size)
        self._num_waits = 10  # expect to waste around 10% of time

    def start_timer(self) -> None:
        """Starts measuring a job duration (warns and does nothing if already started)"""
        if self._start is not None:
            warnings.warn("Ignoring since timer was already started.")
            return
        self._start = time.monotonic()

    def stop_timer(self) -> None:
        """Stops the timer and registers the measured duration (warns and does nothing if not started)"""
        if self._start is None:
            warnings.warn("Ignoring since timer was stopped before starting.")
            return
        self._queue.append(time.monotonic() - self._start)
        self._start = None

    def _get_advised_sleep_duration(self) -> float:
        """Returns the advised sleep duration, in seconds

        If no duration was registered yet, the elapsed time of the running timer is
        used instead, or min_sleep if no timer is running.
        """
        if not self._queue:
            if self._start is None:
                return self._min
            value = time.monotonic() - self._start
        else:
            value = float(np.mean(self._queue))
        return float(np.clip(value / self._num_waits, self._min, self._max))

    def sleep(self) -> None:
        """Sleeps for the advised duration"""
        time.sleep(self._get_advised_sleep_duration())
114
115
# element type of OrderedSet: must be hashable since elements are stored as dict keys
X = tp.TypeVar("X", bound=tp.Hashable)


class OrderedSet(tp.MutableSet[X]):
    """Set of elements retaining the insertion order
    All new elements are appended to the end of the set. Adding an element
    which is already present moves it to the end of the set.

    Parameters
    ----------
    keys: iterable, optional
        initial elements, added in iteration order
    """

    def __init__(self, keys: tp.Optional[tp.Iterable[X]] = None) -> None:
        # keys are the elements; values are the global index at first insertion
        self._data: "collections.OrderedDict[X, int]" = collections.OrderedDict()
        self._global_index = 0  # keep track of insertion global index if need be
        if keys is not None:
            for key in keys:
                self.add(key)

    def add(self, key: X) -> None:
        """Adds key at the end of the set (moves it there if already present)"""
        # pop + re-insert moves an existing key to the end while keeping its original index
        self._data[key] = self._data.pop(key, self._global_index)
        self._global_index += 1

    def popright(self) -> X:
        """Removes and returns the last element of the set

        Raises
        ------
        KeyError
            if the set is empty
        """
        if not self._data:
            raise KeyError("popright from an empty OrderedSet")
        key, _ = self._data.popitem(last=True)
        return key

    def discard(self, key: X) -> None:
        """Removes key if present, does nothing otherwise (MutableSet contract, unlike remove)"""
        self._data.pop(key, None)

    def __contains__(self, key: tp.Any) -> bool:
        return key in self._data

    def __iter__(self) -> tp.Iterator[X]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({list(self._data)!r})"
151
152
def different_from_defaults(
    *,
    instance: tp.Any,
    instance_dict: tp.Optional[tp.Dict[str, tp.Any]] = None,
    check_mismatches: bool = False,
) -> tp.Dict[str, tp.Any]:
    """Checks which attributes are different from defaults arguments

    Parameters
    ----------
    instance: object
        the object to inspect
    instance_dict: dict
        the dict corresponding to the instance, if not provided it's instance.__dict__
    check_mismatches: bool
        checks that the attributes match the parameters of the instance's __init__

    Returns
    -------
    dict
        the public entries (names not starting with "_") of instance_dict whose
        value differs from the default of the corresponding __init__ argument.
        Arguments without default (default is inspect.Parameter.empty) are reported.

    Raises
    ------
    RuntimeError
        if check_mismatches is True and the attribute names do not exactly match
        the __init__ argument names

    Note
    ----
    This is convenient for short repr of data structures
    """
    defaults = {
        name: param.default
        for name, param in inspect.signature(instance.__class__.__init__).parameters.items()
        if name not in ["self", "__class__"]
    }
    if instance_dict is None:
        instance_dict = instance.__dict__
    if check_mismatches:
        diff = set(defaults.keys()).symmetric_difference(instance_dict.keys())
        if diff:  # this is to help during development
            raise RuntimeError(f"Mismatch between attributes and arguments of {instance}: {diff}")
    else:
        # only compare the arguments available in instance_dict (filtering on instance.__dict__
        # would raise a KeyError below whenever a custom instance_dict is provided)
        defaults = {x: y for x, y in defaults.items() if x in instance_dict}
    # only print non defaults
    return {
        x: instance_dict[x] for x, y in defaults.items() if y != instance_dict[x] and not x.startswith("_")
    }
191
192
@contextlib.contextmanager
def set_env(**environ: tp.Any) -> tp.Generator[None, None, None]:
    """Temporarily changes environment variables.

    Parameters
    ----------
    **environ: Any
        upper case variable names mapped to their temporary values (converted to str)

    Raises
    ------
    ValueError
        if a variable name is not upper case

    Note
    ----
    On exit, previous values are restored and variables which did not exist are
    removed, even if they were modified or deleted within the context.
    """
    for x in environ:
        if x != x.upper():
            raise ValueError(f"Only capitalized environment variable are allowed, but got {x!r}")
    old_environ = {x: os.environ.get(x, None) for x in environ}
    try:
        # inside the try so that a partially applied update (e.g. invalid value) is rolled back
        os.environ.update({x: str(y) for x, y in environ.items()})
        yield
    finally:
        for k, val in old_environ.items():
            if val is None:
                # default to None: the variable may have been deleted within the context
                os.environ.pop(k, None)
            else:
                os.environ[k] = val
208
209
def flatten(obj: tp.Any) -> tp.Any:
    """Flatten a dict/list structure

    Nested keys/indices are joined with "." into string keys. Objects which are
    neither dicts, lists nor tuples are returned unchanged. Empty nested
    containers produce no key at all.

    Example
    -------

    >>> flatten(["a", {"truc": [4, 5]}])
    {'0': 'a', '1.truc.0': 4, '1.truc.1': 5}
    """
    if isinstance(obj, (tuple, list)):
        items: tp.Iterable[tp.Tuple[tp.Any, tp.Any]] = enumerate(obj)
    elif isinstance(obj, dict):
        items = obj.items()
    else:
        return obj
    output: tp.Dict[str, tp.Any] = {}
    for key, val in items:
        content = flatten(val)
        if isinstance(content, dict):
            # nested structure: prefix all flattened sub-keys with the current key
            output.update({f"{key}.{sub_key}": sub_val for sub_key, sub_val in content.items()})
        else:
            # leaf: flatten returned the value unchanged
            output[str(key)] = content
    return output
233