# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import site
import glob
import ctypes
import time
import inspect
import warnings
import itertools
import contextlib
import collections
import typing as tp
import numpy as np


def pytorch_import_fix() -> None:
    """Hackfix needed before pytorch import ("dlopen: cannot load any more object with static TLS")
    See issue #305

    Note
    ----
    Every :code:`libgomp*.so*` file found under :code:`torch/lib` in the site-packages
    directories is loaded through ctypes. This is best-effort: any exception is
    deliberately swallowed.
    """
    try:
        for packages in site.getsitepackages():
            for lib in glob.glob(f"{packages}/torch/lib/libgomp*.so*"):
                ctypes.cdll.LoadLibrary(lib)
    except Exception:  # pylint: disable=broad-except
        # best-effort workaround: a failure here must never prevent the caller from proceeding
        pass


def pairwise(iterable: tp.Iterable[tp.Any]) -> tp.Iterator[tp.Tuple[tp.Any, tp.Any]]:
    """Returns an iterator over sliding pairs of the input iterator
    s -> (s0,s1), (s1,s2), (s2, s3), ...

    Note
    ----
    Nothing will be returned if length of iterator is strictly less
    than 2.
    """  # From itertools documentation
    a, b = itertools.tee(iterable)
    next(b, None)  # shift the second copy by one element
    return zip(a, b)


def grouper(
    iterable: tp.Iterable[tp.Any], n: int, fillvalue: tp.Any = None
) -> tp.Iterator[tp.Tuple[tp.Any, ...]]:
    """Collect data into fixed-length chunks or blocks
    Copied from itertools recipe documentation

    Parameters
    ----------
    iterable: iterable
        the data to split into chunks
    n: int
        length of each chunk
    fillvalue: Any
        value used for padding the last chunk if it is incomplete

    Returns
    -------
    iterator
        an iterator over tuples of length n

    Example
    -------
    grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    """
    # the same iterator repeated n times: zip_longest pulls n consecutive items per chunk
    args = [iter(iterable)] * n
    return itertools.zip_longest(*args, fillvalue=fillvalue)


def roundrobin(*iterables: tp.Iterable[tp.Any]) -> tp.Iterator[tp.Any]:
    """Yields one element of each iterable in turn, until all are exhausted

    Example
    -------
    roundrobin('ABC', 'D', 'EF') --> A D E B F C
    """
    # Recipe credited to George Sakkis
    num_active = len(iterables)
    nexts = itertools.cycle(iter(it).__next__ for it in iterables)
    while num_active:
        try:
            for next_ in nexts:
                yield next_()
        except StopIteration:
            # Remove the iterator we just exhausted from the cycle.
            num_active -= 1
            nexts = itertools.cycle(itertools.islice(nexts, num_active))


class Sleeper:
    """Simple object for managing the waiting time of a job

    Parameters
    ----------
    min_sleep: float
        minimum sleep time
    max_sleep: float
        maximum sleep time
    averaging_size: int
        size for averaging the registered durations
    """

    def __init__(self, min_sleep: float = 1e-7, max_sleep: float = 1.0, averaging_size: int = 10) -> None:
        self._min = min_sleep
        self._max = max_sleep
        self._start: tp.Optional[float] = None  # start time of the running timer, if any
        self._queue: tp.Deque[float] = collections.deque(maxlen=averaging_size)
        self._num_waits = 10  # expect to waste around 10% of time

    def start_timer(self) -> None:
        """Starts timing a job (warns and does nothing if the timer is already running)"""
        if self._start is not None:
            warnings.warn("Ignoring since timer was already started.")
            return
        self._start = time.time()

    def stop_timer(self) -> None:
        """Stops the timer and registers the elapsed duration
        (warns and does nothing if the timer was not started)
        """
        if self._start is None:
            warnings.warn("Ignoring since timer was stopped before starting.")
            return
        self._queue.append(time.time() - self._start)
        self._start = None

    def _get_advised_sleep_duration(self) -> float:
        """Computes the sleep duration: a fraction of the mean registered duration
        (or of the currently running one if none is registered yet), clipped
        between the minimum and maximum sleep times.
        """
        if not self._queue:
            if self._start is None:
                # nothing registered and nothing running: no information available
                return self._min
            value = time.time() - self._start
        else:
            value = np.mean(self._queue)
        return float(np.clip(value / self._num_waits, self._min, self._max))

    def sleep(self) -> None:
        """Sleeps for the advised duration"""
        time.sleep(self._get_advised_sleep_duration())


X = tp.TypeVar("X", bound=tp.Hashable)


