# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=
"""Parallelization utility optimizer."""

__all__ = ['split_data', 'split_and_load', 'clip_global_norm',
           'check_sha1', 'download', 'replace_file']

import os
import sys
import hashlib
import uuid
import warnings
import collections
import weakref
import requests

import numpy as np

from .. import ndarray
from ..util import is_np_shape, is_np_array
from .. import numpy as _mx_np  # pylint: disable=reimported


def split_data(data, num_slice, batch_axis=0, even_split=True):
    """Splits an NDArray into `num_slice` slices along `batch_axis`.
    Usually used for data parallelism where each slice is sent
    to one device (i.e. GPU).

    Parameters
    ----------
    data : NDArray
        A batch of data.
    num_slice : int
        Number of desired slices.
    batch_axis : int, default 0
        The axis along which to slice.
    even_split : bool, default True
        Whether to force all slices to have the same number of elements.
        If `True`, an error will be raised when `num_slice` does not evenly
        divide `data.shape[batch_axis]`.

    Returns
    -------
    list of NDArray
        Return value is a list even if `num_slice` is 1.
    """
    size = data.shape[batch_axis]
    if even_split and size % num_slice != 0:
        raise ValueError(
            "data with shape %s cannot be evenly split into %d slices along axis %d. " \
            "Use a batch size that's multiple of %d or set even_split=False to allow " \
            "uneven partitioning of data."%(
                str(data.shape), num_slice, batch_axis, num_slice))

    # Compute cut points: the first `extras` slices get one extra element
    # when the split is uneven.
    n_each_section, extras = divmod(size, num_slice)
    section_sizes = [0] + (extras * [n_each_section + 1] +
                           (num_slice - extras) * [n_each_section])
    div_points = np.array(section_sizes).cumsum()
    if is_np_array():
        slices = _mx_np.split(data, indices_or_sections=list(div_points[1: -1]), axis=batch_axis)
    else:
        slices = []
        if batch_axis != 0:
            for i in range(num_slice):
                st = div_points[i]
                end = div_points[i + 1]
                slices.append(ndarray.slice_axis(data, axis=batch_axis, begin=st, end=end))
        else:
            # Fixes issue: https://github.com/apache/incubator-mxnet/issues/19268
            slices = [data[div_points[i]:div_points[i + 1]] if i < num_slice - 1 else data[div_points[i]:size]
                      for i in range(num_slice)]
    return slices


def split_and_load(data, ctx_list, batch_axis=0, even_split=True):
    """Splits an NDArray into `len(ctx_list)` slices along `batch_axis` and loads
    each slice to one context in `ctx_list`.

    Parameters
    ----------
    data : NDArray or ndarray
        A batch of data.
    ctx_list : list of Context
        A list of Contexts.
    batch_axis : int, default 0
        The axis along which to slice.
    even_split : bool, default True
        Whether to force all slices to have the same number of elements.

    Returns
    -------
    list of NDArrays or ndarrays
        Each corresponds to a context in `ctx_list`.
    """
    array_fn = _mx_np.array if is_np_array() else ndarray.array
    if not isinstance(data, ndarray.NDArray):
        data = array_fn(data, ctx=ctx_list[0])
    # Single context: no split needed, just move the data there.
    if len(ctx_list) == 1:
        return [data.as_in_context(ctx_list[0])]

    slices = split_data(data, len(ctx_list), batch_axis, even_split)
    return [i.as_in_context(ctx) for i, ctx in zip(slices, ctx_list)]


