import copy

import numpy

import chainer
from chainer.backends import cuda
from chainer.functions.normalization import batch_normalization
from chainer import initializers
from chainer import link
import chainer.utils
from chainer import variable
from chainermn.functions import batch_normalization as \
    chainermn_batch_normalization


class MultiNodeBatchNormalization(link.Link):

    """Batch normalization layer that can use the whole batch stats.

    When using chainer.link.BatchNormalization, batch mean and std are
    computed independently for the local batch in each worker. When local
    batch size is too small, training is unstable due to unreliable batch
    stats.

    In contrast, when using this MultiNodeBatchNormalization, workers
    communicate to conduct 'correct' batch normalization (e.g., obtaining
    mean and std for the whole global batch).

    This link works only with Chainer >= 2.0.0.

    Args:
        size (int or tuple of ints): Size (or shape) of channel
            dimensions.
        comm (ChainerMN communicator): communicator to share
            the batch stats.
        decay (float): Decay rate of moving average. It is used on training.
        eps (float): Epsilon value for numerical stability.
        dtype (numpy.dtype): Type to use in computing.
        use_gamma (bool): If ``True``, use scaling parameter. Otherwise, use
            unit(1) which makes no effect.
        use_beta (bool): If ``True``, use shifting parameter. Otherwise, use
            unit(0) which makes no effect.
        initial_gamma: Initializer of the scaling parameter ``gamma``.
            Defaults to 1 when ``None``.
        initial_beta: Initializer of the shifting parameter ``beta``.
            Defaults to 0 when ``None``.
        communication_backend (str): ``mpi``, ``nccl`` or ``auto``. It is used
            to determine communication backend. If ``auto``, use the best
            communication backend for each communicator.
    """

    def __init__(self, size, comm, decay=0.9, eps=2e-5, dtype=None,
                 use_gamma=True, use_beta=True,
                 initial_gamma=None, initial_beta=None,
                 communication_backend='auto'):
        chainer.utils.experimental(
            'chainermn.links.MultiNodeBatchNormalization')

        super(MultiNodeBatchNormalization, self).__init__()
        # Statistics are kept in a high-precision dtype; mixed16 inputs
        # are promoted to float32 for numerically stable accumulation.
        self._highprec_dtype = chainer.get_dtype(
            dtype, map_mixed16=numpy.float32)
        self.comm = comm
        # Moving averages are persistent (serialized with the model) but
        # are not trainable parameters.
        self.avg_mean = numpy.zeros(size, dtype=self._highprec_dtype)
        self.register_persistent('avg_mean')
        self.avg_var = numpy.zeros(size, dtype=self._highprec_dtype)
        self.register_persistent('avg_var')
        # N counts forward passes while fine-tuning; it drives the decay
        # schedule so fine-tuned stats become an exact population mean.
        self.N = 0
        self.register_persistent('N')
        self.decay = decay
        self.eps = eps

        # Resolve 'auto' to a concrete backend ('mpi' or 'nccl') suited
        # to the given communicator.
        self._communication_backend = \
            chainermn_batch_normalization.get_communication_backend(
                comm, communication_backend)

        with self.init_scope():
            if use_gamma:
                if initial_gamma is None:
                    initial_gamma = 1
                initial_gamma = initializers._get_initializer(initial_gamma)
                initial_gamma.dtype = self._highprec_dtype
                self.gamma = variable.Parameter(initial_gamma, size)
            if use_beta:
                if initial_beta is None:
                    initial_beta = 0
                initial_beta = initializers._get_initializer(initial_beta)
                initial_beta.dtype = self._highprec_dtype
                self.beta = variable.Parameter(initial_beta, size)

    def __call__(self, x, finetune=False):
        """Apply multi-node batch normalization to the input variable.

        Args:
            x (chainer.Variable): Input batch.
            finetune (bool): If ``True`` while in training mode, the link
                accumulates exact population statistics instead of a
                fixed-decay moving average.

        Returns:
            chainer.Variable: Normalized output.
        """
        # When gamma/beta were disabled in __init__, substitute constant
        # 1/0 tensors so the normalization function signature is uniform.
        if hasattr(self, 'gamma'):
            gamma = self.gamma
        else:
            with cuda.get_device_from_id(self._device_id):
                gamma = variable.Variable(self.xp.ones(
                    self.avg_mean.shape, dtype=self._highprec_dtype))
        if hasattr(self, 'beta'):
            beta = self.beta
        else:
            with cuda.get_device_from_id(self._device_id):
                beta = variable.Variable(self.xp.zeros(
                    self.avg_mean.shape, dtype=self._highprec_dtype))

        if chainer.configuration.config.train:
            if finetune:
                # decay = 1 - 1/N yields a running exact average of the
                # statistics over all fine-tuning batches seen so far.
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            # The impl_selector swaps in a multi-node implementation that
            # all-reduces batch statistics across workers.
            func = batch_normalization.BatchNormalization(
                self.eps, self.avg_mean, self.avg_var, decay,
                impl_selector=(
                    chainermn_batch_normalization.MultiNodeBNImplSelector(
                        self.comm, self._communication_backend)))

            ret = func.apply((x, gamma, beta))[0]

            # Write the updated running statistics back into the
            # persistent arrays in place.
            self.avg_mean[:] = func.running_mean
            self.avg_var[:] = func.running_var
        else:
            # Use running average statistics or fine-tuned statistics.
            mean = variable.Variable(self.avg_mean)
            var = variable.Variable(self.avg_var)
            ret = batch_normalization.fixed_batch_normalization(
                x, gamma, beta, mean, var, self.eps)
        return ret

    def start_finetuning(self):
        """Resets the population count for collecting population statistics.

        This method can be skipped if it is the first time to use the
        fine-tuning mode. Otherwise, this method should be called before
        starting the fine-tuning mode again.

        """
        self.N = 0

    def __deepcopy__(self, memo):
        # The communicator wraps process-level resources (e.g. MPI) and
        # must not be deep-copied: temporarily detach it, copy the rest
        # of the link, then re-attach the *same* communicator to both the
        # original and the copy.
        to_be_preserved = ['comm']
        preserved = {}
        for name in to_be_preserved:
            preserved[name] = getattr(self, name)
            setattr(self, name, None)

        try:
            # Propagate ``memo`` so objects shared with this link are
            # copied only once (the previous implementation dropped it,
            # breaking shared/cyclic reference handling).
            ret = copy.deepcopy(
                super(MultiNodeBatchNormalization, self), memo)
        finally:
            # Always restore the communicator on the original, even if
            # the copy fails, so the link is not left in a broken state.
            for name in to_be_preserved:
                setattr(self, name, preserved[name])

        for name in to_be_preserved:
            setattr(ret, name, preserved[name])

        return ret