# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import logging

import mxnet as mx
import numpy as np


class RandomNumberQueue(object):
    """Amortize random number generation by drawing a pool of uniform
    samples from numpy up front and handing them out one at a time."""

    def __init__(self, pool_size=1000):
        self._pool = np.random.rand(pool_size)
        self._index = 0

    def get_sample(self):
        """Return the next uniform sample in [0, 1), refilling the pool
        once it is exhausted."""
        if self._index >= len(self._pool):
            self._pool = np.random.rand(len(self._pool))
            self._index = 0
        self._index += 1
        return self._pool[self._index - 1]


class StochasticDepthModule(mx.module.BaseModule):
    """Stochastic depth module: a two-branch computation in which one branch
    does the actual computing and the other is a skip connection (usually an
    identity map). This is similar to a residual block, except that a random
    variable is used to randomly turn off the computing branch, in order to
    save computation during training.

    Parameters
    ----------
    symbol_compute: Symbol
        The computation branch.
    symbol_skip: Symbol
        The skip branch. Could be None, in which case an identity map is used
        automatically. Note that the two branches should produce exactly the
        same output shapes.
    data_names: list of str
        Default is `('data',)`, indicating the input names. Note that if
        `symbol_skip` is not None, it should have the same input names as
        `symbol_compute`.
    label_names: list of str
        Default is None, indicating that this module does not take labels.
    logger: logging module or Logger
        Default is the root `logging` module.
    context: Context or list of Context
        Default is `mx.context.cpu()`. The device(s) to run on; passed
        through to the underlying `mx.module.Module` instances.
    work_load_list: list of float
        Default is None. Workload partition across devices; passed through
        to the underlying modules.
    fixed_param_names: list of str
        Default is None. Names of parameters to keep fixed during training;
        passed through to the underlying modules.
    death_rate: float
        Default 0. The probability of turning off the computing branch.
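
    Examples
    --------
    A minimal usage sketch (the symbol, shapes, and death rate below are
    illustrative assumptions, not part of this module)::

        data = mx.sym.Variable('data')
        compute = mx.sym.Activation(
            mx.sym.FullyConnected(data, num_hidden=16), act_type='relu')
        # identity skip, so input and output shapes must match: (batch, 16)
        mod = StochasticDepthModule(compute, death_rate=0.5)
        mod.bind(data_shapes=[('data', (8, 16))])
        mod.init_params()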
    """
    def __init__(self, symbol_compute, symbol_skip=None,
                 data_names=('data',), label_names=None,
                 logger=logging, context=mx.context.cpu(),
                 work_load_list=None, fixed_param_names=None,
                 death_rate=0):
        super(StochasticDepthModule, self).__init__(logger=logger)

        self._module_compute = mx.module.Module(
            symbol_compute, data_names=data_names,
            label_names=label_names, logger=logger,
            context=context, work_load_list=work_load_list,
            fixed_param_names=fixed_param_names)

        if symbol_skip is not None:
            self._module_skip = mx.module.Module(
                symbol_skip, data_names=data_names,
                label_names=label_names, logger=logger,
                context=context, work_load_list=work_load_list,
                fixed_param_names=fixed_param_names)
        else:
            self._module_skip = None

        # survival probability of the compute branch
        self._open_rate = 1 - death_rate
        # whether the compute branch ran in the most recent forward pass
        self._gate_open = True
        self._outputs = None
        self._input_grads = None
        self._rnd_queue = RandomNumberQueue()

    @property
    def data_names(self):
        return self._module_compute.data_names

    @property
    def output_names(self):
        return self._module_compute.output_names

    @property
    def data_shapes(self):
        return self._module_compute.data_shapes

    @property
    def label_shapes(self):
        return self._module_compute.label_shapes

    @property
    def output_shapes(self):
        return self._module_compute.output_shapes

    def get_params(self):
        arg_params, aux_params = self._module_compute.get_params()
        if self._module_skip:
            arg_params = arg_params.copy()
            aux_params = aux_params.copy()
            skip_arg_params, skip_aux_params = self._module_skip.get_params()
            # the two branches must not share parameter names
            assert not set(arg_params) & set(skip_arg_params)
            assert not set(aux_params) & set(skip_aux_params)
            arg_params.update(skip_arg_params)
            aux_params.update(skip_aux_params)
        return arg_params, aux_params

    def init_params(self, *args, **kwargs):
        self._module_compute.init_params(*args, **kwargs)
        if self._module_skip:
            self._module_skip.init_params(*args, **kwargs)

    def bind(self, *args, **kwargs):
        self._module_compute.bind(*args, **kwargs)
        if self._module_skip:
            self._module_skip.bind(*args, **kwargs)

    def init_optimizer(self, *args, **kwargs):
        self._module_compute.init_optimizer(*args, **kwargs)
        if self._module_skip:
            self._module_skip.init_optimizer(*args, **kwargs)

    def borrow_optimizer(self, shared_module):
        self._module_compute.borrow_optimizer(shared_module._module_compute)
        if self._module_skip:
            self._module_skip.borrow_optimizer(shared_module._module_skip)

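    # Gating semantics: during training each forward pass samples a Bernoulli
    # gate that keeps the compute branch alive with probability
    # 1 - death_rate; when the gate is closed only the skip branch runs, so
    # the compute branch costs nothing on that pass. At prediction time the
    # compute branch always runs, scaled by its survival probability, which
    # matches the training-time expectation:
    #     E[output] = skip(x) + (1 - death_rate) * compute(x)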
    def forward(self, data_batch, is_train=None):
        if is_train is None:
            is_train = self._module_compute.for_training

        if self._module_skip:
            self._module_skip.forward(data_batch, is_train=is_train)
            # copy so that the in-place additions below do not clobber the
            # skip module's internal output buffers
            self._outputs = [x.copy()
                             for x in self._module_skip.get_outputs()]
        else:
            # identity skip: copy so we do not modify the caller's input batch
            self._outputs = [x.copy() for x in data_batch.data]

        if is_train:
            # sample the Bernoulli gate; the compute branch only runs (and
            # only costs computation) when the gate is open
            self._gate_open = self._rnd_queue.get_sample() < self._open_rate
            if self._gate_open:
                self._module_compute.forward(data_batch, is_train=True)
                computed_outputs = self._module_compute.get_outputs()
                for i in range(len(self._outputs)):
                    self._outputs[i] += computed_outputs[i]
        else:
            # prediction: take the expectation over the gate by scaling the
            # compute branch with its survival probability
            self._module_compute.forward(data_batch, is_train=False)
            computed_outputs = self._module_compute.get_outputs()
            for i in range(len(self._outputs)):
                self._outputs[i] += self._open_rate * computed_outputs[i]

    def backward(self, out_grads=None):
        if self._module_skip:
            self._module_skip.backward(out_grads=out_grads)
            # copy so that the in-place additions below do not clobber the
            # skip module's internal gradient buffers
            self._input_grads = [
                g.copy() for g in self._module_skip.get_input_grads()]
        else:
            # identity skip: copy so we do not modify the caller's out_grads
            self._input_grads = [g.copy() for g in out_grads]

        if self._gate_open:
            self._module_compute.backward(out_grads=out_grads)
            computed_input_grads = self._module_compute.get_input_grads()
            for i in range(len(self._input_grads)):
                self._input_grads[i] += computed_input_grads[i]

    def update(self):
        self._module_compute.update()
        if self._module_skip:
            self._module_skip.update()

    def update_metric(self, eval_metric, labels):
        self._module_compute.update_metric(eval_metric, labels)
        if self._module_skip:
            self._module_skip.update_metric(eval_metric, labels)

    def get_outputs(self, merge_multi_context=True):
        assert merge_multi_context, "Force merging for now"
        return self._outputs

    def get_input_grads(self, merge_multi_context=True):
        assert merge_multi_context, "Force merging for now"
        return self._input_grads
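

if __name__ == '__main__':
    # A minimal smoke test, added as an illustrative sketch (the network,
    # shapes, and death rate here are assumptions, not part of the original
    # example): one stochastic forward pass through an identity-skip block.
    logging.basicConfig(level=logging.INFO)
    data = mx.sym.Variable('data')
    compute = mx.sym.Activation(
        mx.sym.FullyConnected(data, num_hidden=16), act_type='relu')
    mod = StochasticDepthModule(compute, death_rate=0.5)
    mod.bind(data_shapes=[('data', (4, 16))])
    mod.init_params()
    batch = mx.io.DataBatch(data=[mx.nd.ones((4, 16))], label=None)
    mod.forward(batch, is_train=True)
    # with death_rate=0.5 the compute branch runs on roughly half the passes
    print(mod.get_outputs()[0].shape)  # (4, 16)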