1"""Fast Pose Utils for loading parameters"""
2from mxnet import random
3from mxnet import ndarray
4from mxnet.initializer import Initializer
5
6
def _try_load_parameters(self, filename=None, model=None, ctx=None, allow_missing=False,
                         ignore_extra=False):
    """Best-effort parameter loading for a Gluon block.

    Loads parameters either from a saved ``.params`` file (``filename``) or by
    copying them from another, already-initialized ``model``. Parameters whose
    names or shapes do not match the target block are silently skipped — this
    is the "try" in try-load.

    Parameters
    ----------
    filename : str, optional
        Path to a saved parameter file. Takes precedence over ``model``.
    model : gluon.Block, optional
        Source model to copy parameters from when ``filename`` is None.
        Its parameters must be initialized (``.data()`` raises otherwise).
    ctx : Context or list of Context, optional
        Device context(s) to load the parameters onto.
    allow_missing : bool, default False
        Passed through to the stock loader on the legacy path only.
    ignore_extra : bool, default False
        Passed through to the stock loader on the legacy path only.
    """
    # NOTE(review): relies on the private Gluon API _collect_params_with_prefix
    # on both the source model and self — confirm against the mxnet version in use.
    if filename is not None:
        loaded = ndarray.load(filename)
    else:
        loaded = {k: v.data() for k, v in model._collect_params_with_prefix().items()}
    params = self._collect_params_with_prefix()
    if not loaded and not params:
        return

    if filename is not None and not any('.' in i for i in loaded.keys()):
        # Legacy flat-key file format: delegate to the stock loader, which
        # understands prefixed parameter names. This shortcut is only valid
        # when loading from a file — previously it was also taken on the
        # model-to-model path and crashed on ``filename=None``.
        del loaded
        self.collect_params().load(
            filename, ctx, allow_missing, ignore_extra, self.prefix)
        return

    # Copy every parameter whose name and shape both match; skip the rest.
    for name in loaded:
        if name in params and params[name].shape == loaded[name].shape:
            params[name]._load_init(loaded[name], ctx)
40
41
def _load_from_pytorch(self, filename, ctx=None):
    """Load parameters from a PyTorch checkpoint into this Gluon block.

    Translates BatchNorm parameter names from PyTorch's ``weight``/``bias``
    convention to MXNet's ``gamma``/``beta``, converts each tensor to an
    MXNet NDArray, and initializes the matching parameters of ``self``.

    Parameters
    ----------
    filename : str
        Path to a PyTorch checkpoint loadable with ``torch.load``.
    ctx : Context or list of Context, optional
        Device context(s) to load the parameters onto.

    Raises
    ------
    KeyError
        If a (translated) checkpoint entry has no counterpart in this model.
    """
    import torch
    from mxnet import nd

    loaded = torch.load(filename)
    params = self._collect_params_with_prefix()

    new_params = {}
    for name in loaded:
        # BatchNorm affine parameters are named weight/bias in PyTorch but
        # gamma/beta in MXNet. '.downsample.1.' matches the BN layer inside
        # a ResNet downsample branch, whose name contains neither 'bn' nor
        # 'batchnorm'.
        key = name
        if 'bn' in name or 'batchnorm' in name or '.downsample.1.' in name:
            if 'weight' in name:
                key = name.replace('weight', 'gamma')
            elif 'bias' in name:
                key = name.replace('bias', 'beta')
        new_params[key] = nd.array(loaded[name].cpu().data.numpy())

    for name, value in new_params.items():
        if name not in params:
            # Was a bare ``print`` + ``raise Exception``; keep it fatal but
            # informative (KeyError is still caught by ``except Exception``).
            raise KeyError('parameter {} from {} has no counterpart in this model'
                           .format(name, filename))
        params[name]._load_init(value, ctx=ctx)
68
69
class ZeroUniform(Initializer):
    """Initializer that draws weights uniformly from ``[0, scale]``.

    Unlike the stock ``Uniform`` initializer, which samples from
    ``[-scale, scale]``, this one keeps all values non-negative.

    Parameters
    ----------
    scale : float, optional
        Upper bound of the sampling range; samples come from ``[0, scale]``.
        Defaults to 1.

    Example
    -------
    >>> # Given 'module', an instance of 'mxnet.module.Module', initialize
    >>> # weights to random values uniformly sampled between 0 and 0.1.
    ...
    >>> init = ZeroUniform(0.1)
    >>> module.init_params(init)
    >>> for dictionary in module.get_params():
    ...     for key in dictionary:
    ...         print(key)
    ...         print(dictionary[key].asnumpy())
    ...
    fullyconnected0_weight
    [[ 0.01360891 0.02144304  0.08511933]]
    """

    def __init__(self, scale=1):
        super(ZeroUniform, self).__init__(scale=scale)
        self.scale = scale

    def _init_weight(self, name, arr):
        # Fill ``arr`` in place with draws from U(0, scale); ``name`` is unused.
        random.uniform(low=0, high=self.scale, out=arr)
101