# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Deep Biaffine Dependency Parser driver class and script."""

import math
import os

import mxnet as mx
import numpy as np
from mxnet import gluon, autograd

from scripts.parsing.common.config import _Config
from scripts.parsing.common.data import ParserVocabulary, DataLoader, ConllWord, ConllSentence
from scripts.parsing.common.exponential_scheduler import ExponentialScheduler
from scripts.parsing.common.utils import init_logger, mxnet_prefer_gpu, Progbar
from scripts.parsing.parser.biaffine_parser import BiaffineParser
from scripts.parsing.parser.evaluate import evaluate_official_script


class DepParser:
    """User interface for the biaffine dependency parser.

    It wraps a biaffine model and provides methods for training,
    evaluation and parsing.
    """

    def __init__(self):
        super().__init__()
        self._parser = None
        self._vocab = None

    def train(self, train_file, dev_file, test_file, save_dir,
              pretrained_embeddings=None, min_occur_count=2,
              lstm_layers=3, word_dims=100, tag_dims=100, dropout_emb=0.33, lstm_hiddens=400,
              dropout_lstm_input=0.33, dropout_lstm_hidden=0.33,
              mlp_arc_size=500, mlp_rel_size=100,
              dropout_mlp=0.33, learning_rate=2e-3, decay=.75, decay_steps=5000,
              beta_1=.9, beta_2=.9, epsilon=1e-12,
              num_buckets_train=40,
              num_buckets_valid=10, num_buckets_test=10, train_iters=50000, train_batch_size=5000,
              test_batch_size=5000, validate_every=100, save_after=5000, debug=False):
        """Train a deep biaffine dependency parser.

        Parameters
        ----------
        train_file : str
            path to training set
        dev_file : str
            path to dev set
        test_file : str
            path to test set
        save_dir : str
            a directory for saving the model and related meta-data
        pretrained_embeddings : tuple
            (embedding_name, source), used for gluonnlp.embedding.create(embedding_name, source)
        min_occur_count : int
            words occurring fewer than this many times are replaced with UNK
        lstm_layers : int
            number of LSTM layers
        word_dims : int
            dimension of the word embedding
        tag_dims : int
            dimension of the tag embedding
        dropout_emb : float
            word dropout rate
        lstm_hiddens : int
            size of the LSTM hidden states
        dropout_lstm_input : float
            dropout on x in variational RNN
        dropout_lstm_hidden : float
            dropout on h in variational RNN
        mlp_arc_size : int
            output size of the MLP for arc feature extraction
        mlp_rel_size : int
            output size of the MLP for rel feature extraction
        dropout_mlp : float
            dropout on the output of the LSTM
        learning_rate : float
            initial learning rate
        decay : float
            decay rate, see ExponentialScheduler
        decay_steps : int
            decay interval in steps, see ExponentialScheduler
        beta_1 : float
            first-moment decay rate of the Adam optimizer
        beta_2 : float
            second-moment decay rate of the Adam optimizer
        epsilon : float
            epsilon term of the Adam optimizer
        num_buckets_train : int
            number of buckets for the training data set
        num_buckets_valid : int
            number of buckets for the dev data set
        num_buckets_test : int
            number of buckets for the test data set
        train_iters : int
            number of training iterations
        train_batch_size : int
            training batch size
        test_batch_size : int
            test batch size
        validate_every : int
            run validation on the dev set every this many batches
        save_after : int
            skip saving the model during the first this many training steps
        debug : bool
            debug mode

        Returns
        -------
        DepParser
            the parser itself
        """
        logger = init_logger(save_dir)
        config = _Config(train_file, dev_file, test_file, save_dir, pretrained_embeddings,
                         min_occur_count,
                         lstm_layers, word_dims, tag_dims, dropout_emb, lstm_hiddens,
                         dropout_lstm_input, dropout_lstm_hidden, mlp_arc_size, mlp_rel_size,
                         dropout_mlp, learning_rate, decay, decay_steps,
                         beta_1, beta_2, epsilon, num_buckets_train, num_buckets_valid,
                         num_buckets_test, train_iters,
                         train_batch_size, debug)
        config.save()
        self._vocab = vocab = ParserVocabulary(train_file,
                                               pretrained_embeddings,
                                               min_occur_count)
        vocab.save(config.save_vocab_path)
        vocab.log_info(logger)

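        # Build the network, learning-rate schedule and optimizer on the
        # preferred device (GPU when available, otherwise CPU).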
        with mx.Context(mxnet_prefer_gpu()):
            self._parser = parser = BiaffineParser(vocab, word_dims, tag_dims,
                                                   dropout_emb,
                                                   lstm_layers,
                                                   lstm_hiddens, dropout_lstm_input,
                                                   dropout_lstm_hidden,
                                                   mlp_arc_size,
                                                   mlp_rel_size, dropout_mlp, debug)
            parser.initialize()
            scheduler = ExponentialScheduler(learning_rate, decay, decay_steps)
            optimizer = mx.optimizer.Adam(learning_rate, beta_1, beta_2, epsilon,
                                          lr_scheduler=scheduler)
            trainer = gluon.Trainer(parser.collect_params(), optimizer=optimizer)
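            # Bucket training sentences by length so each batch contains
            # similarly-sized sentences and padding stays small.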
            data_loader = DataLoader(train_file, num_buckets_train, vocab)
            global_step = 0
            UAS = best_UAS = 0.
            batch_id = 0
            epoch = 1
            total_epoch = math.ceil(train_iters / validate_every)
            logger.info('Epoch %d out of %d', epoch, total_epoch)
            bar = Progbar(target=min(validate_every, data_loader.samples))
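
            # Main training loop: draw shuffled batches, compute the arc/rel
            # loss, backpropagate and update until train_iters steps are done.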
            while global_step < train_iters:
                for words, tags, arcs, rels in data_loader.get_batches(batch_size=train_batch_size,
                                                                       shuffle=True):
                    with autograd.record():
                        arc_accuracy, _, _, loss = parser.forward(words, tags, arcs, rels)
                        loss_value = loss.asscalar()
                    loss.backward()
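                    # Trainer.step rescales the accumulated gradients by
                    # 1/train_batch_size before applying the Adam update.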
                    trainer.step(train_batch_size)
                    batch_id += 1
                    try:
                        bar.update(batch_id,
                                   exact=[('UAS', arc_accuracy, 2),
                                          ('loss', loss_value)])
                    except OverflowError:
                        pass  # the loss can occasionally be 0 or infinity, which crashes the bar

                    global_step += 1
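
                    # Periodically evaluate on the dev set and checkpoint the
                    # model whenever dev UAS improves (after save_after steps).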
                    if global_step % validate_every == 0:
                        bar = Progbar(target=min(validate_every, train_iters - global_step))
                        batch_id = 0
                        UAS, LAS, speed = evaluate_official_script(parser, vocab,
                                                                   num_buckets_valid,
                                                                   test_batch_size,
                                                                   dev_file,
                                                                   os.path.join(save_dir,
                                                                                'valid_tmp'))
                        logger.info('Dev: UAS %.2f%% LAS %.2f%% %d sents/s', UAS, LAS, speed)
                        epoch += 1
                        if global_step < train_iters:
                            logger.info('Epoch %d out of %d', epoch, total_epoch)
                        if global_step > save_after and UAS > best_UAS:
                            logger.info('- new best score!')
                            best_UAS = UAS
                            parser.save(config.save_model_path)

        # Fallback: save the final model when no checkpoint was written during
        # training (e.g. validate_every exceeds train_iters) or when the last
        # validation score is not the best one.
        if not os.path.isfile(config.save_model_path) or best_UAS != UAS:
            parser.save(config.save_model_path)

        return self

    def load(self, path):
        """Load a trained parser from disk.

        Parameters
        ----------
        path : str
            path to the directory which typically contains a config.pkl file and a model.bin file

        Returns
        -------
        DepParser
            the parser itself
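
        Examples
        --------
        A minimal usage sketch (the path is illustrative):

        >>> parser = DepParser().load('tests/data/biaffine/model')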
        """
        config = _Config.load(os.path.join(path, 'config.pkl'))
        config.save_dir = path  # redirect the root path to what the user specified
        self._vocab = vocab = ParserVocabulary.load(config.save_vocab_path)
        with mx.Context(mxnet_prefer_gpu()):
            self._parser = BiaffineParser(vocab, config.word_dims, config.tag_dims,
                                          config.dropout_emb,
                                          config.lstm_layers,
                                          config.lstm_hiddens, config.dropout_lstm_input,
                                          config.dropout_lstm_hidden,
                                          config.mlp_arc_size, config.mlp_rel_size,
                                          config.dropout_mlp, config.debug)
            self._parser.load(config.save_model_path)
        return self

    def evaluate(self, test_file, save_dir=None, logger=None,
                 num_buckets_test=10, test_batch_size=5000):
        """Run evaluation on the test set.

        Parameters
        ----------
        test_file : str
            path to test set
        save_dir : str
            directory for storing intermediate results and the log
        logger : logging.Logger
            logger for printing results
        num_buckets_test : int
            number of buckets for the test set
        test_batch_size : int
            test batch size

        Returns
        -------
        tuple
            (UAS, LAS)
        """
        parser = self._parser
        vocab = self._vocab
        with mx.Context(mxnet_prefer_gpu()):
            UAS, LAS, speed = evaluate_official_script(parser, vocab, num_buckets_test,
                                                       test_batch_size, test_file,
                                                       os.path.join(save_dir, 'valid_tmp'))
        if logger is None:
            logger = init_logger(save_dir, 'test.log')
        logger.info('Test: UAS %.2f%% LAS %.2f%% %d sents/s', UAS, LAS, speed)

        return UAS, LAS

    def parse(self, sentence):
        """Parse a raw sentence into a ConllSentence.

        Parameters
        ----------
        sentence : list
            a list of (word, tag) tuples

        Returns
        -------
        ConllSentence
            the parsed sentence
        """
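        # Index 0 of each column is reserved for the artificial ROOT token;
        # the words of the actual sentence start at index 1.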
        words = np.zeros((len(sentence) + 1, 1), np.int32)
        tags = np.zeros((len(sentence) + 1, 1), np.int32)
        words[0, 0] = ParserVocabulary.ROOT
        tags[0, 0] = ParserVocabulary.ROOT
        vocab = self._vocab

        for i, (word, tag) in enumerate(sentence):
            words[i + 1, 0], tags[i + 1, 0] = vocab.word2id(word.lower()), vocab.tag2id(tag)

        with mx.Context(mxnet_prefer_gpu()):
            outputs = self._parser.forward(words, tags)
        words = []
        for arc, rel, (word, tag) in zip(outputs[0][0], outputs[0][1], sentence):
            words.append(ConllWord(idx=len(words) + 1, form=word, pos=tag,
                                   head=arc, relation=vocab.id2rel(rel)))
        return ConllSentence(words)


if __name__ == '__main__':
    dep_parser = DepParser()
    dep_parser.train(train_file='tests/data/biaffine/ptb/train.conllx',
                     dev_file='tests/data/biaffine/ptb/dev.conllx',
                     test_file='tests/data/biaffine/ptb/test.conllx',
                     save_dir='tests/data/biaffine/model',
                     pretrained_embeddings=('glove', 'glove.6B.100d'))
    dep_parser.load('tests/data/biaffine/model')
    dep_parser.evaluate(test_file='tests/data/biaffine/ptb/test.conllx',
                        save_dir='tests/data/biaffine/model')

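    # Parse a raw, POS-tagged sentence with the trained model.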
    sent = [('Is', 'VBZ'), ('this', 'DT'), ('the', 'DT'), ('future', 'NN'),
            ('of', 'IN'), ('chamber', 'NN'), ('music', 'NN'), ('?', '.')]
    print(dep_parser.parse(sent))