def clip_global_norm(arrays, max_norm, check_isfinite=True):
    """Rescales NDArrays so that the sum of their 2-norm is smaller than `max_norm`.

    Parameters
    ----------
    arrays : list of NDArray
    max_norm : float
    check_isfinite : bool, default True
         If True, check that the total_norm is finite (not nan or inf). This
         requires a blocking .asscalar() call.

    Returns
    -------
    NDArray or float
      Total norm. Return type is NDArray of shape (1,) if check_isfinite is
      False. Otherwise a float is returned.

    """
    def _norm(array):
        # Squared 2-norm of one array; dense arrays use a dot of the
        # flattened view, sparse arrays go through norm().square().
        if array.stype == 'default':
            x = array.reshape((-1,))
            return ndarray.dot(x, x)
        return array.norm().square()
    assert len(arrays) > 0
    ctx = arrays[0].context
    total_norm = ndarray.add_n(*[_norm(arr).as_in_context(ctx) for arr in arrays])
    total_norm = ndarray.sqrt(total_norm)
    if check_isfinite:
        if not np.isfinite(total_norm.asscalar()):
            warnings.warn(
                UserWarning('nan or inf is detected. '
                            'Clipping results will be undefined.'), stacklevel=2)
    # Clamp scale to at most 1 so arrays are never scaled up.
    scale = max_norm / (total_norm + 1e-8)
    scale = ndarray.min(ndarray.concat(scale, ndarray.ones(1, ctx=ctx), dim=0))
    for arr in arrays:
        arr *= scale.as_in_context(arr.context)
    if check_isfinite:
        return total_norm.asscalar()
    else:
        return total_norm


def _indent(s_, numSpaces):
    """Indent string
    """
    s = s_.split('\n')
    if len(s) == 1:
        return s_
    # First line keeps its position; subsequent lines get the extra indent.
    first = s.pop(0)
    s = [first] + [(numSpaces * ' ') + line for line in s]
    s = '\n'.join(s)
    return s


def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        # Read in 1 MiB chunks to bound memory use on large files.
        while True:
            data = f.read(1048576)
            if not data:
                break
            sha1.update(data)

    return sha1.hexdigest() == sha1_hash


if not sys.platform.startswith('win32'):
    # refer to https://github.com/untitaker/python-atomicwrites
    def replace_file(src, dst):
        """Implement atomic os.replace with linux and OSX.

        Parameters
        ----------
        src : source file path
        dst : destination file path
        """
        try:
            os.rename(src, dst)
        except OSError:
            # Best-effort cleanup of the temp file, then surface a
            # descriptive error to the caller.
            try:
                os.remove(src)
            except OSError:
                pass
            finally:
                raise OSError(
                    'Moving downloaded temp file - {}, to {} failed. \
                    Please retry the download.'.format(src, dst))
else:
    import ctypes

    _MOVEFILE_REPLACE_EXISTING = 0x1
    # Setting this value guarantees that a move performed as a copy
    # and delete operation is flushed to disk before the function returns.
    # The flush occurs at the end of the copy operation.
    _MOVEFILE_WRITE_THROUGH = 0x8
    _windows_default_flags = _MOVEFILE_WRITE_THROUGH

    def _str_to_unicode(x):
        """Handle text decoding. Internal use only"""
        if not isinstance(x, str):
            return x.decode(sys.getfilesystemencoding())
        return x

    def _handle_errors(rv, src):
        """Handle WinError. Internal use only"""
        if not rv:
            msg = ctypes.FormatError(ctypes.GetLastError())
            # if the MoveFileExW fails(e.g. fail to acquire file lock), removes the tempfile
            try:
                os.remove(src)
            except OSError:
                pass
            finally:
                raise OSError(msg)

    def replace_file(src, dst):
        """Implement atomic os.replace with windows.

        refer to https://docs.microsoft.com/en-us/windows/desktop/api/winbase/nf-winbase-movefileexw
        The function fails when one of the process(copy, flush, delete) fails.

        Parameters
        ----------
        src : source file path
        dst : destination file path
        """
        _handle_errors(ctypes.windll.kernel32.MoveFileExW(
            _str_to_unicode(src), _str_to_unicode(dst),
            _windows_default_flags | _MOVEFILE_REPLACE_EXISTING
        ), src)


