"""Fast Pose utilities for loading parameters into Gluon blocks."""
from mxnet import random
from mxnet import ndarray
from mxnet.initializer import Initializer


# Keys present in PyTorch BatchNorm state_dicts that have no Gluon parameter
# counterpart (Gluon BatchNorm only has gamma/beta/running_mean/running_var).
_PYTORCH_ONLY_KEYS = ('num_batches_tracked',)


def _shapes_compatible(param_shape, data_shape):
    """Return True if an array of ``data_shape`` can be loaded into ``param_shape``.

    A falsy ``param_shape`` (``None`` or ``()``) and individual dimensions equal
    to ``0`` or ``-1`` are treated as not-yet-inferred (deferred initialization)
    and accept any size; every other dimension must match exactly.
    """
    if not param_shape:
        return True
    if len(param_shape) != len(data_shape):
        return False
    return all(p in (0, -1, d) for p, d in zip(param_shape, data_shape))


def _try_load_parameters(self, filename=None, model=None, ctx=None, allow_missing=False,
                         ignore_extra=False):
    """Best-effort load of parameters into ``self``, skipping entries that do not fit.

    Values come either from a parameter file (``filename``) or from another
    block (``model``). Each entry whose structured name matches a parameter of
    ``self`` and whose shape is compatible is loaded; entries with unknown
    names or incompatible shapes are silently skipped, which allows partially
    initializing a network from a related one.

    Parameters
    ----------
    self : mxnet.gluon.Block
        Block receiving the parameters (this function is meant to be bound as a method).
    filename : str, optional
        Path of a parameter file. Takes precedence over ``model``.
    model : mxnet.gluon.Block, optional
        Block whose current parameter values are copied when ``filename`` is None.
    ctx : Context or list of Context, optional
        Context(s) the loaded parameters are initialized on.
    allow_missing, ignore_extra : bool
        Only forwarded to ``ParameterDict.load`` for legacy files whose keys
        contain no '.'; the structured-name path always tolerates missing and
        extra entries.

    Raises
    ------
    ValueError
        If neither ``filename`` nor ``model`` is given.
    """
    if filename is not None:
        loaded = ndarray.load(filename)
    elif model is not None:
        loaded = {k: v.data() for k, v in model._collect_params_with_prefix().items()}
    else:
        raise ValueError('Either `filename` or `model` must be provided.')
    params = self._collect_params_with_prefix()
    if not loaded and not params:
        return

    # Keys without any '.' are treated as a legacy file layout and handed to
    # ParameterDict.load, which re-reads the file. That path needs a file, so
    # values copied from ``model`` always use the name matching below instead.
    if filename is not None and not any('.' in key for key in loaded):
        del loaded  # drop the eager copy before the file is read again
        self.collect_params().load(
            filename, ctx, allow_missing, ignore_extra, self.prefix)
        return

    for name, value in loaded.items():
        param = params.get(name)
        if param is None or not _shapes_compatible(param.shape, value.shape):
            continue
        param._load_init(value, ctx)


def _pytorch_to_mxnet_name(name):
    """Translate a PyTorch ``state_dict`` key into the matching Gluon parameter name.

    Keys containing 'bn' or 'batchnorm', or containing '.downsample.1.'
    (presumably the BatchNorm of a ResNet-style downsample branch), are treated
    as BatchNorm: their ``weight``/``bias`` leaf is renamed to Gluon's
    ``gamma``/``beta``. All other keys are returned unchanged.
    """
    if 'bn' in name or 'batchnorm' in name or '.downsample.1.' in name:
        prefix, sep, leaf = name.rpartition('.')
        leaf = {'weight': 'gamma', 'bias': 'beta'}.get(leaf, leaf)
        return prefix + sep + leaf
    return name


def _load_from_pytorch(self, filename, ctx=None):
    """Load a PyTorch ``state_dict`` checkpoint into this Gluon block.

    Parameters
    ----------
    self : mxnet.gluon.Block
        Block receiving the parameters (this function is meant to be bound as a method).
    filename : str
        Path of a checkpoint produced by ``torch.save(module.state_dict(), ...)``
        (assumed to be a flat name -> tensor mapping).
    ctx : Context or list of Context, optional
        Context(s) the loaded parameters are initialized on.

    Raises
    ------
    ValueError
        If any converted key has no matching parameter in ``self``. The check
        runs before anything is loaded, so the block is left untouched.
    """
    import torch  # imported lazily: torch is only needed for this conversion

    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    # map_location='cpu' lets checkpoints saved on GPU load on CPU-only hosts.
    loaded = torch.load(filename, map_location='cpu')
    params = self._collect_params_with_prefix()

    new_params = {}
    for name, tensor in loaded.items():
        if name.rpartition('.')[2] in _PYTORCH_ONLY_KEYS:
            continue
        new_params[_pytorch_to_mxnet_name(name)] = ndarray.array(tensor.cpu().data.numpy())

    missing = sorted(name for name in new_params if name not in params)
    if missing:
        raise ValueError(
            'PyTorch parameters with no matching MXNet parameter: {}'.format(
                ', '.join(missing)))

    for name, value in new_params.items():
        params[name]._load_init(value, ctx=ctx)


class ZeroUniform(Initializer):
    """Initializes weights with random values uniformly sampled from ``[0, scale)``.

    Parameters
    ----------
    scale : float, optional
        Upper bound of the generated random values. Values are drawn by
        ``mxnet.random.uniform(0, scale)``, i.e. from the half-open range
        [0, `scale`). Default scale is 1.

    Example
    -------
    >>> # Given 'module', an instance of 'mxnet.module.Module', initialize weights
    >>> # to random values uniformly sampled between 0 and 0.1.
    ...
    >>> init = ZeroUniform(0.1)
    >>> module.init_params(init)
    >>> for dictionary in module.get_params():
    ...     for key in dictionary:
    ...         print(key)
    ...         print(dictionary[key].asnumpy())
    ...
    fullyconnected0_weight
    [[ 0.01360891  0.02144304  0.08511933]]
    """
    def __init__(self, scale=1):
        # Passing scale to the base class records it in the initializer's kwargs.
        super(ZeroUniform, self).__init__(scale=scale)
        self.scale = scale

    def _init_weight(self, _, arr):
        # Fill ``arr`` in place; the parameter name argument is unused.
        random.uniform(0, self.scale, out=arr)