# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=
"""Parallelization utilities."""

__all__ = ['split_data', 'split_and_load', 'clip_global_norm',
           'check_sha1', 'download', 'replace_file']

import os
import sys
import hashlib
import uuid
import warnings
import collections
import weakref
import requests

import numpy as np

from .. import ndarray
from ..util import is_np_shape, is_np_array
from .. import numpy as _mx_np  # pylint: disable=reimported


def split_data(data, num_slice, batch_axis=0, even_split=True):
    """Splits an NDArray into `num_slice` slices along `batch_axis`.
    Usually used for data parallelism where each slice is sent
    to one device (i.e. GPU).

    Parameters
    ----------
    data : NDArray
        A batch of data.
    num_slice : int
        Number of desired slices.
    batch_axis : int, default 0
        The axis along which to slice.
    even_split : bool, default True
        Whether to force all slices to have the same number of elements.
        If `True`, an error will be raised when `num_slice` does not evenly
        divide `data.shape[batch_axis]`.

    Returns
    -------
    list of NDArray
        Return value is a list even if `num_slice` is 1.
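
    Examples
    --------
    A minimal usage sketch, added for illustration; it assumes the classic
    `mx.nd` frontend and a batch size divisible by `num_slice`:

    >>> import mxnet as mx
    >>> data = mx.nd.arange(12).reshape((6, 2))
    >>> slices = mx.gluon.utils.split_data(data, num_slice=3)
    >>> len(slices)
    3
    >>> slices[0].shape
    (2, 2)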
    """
    size = data.shape[batch_axis]
    if even_split and size % num_slice != 0:
        raise ValueError(
            "data with shape %s cannot be evenly split into %d slices along axis %d. "
            "Use a batch size that's a multiple of %d or set even_split=False to allow "
            "uneven partitioning of data." % (
                str(data.shape), num_slice, batch_axis, num_slice))

    n_each_section, extras = divmod(size, num_slice)
    section_sizes = [0] + (extras * [n_each_section + 1] +
                           (num_slice - extras) * [n_each_section])
    div_points = np.array(section_sizes).cumsum()
    if is_np_array():
        slices = _mx_np.split(data, indices_or_sections=list(div_points[1: -1]), axis=batch_axis)
    else:
        slices = []
        if batch_axis != 0:
            for i in range(num_slice):
                st = div_points[i]
                end = div_points[i + 1]
                slices.append(ndarray.slice_axis(data, axis=batch_axis, begin=st, end=end))
        else:
            # Fixes issue: https://github.com/apache/incubator-mxnet/issues/19268
            slices = [data[div_points[i]:div_points[i + 1]] if i < num_slice - 1 else data[div_points[i]:size]
                      for i in range(num_slice)]
    return slices


def split_and_load(data, ctx_list, batch_axis=0, even_split=True):
    """Splits an NDArray into `len(ctx_list)` slices along `batch_axis` and loads
    each slice to one context in `ctx_list`.

    Parameters
    ----------
    data : NDArray or ndarray
        A batch of data.
    ctx_list : list of Context
        A list of Contexts.
    batch_axis : int, default 0
        The axis along which to slice.
    even_split : bool, default True
        Whether to force all slices to have the same number of elements.

    Returns
    -------
    list of NDArrays or ndarrays
        Each corresponds to a context in `ctx_list`.
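
    Examples
    --------
    A minimal usage sketch, added for illustration; two CPU contexts stand in
    for separate devices:

    >>> import mxnet as mx
    >>> data = mx.nd.arange(8).reshape((4, 2))
    >>> parts = mx.gluon.utils.split_and_load(data, [mx.cpu(0), mx.cpu(1)])
    >>> [p.shape for p in parts]
    [(2, 2), (2, 2)]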
    """
    array_fn = _mx_np.array if is_np_array() else ndarray.array
    if not isinstance(data, ndarray.NDArray):
        data = array_fn(data, ctx=ctx_list[0])
    if len(ctx_list) == 1:
        return [data.as_in_context(ctx_list[0])]

    slices = split_data(data, len(ctx_list), batch_axis, even_split)
    return [i.as_in_context(ctx) for i, ctx in zip(slices, ctx_list)]


def clip_global_norm(arrays, max_norm, check_isfinite=True):
    """Rescales NDArrays so that their global 2-norm (the 2-norm of the
    concatenation of all arrays) is smaller than `max_norm`.

    Parameters
    ----------
    arrays : list of NDArray
    max_norm : float
    check_isfinite : bool, default True
        If True, check that the total_norm is finite (not nan or inf). This
        requires a blocking .asscalar() call.

    Returns
    -------
    NDArray or float
        Total norm. Return type is NDArray of shape (1,) if check_isfinite is
        False. Otherwise a float is returned.

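    Examples
    --------
    A minimal usage sketch, added for illustration; the arrays stand in for
    any list of gradients:

    >>> import mxnet as mx
    >>> grads = [mx.nd.ones((2, 2)), mx.nd.ones((3,))]
    >>> total = mx.gluon.utils.clip_global_norm(grads, max_norm=1.0)
    >>> # grads are rescaled in place so their global 2-norm is at most ~1.0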
    """
    def _norm(array):
        # squared 2-norm of the array; dense arrays use dot(x, x)
        if array.stype == 'default':
            x = array.reshape((-1,))
            return ndarray.dot(x, x)
        return array.norm().square()
    assert len(arrays) > 0
    ctx = arrays[0].context
    total_norm = ndarray.add_n(*[_norm(arr).as_in_context(ctx) for arr in arrays])
    total_norm = ndarray.sqrt(total_norm)
    if check_isfinite:
        if not np.isfinite(total_norm.asscalar()):
            warnings.warn(
                UserWarning('nan or inf is detected. '
                            'Clipping results will be undefined.'), stacklevel=2)
    scale = max_norm / (total_norm + 1e-8)
    scale = ndarray.min(ndarray.concat(scale, ndarray.ones(1, ctx=ctx), dim=0))
    for arr in arrays:
        arr *= scale.as_in_context(arr.context)
    if check_isfinite:
        return total_norm.asscalar()
    else:
        return total_norm


def _indent(s_, numSpaces):
    """Indent every line of a multi-line string after the first by `numSpaces` spaces."""
    s = s_.split('\n')
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [first] + [(numSpaces * ' ') + line for line in s]
    s = '\n'.join(s)
    return s


def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
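
    Examples
    --------
    A minimal usage sketch; the file name and `expected_hash` below are
    hypothetical placeholders:

    >>> if check_sha1('model.params', expected_hash):  # doctest: +SKIP
    ...     print('hash matched')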
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        while True:
            # read in 1 MiB chunks so large files do not need to fit in memory
            data = f.read(1048576)
            if not data:
                break
            sha1.update(data)

    return sha1.hexdigest() == sha1_hash


if not sys.platform.startswith('win32'):
    # refer to https://github.com/untitaker/python-atomicwrites
    def replace_file(src, dst):
        """Implement atomic os.replace on Linux and macOS.

        Parameters
        ----------
        src : source file path
        dst : destination file path
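
        Examples
        --------
        A minimal usage sketch; the paths below are hypothetical placeholders:

        >>> replace_file('model.params.tmp', 'model.params')  # doctest: +SKIP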
        """
        try:
            os.rename(src, dst)
        except OSError:
            try:
                os.remove(src)
            except OSError:
                pass
            finally:
                raise OSError(
                    'Moving downloaded temp file - {}, to {} failed. '
                    'Please retry the download.'.format(src, dst))
else:
    import ctypes

    _MOVEFILE_REPLACE_EXISTING = 0x1
    # Setting this value guarantees that a move performed as a copy
    # and delete operation is flushed to disk before the function returns.
    # The flush occurs at the end of the copy operation.
    _MOVEFILE_WRITE_THROUGH = 0x8
    _windows_default_flags = _MOVEFILE_WRITE_THROUGH

    def _str_to_unicode(x):
        """Handle text decoding. Internal use only."""
        if not isinstance(x, str):
            return x.decode(sys.getfilesystemencoding())
        return x

    def _handle_errors(rv, src):
        """Handle WinError. Internal use only."""
        if not rv:
            msg = ctypes.FormatError(ctypes.GetLastError())
            # if MoveFileExW fails (e.g. fails to acquire a file lock), remove the temp file
            try:
                os.remove(src)
            except OSError:
                pass
            finally:
                raise OSError(msg)

    def replace_file(src, dst):
        """Implement atomic os.replace on Windows.

        Refer to https://docs.microsoft.com/en-us/windows/desktop/api/winbase/nf-winbase-movefileexw
        The function fails when any of the operations (copy, flush, delete) fails.

        Parameters
        ----------
        src : source file path
        dst : destination file path
        """
        _handle_errors(ctypes.windll.kernel32.MoveFileExW(
            _str_to_unicode(src), _str_to_unicode(dst),
            _windows_default_flags | _MOVEFILE_REPLACE_EXISTING
        ), src)


def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
    """Download a file from a given URL.

    Parameters
    ----------
    url : str
        URL to download.
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with the same name as in the URL.
    overwrite : bool, optional
        Whether to overwrite the destination file if it already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.
    retries : integer, default 5
        The number of times to attempt the download in case of failure or non-200 return codes.
    verify_ssl : bool, default True
        Verify SSL certificates.

    Returns
    -------
    str
        The file path of the downloaded file.
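
    Examples
    --------
    A minimal usage sketch; the URL below is a placeholder, not a real
    endpoint:

    >>> download('https://example.com/data.zip', path='data/')  # doctest: +SKIP
    'data/data.zip'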
    """
    if path is None:
        fname = url.split('/')[-1]
        # Empty filenames are invalid
        assert fname, 'Can\'t construct file-name from this URL. ' \
            'Please set the `path` option manually.'
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path
    assert retries >= 0, "Number of retries should be at least 0, currently it's {}".format(
        retries)

    if not verify_ssl:
        warnings.warn(
            'Unverified HTTPS request is being made (verify_ssl=False). '
            'Adding certificate verification is strongly advised.')

    if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname, exist_ok=True)
        while retries + 1 > 0:
            # Disable pylint too broad Exception
            # pylint: disable=W0703
            try:
                print('Downloading {} from {}...'.format(fname, url))
                r = requests.get(url, stream=True, verify=verify_ssl)
                if r.status_code != 200:
                    raise RuntimeError('Failed downloading url {}'.format(url))
                # create uuid for temporary files
                random_uuid = str(uuid.uuid4())
                with open('{}.{}'.format(fname, random_uuid), 'wb') as f:
                    for chunk in r.iter_content(chunk_size=1024):
                        if chunk: # filter out keep-alive new chunks
                            f.write(chunk)
                # if the target file exists (created by another process)
                # and has the expected hash, delete the temporary file;
                # otherwise move the temporary file into place
                if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
                    # atomic operation in the same file system
                    replace_file('{}.{}'.format(fname, random_uuid), fname)
                else:
                    try:
                        os.remove('{}.{}'.format(fname, random_uuid))
                    except OSError:
                        pass
                    finally:
                        warnings.warn(
                            'File {} exists in file system so the downloaded file is deleted'.format(fname))
                if sha1_hash and not check_sha1(fname, sha1_hash):
                    raise UserWarning(
                        'File {} is downloaded but the content hash does not match.'
                        ' The repo may be outdated or download may be incomplete. '
                        'If the "repo_url" is overridden, consider switching to '
                        'the default repo.'.format(fname))
                break
            except Exception as e:
                retries -= 1
                if retries <= 0:
                    raise e

                print('download failed due to {}, retrying, {} attempt{} left'
                      .format(repr(e), retries, 's' if retries > 1 else ''))

    return fname

def _get_repo_url():
    """Return the base URL for Gluon dataset and model repository."""
    default_repo = 'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'
    repo_url = os.environ.get('MXNET_GLUON_REPO', default_repo)
    if repo_url[-1] != '/':
        repo_url = repo_url + '/'
    return repo_url

def _get_repo_file_url(namespace, filename):
    """Return the URL for a hosted file in the Gluon repository.

    Parameters
    ----------
    namespace : str
        Namespace of the file.
    filename : str
        Name of the file.
    """
    return '{base_url}{namespace}/{filename}'.format(base_url=_get_repo_url(),
                                                     namespace=namespace,
                                                     filename=filename)

def _brief_print_list(lst, limit=7):
    """Print at most `limit` elements of a list.
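
    A small doctest sketch, added for illustration:

    >>> _brief_print_list(range(10), limit=4)
    "'0', '1', ..., '8', '9'"
    """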
    lst = list(lst)
    if len(lst) > limit:
        return _brief_print_list(lst[:limit//2], limit) + ', ..., ' + \
            _brief_print_list(lst[-limit//2:], limit)
    return ', '.join(["'%s'" % str(i) for i in lst])


class HookHandle(object):
    """A handle that can attach/detach a hook.
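
    A minimal usage sketch, added for illustration; ``hooks_dict`` and
    ``my_hook`` are hypothetical placeholders:

    >>> handle = HookHandle()
    >>> handle.attach(hooks_dict, my_hook)  # doctest: +SKIP
    >>> handle.detach()                     # doctest: +SKIP
    """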

    def __init__(self):
        self._hooks_dict_ref = None
        self._id = None

    def attach(self, hooks_dict, hook):
        assert not self._hooks_dict_ref, 'The same handle cannot be attached twice.'
        self._id = id(hook)
        hooks_dict[self._id] = hook
        # hold a weak reference so the handle does not keep the hooks dict alive
        self._hooks_dict_ref = weakref.ref(hooks_dict)

    def detach(self):
        hooks_dict = self._hooks_dict_ref()
        if hooks_dict is not None and self._id in hooks_dict:
            del hooks_dict[self._id]

    def __getstate__(self):
        return (self._hooks_dict_ref(), self._id)

    def __setstate__(self, state):
        if state[0] is None:
            self._hooks_dict_ref = weakref.ref(collections.OrderedDict())
        else:
            self._hooks_dict_ref = weakref.ref(state[0])
        self._id = state[1]

    def __enter__(self):
        return self

    def __exit__(self, ptype, value, trace):
        self.detach()


def shape_is_known(shape):
    """Check whether a shape is completely known with or without np semantics.

    Please see the doc of is_np_shape for more details.
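
    Examples
    --------
    A minimal sketch, added for illustration; it assumes legacy (non-np)
    shape semantics, where 0 marks an unknown dimension:

    >>> shape_is_known((2, 3))
    True
    >>> shape_is_known((2, 0))
    False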
    """
    if shape is None:
        return False
    # np semantics uses -1 for an unknown dimension size; legacy semantics uses 0
    unknown_dim_size = -1 if is_np_shape() else 0
    if len(shape) == 0:
        return unknown_dim_size == -1
    for dim_size in shape:
        if dim_size == unknown_dim_size:
            return False
        assert dim_size > unknown_dim_size, "shape dimension size cannot be less than {}, while " \
                                            "received {}".format(unknown_dim_size, dim_size)
    return True


def _check_same_symbol_type(symbols):
    """Check whether all the symbols in the list are of the same type.
    Raise type error if the types are different. Return the class of
    the symbols."""
    from ..symbol.numpy import _Symbol as np_symbol
    from ..symbol import Symbol as nd_symbol
    is_np_sym = isinstance(symbols[0], np_symbol)
    for s in symbols[1:]:
        if is_np_sym != isinstance(s, np_symbol):
            raise TypeError('Found both classic symbol (mx.sym.Symbol) and numpy symbol '
                            '(mx.sym.np._Symbol) in outputs. This will prevent you from building '
                            'a computation graph by grouping them since different types of symbols '
                            'are not allowed to be grouped in Gluon to form a computation graph. '
                            'You will need to convert them to the same type of symbols, either '
                            'classic or numpy following this rule: if you want numpy ndarray '
                            'output(s) from the computation graph, please convert all the classic '
                            'symbols in the list to numpy symbols by calling `as_np_ndarray()` '
                            'on each of them; if you want classic ndarray output(s) from the '
                            'computation graph, please convert all the numpy symbols in the list '
                            'to classic symbols by calling `as_nd_ndarray()` on each of them.')
    return np_symbol if is_np_sym else nd_symbol


def _check_all_np_ndarrays(out):
    """Check if ndarrays/symbols in out are all np.ndarray/np._Symbol."""
    from ..numpy import ndarray as np_ndarray
    from ..symbol.numpy import _Symbol as np_symbol
    from ..symbol import Symbol as nd_symbol
    from ..ndarray import NDArray as nd_ndarray

    # pylint: disable=no-else-raise
    if isinstance(out, (nd_ndarray, nd_symbol)) and not isinstance(out, (np_ndarray, np_symbol)):
        raise TypeError("Block's output ndarrays/symbols must be of type `mxnet.numpy.ndarray`"
                        " or `mxnet.symbol.numpy._Symbol`, but got output type {}"
                        .format(str(type(out))))
    elif isinstance(out, (list, tuple)):
        for i in out:
            _check_all_np_ndarrays(i)
    # pylint: enable=no-else-raise