# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse
import time
import math
import os
import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon import contrib
import model

parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.')
parser.add_argument('--model', type=str, default='lstm',
                    help='type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)')
parser.add_argument('--emsize', type=int, default=650,
                    help='size of word embeddings')
parser.add_argument('--nhid', type=int, default=650,
                    help='number of hidden units per layer')
parser.add_argument('--nlayers', type=int, default=2,
                    help='number of layers')
parser.add_argument('--lr', type=float, default=20,
                    help='initial learning rate')
parser.add_argument('--clip', type=float, default=0.25,
                    help='gradient clipping')
parser.add_argument('--epochs', type=int, default=40,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=20, metavar='N',
                    help='batch size')
parser.add_argument('--bptt', type=int, default=35,
                    help='sequence length')
parser.add_argument('--dropout', type=float, default=0.5,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--tied', action='store_true',
                    help='tie the word embedding and softmax weights')
parser.add_argument('--cuda', action='store_true',
                    help='whether to use a GPU')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                    help='report interval')
parser.add_argument('--save', type=str, default='model.params',
                    help='path to save the final model')
parser.add_argument('--gctype', type=str, default='none',
                    help='type of gradient compression to use; \
                          takes `2bit` or `none` for now')
parser.add_argument('--gcthreshold', type=float, default=0.5,
                    help='threshold for 2bit gradient compression')
parser.add_argument('--hybridize', action='store_true',
                    help='whether to hybridize in mxnet>=1.3 (default=False)')
parser.add_argument('--static-alloc', action='store_true',
                    help='whether to use static-alloc hybridize in mxnet>=1.3 (default=False)')
parser.add_argument('--static-shape', action='store_true',
                    help='whether to use static-shape hybridize in mxnet>=1.3 (default=False)')
parser.add_argument('--export-model', action='store_true',
                    help='export a symbol graph and exit (default=False)')
args = parser.parse_args()

print(args)
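# Example invocation (the script file name below is assumed; adjust it to
# this file's actual name):
#   python train.py --cuda --tied --hybridize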

###############################################################################
# Load data
###############################################################################


if args.cuda:
    context = mx.gpu(0)
else:
    context = mx.cpu(0)

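# Exporting a symbolic graph via model.export() requires the network to be
# hybridized first, so force hybridization on.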
if args.export_model:
    args.hybridize = True

# optional parameters only for mxnet >= 1.3
hybridize_optional = dict(filter(lambda kv: kv[1],
    {'static_alloc': args.static_alloc, 'static_shape': args.static_shape}.items()))
if args.hybridize:
    print('hybridize_optional', hybridize_optional)

dirname = './data'
dirname = os.path.expanduser(dirname)
if not os.path.exists(dirname):
    os.makedirs(dirname)
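# WikiText2 downloads the dataset into `dirname` on first use; later runs
# reuse the cached copy.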

train_dataset = contrib.data.text.WikiText2(dirname, 'train', seq_len=args.bptt)
vocab = train_dataset.vocabulary
val_dataset, test_dataset = [contrib.data.text.WikiText2(dirname, segment,
                                                         vocab=vocab,
                                                         seq_len=args.bptt)
                             for segment in ['validation', 'test']]

nbatch_train = len(train_dataset) // args.batch_size
train_data = gluon.data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   sampler=contrib.data.IntervalSampler(len(train_dataset),
                                                                        nbatch_train),
                                   last_batch='discard')
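# IntervalSampler yields indices 0, nbatch, 2*nbatch, ..., then 1, 1+nbatch,
# ..., so row k of consecutive batches holds consecutive chunks of the same
# text stream. This keeps the hidden state carried across batches aligned
# with a continuous passage, which is what truncated BPTT assumes.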

nbatch_val = len(val_dataset) // args.batch_size
val_data = gluon.data.DataLoader(val_dataset,
                                 batch_size=args.batch_size,
                                 sampler=contrib.data.IntervalSampler(len(val_dataset),
                                                                      nbatch_val),
                                 last_batch='discard')

nbatch_test = len(test_dataset) // args.batch_size
test_data = gluon.data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  sampler=contrib.data.IntervalSampler(len(test_dataset),
                                                                       nbatch_test),
                                  last_batch='discard')


###############################################################################
# Build the model
###############################################################################


ntokens = len(vocab)
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                       args.nlayers, args.dropout, args.tied)
if args.hybridize:
    model.hybridize(**hybridize_optional)
model.initialize(mx.init.Xavier(), ctx=context)

compression_params = None if args.gctype == 'none' else {'type': args.gctype, 'threshold': args.gcthreshold}
trainer = gluon.Trainer(model.collect_params(), 'sgd',
                        {'learning_rate': args.lr,
                         'momentum': 0,
                         'wd': 0},
                        compression_params=compression_params)
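# 2-bit gradient compression quantizes each gradient element to
# {-threshold, 0, +threshold} and keeps the quantization error as a local
# residual; it is intended for distributed training through the kvstore and
# has no effect in a purely local run.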
loss = gluon.loss.SoftmaxCrossEntropyLoss()
if args.hybridize:
    loss.hybridize(**hybridize_optional)

###############################################################################
# Training code
###############################################################################

def detach(hidden):
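    """Detach hidden states from the computation graph.

    This stops gradients from flowing back into previous batches, which is
    what makes the backpropagation truncated (truncated BPTT).
    """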
    if isinstance(hidden, (tuple, list)):
        hidden = [i.detach() for i in hidden]
    else:
        hidden = hidden.detach()
    return hidden

def evaluate(data_source):
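    """Return the average per-token loss of the model on `data_source`."""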
    total_L = 0.0
    ntotal = 0
    hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
    for i, (data, target) in enumerate(data_source):
        data = data.as_in_context(context).T
        target = target.as_in_context(context).T.reshape((-1, 1))
        output, hidden = model(data, hidden)
        L = loss(output, target)
        total_L += mx.nd.sum(L).asscalar()
        ntotal += L.size
    return total_L / ntotal

def train():
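    """Train for args.epochs epochs, evaluating on the validation set after
    each epoch; the best model (by validation loss) is saved, and the
    learning rate is annealed whenever validation loss stops improving.
    """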
    best_val = float("Inf")
    for epoch in range(args.epochs):
        total_L = 0.0
        start_time = time.time()
        hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
        for i, (data, target) in enumerate(train_data):
            data = data.as_in_context(context).T
            target = target.as_in_context(context).T.reshape((-1, 1))
            hidden = detach(hidden)
            with autograd.record():
                output, hidden = model(data, hidden)
                # L is a vector of length batch_size * bptt
                L = loss(output, target)
                L = L / (args.bptt * args.batch_size)
                L.backward()

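            # Rescale all gradients in place so their global L2 norm does not
            # exceed args.clip; this guards against exploding gradients in RNNs.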
            grads = [p.grad(context) for p in model.collect_params().values()]
            gluon.utils.clip_global_norm(grads, args.clip)

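            # The loss was already averaged over bptt * batch_size above, so
            # step with a batch size of 1 to avoid rescaling gradients again.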
            trainer.step(1)
            total_L += mx.nd.sum(L).asscalar()

            if i % args.log_interval == 0 and i > 0:
                cur_L = total_L / args.log_interval
                print('[Epoch %d Batch %d] loss %.2f, ppl %.2f'%(
                    epoch, i, cur_L, math.exp(cur_L)))
                total_L = 0.0

            if args.export_model:
                model.export('model')
                return

        val_L = evaluate(val_data)

        print('[Epoch %d] time cost %.2fs, valid loss %.2f, valid ppl %.2f'%(
            epoch, time.time()-start_time, val_L, math.exp(val_L)))

        if val_L < best_val:
            best_val = val_L
            test_L = evaluate(test_data)
            model.save_parameters(args.save)
            print('test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))
        else:
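            # Anneal: quarter the learning rate whenever validation loss
            # fails to improve.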
            args.lr = args.lr * 0.25
            trainer.set_learning_rate(args.lr)

if __name__ == '__main__':
    train()
    if not args.export_model:
        model.load_parameters(args.save, context)
        test_L = evaluate(test_data)
        print('Best test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))