# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import logging
import mxnet as mx
import numpy as np


class RandomNumberQueue(object):
    """Pool of pre-generated uniform [0, 1) samples.

    Drawing from a pre-computed numpy pool is cheaper than calling
    ``np.random.rand()`` once per sample inside the training loop.
    The pool is transparently regenerated when exhausted.

    Parameters
    ----------
    pool_size: int
        Number of samples generated per pool refill. Default 1000.
    """
    def __init__(self, pool_size=1000):
        self._pool = np.random.rand(pool_size)
        self._index = 0

    def get_sample(self):
        """Return the next uniform [0, 1) sample, refilling the pool if needed."""
        if self._index >= len(self._pool):
            self._pool = np.random.rand(len(self._pool))
            self._index = 0
        self._index += 1
        return self._pool[self._index - 1]


class StochasticDepthModule(mx.module.BaseModule):
    """Stochastic depth module is a two branch computation: one is actual computing and the
    other is the skip computing (usually an identity map). This is similar to a Residual block,
    except that a random variable is used to randomly turn off the computing branch, in order
    to save computation during training.

    Parameters
    ----------
    symbol_compute: Symbol
        The computation branch.
    symbol_skip: Symbol
        The skip branch. Could be None, in which case an identity map will be automatically
        used. Note the two branch should produce exactly the same output shapes.
    data_names: list of str
        Default is `['data']`. Indicating the input names. Note if `symbol_skip` is not None,
        it should have the same input names as `symbol_compute`.
    label_names: list of str
        Default is None, indicating that this module does not take labels.
    death_rate: float
        Default 0. The probability of turning off the computing branch.
    """
    def __init__(self, symbol_compute, symbol_skip=None,
                 data_names=('data',), label_names=None,
                 logger=logging, context=mx.context.cpu(),
                 work_load_list=None, fixed_param_names=None,
                 death_rate=0):
        super(StochasticDepthModule, self).__init__(logger=logger)

        self._module_compute = mx.module.Module(
            symbol_compute, data_names=data_names,
            label_names=label_names, logger=logger,
            context=context, work_load_list=work_load_list,
            fixed_param_names=fixed_param_names)

        if symbol_skip is not None:
            self._module_skip = mx.module.Module(
                symbol_skip, data_names=data_names,
                label_names=label_names, logger=logger,
                context=context, work_load_list=work_load_list,
                fixed_param_names=fixed_param_names)
        else:
            self._module_skip = None

        # Probability that the compute branch stays ON during training.
        self._open_rate = 1 - death_rate
        self._gate_open = True
        self._outputs = None
        self._input_grads = None
        self._rnd_queue = RandomNumberQueue()

    @property
    def data_names(self):
        return self._module_compute.data_names

    @property
    def output_names(self):
        return self._module_compute.output_names

    @property
    def data_shapes(self):
        return self._module_compute.data_shapes

    @property
    def label_shapes(self):
        return self._module_compute.label_shapes

    @property
    def output_shapes(self):
        return self._module_compute.output_shapes

    def get_params(self):
        """Return ``(arg_params, aux_params)`` merged over both branches.

        Always returns a tuple (the original returned a list when a skip
        module was present, which was inconsistent with BaseModule's
        contract). The two branches must not share parameter names.
        """
        arg_params, aux_params = self._module_compute.get_params()
        if self._module_skip:
            # Copy before merging so the compute module's own dicts are
            # not mutated by update() below.
            arg_params = arg_params.copy()
            aux_params = aux_params.copy()
            skip_args, skip_auxs = self._module_skip.get_params()
            for merged, extra in ((arg_params, skip_args), (aux_params, skip_auxs)):
                # make sure they do not contain duplicated param names
                assert len(set(merged.keys()) & set(extra.keys())) == 0
                merged.update(extra)
        return arg_params, aux_params

    def init_params(self, *args, **kwargs):
        self._module_compute.init_params(*args, **kwargs)
        if self._module_skip:
            self._module_skip.init_params(*args, **kwargs)

    def bind(self, *args, **kwargs):
        self._module_compute.bind(*args, **kwargs)
        if self._module_skip:
            self._module_skip.bind(*args, **kwargs)

    def init_optimizer(self, *args, **kwargs):
        self._module_compute.init_optimizer(*args, **kwargs)
        if self._module_skip:
            self._module_skip.init_optimizer(*args, **kwargs)

    def borrow_optimizer(self, shared_module):
        self._module_compute.borrow_optimizer(shared_module._module_compute)
        if self._module_skip:
            self._module_skip.borrow_optimizer(shared_module._module_skip)

    def forward(self, data_batch, is_train=None):
        """Forward both branches and sum their outputs.

        During training the compute branch is included with probability
        ``1 - death_rate`` (a fresh Bernoulli draw per call); during
        prediction it is always included, scaled by ``1 - death_rate``
        (the expectation over the training-time gate).
        """
        if is_train is None:
            is_train = self._module_compute.for_training

        if self._module_skip:
            # NOTE(review): is_train=True is always passed to the skip
            # branch, even at prediction time — preserved from the original;
            # confirm this is intended if the skip branch contains
            # train/test-sensitive ops (e.g. BatchNorm).
            self._module_skip.forward(data_batch, is_train=True)
            # Copy so the in-place accumulation below does not clobber the
            # skip module's internal output buffers.
            self._outputs = [x.copy() for x in self._module_skip.get_outputs()]
        else:
            # Identity skip branch. Copy so the in-place accumulation below
            # does not mutate the caller's input arrays in data_batch.
            self._outputs = [x.copy() for x in data_batch.data]

        if is_train:
            self._gate_open = self._rnd_queue.get_sample() < self._open_rate
            if self._gate_open:
                self._module_compute.forward(data_batch, is_train=True)
                computed_outputs = self._module_compute.get_outputs()
                for i in range(len(self._outputs)):
                    self._outputs[i] += computed_outputs[i]
        else:  # do expectation for prediction
            self._module_compute.forward(data_batch, is_train=False)
            computed_outputs = self._module_compute.get_outputs()
            for i in range(len(self._outputs)):
                self._outputs[i] += self._open_rate * computed_outputs[i]

    def backward(self, out_grads=None):
        """Backward both branches and sum the input gradients.

        The compute branch contributes only if its gate was open during the
        preceding training forward pass.
        """
        if self._module_skip:
            self._module_skip.backward(out_grads=out_grads)
            # Copy so the in-place accumulation below does not clobber the
            # skip module's internal gradient buffers.
            self._input_grads = [x.copy() for x in self._module_skip.get_input_grads()]
        else:
            # Identity skip branch: gradients pass straight through. Copy so
            # the in-place accumulation below does not mutate the caller's
            # out_grads arrays.
            self._input_grads = None if out_grads is None \
                else [x.copy() for x in out_grads]

        if self._gate_open:
            self._module_compute.backward(out_grads=out_grads)
            computed_input_grads = self._module_compute.get_input_grads()
            for i in range(len(self._input_grads)):
                self._input_grads[i] += computed_input_grads[i]

    def update(self):
        self._module_compute.update()
        if self._module_skip:
            self._module_skip.update()

    def update_metric(self, eval_metric, labels):
        self._module_compute.update_metric(eval_metric, labels)
        if self._module_skip:
            self._module_skip.update_metric(eval_metric, labels)

    def get_outputs(self, merge_multi_context=True):
        assert merge_multi_context, "Force merging for now"
        return self._outputs

    def get_input_grads(self, merge_multi_context=True):
        assert merge_multi_context, "Force merging for now"
        return self._input_grads