def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
    """Download a given URL

    Parameters
    ----------
    url : str
        URL to download
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url.
    overwrite : bool, optional
        Whether to overwrite destination file if already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.
    retries : integer, default 5
        The number of times to attempt the download in case of failure or non 200 return codes
    verify_ssl : bool, default True
        Verify SSL certificates.

    Returns
    -------
    str
        The file path of the downloaded file.
    """
    if path is None:
        fname = url.split('/')[-1]
        # Empty filenames are invalid
        assert fname, 'Can\'t construct file-name from this URL. ' \
            'Please set the `path` option manually.'
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path
    assert retries >= 0, "Number of retries should be at least 0, currently it's {}".format(
        retries)

    if not verify_ssl:
        warnings.warn(
            'Unverified HTTPS request is being made (verify_ssl=False). '
            'Adding certificate verification is strongly advised.')

    # Skip the download entirely when a valid file already exists.
    if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname, exist_ok=True)
        while retries + 1 > 0:
            # Disable pylint too broad Exception
            # pylint: disable=W0703
            try:
                print('Downloading {} from {}...'.format(fname, url))
                r = requests.get(url, stream=True, verify=verify_ssl)
                if r.status_code != 200:
                    raise RuntimeError('Failed downloading url {}'.format(url))
                # create uuid for temporary files
                random_uuid = str(uuid.uuid4())
                with open('{}.{}'.format(fname, random_uuid), 'wb') as f:
                    for chunk in r.iter_content(chunk_size=1024):
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
                # if the target file exists(created by other processes)
                # and have the same hash with target file
                # delete the temporary file
                if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
                    # atomic operation in the same file system
                    replace_file('{}.{}'.format(fname, random_uuid), fname)
                else:
                    try:
                        os.remove('{}.{}'.format(fname, random_uuid))
                    except OSError:
                        pass
                    finally:
                        warnings.warn(
                            'File {} exists in file system so the downloaded file is deleted'.format(fname))
                if sha1_hash and not check_sha1(fname, sha1_hash):
                    raise UserWarning(
                        'File {} is downloaded but the content hash does not match.'
                        ' The repo may be outdated or download may be incomplete. '
                        'If the "repo_url" is overridden, consider switching to '
                        'the default repo.'.format(fname))
                break
            except Exception as e:
                retries -= 1
                if retries <= 0:
                    raise e

                print('download failed due to {}, retrying, {} attempt{} left'
                      .format(repr(e), retries, 's' if retries > 1 else ''))

    return fname

def _get_repo_url():
    """Return the base URL for Gluon dataset and model repository."""
    default_repo = 'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'
    repo_url = os.environ.get('MXNET_GLUON_REPO', default_repo)
    # Normalize so the URL always ends with a single trailing slash.
    if repo_url[-1] != '/':
        repo_url = repo_url+'/'
    return repo_url

def _get_repo_file_url(namespace, filename):
    """Return the URL for hosted file in Gluon repository.

    Parameters
    ----------
    namespace : str
        Namespace of the file.
    filename : str
        Name of the file
    """
    # Bug fix: the template previously contained a literal '(unknown)' in
    # place of the '{filename}' placeholder, so the passed filename was
    # silently ignored and every URL was wrong.
    return '{base_url}{namespace}/{filename}'.format(base_url=_get_repo_url(),
                                                     namespace=namespace,
                                                     filename=filename)

