# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Language models."""
__all__ = ['AWDRNN', 'StandardRNN', 'BigRNN', 'awd_lstm_lm_1150', 'awd_lstm_lm_600',
           'standard_lstm_lm_200', 'standard_lstm_lm_650', 'standard_lstm_lm_1500',
           'big_rnn_lm_2048_512']

import os

from mxnet.gluon import Block, nn, rnn, contrib
from mxnet import nd, cpu, autograd, sym
from mxnet.gluon.model_zoo import model_store

from . import train
from .utils import _load_vocab, _load_pretrained_params
from ..base import get_home_dir


class AWDRNN(train.AWDRNN):
    """AWD LSTM language model by Salesforce.

    Reference: https://github.com/salesforce/awd-lstm-lm

    License: BSD 3-Clause

    Parameters
    ----------
    mode : str
        The type of RNN to use. Options are 'lstm', 'gru', 'rnn_tanh', 'rnn_relu'.
    vocab_size : int
        Size of the input vocabulary.
    embed_size : int
        Dimension of embedding vectors.
    hidden_size : int
        Number of hidden units per RNN layer.
    num_layers : int
        Number of RNN layers.
    tie_weights : bool
        Whether to tie the weight matrices of the output dense layer and the input
        embedding layer.
    dropout : float
        Dropout rate to use on the encoder output.
    weight_drop : float
        Dropout rate to use on the encoder h2h weights.
    drop_h : float
        Dropout rate to use on the output of intermediate encoder layers.
    drop_i : float
        Dropout rate to use on the output of the embedding layer.
    drop_e : float
        Dropout rate to use on the embedding weights.
    """
    def __init__(self, mode, vocab_size, embed_size, hidden_size, num_layers,
                 tie_weights, dropout, weight_drop, drop_h,
                 drop_i, drop_e, **kwargs):
        super(AWDRNN, self).__init__(mode, vocab_size, embed_size, hidden_size, num_layers,
                                     tie_weights, dropout, weight_drop,
                                     drop_h, drop_i, drop_e, **kwargs)

    def hybrid_forward(self, F, inputs, begin_state=None):
        # pylint: disable=arguments-differ
        """Implement the forward computation.

        Parameters
        ----------
        inputs : NDArray
            Input tensor with shape `(sequence_length, batch_size)`
            when `layout` is "TNC".
        begin_state : list
            Initial recurrent states, one per layer (length equal to num_layers).
            Each state has shape `(1, batch_size, num_hidden)`.

        Returns
        -------
        out : NDArray
            Output tensor with shape `(sequence_length, batch_size, vocab_size)`
            when `layout` is "TNC".
        out_states : list
            Output recurrent states, one per layer (length equal to num_layers).
            Each state has shape `(1, batch_size, num_hidden)`.
        """
        encoded = self.embedding(inputs)
        if not begin_state:
            if F == nd:
                begin_state = self.begin_state(batch_size=inputs.shape[1])
            else:
                begin_state = self.begin_state(batch_size=0, func=sym.zeros)
        out_states = []
        for i, (e, s) in enumerate(zip(self.encoder, begin_state)):
            encoded, state = e(encoded, s)
            out_states.append(state)
            if self._drop_h and i != len(self.encoder)-1:
                encoded = F.Dropout(encoded, p=self._drop_h, axes=(0,))
        if self._dropout:
            encoded = F.Dropout(encoded, p=self._dropout, axes=(0,))
        with autograd.predict_mode():
            out = self.decoder(encoded)
        return out, out_states
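

# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original module): a minimal,
# hedged example of constructing an AWDRNN directly and running one forward
# pass. The hyper-parameter values mirror the awd_lstm_lm_1150 preset defined
# later in this file; the vocabulary size (33278), sequence length, and batch
# size below are assumptions chosen purely for illustration.
def _example_awd_rnn_forward():
    """Illustrative only; never called at import time."""
    import mxnet as mx

    model = AWDRNN(mode='lstm', vocab_size=33278, embed_size=400, hidden_size=1150,
                   num_layers=3, tie_weights=True, dropout=0.4, weight_drop=0.5,
                   drop_h=0.2, drop_i=0.65, drop_e=0.1)
    model.initialize(mx.init.Xavier(), ctx=mx.cpu())

    # Random token ids with layout "TNC": (sequence_length, batch_size) -- assumed shapes.
    inputs = mx.nd.random.randint(0, 33278, shape=(35, 4)).astype('float32')

    # begin_state is optional; when omitted, hybrid_forward creates zero states
    # matching the batch size of `inputs`.
    out, out_states = model(inputs)
    # `out` is expected to have shape (35, 4, 33278); `out_states` holds one state per layer.
    return out, out_states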


class StandardRNN(train.StandardRNN):
    """Standard RNN language model.

    Parameters
    ----------
    mode : str
        The type of RNN to use. Options are 'lstm', 'gru', 'rnn_tanh', 'rnn_relu'.
    vocab_size : int
        Size of the input vocabulary.
    embed_size : int
        Dimension of embedding vectors.
    hidden_size : int
        Number of hidden units per RNN layer.
    num_layers : int
        Number of RNN layers.
    dropout : float
        Dropout rate to use on the encoder output.
    tie_weights : bool
        Whether to tie the weight matrices of the output dense layer and the input
        embedding layer.
    """
    def __init__(self, mode, vocab_size, embed_size, hidden_size,
                 num_layers, dropout, tie_weights, **kwargs):
        if tie_weights:
            assert embed_size == hidden_size, 'Embedding dimension must be equal to ' \
                                              'hidden dimension in order to tie weights. ' \
                                              'Got: emb: {}, hid: {}.'.format(embed_size,
                                                                              hidden_size)
        super(StandardRNN, self).__init__(mode, vocab_size, embed_size, hidden_size,
                                          num_layers, dropout, tie_weights, **kwargs)

    def hybrid_forward(self, F, inputs, begin_state=None):  # pylint: disable=arguments-differ
        """Defines the forward computation. Arguments can be either
        :py:class:`NDArray` or :py:class:`Symbol`.

        Parameters
        ----------
        inputs : NDArray
            Input tensor with shape `(sequence_length, batch_size)`
            when `layout` is "TNC".
        begin_state : list
            List of initial recurrent state tensors for the encoder.
            Each state has shape `(num_layers, batch_size, num_hidden)`.

        Returns
        -------
        out : NDArray
            Output tensor with shape `(sequence_length, batch_size, vocab_size)`
            when `layout` is "TNC".
        out_states : list
            List of output recurrent state tensors for the encoder.
            Each state has shape `(num_layers, batch_size, num_hidden)`.
        """
        encoded = self.embedding(inputs)
        if not begin_state:
            if F == nd:
                begin_state = self.begin_state(batch_size=inputs.shape[1])
            else:
                begin_state = self.begin_state(batch_size=0, func=sym.zeros)
        encoded, state = self.encoder(encoded, begin_state)
        if self._dropout:
            # Use F (NDArray or Symbol) so dropout also works when hybridized.
            encoded = F.Dropout(encoded, p=self._dropout, axes=(0,))
        out = self.decoder(encoded)
        return out, state
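

# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original module): StandardRNN
# requires embed_size == hidden_size when tie_weights is True (see the assert
# in __init__). The sketch mirrors the standard_lstm_lm_200 preset defined
# later in this file and carries the recurrent state across two consecutive
# segments, the usual truncated-BPTT pattern. Vocabulary size and shapes are
# assumptions for illustration only.
def _example_standard_rnn_state_carrying():
    """Illustrative only; never called at import time."""
    import mxnet as mx

    model = StandardRNN(mode='lstm', vocab_size=33278, embed_size=200,
                        hidden_size=200, num_layers=2, dropout=0.2, tie_weights=True)
    model.initialize(mx.init.Xavier(), ctx=mx.cpu())

    state = model.begin_state(batch_size=4)
    out = None
    for _ in range(2):  # two consecutive segments of the same batch (assumed shapes)
        segment = mx.nd.random.randint(0, 33278, shape=(35, 4)).astype('float32')
        out, state = model(segment, state)
        # During training one would detach the state here so gradients do not
        # flow across segment boundaries.
        state = [s.detach() for s in state]
    return out, state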


awd_lstm_lm_1150_hparams = {
    'embed_size': 400,
    'hidden_size': 1150,
    'mode': 'lstm',
    'num_layers': 3,
    'tie_weights': True,
    'dropout': 0.4,
    'weight_drop': 0.5,
    'drop_h': 0.2,
    'drop_i': 0.65,
    'drop_e': 0.1
}

awd_lstm_lm_600_hparams = {
    'embed_size': 200,
    'hidden_size': 600,
    'mode': 'lstm',
    'num_layers': 3,
    'tie_weights': True,
    'dropout': 0.2,
    'weight_drop': 0.2,
    'drop_h': 0.1,
    'drop_i': 0.3,
    'drop_e': 0.05
}

standard_lstm_lm_200_hparams = {
    'embed_size': 200,
    'hidden_size': 200,
    'mode': 'lstm',
    'num_layers': 2,
    'tie_weights': True,
    'dropout': 0.2
}

standard_lstm_lm_650_hparams = {
    'embed_size': 650,
    'hidden_size': 650,
    'mode': 'lstm',
    'num_layers': 2,
    'tie_weights': True,
    'dropout': 0.5
}

standard_lstm_lm_1500_hparams = {
    'embed_size': 1500,
    'hidden_size': 1500,
    'mode': 'lstm',
    'num_layers': 2,
    'tie_weights': True,
    'dropout': 0.65
}

awd_lstm_lm_hparams = {
    'awd_lstm_lm_1150': awd_lstm_lm_1150_hparams,
    'awd_lstm_lm_600': awd_lstm_lm_600_hparams
}

standard_lstm_lm_hparams = {
    'standard_lstm_lm_200': standard_lstm_lm_200_hparams,
    'standard_lstm_lm_650': standard_lstm_lm_650_hparams,
    'standard_lstm_lm_1500': standard_lstm_lm_1500_hparams
}


def _get_rnn_model(model_cls, model_name, dataset_name, vocab, pretrained, ctx, root, **kwargs):
    vocab = _load_vocab(dataset_name, vocab, root)
    kwargs['vocab_size'] = len(vocab)
    net = model_cls(**kwargs)
    if pretrained:
        _load_pretrained_params(net, model_name, dataset_name, root, ctx)
    return net, vocab


def awd_lstm_lm_1150(dataset_name=None, vocab=None, pretrained=False, ctx=cpu(),
                     root=os.path.join(get_home_dir(), 'models'),
                     hparam_allow_override=False, **kwargs):
    r"""3-layer LSTM language model with weight-drop, variational dropout, and tied weights.

    Embedding size is 400, and hidden layer size is 1150.

    Parameters
    ----------
    dataset_name : str or None, default None
        The dataset name on which the pre-trained model is trained.
        Options are 'wikitext-2'. If specified, the returned vocabulary is extracted from
        the training set of the dataset.
        If None, then vocab is required and is directly returned; it is used to determine
        the size of the embedding weights.
        The pre-trained model achieves 73.32/69.74 ppl on the validation and test sets of
        wikitext-2, respectively.
    vocab : gluonnlp.Vocab or None, default None
        Vocabulary object to be used with the language model.
        Required when dataset_name is not specified.
    pretrained : bool, default False
        Whether to load the pre-trained weights for the model.
    ctx : Context, default CPU
        The context in which to load the pre-trained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
        MXNET_HOME defaults to '~/.mxnet'.
    hparam_allow_override : bool, default False
        If set to True, pre-defined hyper-parameters of the model
        (e.g. the number of layers, hidden units) can be overridden.

    Returns
    -------
    gluon.Block, gluonnlp.Vocab
    """
    predefined_args = awd_lstm_lm_hparams['awd_lstm_lm_1150'].copy()
    if not hparam_allow_override:
        mutable_args = frozenset(['dropout', 'weight_drop', 'drop_h', 'drop_i', 'drop_e'])
        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
            'Cannot override predefined model settings.'
    predefined_args.update(kwargs)
    return _get_rnn_model(AWDRNN, 'awd_lstm_lm_1150', dataset_name, vocab, pretrained,
                          ctx, root, **predefined_args)
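

# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original module): typical use
# of the factory function above. With dataset_name='wikitext-2' and
# pretrained=True the call would download the pre-trained parameters and the
# matching vocabulary into `root` on first use; the scored tokens are assumed
# for illustration. This is a hedged sketch, not a doctest.
def _example_awd_lstm_lm_1150_pretrained():
    """Illustrative only; requires network access when pretrained=True."""
    import mxnet as mx

    model, vocab = awd_lstm_lm_1150(dataset_name='wikitext-2', pretrained=True,
                                    ctx=mx.cpu())
    # Map an (assumed) token sequence to indices via the returned vocab and
    # reshape to layout "TNC": (sequence_length, batch_size=1).
    token_ids = mx.nd.array(vocab[['the', 'quick', 'brown', 'fox']]).reshape((-1, 1))
    out, _ = model(token_ids)
    return out, model, vocab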


def awd_lstm_lm_600(dataset_name=None, vocab=None, pretrained=False, ctx=cpu(),
                    root=os.path.join(get_home_dir(), 'models'),
                    hparam_allow_override=False, **kwargs):
    r"""3-layer LSTM language model with weight-drop, variational dropout, and tied weights.

    Embedding size is 200, and hidden layer size is 600.

    Parameters
    ----------
    dataset_name : str or None, default None
        The dataset name on which the pre-trained model is trained.
        Options are 'wikitext-2'. If specified, the returned vocabulary is extracted from
        the training set of the dataset.
        If None, then vocab is required and is directly returned; it is used to determine
        the size of the embedding weights.
        The pre-trained model achieves 84.61/80.96 ppl on the validation and test sets of
        wikitext-2, respectively.
    vocab : gluonnlp.Vocab or None, default None
        Vocabulary object to be used with the language model.
        Required when dataset_name is not specified.
    pretrained : bool, default False
        Whether to load the pre-trained weights for the model.
    ctx : Context, default CPU
        The context in which to load the pre-trained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
        MXNET_HOME defaults to '~/.mxnet'.
    hparam_allow_override : bool, default False
        If set to True, pre-defined hyper-parameters of the model
        (e.g. the number of layers, hidden units) can be overridden.

    Returns
    -------
    gluon.Block, gluonnlp.Vocab
    """
    predefined_args = awd_lstm_lm_hparams['awd_lstm_lm_600'].copy()
    if not hparam_allow_override:
        mutable_args = frozenset(['dropout', 'weight_drop', 'drop_h', 'drop_i', 'drop_e'])
        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
            'Cannot override predefined model settings.'
    predefined_args.update(kwargs)
    return _get_rnn_model(AWDRNN, 'awd_lstm_lm_600', dataset_name, vocab, pretrained,
                          ctx, root, **predefined_args)


def standard_lstm_lm_200(dataset_name=None, vocab=None, pretrained=False, ctx=cpu(),
                         root=os.path.join(get_home_dir(), 'models'),
                         hparam_allow_override=False, **kwargs):
    r"""Standard 2-layer LSTM language model with tied embedding and output weights.

    Both embedding and hidden dimensions are 200.

    Parameters
    ----------
    dataset_name : str or None, default None
        The dataset name on which the pre-trained model is trained.
        Options are 'wikitext-2'. If specified, the returned vocabulary is extracted from
        the training set of the dataset.
        If None, then vocab is required and is directly returned; it is used to determine
        the size of the embedding weights.
        The pre-trained model achieves 108.25/102.26 ppl on the validation and test sets of
        wikitext-2, respectively.
    vocab : gluonnlp.Vocab or None, default None
        Vocabulary object to be used with the language model.
        Required when dataset_name is not specified.
    pretrained : bool, default False
        Whether to load the pre-trained weights for the model.
    ctx : Context, default CPU
        The context in which to load the pre-trained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
        MXNET_HOME defaults to '~/.mxnet'.
    hparam_allow_override : bool, default False
        If set to True, pre-defined hyper-parameters of the model
        (e.g. the number of layers, hidden units) can be overridden.

    Returns
    -------
    gluon.Block, gluonnlp.Vocab
    """
    predefined_args = standard_lstm_lm_hparams['standard_lstm_lm_200'].copy()
    if not hparam_allow_override:
        mutable_args = frozenset(['dropout'])
        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
            'Cannot override predefined model settings.'
    predefined_args.update(kwargs)
    return _get_rnn_model(StandardRNN, 'standard_lstm_lm_200', dataset_name, vocab, pretrained,
                          ctx, root, **predefined_args)
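

# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original module): how the
# hparam_allow_override flag interacts with the predefined settings. Only the
# dropout-style arguments listed in mutable_args may be overridden by default;
# structural settings such as num_layers require hparam_allow_override=True.
# The toy vocabulary below is an assumption chosen purely for illustration.
def _example_hparam_override():
    """Illustrative only; builds randomly initialised models, no downloads."""
    import gluonnlp as nlp

    vocab = nlp.Vocab(nlp.data.count_tokens('the quick brown fox'.split()))

    # Allowed without the flag: 'dropout' is in mutable_args.
    model, _ = standard_lstm_lm_200(vocab=vocab, dropout=0.3)

    # Overriding a structural hyper-parameter needs the explicit opt-in;
    # otherwise the assert above fails with 'Cannot override predefined model settings.'
    deeper, _ = standard_lstm_lm_200(vocab=vocab, num_layers=3,
                                     hparam_allow_override=True)

    # Note: with pretrained=False the returned blocks are not yet initialised;
    # call .initialize() before running a forward pass.
    return model, deeper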


def standard_lstm_lm_650(dataset_name=None, vocab=None, pretrained=False, ctx=cpu(),
                         root=os.path.join(get_home_dir(), 'models'),
                         hparam_allow_override=False, **kwargs):
    r"""Standard 2-layer LSTM language model with tied embedding and output weights.

    Both embedding and hidden dimensions are 650.

    Parameters
    ----------
    dataset_name : str or None, default None
        The dataset name on which the pre-trained model is trained.
        Options are 'wikitext-2'. If specified, the returned vocabulary is extracted from
        the training set of the dataset.
        If None, then vocab is required and is directly returned; it is used to determine
        the size of the embedding weights.
        The pre-trained model achieves 98.96/93.90 ppl on the validation and test sets of
        wikitext-2, respectively.
    vocab : gluonnlp.Vocab or None, default None
        Vocabulary object to be used with the language model.
        Required when dataset_name is not specified.
    pretrained : bool, default False
        Whether to load the pre-trained weights for the model.
    ctx : Context, default CPU
        The context in which to load the pre-trained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
        MXNET_HOME defaults to '~/.mxnet'.
    hparam_allow_override : bool, default False
        If set to True, pre-defined hyper-parameters of the model
        (e.g. the number of layers, hidden units) can be overridden.

    Returns
    -------
    gluon.Block, gluonnlp.Vocab
    """
    predefined_args = standard_lstm_lm_hparams['standard_lstm_lm_650'].copy()
    if not hparam_allow_override:
        mutable_args = frozenset(['dropout'])
        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
            'Cannot override predefined model settings.'
    predefined_args.update(kwargs)
    return _get_rnn_model(StandardRNN, 'standard_lstm_lm_650', dataset_name, vocab, pretrained,
                          ctx, root, **predefined_args)


def standard_lstm_lm_1500(dataset_name=None, vocab=None, pretrained=False, ctx=cpu(),
                          root=os.path.join(get_home_dir(), 'models'),
                          hparam_allow_override=False, **kwargs):
    r"""Standard 2-layer LSTM language model with tied embedding and output weights.

    Both embedding and hidden dimensions are 1500.

    Parameters
    ----------
    dataset_name : str or None, default None
        The dataset name on which the pre-trained model is trained.
        Options are 'wikitext-2'. If specified, the returned vocabulary is extracted from
        the training set of the dataset.
        If None, then vocab is required and is directly returned; it is used to determine
        the size of the embedding weights.
        The pre-trained model achieves 98.29/92.83 ppl on the validation and test sets of
        wikitext-2, respectively.
    vocab : gluonnlp.Vocab or None, default None
        Vocabulary object to be used with the language model.
        Required when dataset_name is not specified.
    pretrained : bool, default False
        Whether to load the pre-trained weights for the model.
    ctx : Context, default CPU
        The context in which to load the pre-trained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
        MXNET_HOME defaults to '~/.mxnet'.
    hparam_allow_override : bool, default False
        If set to True, pre-defined hyper-parameters of the model
        (e.g. the number of layers, hidden units) can be overridden.

    Returns
    -------
    gluon.Block, gluonnlp.Vocab
    """
    predefined_args = standard_lstm_lm_hparams['standard_lstm_lm_1500'].copy()
    if not hparam_allow_override:
        mutable_args = frozenset(['dropout'])
        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
            'Cannot override predefined model settings.'
    predefined_args.update(kwargs)
    return _get_rnn_model(StandardRNN, 'standard_lstm_lm_1500',
                          dataset_name, vocab, pretrained, ctx, root, **predefined_args)


model_store._model_sha1.update(
    {name: checksum for checksum, name in [
        ('a416351377d837ef12d17aae27739393f59f0b82', 'standard_lstm_lm_1500_wikitext-2'),
        ('631f39040cd65b49f5c8828a0aba65606d73a9cb', 'standard_lstm_lm_650_wikitext-2'),
        ('b233c700e80fb0846c17fe14846cb7e08db3fd51', 'standard_lstm_lm_200_wikitext-2'),
        ('f9562ed05d9bcc7e1f5b7f3c81a1988019878038', 'awd_lstm_lm_1150_wikitext-2'),
        ('e952becc7580a0b5a6030aab09d0644e9a13ce18', 'awd_lstm_lm_600_wikitext-2'),
        ('6bb3e991eb4439fabfe26c129da2fe15a324e918', 'big_rnn_lm_2048_512_gbw')
    ]})


class BigRNN(Block):
    """Big language model with LSTMP for inference.

    Parameters
    ----------
    vocab_size : int
        Size of the input vocabulary.
    embed_size : int
        Dimension of embedding vectors.
    hidden_size : int
        Number of hidden units for LSTMP.
    num_layers : int
        Number of LSTMP layers.
    projection_size : int
        Number of projection units for LSTMP.
    embed_dropout : float, default 0.0
        Dropout rate to use on the embedding output.
    encode_dropout : float, default 0.0
        Dropout rate to use on the encoder output.
    """
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers,
                 projection_size, embed_dropout=0.0, encode_dropout=0.0, **kwargs):
        super(BigRNN, self).__init__(**kwargs)
        self._embed_size = embed_size
        self._hidden_size = hidden_size
        self._projection_size = projection_size
        self._num_layers = num_layers
        self._embed_dropout = embed_dropout
        self._encode_dropout = encode_dropout
        self._vocab_size = vocab_size

        with self.name_scope():
            self.embedding = self._get_embedding()
            self.encoder = self._get_encoder()
            self.decoder = self._get_decoder()

    def _get_embedding(self):
        prefix = 'embedding0_'
        embedding = nn.HybridSequential(prefix=prefix)
        with embedding.name_scope():
            embedding.add(nn.Embedding(self._vocab_size, self._embed_size, prefix=prefix))
            if self._embed_dropout:
                embedding.add(nn.Dropout(self._embed_dropout))
        return embedding

    def _get_encoder(self):
        block = rnn.HybridSequentialRNNCell()
        with block.name_scope():
            for _ in range(self._num_layers):
                block.add(contrib.rnn.LSTMPCell(self._hidden_size, self._projection_size))
                if self._encode_dropout:
                    block.add(rnn.DropoutCell(self._encode_dropout))
        return block

    def _get_decoder(self):
        output = nn.Dense(self._vocab_size, prefix='decoder0_')
        return output

    def begin_state(self, **kwargs):
        return self.encoder.begin_state(**kwargs)

    def forward(self, inputs, begin_state):  # pylint: disable=arguments-differ
        """Implement the forward computation.

        Parameters
        ----------
        inputs : NDArray
            Input tensor with shape `(sequence_length, batch_size)`
            when `layout` is "TNC".
        begin_state : list
            List of initial recurrent state tensors, with length equal to num_layers*2.
            For each layer the two initial states have shape `(batch_size, num_hidden)`
            and `(batch_size, num_projection)`.

        Returns
        -------
        out : NDArray
            Output tensor with shape `(sequence_length, batch_size, vocab_size)`
            when `layout` is "TNC".
        out_states : list
            List of output recurrent state tensors, with length equal to num_layers*2.
            For each layer the two output states have shape `(batch_size, num_hidden)`
            and `(batch_size, num_projection)`.
        """
        encoded = self.embedding(inputs)
        length = inputs.shape[0]
        batch_size = inputs.shape[1]
        encoded, state = self.encoder.unroll(length, encoded, begin_state,
                                             layout='TNC', merge_outputs=True)
        encoded = encoded.reshape((-1, self._projection_size))
        out = self.decoder(encoded)
        out = out.reshape((length, batch_size, -1))
        return out, state
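

# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original module): unlike the
# AWDRNN/StandardRNN blocks above, BigRNN.forward requires begin_state to be
# passed explicitly. The states come from begin_state(), which delegates to
# the LSTMP cell stack (two state tensors per layer). Vocabulary size and
# shapes below are assumptions loosely mirroring the big_rnn_lm_2048_512
# preset defined later in this file.
def _example_big_rnn_forward():
    """Illustrative only; never called at import time."""
    import mxnet as mx

    model = BigRNN(vocab_size=10000, embed_size=512, hidden_size=2048,
                   num_layers=1, projection_size=512,
                   embed_dropout=0.1, encode_dropout=0.1)
    model.initialize(mx.init.Xavier(), ctx=mx.cpu())

    # Random token ids with layout "TNC": (sequence_length, batch_size) -- assumed shapes.
    inputs = mx.nd.random.randint(0, 10000, shape=(20, 2)).astype('float32')
    states = model.begin_state(batch_size=2, func=mx.nd.zeros)
    out, states = model(inputs, states)
    # `out` is expected to have shape (20, 2, 10000); `states` can be fed back
    # in for the next segment of the same batch.
    return out, states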


big_rnn_lm_2048_512_hparams = {
    'embed_size': 512,
    'hidden_size': 2048,
    'projection_size': 512,
    'num_layers': 1,
    'embed_dropout': 0.1,
    'encode_dropout': 0.1
}

big_rnn_lm_hparams = {
    'big_rnn_lm_2048_512': big_rnn_lm_2048_512_hparams
}


def big_rnn_lm_2048_512(dataset_name=None, vocab=None, pretrained=False, ctx=cpu(),
                        root=os.path.join(get_home_dir(), 'models'),
                        hparam_allow_override=False, **kwargs):
    r"""Big 1-layer LSTMP language model.

    Both embedding and projection sizes are 512. Hidden size is 2048.

    Parameters
    ----------
    dataset_name : str or None, default None
        The dataset name on which the pre-trained model is trained.
        Options are 'gbw'. If specified, the returned vocabulary is extracted from
        the training set of the dataset.
        If None, then vocab is required and is directly returned; it is used to determine
        the size of the embedding weights.
        The pre-trained model achieves 44.05 ppl on the test set of GBW.
    vocab : gluonnlp.Vocab or None, default None
        Vocabulary object to be used with the language model.
        Required when dataset_name is not specified.
    pretrained : bool, default False
        Whether to load the pre-trained weights for the model.
    ctx : Context, default CPU
        The context in which to load the pre-trained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
        MXNET_HOME defaults to '~/.mxnet'.
    hparam_allow_override : bool, default False
        If set to True, pre-defined hyper-parameters of the model
        (e.g. the number of layers, hidden units) can be overridden.

    Returns
    -------
    gluon.Block, gluonnlp.Vocab
    """
    predefined_args = big_rnn_lm_hparams['big_rnn_lm_2048_512'].copy()
    if not hparam_allow_override:
        mutable_args = frozenset(['embed_dropout', 'encode_dropout'])
        assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \
            'Cannot override predefined model settings.'
    predefined_args.update(kwargs)
    return _get_rnn_model(BigRNN, 'big_rnn_lm_2048_512', dataset_name, vocab, pretrained,
                          ctx, root, **predefined_args)
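

# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original module): typical use
# of the factory above. With dataset_name='gbw' and pretrained=True the call
# would download the pre-trained parameters and GBW vocabulary on first use;
# the scored tokens and batch size are assumptions for illustration. This is
# a hedged sketch, not a doctest.
def _example_big_rnn_lm_2048_512_pretrained():
    """Illustrative only; requires network access when pretrained=True."""
    import mxnet as mx

    model, vocab = big_rnn_lm_2048_512(dataset_name='gbw', pretrained=True,
                                       ctx=mx.cpu())
    batch_size = 1
    states = model.begin_state(batch_size=batch_size, func=mx.nd.zeros)
    # Out-of-vocabulary tokens map to the unknown index of the returned vocab.
    token_ids = mx.nd.array(vocab[['hello', 'world']]).reshape((-1, batch_size))
    out, states = model(token_ids, states)
    return out, states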