# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Word-level language-model training on WikiText-2 with MXNet Gluon autograd.

Parses hyper-parameters from the command line, downloads/loads the
WikiText-2 corpus, and builds the RNN model, SGD trainer and loss used
by the training loop defined below.
"""

import argparse
import time
import math
import os
import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon import contrib
import model

parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.')
parser.add_argument('--model', type=str, default='lstm',
                    help='type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)')
parser.add_argument('--emsize', type=int, default=650,
                    help='size of word embeddings')
parser.add_argument('--nhid', type=int, default=650,
                    help='number of hidden units per layer')
parser.add_argument('--nlayers', type=int, default=2,
                    help='number of layers')
parser.add_argument('--lr', type=float, default=20,
                    help='initial learning rate')
parser.add_argument('--clip', type=float, default=0.25,
                    help='gradient clipping')
parser.add_argument('--epochs', type=int, default=40,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=20, metavar='N',
                    help='batch size')
parser.add_argument('--bptt', type=int, default=35,
                    help='sequence length')
parser.add_argument('--dropout', type=float, default=0.5,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--tied', action='store_true',
                    help='tie the word embedding and softmax weights')
parser.add_argument('--cuda', action='store_true',
                    help='Whether to use gpu')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                    help='report interval')
parser.add_argument('--save', type=str, default='model.params',
                    help='path to save the final model')
parser.add_argument('--gctype', type=str, default='none',
                    help='type of gradient compression to use, \
                          takes `2bit` or `none` for now.')
parser.add_argument('--gcthreshold', type=float, default=0.5,
                    help='threshold for 2bit gradient compression')
parser.add_argument('--hybridize', action='store_true',
                    help='whether to hybridize in mxnet>=1.3 (default=False)')
parser.add_argument('--static-alloc', action='store_true',
                    help='whether to use static-alloc hybridize in mxnet>=1.3 (default=False)')
parser.add_argument('--static-shape', action='store_true',
                    help='whether to use static-shape hybridize in mxnet>=1.3 (default=False)')
parser.add_argument('--export-model', action='store_true',
                    help='export a symbol graph and exit (default=False)')
args = parser.parse_args()

print(args)

###############################################################################
# Load data
###############################################################################


context = mx.gpu(0) if args.cuda else mx.cpu(0)

# Exporting a symbol graph requires a hybridized (symbolic) network.
if args.export_model:
    args.hybridize = True

# Optional hybridize() keyword arguments, only available in mxnet >= 1.3.
# Keep only the flags the user actually enabled (truthy values).
hybridize_optional = {k: v
                      for k, v in {'static_alloc': args.static_alloc,
                                   'static_shape': args.static_shape}.items()
                      if v}
if args.hybridize:
    print('hybridize_optional', hybridize_optional)

dirname = os.path.expanduser('./data')
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(dirname, exist_ok=True)

train_dataset = contrib.data.text.WikiText2(dirname, 'train', seq_len=args.bptt)
vocab = train_dataset.vocabulary
# Validation and test share the vocabulary built from the training split.
val_dataset, test_dataset = [contrib.data.text.WikiText2(dirname, segment,
                                                         vocab=vocab,
                                                         seq_len=args.bptt)
                             for segment in ['validation', 'test']]


def _interval_loader(dataset):
    """Return (num_batches, DataLoader) for `dataset`.

    The IntervalSampler strides through the corpus so that consecutive
    batches contain consecutive text segments, which lets the RNN hidden
    state be carried across batches.  Incomplete trailing batches are
    discarded so every batch has exactly args.batch_size sequences.
    """
    nbatch = len(dataset) // args.batch_size
    loader = gluon.data.DataLoader(dataset,
                                   batch_size=args.batch_size,
                                   sampler=contrib.data.IntervalSampler(len(dataset),
                                                                        nbatch),
                                   last_batch='discard')
    return nbatch, loader


nbatch_train, train_data = _interval_loader(train_dataset)
nbatch_val, val_data = _interval_loader(val_dataset)
nbatch_test, test_data = _interval_loader(test_dataset)


###############################################################################
# Build the model
###############################################################################


ntokens = len(vocab)
# NOTE: this rebinds the name `model` from the imported module to the
# network instance; the module is not needed after this point.
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                       args.nlayers, args.dropout, args.tied)
if args.hybridize:
    model.hybridize(**hybridize_optional)
model.initialize(mx.init.Xavier(), ctx=context)

# 2-bit gradient compression is only meaningful for distributed training;
# 'none' disables it entirely.
compression_params = None if args.gctype == 'none' else {'type': args.gctype, 'threshold': args.gcthreshold}
trainer = gluon.Trainer(model.collect_params(), 'sgd',
                        {'learning_rate': args.lr,
                         'momentum': 0,
                         'wd': 0},
                        compression_params=compression_params)
loss = gluon.loss.SoftmaxCrossEntropyLoss()
if args.hybridize:
    loss.hybridize(**hybridize_optional)

###############################################################################
# Training code
###############################################################################

def detach(hidden):
    """Detach RNN hidden state(s) from the computation graph.

    Truncates backpropagation-through-time at the batch boundary so that
    gradients do not flow into previous batches.  LSTMs carry a list/tuple
    of states (h, c); other cells carry a single NDArray.
    """
    if isinstance(hidden, (tuple, list)):
        hidden = [state.detach() for state in hidden]
    else:
        hidden = hidden.detach()
    return hidden


def evaluate(data_source):
    """Return the model's average per-token loss over `data_source`.

    Renamed from `eval` to stop shadowing the builtin.  The hidden state is
    carried across batches (the IntervalSampler keeps text contiguous);
    no gradients are recorded here.
    """
    total_L = 0.0
    ntotal = 0
    hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
    for data, target in data_source:
        # Transpose to (seq_len, batch) layout expected by the RNN.
        data = data.as_in_context(context).T
        target = target.as_in_context(context).T.reshape((-1, 1))
        output, hidden = model(data, hidden)
        L = loss(output, target)
        total_L += mx.nd.sum(L).asscalar()
        ntotal += L.size
    return total_L / ntotal


def train():
    """Train for args.epochs epochs with gradient clipping and LR annealing.

    After each epoch the model is validated; on improvement the parameters
    are saved to args.save and the test loss is reported, otherwise the
    learning rate is multiplied by 0.25.  With --export-model the symbol
    graph is exported after the first batch and the function returns early.
    """
    best_val = float("inf")
    # The parameter collection is fixed for the whole run; build it once
    # instead of on every batch.
    params = model.collect_params().values()
    for epoch in range(args.epochs):
        total_L = 0.0
        start_time = time.time()
        hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
        for i, (data, target) in enumerate(train_data):
            data = data.as_in_context(context).T
            target = target.as_in_context(context).T.reshape((-1, 1))
            # Truncated BPTT: stop gradients at the batch boundary.
            hidden = detach(hidden)
            with autograd.record():
                output, hidden = model(data, hidden)
                # Here L is a vector of size batch_size * bptt size
                L = loss(output, target)
                # Normalize so the gradient scale is independent of
                # batch_size and bptt.
                L = L / (args.bptt * args.batch_size)
            L.backward()

            grads = [p.grad(context) for p in params]
            gluon.utils.clip_global_norm(grads, args.clip)

            # Loss already normalized above, so step with batch size 1.
            trainer.step(1)
            total_L += mx.nd.sum(L).asscalar()

            if i % args.log_interval == 0 and i > 0:
                cur_L = total_L / args.log_interval
                print('[Epoch %d Batch %d] loss %.2f, ppl %.2f'%(
                    epoch, i, cur_L, math.exp(cur_L)))
                total_L = 0.0

            if args.export_model:
                model.export('model')
                return

        val_L = evaluate(val_data)

        print('[Epoch %d] time cost %.2fs, valid loss %.2f, valid ppl %.2f'%(
            epoch, time.time()-start_time, val_L, math.exp(val_L)))

        if val_L < best_val:
            best_val = val_L
            test_L = evaluate(test_data)
            model.save_parameters(args.save)
            print('test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))
        else:
            # No improvement: anneal the learning rate.
            args.lr = args.lr*0.25
            trainer.set_learning_rate(args.lr)

if __name__ == '__main__':
    train()
    if not args.export_model:
        # Reload the best checkpoint and report its test performance.
        model.load_parameters(args.save, context)
        test_L = evaluate(test_data)
        print('Best test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))