def _brief_print_list(lst, limit=7):
    """Print at most `limit` elements of list."""
    lst = list(lst)
    if len(lst) > limit:
        # Show the head and tail halves with an ellipsis in between.
        return _brief_print_list(lst[:limit//2], limit) + ', ..., ' + \
            _brief_print_list(lst[-limit//2:], limit)
    return ', '.join(["'%s'"%str(i) for i in lst])


class HookHandle(object):
    """A handle that can attach/detach a hook."""

    def __init__(self):
        # Weak reference to the dict holding the hook, so the handle does
        # not keep the owning block's hook registry alive.
        self._hooks_dict_ref = None
        self._id = None

    def attach(self, hooks_dict, hook):
        assert not self._hooks_dict_ref, 'The same handle cannot be attached twice.'
        self._id = id(hook)
        hooks_dict[self._id] = hook
        self._hooks_dict_ref = weakref.ref(hooks_dict)

    def detach(self):
        hooks_dict = self._hooks_dict_ref()
        if hooks_dict is not None and self._id in hooks_dict:
            del hooks_dict[self._id]

    def __getstate__(self):
        # Dereference the weakref for pickling; the dict itself is pickled.
        return (self._hooks_dict_ref(), self._id)

    def __setstate__(self, state):
        if state[0] is None:
            self._hooks_dict_ref = weakref.ref(collections.OrderedDict())
        else:
            self._hooks_dict_ref = weakref.ref(state[0])
        self._id = state[1]

    def __enter__(self):
        return self

    def __exit__(self, ptype, value, trace):
        self.detach()


def shape_is_known(shape):
    """Check whether a shape is completely known with or without np semantics.

    Please see the doc of is_np_shape for more details.
    """
    if shape is None:
        return False
    # np semantics mark unknown dims with -1; legacy semantics use 0.
    unknown_dim_size = -1 if is_np_shape() else 0
    if len(shape) == 0:
        return unknown_dim_size == -1
    for dim_size in shape:
        if dim_size == unknown_dim_size:
            return False
        assert dim_size > unknown_dim_size, "shape dimension size cannot be less than {}, while " \
                                            "received {}".format(unknown_dim_size, dim_size)
    return True


def _check_same_symbol_type(symbols):
    """Check whether all the symbols in the list are of the same type.
    Raise type error if the types are different. Return the class of
    the symbols."""
    from ..symbol.numpy import _Symbol as np_symbol
    from ..symbol import Symbol as nd_symbol
    is_np_sym = isinstance(symbols[0], np_symbol)
    for s in symbols[1:]:
        if is_np_sym != isinstance(s, np_symbol):
            raise TypeError('Found both classic symbol (mx.sym.Symbol) and numpy symbol '
                            '(mx.sym.np._Symbol) in outputs. This will prevent you from building '
                            'a computation graph by grouping them since different types of symbols '
                            'are not allowed to be grouped in Gluon to form a computation graph. '
                            'You will need to convert them to the same type of symbols, either '
                            'classic or numpy following this rule: if you want numpy ndarray '
                            'output(s) from the computation graph, please convert all the classic '
                            'symbols in the list to numpy symbols by calling `as_np_ndarray()` '
                            'on each of them; if you want classic ndarray output(s) from the '
                            'computation graph, please convert all the numpy symbols in the list '
                            'to classic symbols by calling `as_nd_ndarray()` on each of them.')
    return np_symbol if is_np_sym else nd_symbol


def _check_all_np_ndarrays(out):
    """Check if ndarrays/symbols in out are all np.ndarray/np._Symbol."""
    from ..numpy import ndarray as np_ndarray
    from ..symbol.numpy import _Symbol as np_symbol
    from ..symbol import Symbol as nd_symbol
    from ..ndarray import NDArray as nd_ndarray

    # pylint: disable=no-else-raise
    if isinstance(out, (nd_ndarray, nd_symbol)) and not isinstance(out, (np_ndarray, np_symbol)):
        raise TypeError("Block's output ndarrays/symbols must be of type `mxnet.numpy.ndarray`"
                        " or `mxnet.symbol.numpy._Symbol`, while got output type {}"
                        .format(str(type(out))))
    elif isinstance(out, (list, tuple)):
        # Recurse into nested containers of outputs.
        for i in out:
            _check_all_np_ndarrays(i)
    # pylint: enable=no-else-raise