1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Language models."""
18__all__ = ['AWDRNN', 'StandardRNN', 'BigRNN', 'awd_lstm_lm_1150', 'awd_lstm_lm_600',
19           'standard_lstm_lm_200', 'standard_lstm_lm_650', 'standard_lstm_lm_1500',
20           'big_rnn_lm_2048_512']
21
22import os
23
24from mxnet.gluon import Block, nn, rnn, contrib
25from mxnet import nd, cpu, autograd, sym
26from mxnet.gluon.model_zoo import model_store
27
28from . import train
29from .utils import _load_vocab, _load_pretrained_params
30from ..base import get_home_dir
31
32
33class AWDRNN(train.AWDRNN):
34    """AWD language model by salesforce.
35
36    Reference: https://github.com/salesforce/awd-lstm-lm
37
38    License: BSD 3-Clause
39
40    Parameters
41    ----------
42    mode : str
43        The type of RNN to use. Options are 'lstm', 'gru', 'rnn_tanh', 'rnn_relu'.
44    vocab_size : int
45        Size of the input vocabulary.
46    embed_size : int
47        Dimension of embedding vectors.
48    hidden_size : int
49        Number of hidden units for RNN.
50    num_layers : int
51        Number of RNN layers.
52    tie_weights : bool, default False
53        Whether to tie the weight matrices of output dense layer and input embedding layer.
54    dropout : float
55        Dropout rate to use for encoder output.
56    weight_drop : float
57        Dropout rate to use on encoder h2h weights.
58    drop_h : float
59        Dropout rate to on the output of intermediate layers of encoder.
60    drop_i : float
61        Dropout rate to on the output of embedding.
62    drop_e : float
63        Dropout rate to use on the embedding layer.
64    """
65    def __init__(self, mode, vocab_size, embed_size, hidden_size, num_layers,
66                 tie_weights, dropout, weight_drop, drop_h,
67                 drop_i, drop_e, **kwargs):
68        super(AWDRNN, self).__init__(mode, vocab_size, embed_size, hidden_size, num_layers,
69                                     tie_weights, dropout, weight_drop,
70                                     drop_h, drop_i, drop_e, **kwargs)
71
72    def hybrid_forward(self, F, inputs, begin_state=None):
73        # pylint: disable=arguments-differ
74        """Implement forward computation.
75
76        Parameters
77        -----------
78        inputs : NDArray
79            input tensor with shape `(sequence_length, batch_size)`
80            when `layout` is "TNC".
81        begin_state : list
82            initial recurrent state tensor with length equals to num_layers.
83            the initial state with shape `(1, batch_size, num_hidden)`
84
85        Returns
86        --------
87        out: NDArray
88            output tensor with shape `(sequence_length, batch_size, input_size)`
89            when `layout` is "TNC".
90        out_states: list
91            output recurrent state tensor with length equals to num_layers.
92            the state with shape `(1, batch_size, num_hidden)`
93        """
94        encoded = self.embedding(inputs)
95        if not begin_state:
96            if F == nd:
97                begin_state = self.begin_state(batch_size=inputs.shape[1])
98            else:
99                begin_state = self.begin_state(batch_size=0, func=sym.zeros)
100        out_states = []
101        for i, (e, s) in enumerate(zip(self.encoder, begin_state)):
102            encoded, state = e(encoded, s)
103            out_states.append(state)
104            if self._drop_h and i != len(self.encoder)-1:
105                encoded = F.Dropout(encoded, p=self._drop_h, axes=(0,))
106        if self._dropout:
107            encoded = F.Dropout(encoded, p=self._dropout, axes=(0,))
108        with autograd.predict_mode():
109            out = self.decoder(encoded)
110        return out, out_states
111
112class StandardRNN(train.StandardRNN):
113    """Standard RNN language model.
114
115    Parameters
116    ----------
117    mode : str
118        The type of RNN to use. Options are 'lstm', 'gru', 'rnn_tanh', 'rnn_relu'.
119    vocab_size : int
120        Size of the input vocabulary.
121    embed_size : int
122        Dimension of embedding vectors.
123    hidden_size : int
124        Number of hidden units for RNN.
125    num_layers : int
126        Number of RNN layers.
127    dropout : float
128        Dropout rate to use for encoder output.
129    tie_weights : bool, default False
130        Whether to tie the weight matrices of output dense layer and input embedding layer.
131    """
132    def __init__(self, mode, vocab_size, embed_size, hidden_size,
133                 num_layers, dropout, tie_weights, **kwargs):
134        if tie_weights:
135            assert embed_size == hidden_size, 'Embedding dimension must be equal to ' \
136                                              'hidden dimension in order to tie weights. ' \
137                                              'Got: emb: {}, hid: {}.'.format(embed_size,
138                                                                              hidden_size)
139        super(StandardRNN, self).__init__(mode, vocab_size, embed_size, hidden_size,
140                                          num_layers, dropout, tie_weights, **kwargs)
141
142    def hybrid_forward(self, F, inputs, begin_state=None): # pylint: disable=arguments-differ
143        """Defines the forward computation. Arguments can be either
144        :py:class:`NDArray` or :py:class:`Symbol`.
145
146        Parameters
147        -----------
148        inputs : NDArray
149            input tensor with shape `(sequence_length, batch_size)`
150              when `layout` is "TNC".
151        begin_state : list
152            initial recurrent state tensor with length equals to num_layers-1.
153            the initial state with shape `(num_layers, batch_size, num_hidden)`
154
155        Returns
156        --------
157        out: NDArray
158            output tensor with shape `(sequence_length, batch_size, input_size)`
159              when `layout` is "TNC".
160        out_states: list
161            output recurrent state tensor with length equals to num_layers-1.
162            the state with shape `(num_layers, batch_size, num_hidden)`
163        """
164        encoded = self.embedding(inputs)
165        if not begin_state:
166            if F == nd:
167                begin_state = self.begin_state(batch_size=inputs.shape[1])
168            else:
169                begin_state = self.begin_state(batch_size=0, func=sym.zeros)
170        encoded, state = self.encoder(encoded, begin_state)
171        if self._dropout:
172            encoded = nd.Dropout(encoded, p=self._dropout, axes=(0,))
173        out = self.decoder(encoded)
174        return out, state
175
176awd_lstm_lm_1150_hparams = {
177        'embed_size': 400,
178        'hidden_size': 1150,
179        'mode': 'lstm',
180        'num_layers': 3,
181        'tie_weights': True,
182        'dropout': 0.4,
183        'weight_drop': 0.5,
184        'drop_h': 0.2,
185        'drop_i': 0.65,
186        'drop_e': 0.1
187}
188
189awd_lstm_lm_600_hparams = {
190        'embed_size': 200,
191        'hidden_size': 600,
192        'mode': 'lstm',
193        'num_layers': 3,
194        'tie_weights': True,
195        'dropout': 0.2,
196        'weight_drop': 0.2,
197        'drop_h': 0.1,
198        'drop_i': 0.3,
199        'drop_e': 0.05
200}
201
202standard_lstm_lm_200_hparams = {
203        'embed_size': 200,
204        'hidden_size': 200,
205        'mode': 'lstm',
206        'num_layers': 2,
207        'tie_weights': True,
208        'dropout': 0.2
209}
210
211standard_lstm_lm_650_hparams = {
212        'embed_size': 650,
213        'hidden_size': 650,
214        'mode': 'lstm',
215        'num_layers': 2,
216        'tie_weights': True,
217        'dropout': 0.5
218}
219
220standard_lstm_lm_1500_hparams = {
221        'embed_size': 1500,
222        'hidden_size': 1500,
223        'mode': 'lstm',
224        'num_layers': 2,
225        'tie_weights': True,
226        'dropout': 0.65
227}
228
229awd_lstm_lm_hparams = {
230        'awd_lstm_lm_1150': awd_lstm_lm_1150_hparams,
231        'awd_lstm_lm_600': awd_lstm_lm_600_hparams
232}
233
234standard_lstm_lm_hparams = {
235        'standard_lstm_lm_200': standard_lstm_lm_200_hparams,
236        'standard_lstm_lm_650': standard_lstm_lm_650_hparams,
237        'standard_lstm_lm_1500': standard_lstm_lm_1500_hparams
238}
239
240def _get_rnn_model(model_cls, model_name, dataset_name, vocab, pretrained, ctx, root, **kwargs):
241    vocab = _load_vocab(dataset_name, vocab, root)
242    kwargs['vocab_size'] = len(vocab)
243    net = model_cls(**kwargs)
244    if pretrained:
245        _load_pretrained_params(net, model_name, dataset_name, root, ctx)
246    return net, vocab
247
248
249def awd_lstm_lm_1150(dataset_name=None, vocab=None, pretrained=False, ctx=cpu(),
250                     root=os.path.join(get_home_dir(), 'models'),
251                     hparam_allow_override=False, **kwargs):
252    r"""3-layer LSTM language model with weight-drop, variational dropout, and tied weights.
253
254    Embedding size is 400, and hidden layer size is 1150.
255
256    Parameters
257    ----------
258    dataset_name : str or None, default None
259        The dataset name on which the pre-trained model is trained.
260        Options are 'wikitext-2'. If specified, then the returned vocabulary is extracted from
261        the training set of the dataset.
262        If None, then vocab is required, for specifying embedding weight size, and is directly
263        returned.
264        The pre-trained model achieves 73.32/69.74 ppl on Val and Test of wikitext-2 respectively.
265    vocab : gluonnlp.Vocab or None, default None
266        Vocab object to be used with the language model.
267        Required when dataset_name is not specified.
268    pretrained : bool, default False
269        Whether to load the pre-trained weights for model.
270    ctx : Context, default CPU
271        The context in which to load the pre-trained weights.
272    root : str, default '$MXNET_HOME/models'
273        Location for keeping the model parameters.
274        MXNET_HOME defaults to '~/.mxnet'.
275    hparam_allow_override : bool, default False
276        If set to True, pre-defined hyper-parameters of the model
277        (e.g. the number of layers, hidden units) can be overriden.
278
279    Returns
280    -------
281    gluon.Block, gluonnlp.Vocab
282    """
283    predefined_args = awd_lstm_lm_hparams['awd_lstm_lm_1150'].copy()
284    if not hparam_allow_override:
285        mutable_args = frozenset(['dropout', 'weight_drop', 'drop_h', 'drop_i', 'drop_e'])
286        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
287            'Cannot override predefined model settings.'
288    predefined_args.update(kwargs)
289    return _get_rnn_model(AWDRNN, 'awd_lstm_lm_1150', dataset_name, vocab, pretrained,
290                          ctx, root, **predefined_args)
291
292
293def awd_lstm_lm_600(dataset_name=None, vocab=None, pretrained=False, ctx=cpu(),
294                    root=os.path.join(get_home_dir(), 'models'),
295                    hparam_allow_override=False, **kwargs):
296    r"""3-layer LSTM language model with weight-drop, variational dropout, and tied weights.
297
298    Embedding size is 200, and hidden layer size is 600.
299
300    Parameters
301    ----------
302    dataset_name : str or None, default None
303        The dataset name on which the pre-trained model is trained.
304        Options are 'wikitext-2'. If specified, then the returned vocabulary is extracted from
305        the training set of the dataset.
306        If None, then vocab is required, for specifying embedding weight size, and is directly
307        returned.
308        The pre-trained model achieves 84.61/80.96 ppl on Val and Test of wikitext-2 respectively.
309    vocab : gluonnlp.Vocab or None, default None
310        Vocab object to be used with the language model.
311        Required when dataset_name is not specified.
312    pretrained : bool, default False
313        Whether to load the pre-trained weights for model.
314    ctx : Context, default CPU
315        The context in which to load the pre-trained weights.
316    root : str, default '$MXNET_HOME/models'
317        Location for keeping the model parameters.
318        MXNET_HOME defaults to '~/.mxnet'.
319    hparam_allow_override : bool, default False
320        If set to True, pre-defined hyper-parameters of the model
321        (e.g. the number of layers, hidden units) can be overriden.
322
323    Returns
324    -------
325    gluon.Block, gluonnlp.Vocab
326    """
327    predefined_args = awd_lstm_lm_hparams['awd_lstm_lm_600'].copy()
328    if not hparam_allow_override:
329        mutable_args = frozenset(['dropout', 'weight_drop', 'drop_h', 'drop_i', 'drop_e'])
330        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
331            'Cannot override predefined model settings.'
332    predefined_args.update(kwargs)
333    return _get_rnn_model(AWDRNN, 'awd_lstm_lm_600', dataset_name, vocab, pretrained,
334                          ctx, root, **predefined_args)
335
336def standard_lstm_lm_200(dataset_name=None, vocab=None, pretrained=False, ctx=cpu(),
337                         root=os.path.join(get_home_dir(), 'models'),
338                         hparam_allow_override=False, **kwargs):
339    r"""Standard 2-layer LSTM language model with tied embedding and output weights.
340
341    Both embedding and hidden dimensions are 200.
342
343    Parameters
344    ----------
345    dataset_name : str or None, default None
346        The dataset name on which the pre-trained model is trained.
347        Options are 'wikitext-2'. If specified, then the returned vocabulary is extracted from
348        the training set of the dataset.
349        If None, then vocab is required, for specifying embedding weight size, and is directly
350        returned.
351        The pre-trained model achieves 108.25/102.26 ppl on Val and Test of wikitext-2 respectively.
352    vocab : gluonnlp.Vocab or None, default None
353        Vocabulary object to be used with the language model.
354        Required when dataset_name is not specified.
355    pretrained : bool, default False
356        Whether to load the pre-trained weights for model.
357    ctx : Context, default CPU
358        The context in which to load the pre-trained weights.
359    root : str, default '$MXNET_HOME/models'
360        Location for keeping the model parameters.
361        MXNET_HOME defaults to '~/.mxnet'.
362    hparam_allow_override : bool, default False
363        If set to True, pre-defined hyper-parameters of the model
364        (e.g. the number of layers, hidden units) can be overriden.
365
366    Returns
367    -------
368    gluon.Block, gluonnlp.Vocab
369    """
370    predefined_args = standard_lstm_lm_hparams['standard_lstm_lm_200'].copy()
371    if not hparam_allow_override:
372        mutable_args = frozenset(['dropout'])
373        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
374            'Cannot override predefined model settings.'
375    predefined_args.update(kwargs)
376    return _get_rnn_model(StandardRNN, 'standard_lstm_lm_200', dataset_name, vocab, pretrained,
377                          ctx, root, **predefined_args)
378
379
380def standard_lstm_lm_650(dataset_name=None, vocab=None, pretrained=False, ctx=cpu(),
381                         root=os.path.join(get_home_dir(), 'models'),
382                         hparam_allow_override=False, **kwargs):
383    r"""Standard 2-layer LSTM language model with tied embedding and output weights.
384
385    Both embedding and hidden dimensions are 650.
386
387    Parameters
388    ----------
389    dataset_name : str or None, default None
390        The dataset name on which the pre-trained model is trained.
391        Options are 'wikitext-2'. If specified, then the returned vocabulary is extracted from
392        the training set of the dataset.
393        If None, then vocab is required, for specifying embedding weight size, and is directly
394        returned.
395        The pre-trained model achieves 98.96/93.90 ppl on Val and Test of wikitext-2 respectively.
396    vocab : gluonnlp.Vocab or None, default None
397        Vocabulary object to be used with the language model.
398        Required when dataset_name is not specified.
399    pretrained : bool, default False
400        Whether to load the pre-trained weights for model.
401    ctx : Context, default CPU
402        The context in which to load the pre-trained weights.
403    root : str, default '$MXNET_HOME/models'
404        Location for keeping the model parameters.
405        MXNET_HOME defaults to '~/.mxnet'.
406    hparam_allow_override : bool, default False
407        If set to True, pre-defined hyper-parameters of the model
408        (e.g. the number of layers, hidden units) can be overriden.
409
410    Returns
411    -------
412    gluon.Block, gluonnlp.Vocab
413    """
414    predefined_args = standard_lstm_lm_hparams['standard_lstm_lm_650'].copy()
415    if not hparam_allow_override:
416        mutable_args = frozenset(['dropout'])
417        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
418            'Cannot override predefined model settings.'
419    predefined_args.update(kwargs)
420    return _get_rnn_model(StandardRNN, 'standard_lstm_lm_650', dataset_name, vocab, pretrained,
421                          ctx, root, **predefined_args)
422
423
424def standard_lstm_lm_1500(dataset_name=None, vocab=None, pretrained=False, ctx=cpu(),
425                          root=os.path.join(get_home_dir(), 'models'),
426                          hparam_allow_override=False, **kwargs):
427    r"""Standard 2-layer LSTM language model with tied embedding and output weights.
428
429    Both embedding and hidden dimensions are 1500.
430
431    Parameters
432    ----------
433    dataset_name : str or None, default None
434        The dataset name on which the pre-trained model is trained.
435        Options are 'wikitext-2'. If specified, then the returned vocabulary is extracted from
436        the training set of the dataset.
437        If None, then vocab is required, for specifying embedding weight size, and is directly
438        returned.
439        The pre-trained model achieves 98.29/92.83 ppl on Val and Test of wikitext-2 respectively.
440    vocab : gluonnlp.Vocab or None, default None
441        Vocabulary object to be used with the language model.
442        Required when dataset_name is not specified.
443    pretrained : bool, default False
444        Whether to load the pre-trained weights for model.
445    ctx : Context, default CPU
446        The context in which to load the pre-trained weights.
447    root : str, default '$MXNET_HOME/models'
448        Location for keeping the model parameters.
449        MXNET_HOME defaults to '~/.mxnet'.
450    hparam_allow_override : bool, default False
451        If set to True, pre-defined hyper-parameters of the model
452        (e.g. the number of layers, hidden units) can be overriden.
453
454    Returns
455    -------
456    gluon.Block, gluonnlp.Vocab
457    """
458    predefined_args = standard_lstm_lm_hparams['standard_lstm_lm_1500'].copy()
459    if not hparam_allow_override:
460        mutable_args = frozenset(['dropout'])
461        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
462            'Cannot override predefined model settings.'
463    predefined_args.update(kwargs)
464    return _get_rnn_model(StandardRNN, 'standard_lstm_lm_1500',
465                          dataset_name, vocab, pretrained, ctx, root, **predefined_args)
466
467model_store._model_sha1.update(
468    {name: checksum for checksum, name in [
469        ('a416351377d837ef12d17aae27739393f59f0b82', 'standard_lstm_lm_1500_wikitext-2'),
470        ('631f39040cd65b49f5c8828a0aba65606d73a9cb', 'standard_lstm_lm_650_wikitext-2'),
471        ('b233c700e80fb0846c17fe14846cb7e08db3fd51', 'standard_lstm_lm_200_wikitext-2'),
472        ('f9562ed05d9bcc7e1f5b7f3c81a1988019878038', 'awd_lstm_lm_1150_wikitext-2'),
473        ('e952becc7580a0b5a6030aab09d0644e9a13ce18', 'awd_lstm_lm_600_wikitext-2'),
474        ('6bb3e991eb4439fabfe26c129da2fe15a324e918', 'big_rnn_lm_2048_512_gbw')
475    ]})
476
477class BigRNN(Block):
478    """Big language model with LSTMP for inference.
479
480    Parameters
481    ----------
482    vocab_size : int
483        Size of the input vocabulary.
484    embed_size : int
485        Dimension of embedding vectors.
486    hidden_size : int
487        Number of hidden units for LSTMP.
488    num_layers : int
489        Number of LSTMP layers.
490    projection_size : int
491        Number of projection units for LSTMP.
492    embed_dropout : float
493        Dropout rate to use for embedding output.
494    encode_dropout : float
495        Dropout rate to use for encoder output.
496
497    """
498    def __init__(self, vocab_size, embed_size, hidden_size, num_layers,
499                 projection_size, embed_dropout=0.0, encode_dropout=0.0, **kwargs):
500        super(BigRNN, self).__init__(**kwargs)
501        self._embed_size = embed_size
502        self._hidden_size = hidden_size
503        self._projection_size = projection_size
504        self._num_layers = num_layers
505        self._embed_dropout = embed_dropout
506        self._encode_dropout = encode_dropout
507        self._vocab_size = vocab_size
508
509        with self.name_scope():
510            self.embedding = self._get_embedding()
511            self.encoder = self._get_encoder()
512            self.decoder = self._get_decoder()
513
514    def _get_embedding(self):
515        prefix = 'embedding0_'
516        embedding = nn.HybridSequential(prefix=prefix)
517        with embedding.name_scope():
518            embedding.add(nn.Embedding(self._vocab_size, self._embed_size, prefix=prefix))
519            if self._embed_dropout:
520                embedding.add(nn.Dropout(self._embed_dropout))
521        return embedding
522
523    def _get_encoder(self):
524        block = rnn.HybridSequentialRNNCell()
525        with block.name_scope():
526            for _ in range(self._num_layers):
527                block.add(contrib.rnn.LSTMPCell(self._hidden_size, self._projection_size))
528                if self._encode_dropout:
529                    block.add(rnn.DropoutCell(self._encode_dropout))
530        return block
531
532    def _get_decoder(self):
533        output = nn.Dense(self._vocab_size, prefix='decoder0_')
534        return output
535
536    def begin_state(self, **kwargs):
537        return self.encoder.begin_state(**kwargs)
538
539    def forward(self, inputs, begin_state): # pylint: disable=arguments-differ
540        """Implement forward computation.
541
542        Parameters
543        -----------
544        inputs : NDArray
545            input tensor with shape `(sequence_length, batch_size)`
546            when `layout` is "TNC".
547        begin_state : list
548            initial recurrent state tensor with length equals to num_layers*2.
549            For each layer the two initial states have shape `(batch_size, num_hidden)`
550            and `(batch_size, num_projection)`
551
552        Returns
553        --------
554        out : NDArray
555            output tensor with shape `(sequence_length, batch_size, vocab_size)`
556              when `layout` is "TNC".
557        out_states : list
558            output recurrent state tensor with length equals to num_layers*2.
559            For each layer the two initial states have shape `(batch_size, num_hidden)`
560            and `(batch_size, num_projection)`
561        """
562        encoded = self.embedding(inputs)
563        length = inputs.shape[0]
564        batch_size = inputs.shape[1]
565        encoded, state = self.encoder.unroll(length, encoded, begin_state,
566                                             layout='TNC', merge_outputs=True)
567        encoded = encoded.reshape((-1, self._projection_size))
568        out = self.decoder(encoded)
569        out = out.reshape((length, batch_size, -1))
570        return out, state
571
572big_rnn_lm_2048_512_hparams = {
573        'embed_size': 512,
574        'hidden_size': 2048,
575        'projection_size': 512,
576        'num_layers': 1,
577        'embed_dropout': 0.1,
578        'encode_dropout': 0.1}
579
580big_rnn_lm_hparams = {
581        'big_rnn_lm_2048_512': big_rnn_lm_2048_512_hparams
582}
583
584def big_rnn_lm_2048_512(dataset_name=None, vocab=None, pretrained=False, ctx=cpu(),
585                        root=os.path.join(get_home_dir(), 'models'),
586                        hparam_allow_override=False, **kwargs):
587    r"""Big 1-layer LSTMP language model.
588
589    Both embedding and projection size are 512. Hidden size is 2048.
590
591    Parameters
592    ----------
593    dataset_name : str or None, default None
594        The dataset name on which the pre-trained model is trained.
595        Options are 'gbw'. If specified, then the returned vocabulary is extracted from
596        the training set of the dataset.
597        If None, then vocab is required, for specifying embedding weight size, and is directly
598        returned.
599        The pre-trained model achieves 44.05 ppl on Test of GBW dataset.
600    vocab : gluonnlp.Vocab or None, default None
601        Vocabulary object to be used with the language model.
602        Required when dataset_name is not specified.
603    pretrained : bool, default False
604        Whether to load the pre-trained weights for model.
605    ctx : Context, default CPU
606        The context in which to load the pre-trained weights.
607    root : str, default '$MXNET_HOME/models'
608        Location for keeping the model parameters.
609        MXNET_HOME defaults to '~/.mxnet'.
610    hparam_allow_override : bool, default False
611        If set to True, pre-defined hyper-parameters of the model
612        (e.g. the number of layers, hidden units) can be overriden.
613
614    Returns
615    -------
616    gluon.Block, gluonnlp.Vocab
617    """
618    predefined_args = big_rnn_lm_hparams['big_rnn_lm_2048_512'].copy()
619    if not hparam_allow_override:
620        mutable_args = frozenset(['embed_dropout', 'encode_dropout'])
621        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
622            'Cannot override predefined model settings.'
623    predefined_args.update(kwargs)
624    return _get_rnn_model(BigRNN, 'big_rnn_lm_2048_512', dataset_name, vocab, pretrained,
625                          ctx, root, **predefined_args)
626