# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Building blocks and utilities for models."""
__all__ = ['apply_weight_drop']

import collections
import functools
import re
import warnings

from mxnet.gluon import Block, contrib, rnn
from mxnet.gluon.model_zoo import model_store
from ..data.utils import _load_pretrained_vocab
from .parameter import WeightDropParameter
from .lstmpcellwithclip import LSTMPCellWithClip

# pylint: disable=too-many-nested-blocks


def apply_weight_drop(block, local_param_regex, rate, axes=(),
                      weight_dropout_mode='training'):
    """Apply weight drop to the parameter of a block.

    Parameters
    ----------
    block : Block or HybridBlock
        The block whose parameters will have weight drop applied.
    local_param_regex : str
        The regex for parameter names, as used in self.params.get(), such as 'weight'.
    rate : float
        Fraction of the input units to drop. Must be a number between 0 and 1.
    axes : tuple of int, default ()
        The axes on which the dropout mask is shared. If empty, regular dropout is applied.
    weight_dropout_mode : {'training', 'always'}, default 'training'
        Whether the weight dropout should be applied only at training time, or always be applied.

    Examples
    --------
    >>> net = gluon.rnn.LSTM(10, num_layers=2, bidirectional=True)
    >>> gluonnlp.model.apply_weight_drop(net, r'.*h2h_weight', 0.5)
    >>> net.collect_params()
    lstm0_ (
      Parameter lstm0_l0_i2h_weight (shape=(40, 0), dtype=float32)
      WeightDropParameter lstm0_l0_h2h_weight (shape=(40, 10), dtype=float32, \
rate=0.5, mode=training)
      Parameter lstm0_l0_i2h_bias (shape=(40,), dtype=float32)
      Parameter lstm0_l0_h2h_bias (shape=(40,), dtype=float32)
      Parameter lstm0_r0_i2h_weight (shape=(40, 0), dtype=float32)
      WeightDropParameter lstm0_r0_h2h_weight (shape=(40, 10), dtype=float32, \
rate=0.5, mode=training)
      Parameter lstm0_r0_i2h_bias (shape=(40,), dtype=float32)
      Parameter lstm0_r0_h2h_bias (shape=(40,), dtype=float32)
      Parameter lstm0_l1_i2h_weight (shape=(40, 20), dtype=float32)
      WeightDropParameter lstm0_l1_h2h_weight (shape=(40, 10), dtype=float32, \
rate=0.5, mode=training)
      Parameter lstm0_l1_i2h_bias (shape=(40,), dtype=float32)
      Parameter lstm0_l1_h2h_bias (shape=(40,), dtype=float32)
      Parameter lstm0_r1_i2h_weight (shape=(40, 20), dtype=float32)
      WeightDropParameter lstm0_r1_h2h_weight (shape=(40, 10), dtype=float32, \
rate=0.5, mode=training)
      Parameter lstm0_r1_i2h_bias (shape=(40,), dtype=float32)
      Parameter lstm0_r1_h2h_bias (shape=(40,), dtype=float32)
    )
    >>> ones = mx.nd.ones((3, 4, 5))
    >>> net.initialize()
    >>> with mx.autograd.train_mode():
    ...     net(ones).max().asscalar() != net(ones).max().asscalar()
    True
    """
    if not rate:
        return

    existing_params = _find_params(block, local_param_regex)
    for (local_param_name, param), \
            (ref_params_list, ref_reg_params_list) in existing_params.items():
        if isinstance(param, WeightDropParameter):
            continue
        dropped_param = WeightDropParameter(param, rate, weight_dropout_mode, axes)
        for ref_params in ref_params_list:
            ref_params[param.name] = dropped_param
        for ref_reg_params in ref_reg_params_list:
            ref_reg_params[local_param_name] = dropped_param
        if hasattr(block, local_param_name):
            local_attr = getattr(block, local_param_name)
            if local_attr == param:
                local_attr = dropped_param
            elif isinstance(local_attr, (list, tuple)):
                if isinstance(local_attr, tuple):
                    local_attr = list(local_attr)
                for i, v in enumerate(local_attr):
                    if v == param:
                        local_attr[i] = dropped_param
            elif isinstance(local_attr, dict):
                for k, v in local_attr.items():
                    if v == param:
                        local_attr[k] = dropped_param
            else:
                continue
            if local_attr:
                super(Block, block).__setattr__(local_param_name, local_attr)

# pylint: enable=too-many-nested-blocks


def _find_params(block, local_param_regex):
    # return {(local_param_name, parameter): (referenced_params_list,
    #                                         referenced_reg_params_list)}

    results = collections.defaultdict(lambda: ([], []))
    pattern = re.compile(local_param_regex)
    local_param_names = ((local_param_name, p) for local_param_name, p in block._reg_params.items()
                         if pattern.match(local_param_name))

    for local_param_name, p in local_param_names:
        ref_params_list, ref_reg_params_list = results[(local_param_name, p)]
        ref_reg_params_list.append(block._reg_params)

        params = block._params
        while params:
            if p.name in params._params:
                ref_params_list.append(params._params)
            if params._shared:
                params = params._shared
                warnings.warn('When applying weight drop, target parameter {} was found '
                              'in a shared parameter dict. The parameter attribute of the '
                              'original block on which the shared parameter dict was attached '
                              'will not be updated with WeightDropParameter. If necessary, '
                              'please update the attribute manually. The likely name of the '
                              'attribute is ".{}"'.format(p.name, local_param_name))
            else:
                break

    if block._children:
        if isinstance(block._children, list):
            children = block._children
        elif isinstance(block._children, dict):
            children = block._children.values()
        for c in children:
            child_results = _find_params(c, local_param_regex)
            for (child_p_name, child_p), (child_pd_list, child_rd_list) in child_results.items():
                pd_list, rd_list = results[(child_p_name, child_p)]
                pd_list.extend(child_pd_list)
                rd_list.extend(child_rd_list)

    return results
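

# Illustrative note: for the LSTM example in the ``apply_weight_drop`` docstring above,
# ``_find_params`` returns a mapping roughly of the form
#
#   {(local_name, <Parameter ...h2h_weight>): ([<ParameterDict._params dicts holding it>, ...],
#                                              [<Block._reg_params dicts holding it>, ...]),
#    ...}
#
# ``apply_weight_drop`` then replaces the entry in every referenced dict with a
# WeightDropParameter, so that all views of the parameter share the dropped weight.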


def _get_rnn_cell(mode, num_layers, input_size, hidden_size,
                  dropout, weight_dropout,
                  var_drop_in, var_drop_state, var_drop_out,
                  skip_connection, proj_size=None, cell_clip=None, proj_clip=None):
    """Create RNN cells given specs.

    Parameters
    ----------
    mode : str
        The type of RNN cell to use. Options are 'lstmpc', 'rnn_tanh', 'rnn_relu', 'lstm', 'gru'.
    num_layers : int
        The number of RNN cells in the encoder.
    input_size : int
        The initial input size of the RNN cell.
    hidden_size : int
        The hidden size of the RNN cell.
    dropout : float
        The dropout rate to use for encoder output.
    weight_dropout : float
        The dropout rate applied to the hidden-to-hidden connections.
    var_drop_in : float
        The variational dropout rate for inputs. Won't apply dropout if it equals 0.
    var_drop_state : float
        The variational dropout rate for state inputs on the first state channel.
        Won't apply dropout if it equals 0.
    var_drop_out : float
        The variational dropout rate for outputs. Won't apply dropout if it equals 0.
    skip_connection : bool
        Whether to add skip connections (add RNN cell input to output).
    proj_size : int
        The projection size of each LSTMPCellWithClip cell.
        Only available when mode is 'lstmpc'.
    cell_clip : float
        Clip cell state between [-cell_clip, cell_clip] in LSTMPCellWithClip cell.
        Only available when mode is 'lstmpc'.
    proj_clip : float
        Clip projection between [-proj_clip, proj_clip] in LSTMPCellWithClip cell.
        Only available when mode is 'lstmpc'.
    """

    assert mode == 'lstmpc' or proj_size is None, \
        'proj_size takes effect only when mode is lstmpc'
    assert mode == 'lstmpc' or cell_clip is None, \
        'cell_clip takes effect only when mode is lstmpc'
    assert mode == 'lstmpc' or proj_clip is None, \
        'proj_clip takes effect only when mode is lstmpc'

    rnn_cell = rnn.HybridSequentialRNNCell()
    with rnn_cell.name_scope():
        for i in range(num_layers):
            if mode == 'rnn_relu':
                cell = rnn.RNNCell(hidden_size, 'relu', input_size=input_size)
            elif mode == 'rnn_tanh':
                cell = rnn.RNNCell(hidden_size, 'tanh', input_size=input_size)
            elif mode == 'lstm':
                cell = rnn.LSTMCell(hidden_size, input_size=input_size)
            elif mode == 'gru':
                cell = rnn.GRUCell(hidden_size, input_size=input_size)
            elif mode == 'lstmpc':
                cell = LSTMPCellWithClip(hidden_size, proj_size,
                                         cell_clip=cell_clip,
                                         projection_clip=proj_clip,
                                         input_size=input_size)
            if var_drop_in + var_drop_state + var_drop_out != 0:
                cell = contrib.rnn.VariationalDropoutCell(cell,
                                                          var_drop_in,
                                                          var_drop_state,
                                                          var_drop_out)

            if skip_connection:
                cell = rnn.ResidualCell(cell)

            rnn_cell.add(cell)

            if i != num_layers - 1 and dropout != 0:
                rnn_cell.add(rnn.DropoutCell(dropout))

        if weight_dropout:
            apply_weight_drop(rnn_cell, 'h2h_weight', rate=weight_dropout)

    return rnn_cell
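

# A minimal usage sketch (kept as a comment so it is not executed on import). It assumes
# ``mxnet`` is imported as ``mx``; the sizes below are made-up example values, not defaults
# used anywhere in this package.
#
#   cells = _get_rnn_cell(mode='lstm', num_layers=2, input_size=200, hidden_size=200,
#                         dropout=0.2, weight_dropout=0.5,
#                         var_drop_in=0.1, var_drop_state=0.1, var_drop_out=0.1,
#                         skip_connection=True)
#   cells.initialize()
#   outputs, states = cells.unroll(length=5, inputs=mx.nd.ones((3, 5, 200)),
#                                  merge_outputs=True)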


def _get_rnn_layer(mode, num_layers, input_size, hidden_size, dropout, weight_dropout):
    """Create an RNN layer given specs."""
    if mode == 'rnn_relu':
        rnn_block = functools.partial(rnn.RNN, activation='relu')
    elif mode == 'rnn_tanh':
        rnn_block = functools.partial(rnn.RNN, activation='tanh')
    elif mode == 'lstm':
        rnn_block = rnn.LSTM
    elif mode == 'gru':
        rnn_block = rnn.GRU

    block = rnn_block(hidden_size, num_layers, dropout=dropout,
                      input_size=input_size)

    if weight_dropout:
        apply_weight_drop(block, '.*h2h_weight', rate=weight_dropout)

    return block
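

# A minimal usage sketch (kept as a comment so it is not executed on import). It assumes
# ``mxnet`` is imported as ``mx``; the sizes are made-up example values. The returned block
# is a fused gluon RNN layer, whose default layout is (seq_len, batch, input_size).
#
#   layer = _get_rnn_layer('lstm', num_layers=2, input_size=100, hidden_size=200,
#                          dropout=0.2, weight_dropout=0.5)
#   layer.initialize()
#   output = layer(mx.nd.ones((5, 3, 100)))   # -> shape (5, 3, 200)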


def _load_vocab(dataset_name, vocab, root, cls=None):
    if dataset_name:
        if vocab is not None:
            warnings.warn('Both dataset_name and vocab are specified. '
                          'Loading vocab based on dataset_name. '
                          'Input "vocab" argument will be ignored.')
        vocab = _load_pretrained_vocab(dataset_name, root, cls)
    else:
        assert vocab is not None, 'Must specify vocab if not loading from predefined datasets.'
    return vocab


def _load_pretrained_params(net, model_name, dataset_name, root, ctx, ignore_extra=False,
                            allow_missing=False):
    assert isinstance(dataset_name, str), \
        'dataset_name(str) is required when loading pretrained models. Got {}'.format(dataset_name)
    path = '_'.join([model_name, dataset_name])
    model_file = model_store.get_model_file(path, root=root)
    net.load_parameters(model_file, ctx=ctx, ignore_extra=ignore_extra,
                        allow_missing=allow_missing)


def _get_cell_type(cell_type):
    """Get the object type of the cell by parsing the input.

    Parameters
    ----------
    cell_type : str or type
        The cell type, given either as a string alias or as an RNN cell constructor.

    Returns
    -------
    cell_constructor : type
        The constructor of the RNNCell.
    """
    if isinstance(cell_type, str):
        if cell_type == 'lstm':
            return rnn.LSTMCell
        elif cell_type == 'gru':
            return rnn.GRUCell
        elif cell_type == 'relu_rnn':
            return functools.partial(rnn.RNNCell, activation='relu')
        elif cell_type == 'tanh_rnn':
            return functools.partial(rnn.RNNCell, activation='tanh')
        else:
            raise NotImplementedError
    else:
        return cell_type
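

# A minimal usage sketch (kept as a comment so it is not executed on import); the calls below
# only illustrate the mapping performed by ``_get_cell_type``.
#
#   _get_cell_type('lstm')        # -> rnn.LSTMCell
#   _get_cell_type('relu_rnn')    # -> functools.partial(rnn.RNNCell, activation='relu')
#   _get_cell_type(rnn.GRUCell)   # -> rnn.GRUCell, non-string input is returned unchanged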