class OrderedSet(tp.MutableSet[X]):
    """Set of elements retaining the insertion order
    All new elements are appended to the end of the set.

    Parameters
    ----------
    keys: optional iterable
        initial elements, added in iteration order

    Note
    ----
    Adding an element which is already in the set moves it to the end.
    """

    def __init__(self, keys: tp.Optional[tp.Iterable[X]] = None) -> None:
        self._data: "collections.OrderedDict[X, int]" = collections.OrderedDict()
        self._global_index = 0  # keep track of insertion global index if need be
        if keys is not None:
            for key in keys:
                self.add(key)

    def add(self, key: X) -> None:
        """Adds the key at the end of the set (moving it there if already present)"""
        # popping then re-inserting moves an existing key to the end while keeping its stored index
        self._data[key] = self._data.pop(key, self._global_index)
        self._global_index += 1

    def popright(self) -> X:
        """Removes and returns the last element of the set

        Raises
        ------
        StopIteration
            if the set is empty
        """
        key = next(reversed(self._data))
        self.discard(key)
        return key

    def discard(self, key: X) -> None:
        """Removes the key if it is in the set, and does nothing otherwise"""
        # discard must not raise on missing keys (MutableSet contract), the mixin
        # methods such as "-=" rely on this
        self._data.pop(key, None)

    def __contains__(self, key: tp.Any) -> bool:
        return key in self._data

    def __iter__(self) -> tp.Iterator[X]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)


def different_from_defaults(
    *,
    instance: tp.Any,
    instance_dict: tp.Optional[tp.Dict[str, tp.Any]] = None,
    check_mismatches: bool = False,
) -> tp.Dict[str, tp.Any]:
    """Checks which attributes are different from defaults arguments

    Parameters
    ----------
    instance: object
        the object to inspect
    instance_dict: dict
        the dict corresponding to the instance, if not provided it's instance.__dict__
    check_mismatches: bool
        checks that the attributes match the parameters

    Returns
    -------
    dict
        the public entries (not starting with "_") of the instance dict
        which differ from the default of the corresponding __init__ parameter

    Raises
    ------
    RuntimeError
        if check_mismatches is True and the attribute names do not exactly match the
        parameter names of __init__

    Note
    ----
    This is convenient for short repr of data structures
    """
    defaults = {
        x: y.default
        for x, y in inspect.signature(instance.__class__.__init__).parameters.items()
        if x not in ["self", "__class__"]
    }
    if instance_dict is None:
        instance_dict = instance.__dict__
    if check_mismatches:
        diff = set(defaults.keys()).symmetric_difference(instance_dict.keys())
        if diff:  # this is to help during development
            raise RuntimeError(f"Mismatch between attributes and arguments of {instance}: {diff}")
    else:
        # filter on the dict which is actually read below, so that a user-provided
        # instance_dict missing some parameters does not raise a KeyError
        defaults = {x: y for x, y in defaults.items() if x in instance_dict}
    # only print non defaults
    return {
        x: instance_dict[x] for x, y in defaults.items() if y != instance_dict[x] and not x.startswith("_")
    }


@contextlib.contextmanager
def set_env(**environ: tp.Any) -> tp.Generator[None, None, None]:
    """Temporarily changes environment variables.

    Parameters
    ----------
    **environ
        the environment variables to set (names must be uppercase), values are
        converted to strings

    Raises
    ------
    ValueError
        if a variable name is not uppercase

    Note
    ----
    On exit, variables which existed before are restored to their previous value,
    and the others are removed.
    """
    # validate before touching anything
    for x in environ:
        if x != x.upper():
            raise ValueError(f"Only capitalized environment variable are allowed, but got {x!r}")
    old_environ = {x: os.environ.get(x, None) for x in environ}
    os.environ.update({x: str(y) for x, y in environ.items()})
    try:
        yield
    finally:
        for k, val in old_environ.items():
            # the variable may have been removed within the context, so do not fail if missing
            os.environ.pop(k, None)
            if val is not None:
                os.environ[k] = val


def flatten(obj: tp.Any) -> tp.Any:
    """Flatten a dict/list structure

    Parameters
    ----------
    obj: Any
        a (possibly nested) structure of dicts, lists and tuples

    Returns
    -------
    Any
        a flat dict with dot-separated keys if obj is a dict, list or tuple,
        and obj itself otherwise

    Note
    ----
    Empty nested containers do not produce any key in the output.

    Example
    -------

    >>> flatten(["a", {"truc": [4, 5]}])
    {'0': 'a', '1.truc.0': 4, '1.truc.1': 5}
    """
    output: tp.Any = {}
    if isinstance(obj, (tuple, list)):
        iterator = enumerate(obj)
    elif isinstance(obj, dict):
        iterator = obj.items()  # type: ignore
    else:
        return obj
    for k, val in iterator:
        content = flatten(val)
        if isinstance(content, dict):
            # nested container: prefix its flattened keys with the current key
            output.update({f"{k}.{x}": y for x, y in content.items()})
        else:
            output[str(k)] = val
    return output