# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
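
# Word-level LSTM language model on the Sherlock Holmes corpus.
# Example invocation (assuming this script is saved as train.py and the corpus files
# are available under the default --data prefix expected by data.Corpus):
#   python train.py --tied --nhid 650 --emsize 650 --dropout 0.5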

import argparse
import logging
import math

import mxnet as mx
from mxnet.model import BatchEndParam

from data import Corpus, CorpusIter
from model import rnn, softmax_ce_loss
from module import CustomStatefulModule

parser = argparse.ArgumentParser(description='Sherlock Holmes LSTM Language Model')
parser.add_argument('--data', type=str, default='./data/sherlockholmes.',
                    help='location of the data corpus')
parser.add_argument('--emsize', type=int, default=650,
                    help='size of word embeddings')
parser.add_argument('--nhid', type=int, default=650,
                    help='number of hidden units per layer')
parser.add_argument('--nlayers', type=int, default=2,
                    help='number of layers')
parser.add_argument('--lr', type=float, default=1.0,
                    help='initial learning rate')
parser.add_argument('--clip', type=float, default=0.2,
                    help='gradient clipping by global norm')
parser.add_argument('--epochs', type=int, default=40,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=32,
                    help='batch size')
parser.add_argument('--dropout', type=float, default=0.5,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--tied', action='store_true',
                    help='tie the word embedding and softmax weights')
parser.add_argument('--bptt', type=int, default=35,
                    help='sequence length')
parser.add_argument('--log-interval', type=int, default=200,
                    help='report interval')
parser.add_argument('--seed', type=int, default=3,
                    help='random seed')

# best validation loss seen so far; initialized to a large sentinel value
best_loss = 9999

def evaluate(valid_module, data_iter, epoch, mode, bptt, batch_size):
    """Run one pass over `data_iter` without updating parameters and log the
    average per-token cross-entropy loss together with its perplexity."""
    total_loss = 0.0
    nbatch = 0
    for batch in data_iter:
        valid_module.forward(batch, is_train=False)
        outputs = valid_module.get_loss()
        total_loss += mx.nd.sum(outputs[0]).asscalar()
        nbatch += 1
    data_iter.reset()
    loss = total_loss / bptt / batch_size / nbatch
    logging.info('Iter[%d] %s loss:\t%.7f, Perplexity: %.7f' % \
                 (epoch, mode, loss, math.exp(loss)))
    return loss

if __name__ == '__main__':
    # args
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    args = parser.parse_args()
    logging.info(args)
    ctx = mx.gpu()
    batch_size = args.batch_size
    bptt = args.bptt
    mx.random.seed(args.seed)

    # data
    corpus = Corpus(args.data)
    ntokens = len(corpus.dictionary)
    train_data = CorpusIter(corpus.train, batch_size, bptt)
    valid_data = CorpusIter(corpus.valid, batch_size, bptt)
    test_data = CorpusIter(corpus.test, batch_size, bptt)

    # model
    pred, states, state_names = rnn(bptt, ntokens, args.emsize, args.nhid,
                                    args.nlayers, args.dropout, batch_size, args.tied)
    loss = softmax_ce_loss(pred)

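    # Note: CustomStatefulModule (defined in module.py) manages the recurrent states
    # listed in `state_names`, which lets the LSTM hidden/cell states be carried across
    # consecutive batches; see module.py for the exact state-handling behavior.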
    # module
    module = CustomStatefulModule(loss, states, state_names=state_names, context=ctx)
    module.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
    module.init_params(initializer=mx.init.Xavier())
    optimizer = mx.optimizer.create('sgd', learning_rate=args.lr, rescale_grad=1.0/batch_size)
    module.init_optimizer(optimizer=optimizer)

    # metric
    speedometer = mx.callback.Speedometer(batch_size, args.log_interval)

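    # Each training step below runs forward/backward on one (bptt, batch_size) slice of
    # the corpus, clips the gradients (by global norm, per the --clip help text), and
    # applies an SGD update before logging throughput and a smoothed loss/perplexity.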
    # train
    logging.info("Training started ...")
    for epoch in range(args.epochs):
        # run one pass over the training data
        total_loss = 0.0
        nbatch = 0
        for batch in train_data:
            module.forward(batch)
            module.backward()
            # clip the global gradient norm at clip * bptt * batch_size, then apply SGD
            module.update(max_norm=args.clip * bptt * batch_size)
            # update metric
            outputs = module.get_loss()
            total_loss += mx.nd.sum(outputs[0]).asscalar()
            speedometer_param = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                              eval_metric=None, locals=locals())
            speedometer(speedometer_param)
            if nbatch % args.log_interval == 0 and nbatch > 0:
                # average per-token loss over the last log_interval batches
                cur_loss = total_loss / bptt / batch_size / args.log_interval
                logging.info('Iter[%d] Batch [%d]\tLoss:  %.7f,\tPerplexity:\t%.7f' % \
                             (epoch, nbatch, cur_loss, math.exp(cur_loss)))
                total_loss = 0.0
            nbatch += 1
        # validation
        valid_loss = evaluate(module, valid_data, epoch, 'Valid', bptt, batch_size)
        if valid_loss < best_loss:
            best_loss = valid_loss
            # test
            test_loss = evaluate(module, test_data, epoch, 'Test', bptt, batch_size)
        else:
            # anneal the learning rate when the validation loss stops improving
            optimizer.lr *= 0.25
        train_data.reset()
    logging.info("Training completed.")