# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17"""Building blocks and utility for models."""
__all__ = ['apply_weight_drop']

import collections
import functools
import re
import warnings

from mxnet.gluon import Block, contrib, rnn
from mxnet.gluon.model_zoo import model_store
from ..data.utils import _load_pretrained_vocab
from .parameter import WeightDropParameter
from .lstmpcellwithclip import LSTMPCellWithClip

# pylint: disable=too-many-nested-blocks


def apply_weight_drop(block, local_param_regex, rate, axes=(),
                      weight_dropout_mode='training'):
    """Apply weight drop to the parameter of a block.

    Parameters
    ----------
    block : Block or HybridBlock
        The block whose parameters will have weight drop applied.
    local_param_regex : str
        The regex that matches the parameter names used in self.params.get(), such as 'weight'.
    rate : float
        Fraction of the input units to drop. Must be a number between 0 and 1.
    axes : tuple of int, default ()
        The axes on which dropout mask is shared. If empty, regular dropout is applied.
    weight_dropout_mode : {'training', 'always'}, default 'training'
        Whether the weight dropout should be applied only at training time, or always be applied.

    Examples
    --------
    >>> net = gluon.rnn.LSTM(10, num_layers=2, bidirectional=True)
    >>> gluonnlp.model.apply_weight_drop(net, r'.*h2h_weight', 0.5)
    >>> net.collect_params()
    lstm0_ (
      Parameter lstm0_l0_i2h_weight (shape=(40, 0), dtype=float32)
      WeightDropParameter lstm0_l0_h2h_weight (shape=(40, 10), dtype=float32, \
rate=0.5, mode=training)
      Parameter lstm0_l0_i2h_bias (shape=(40,), dtype=float32)
      Parameter lstm0_l0_h2h_bias (shape=(40,), dtype=float32)
      Parameter lstm0_r0_i2h_weight (shape=(40, 0), dtype=float32)
      WeightDropParameter lstm0_r0_h2h_weight (shape=(40, 10), dtype=float32, \
rate=0.5, mode=training)
      Parameter lstm0_r0_i2h_bias (shape=(40,), dtype=float32)
      Parameter lstm0_r0_h2h_bias (shape=(40,), dtype=float32)
      Parameter lstm0_l1_i2h_weight (shape=(40, 20), dtype=float32)
      WeightDropParameter lstm0_l1_h2h_weight (shape=(40, 10), dtype=float32, \
rate=0.5, mode=training)
      Parameter lstm0_l1_i2h_bias (shape=(40,), dtype=float32)
      Parameter lstm0_l1_h2h_bias (shape=(40,), dtype=float32)
      Parameter lstm0_r1_i2h_weight (shape=(40, 20), dtype=float32)
      WeightDropParameter lstm0_r1_h2h_weight (shape=(40, 10), dtype=float32, \
rate=0.5, mode=training)
      Parameter lstm0_r1_i2h_bias (shape=(40,), dtype=float32)
      Parameter lstm0_r1_h2h_bias (shape=(40,), dtype=float32)
    )
    >>> ones = mx.nd.ones((3, 4, 5))
    >>> net.initialize()
    >>> with mx.autograd.train_mode():
    ...     net(ones).max().asscalar() != net(ones).max().asscalar()
    True
    """
    if not rate:
        return

    existing_params = _find_params(block, local_param_regex)
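    # Wrap each matched parameter in a WeightDropParameter and re-point every
    # reference to it: the owning ParameterDicts, the blocks' _reg_params, and
    # any block attribute that holds the parameter directly or inside a
    # list/tuple/dict.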
    for (local_param_name, param), \
            (ref_params_list, ref_reg_params_list) in existing_params.items():
        if isinstance(param, WeightDropParameter):
            continue
        dropped_param = WeightDropParameter(param, rate, weight_dropout_mode, axes)
        for ref_params in ref_params_list:
            ref_params[param.name] = dropped_param
        for ref_reg_params in ref_reg_params_list:
            ref_reg_params[local_param_name] = dropped_param
            if hasattr(block, local_param_name):
                local_attr = getattr(block, local_param_name)
                if local_attr == param:
                    local_attr = dropped_param
                elif isinstance(local_attr, (list, tuple)):
                    if isinstance(local_attr, tuple):
                        local_attr = list(local_attr)
                    for i, v in enumerate(local_attr):
                        if v == param:
                            local_attr[i] = dropped_param
                elif isinstance(local_attr, dict):
                    for k, v in local_attr.items():
                        if v == param:
                            local_attr[k] = dropped_param
                else:
                    continue
                if local_attr:
                    super(Block, block).__setattr__(local_param_name, local_attr)

# pylint: enable=too-many-nested-blocks


def _find_params(block, local_param_regex):
    # return {(local_param_name, parameter): (referenced_params_list,
    #                                         referenced_reg_params_list)}
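    # For example (names purely illustrative), applying r'h2h_weight' to an LSTMCell
    # block could yield an entry such as
    #   {('h2h_weight', Parameter 'lstm0_h2h_weight'):
    #        ([<the ParameterDicts holding 'lstm0_h2h_weight'>],
    #         [<the _reg_params dicts registering 'h2h_weight'>])}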

    results = collections.defaultdict(lambda: ([], []))
    pattern = re.compile(local_param_regex)
    local_param_names = ((local_param_name, p) for local_param_name, p in block._reg_params.items()
                         if pattern.match(local_param_name))

    for local_param_name, p in local_param_names:
        ref_params_list, ref_reg_params_list = results[(local_param_name, p)]
        ref_reg_params_list.append(block._reg_params)

        params = block._params
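        # Walk up the chain of shared ParameterDicts so that every dict which
        # holds a reference to this parameter gets collected for updating.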
        while params:
            if p.name in params._params:
                ref_params_list.append(params._params)
            if params._shared:
                params = params._shared
                warnings.warn('When applying weight drop, target parameter {} was found '
                              'in a shared parameter dict. The parameter attribute of the '
                              'original block on which the shared parameter dict was attached '
                              'will not be updated with WeightDropParameter. If necessary, '
                              'please update the attribute manually. The likely name of the '
                              'attribute is ".{}"'.format(p.name, local_param_name))
            else:
                break

    if block._children:
        if isinstance(block._children, list):
            children = block._children
        elif isinstance(block._children, dict):
            children = block._children.values()
        for c in children:
            child_results = _find_params(c, local_param_regex)
            for (child_p_name, child_p), (child_pd_list, child_rd_list) in child_results.items():
                pd_list, rd_list = results[(child_p_name, child_p)]
                pd_list.extend(child_pd_list)
                rd_list.extend(child_rd_list)

    return results


def _get_rnn_cell(mode, num_layers, input_size, hidden_size,
                  dropout, weight_dropout,
                  var_drop_in, var_drop_state, var_drop_out,
                  skip_connection, proj_size=None, cell_clip=None, proj_clip=None):
    """Create an RNN cell given the specs.

    Parameters
    ----------
    mode : str
        The type of RNN cell to use. Options are 'lstmpc', 'rnn_tanh', 'rnn_relu', 'lstm', 'gru'.
    num_layers : int
        The number of RNN cells in the encoder.
    input_size : int
        The initial input size of the RNN cell.
    hidden_size : int
        The hidden size of the RNN cell.
    dropout : float
        The dropout rate to use for encoder output.
    weight_dropout : float
        The dropout rate applied to the hidden-to-hidden connections.
    var_drop_in : float
        The variational dropout rate for inputs. Won't apply dropout if it equals 0.
    var_drop_state : float
        The variational dropout rate for state inputs on the first state channel.
        Won't apply dropout if it equals 0.
    var_drop_out : float
        The variational dropout rate for outputs. Won't apply dropout if it equals 0.
    skip_connection : bool
        Whether to add skip connections (add the RNN cell input to its output).
    proj_size : int
        The projection size of each LSTMPCellWithClip cell.
        Only available when mode is 'lstmpc'.
    cell_clip : float
        Clip the cell state between [-cell_clip, cell_clip] in the LSTMPCellWithClip cell.
        Only available when mode is 'lstmpc'.
    proj_clip : float
        Clip the projection between [-proj_clip, proj_clip] in the LSTMPCellWithClip cell.
        Only available when mode is 'lstmpc'.
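
    Examples
    --------
    A minimal usage sketch of this internal helper; the argument values below are
    illustrative only, not defaults::

        cell = _get_rnn_cell('lstm', num_layers=2, input_size=100, hidden_size=200,
                             dropout=0.2, weight_dropout=0.5, var_drop_in=0.,
                             var_drop_state=0., var_drop_out=0., skip_connection=False)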
    """

    assert mode == 'lstmpc' or proj_size is None, \
        'proj_size takes effect only when mode is lstmpc'
    assert mode == 'lstmpc' or cell_clip is None, \
        'cell_clip takes effect only when mode is lstmpc'
    assert mode == 'lstmpc' or proj_clip is None, \
        'proj_clip takes effect only when mode is lstmpc'

    rnn_cell = rnn.HybridSequentialRNNCell()
    with rnn_cell.name_scope():
        for i in range(num_layers):
            if mode == 'rnn_relu':
                cell = rnn.RNNCell(hidden_size, 'relu', input_size=input_size)
            elif mode == 'rnn_tanh':
                cell = rnn.RNNCell(hidden_size, 'tanh', input_size=input_size)
            elif mode == 'lstm':
                cell = rnn.LSTMCell(hidden_size, input_size=input_size)
            elif mode == 'gru':
                cell = rnn.GRUCell(hidden_size, input_size=input_size)
            elif mode == 'lstmpc':
                cell = LSTMPCellWithClip(hidden_size, proj_size,
                                         cell_clip=cell_clip,
                                         projection_clip=proj_clip,
                                         input_size=input_size)
            if var_drop_in + var_drop_state + var_drop_out != 0:
                cell = contrib.rnn.VariationalDropoutCell(cell,
                                                          var_drop_in,
                                                          var_drop_state,
                                                          var_drop_out)

            if skip_connection:
                cell = rnn.ResidualCell(cell)

            rnn_cell.add(cell)

            if i != num_layers - 1 and dropout != 0:
                rnn_cell.add(rnn.DropoutCell(dropout))

            if weight_dropout:
                apply_weight_drop(rnn_cell, 'h2h_weight', rate=weight_dropout)

    return rnn_cell


def _get_rnn_layer(mode, num_layers, input_size, hidden_size, dropout, weight_dropout):
    """Create an RNN layer given the specs."""
    if mode == 'rnn_relu':
        rnn_block = functools.partial(rnn.RNN, activation='relu')
    elif mode == 'rnn_tanh':
        rnn_block = functools.partial(rnn.RNN, activation='tanh')
    elif mode == 'lstm':
        rnn_block = rnn.LSTM
    elif mode == 'gru':
        rnn_block = rnn.GRU

    block = rnn_block(hidden_size, num_layers, dropout=dropout,
                      input_size=input_size)

    if weight_dropout:
        apply_weight_drop(block, '.*h2h_weight', rate=weight_dropout)

    return block
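
# A minimal usage sketch for this internal helper (argument values are illustrative):
#
#   layer = _get_rnn_layer('lstm', num_layers=2, input_size=100, hidden_size=200,
#                          dropout=0.2, weight_dropout=0.5)
#   layer.initialize()
#   out = layer(mx.nd.ones((5, 2, 100)))  # layout (sequence_length, batch_size, input_size)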


def _load_vocab(dataset_name, vocab, root, cls=None):
    if dataset_name:
        if vocab is not None:
            warnings.warn('Both dataset_name and vocab are specified. '
                          'Loading vocab based on dataset_name. '
                          'Input "vocab" argument will be ignored.')
        vocab = _load_pretrained_vocab(dataset_name, root, cls)
    else:
        assert vocab is not None, 'Must specify vocab if not loading from predefined datasets.'
    return vocab
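
# Typical call pattern (the dataset name and root below are only examples):
#
#   vocab = _load_vocab('wikitext-2', vocab=None, root='~/.mxnet/models')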


def _load_pretrained_params(net, model_name, dataset_name, root, ctx, ignore_extra=False,
                            allow_missing=False):
    assert isinstance(dataset_name, str), \
      'dataset_name(str) is required when loading pretrained models. Got {}'.format(dataset_name)
    path = '_'.join([model_name, dataset_name])
    model_file = model_store.get_model_file(path, root=root)
    net.load_parameters(model_file, ctx=ctx, ignore_extra=ignore_extra,
                        allow_missing=allow_missing)
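
# Typical call pattern (model and dataset names below are only examples):
#
#   _load_pretrained_params(net, model_name='standard_lstm_lm_200',
#                           dataset_name='wikitext-2', root='~/.mxnet/models',
#                           ctx=mx.cpu())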

def _get_cell_type(cell_type):
    """Get the object type of the cell by parsing the input.

    Parameters
    ----------
    cell_type : str or type

    Returns
    -------
    cell_constructor : type
        The constructor of the RNNCell.
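
    Examples
    --------
    A small illustration (cell classes come from ``mxnet.gluon.rnn``)::

        _get_cell_type('lstm')       # returns rnn.LSTMCell
        _get_cell_type(rnn.GRUCell)  # the type is returned unchanged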
    """
    if isinstance(cell_type, str):
        if cell_type == 'lstm':
            return rnn.LSTMCell
        elif cell_type == 'gru':
            return rnn.GRUCell
        elif cell_type == 'relu_rnn':
            return functools.partial(rnn.RNNCell, activation='relu')
        elif cell_type == 'tanh_rnn':
            return functools.partial(rnn.RNNCell, activation='tanh')
        else:
            raise NotImplementedError
    else:
        return cell_type