# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse
import logging
import math

import mxnet as mx
from mxnet.model import BatchEndParam

from data import Corpus, CorpusIter
from model import rnn, softmax_ce_loss
from module import *

parser = argparse.ArgumentParser(description='Sherlock Holmes LSTM Language Model')
parser.add_argument('--data', type=str, default='./data/sherlockholmes.',
                    help='location of the data corpus')
parser.add_argument('--emsize', type=int, default=650,
                    help='size of word embeddings')
parser.add_argument('--nhid', type=int, default=650,
                    help='number of hidden units per layer')
parser.add_argument('--nlayers', type=int, default=2,
                    help='number of layers')
parser.add_argument('--lr', type=float, default=1.0,
                    help='initial learning rate')
parser.add_argument('--clip', type=float, default=0.2,
                    help='gradient clipping by global norm')
parser.add_argument('--epochs', type=int, default=40,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=32,
                    help='batch size')
parser.add_argument('--dropout', type=float, default=0.5,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--tied', action='store_true',
                    help='tie the word embedding and softmax weights')
parser.add_argument('--bptt', type=int, default=35,
                    help='sequence length')
parser.add_argument('--log-interval', type=int, default=200,
                    help='report interval')
parser.add_argument('--seed', type=int, default=3,
                    help='random seed')

best_loss = 9999  # best validation loss seen so far

def evaluate(valid_module, data_iter, epoch, mode, bptt, batch_size):
    """Run a forward-only pass over data_iter and report average per-token loss and perplexity."""
    total_loss = 0.0
    nbatch = 0
    for batch in data_iter:
        valid_module.forward(batch, is_train=False)
        outputs = valid_module.get_loss()
        total_loss += mx.nd.sum(outputs[0]).asscalar()
        nbatch += 1
    data_iter.reset()
    # the loss output is summed over all bptt * batch_size tokens of each batch
    loss = total_loss / bptt / batch_size / nbatch
    logging.info('Iter[%d] %s loss:\t%.7f, Perplexity: %.7f' % \
                 (epoch, mode, loss, math.exp(loss)))
    return loss

if __name__ == '__main__':
    # args
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    args = parser.parse_args()
    logging.info(args)
    ctx = mx.gpu()  # assumes a GPU is available; use mx.cpu() otherwise
    batch_size = args.batch_size
    bptt = args.bptt
    mx.random.seed(args.seed)

    # data
    corpus = Corpus(args.data)
    ntokens = len(corpus.dictionary)
    train_data = CorpusIter(corpus.train, batch_size, bptt)
    valid_data = CorpusIter(corpus.valid, batch_size, bptt)
    test_data = CorpusIter(corpus.test, batch_size, bptt)

    # model
    pred, states, state_names = rnn(bptt, ntokens, args.emsize, args.nhid,
                                    args.nlayers, args.dropout, batch_size, args.tied)
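    # rnn() returns the unrolled prediction symbol along with the recurrent
    # states and their names (carried across batches by the stateful module
    # below); attach a softmax cross-entropy head to it for training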
    loss = softmax_ce_loss(pred)

    # module
    module = CustomStatefulModule(loss, states, state_names=state_names, context=ctx)
    module.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
    module.init_params(initializer=mx.init.Xavier())
    optimizer = mx.optimizer.create('sgd', learning_rate=args.lr, rescale_grad=1.0/batch_size)
    module.init_optimizer(optimizer=optimizer)

    # metric
    speedometer = mx.callback.Speedometer(batch_size, args.log_interval)

    # train
    logging.info("Training started ... ")
    for epoch in range(args.epochs):
        # train
        total_loss = 0.0
        nbatch = 0
        for batch in train_data:
            module.forward(batch)
            module.backward()
            # clip gradients by global norm; the threshold is scaled by
            # bptt * batch_size because the loss is summed over all tokens
            module.update(max_norm=args.clip * bptt * batch_size)
            # update metric
            outputs = module.get_loss()
            total_loss += mx.nd.sum(outputs[0]).asscalar()
            speedometer_param = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                              eval_metric=None, locals=locals())
            speedometer(speedometer_param)
            if nbatch % args.log_interval == 0 and nbatch > 0:
                cur_loss = total_loss / bptt / batch_size / args.log_interval
                logging.info('Iter[%d] Batch [%d]\tLoss: %.7f,\tPerplexity:\t%.7f' % \
                             (epoch, nbatch, cur_loss, math.exp(cur_loss)))
                total_loss = 0.0
            nbatch += 1
        # validation
        valid_loss = evaluate(module, valid_data, epoch, 'Valid', bptt, batch_size)
        if valid_loss < best_loss:
            best_loss = valid_loss
            # test
            test_loss = evaluate(module, test_data, epoch, 'Test', bptt, batch_size)
        else:
            # anneal the learning rate when validation loss stops improving
            optimizer.lr *= 0.25
        train_data.reset()
    logging.info("Training completed. ")
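# Example invocation (a sketch; assumes this file is saved as train.py and that
# the Sherlock Holmes corpus splits expected by Corpus exist under ./data/):
#   python train.py --tied --nhid 650 --emsize 650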