import chainer
from chainer.backends import cuda
from chainer.functions.normalization import batch_normalization
from chainer import initializers
from chainer import link
import chainer.utils
from chainer import variable
from chainermn.functions import batch_normalization as \
    chainermn_batch_normalization

import numpy
import copy


class MultiNodeBatchNormalization(link.Link):

    """Batch normalization layer that can use the whole batch stats.

    When using ``chainer.links.BatchNormalization``, the batch mean and
    std are computed independently for the local batch in each worker.
    When the local batch size is too small, training becomes unstable
    due to unreliable batch stats.

    In contrast, when using this MultiNodeBatchNormalization, workers
    communicate to conduct 'correct' batch normalization (i.e., obtaining
    the mean and std of the whole global batch).

    This link works only with Chainer >= 2.0.0.

    Args:
        size (int or tuple of ints): Size (or shape) of channel
            dimensions.
        comm (ChainerMN communicator): communicator to share
            the batch stats.
        decay (float): Decay rate of moving average. It is used during
            training.
        eps (float): Epsilon value for numerical stability.
        dtype (numpy.dtype): Type to use in computing.
        use_gamma (bool): If ``True``, use scaling parameter. Otherwise,
            a fixed value of 1 is used, which has no effect.
        use_beta (bool): If ``True``, use shifting parameter. Otherwise,
            a fixed value of 0 is used, which has no effect.
        initial_gamma: Initializer of the scaling parameter. If ``None``,
            the scale is initialized to 1.
        initial_beta: Initializer of the shifting parameter. If ``None``,
            the shift is initialized to 0.
        communication_backend (str): ``mpi``, ``nccl`` or ``auto``. It is
            used to determine the communication backend. If ``auto``, the
            best communication backend is chosen for each communicator.
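
    A minimal usage sketch (the communicator type, channel size, and the
    input variable ``x`` are only illustrative):

    .. code-block:: python

        import chainermn
        from chainermn.links import MultiNodeBatchNormalization

        comm = chainermn.create_communicator('pure_nccl')
        bn = MultiNodeBatchNormalization(64, comm)
        # During training, mean/std are computed over the global batch
        # across all workers instead of each local batch.
        y = bn(x)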
    """

    def __init__(self, size, comm, decay=0.9, eps=2e-5, dtype=None,
                 use_gamma=True, use_beta=True,
                 initial_gamma=None, initial_beta=None,
                 communication_backend='auto'):
        chainer.utils.experimental(
            'chainermn.links.MultiNodeBatchNormalization')

        super(MultiNodeBatchNormalization, self).__init__()
        self._highprec_dtype = chainer.get_dtype(
            dtype, map_mixed16=numpy.float32)
        self.comm = comm
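        # The moving statistics and the iteration counter registered below
        # are persistent values: they are saved and loaded together with
        # the link's parameters.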
        self.avg_mean = numpy.zeros(size, dtype=self._highprec_dtype)
        self.register_persistent('avg_mean')
        self.avg_var = numpy.zeros(size, dtype=self._highprec_dtype)
        self.register_persistent('avg_var')
        self.N = 0
        self.register_persistent('N')
        self.decay = decay
        self.eps = eps

        self._communication_backend = \
            chainermn_batch_normalization.get_communication_backend(
                comm, communication_backend)

        with self.init_scope():
            if use_gamma:
                if initial_gamma is None:
                    initial_gamma = 1
                initial_gamma = initializers._get_initializer(initial_gamma)
                initial_gamma.dtype = self._highprec_dtype
                self.gamma = variable.Parameter(initial_gamma, size)
            if use_beta:
                if initial_beta is None:
                    initial_beta = 0
                initial_beta = initializers._get_initializer(initial_beta)
                initial_beta.dtype = self._highprec_dtype
                self.beta = variable.Parameter(initial_beta, size)

    def __call__(self, x, finetune=False):
        if hasattr(self, 'gamma'):
            gamma = self.gamma
        else:
            with cuda.get_device_from_id(self._device_id):
                gamma = variable.Variable(self.xp.ones(
                    self.avg_mean.shape, dtype=self._highprec_dtype))
        if hasattr(self, 'beta'):
            beta = self.beta
        else:
            with cuda.get_device_from_id(self._device_id):
                beta = variable.Variable(self.xp.zeros(
                    self.avg_mean.shape, dtype=self._highprec_dtype))

        if chainer.configuration.config.train:
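            # In fine-tuning mode, keep a cumulative average of the batch
            # statistics over all batches seen so far (decay = 1 - 1/N),
            # instead of an exponential moving average.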
            if finetune:
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            func = batch_normalization.BatchNormalization(
                self.eps, self.avg_mean, self.avg_var, decay,
                impl_selector=(
                    chainermn_batch_normalization.MultiNodeBNImplSelector(
                        self.comm, self._communication_backend)))
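            # MultiNodeBNImplSelector makes the function aggregate the
            # batch statistics over the global batch across all workers,
            # using the communication backend chosen above.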

            ret = func.apply((x, gamma, beta))[0]

            self.avg_mean[:] = func.running_mean
            self.avg_var[:] = func.running_var
        else:
            # Use running average statistics or fine-tuned statistics.
            mean = variable.Variable(self.avg_mean)
            var = variable.Variable(self.avg_var)
            ret = batch_normalization.fixed_batch_normalization(
                x, gamma, beta, mean, var, self.eps)
        return ret

    def start_finetuning(self):
        """Resets the population count for collecting population statistics.

        This method can be skipped the first time the fine-tuning mode is
        used. Otherwise, it should be called before starting the
        fine-tuning mode again.

        """
        self.N = 0

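    # The communicator is shared, not copied: it is temporarily detached
    # so that ``deepcopy`` does not attempt to copy it, and the same
    # instance is then attached to both the original link and the copy.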
    def __deepcopy__(self, memo):
        to_be_preserved = ['comm']
        preserved = {}
        for name in to_be_preserved:
            preserved[name] = getattr(self, name)
            setattr(self, name, None)

        ret = copy.deepcopy(super(MultiNodeBatchNormalization, self))

        for name in to_be_preserved:
            setattr(self, name, preserved[name])
            setattr(ret, name, preserved[name])

        return ret