# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Deep Biaffine Dependency Parser driver class and script."""

import math
import os

import numpy as np

import mxnet as mx
from mxnet import autograd, gluon

from scripts.parsing.common.config import _Config
from scripts.parsing.common.data import ParserVocabulary, DataLoader, ConllWord, ConllSentence
from scripts.parsing.common.exponential_scheduler import ExponentialScheduler
from scripts.parsing.common.utils import init_logger, mxnet_prefer_gpu, Progbar
from scripts.parsing.parser.biaffine_parser import BiaffineParser
from scripts.parsing.parser.evaluate import evaluate_official_script


class DepParser:
    """User interfaces for biaffine dependency parser.

    It wraps a biaffine model inside, provides training, evaluating and parsing.
    """

    def __init__(self):
        super().__init__()
        self._parser = None  # BiaffineParser instance, set by train() or load()
        self._vocab = None  # ParserVocabulary instance, set by train() or load()

    def train(self, train_file, dev_file, test_file, save_dir,
              pretrained_embeddings=None, min_occur_count=2,
              lstm_layers=3, word_dims=100, tag_dims=100, dropout_emb=0.33, lstm_hiddens=400,
              dropout_lstm_input=0.33, dropout_lstm_hidden=0.33,
              mlp_arc_size=500, mlp_rel_size=100,
              dropout_mlp=0.33, learning_rate=2e-3, decay=.75, decay_steps=5000,
              beta_1=.9, beta_2=.9, epsilon=1e-12,
              num_buckets_train=40,
              num_buckets_valid=10, num_buckets_test=10, train_iters=50000,
              train_batch_size=5000, test_batch_size=5000, validate_every=100,
              save_after=5000, debug=False):
        """Train a deep biaffine dependency parser.

        Parameters
        ----------
        train_file : str
            path to training set
        dev_file : str
            path to dev set
        test_file : str
            path to test set
        save_dir : str
            a directory for saving model and related meta-data
        pretrained_embeddings : tuple
            (embedding_name, source), used for gluonnlp.embedding.create(embedding_name, source)
        min_occur_count : int
            threshold of rare words, which will be replaced with UNKs
        lstm_layers : int
            layers of lstm
        word_dims : int
            dimension of word embedding
        tag_dims : int
            dimension of tag embedding
        dropout_emb : float
            word dropout
        lstm_hiddens : int
            size of lstm hidden states
        dropout_lstm_input : float
            dropout on x in variational RNN
        dropout_lstm_hidden : float
            dropout on h in variational RNN
        mlp_arc_size : int
            output size of MLP for arc feature extraction
        mlp_rel_size : int
            output size of MLP for rel feature extraction
        dropout_mlp : float
            dropout on the output of LSTM
        learning_rate : float
            initial learning rate
        decay : float
            decay rate, see ExponentialScheduler
        decay_steps : int
            decay interval in steps, see ExponentialScheduler
        beta_1 : float
            beta1 of the Adam optimizer
        beta_2 : float
            beta2 of the Adam optimizer
        epsilon : float
            epsilon of the Adam optimizer
        num_buckets_train : int
            number of buckets for training data set
        num_buckets_valid : int
            number of buckets for dev data set
        num_buckets_test : int
            number of buckets for testing data set
        train_iters : int
            training iterations
        train_batch_size : int
            training batch size
        test_batch_size : int
            test batch size
        validate_every : int
            validate on dev set every such number of batches
        save_after : int
            skip saving model in early epochs
        debug : bool
            debug mode

        Returns
        -------
        DepParser
            parser itself
        """
        logger = init_logger(save_dir)
        config = _Config(train_file, dev_file, test_file, save_dir, pretrained_embeddings,
                         min_occur_count,
                         lstm_layers, word_dims, tag_dims, dropout_emb, lstm_hiddens,
                         dropout_lstm_input, dropout_lstm_hidden, mlp_arc_size, mlp_rel_size,
                         dropout_mlp, learning_rate, decay, decay_steps,
                         beta_1, beta_2, epsilon, num_buckets_train, num_buckets_valid,
                         num_buckets_test, train_iters,
                         train_batch_size, debug)
        config.save()
        # Vocabulary is built from the training set only, then persisted so that
        # load() can restore the identical word/tag/rel id mappings.
        self._vocab = vocab = ParserVocabulary(train_file,
                                               pretrained_embeddings,
                                               min_occur_count)
        vocab.save(config.save_vocab_path)
        vocab.log_info(logger)

        with mx.Context(mxnet_prefer_gpu()):
            self._parser = parser = BiaffineParser(vocab, word_dims, tag_dims,
                                                   dropout_emb,
                                                   lstm_layers,
                                                   lstm_hiddens, dropout_lstm_input,
                                                   dropout_lstm_hidden,
                                                   mlp_arc_size,
                                                   mlp_rel_size, dropout_mlp, debug)
            parser.initialize()
            scheduler = ExponentialScheduler(learning_rate, decay, decay_steps)
            optimizer = mx.optimizer.Adam(learning_rate, beta_1, beta_2, epsilon,
                                          lr_scheduler=scheduler)
            trainer = gluon.Trainer(parser.collect_params(), optimizer=optimizer)
            data_loader = DataLoader(train_file, num_buckets_train, vocab)
            global_step = 0
            best_UAS = 0.
            # Fix: UAS is read after the loop; without this init it raises
            # NameError when train_iters < validate_every (no validation ran).
            UAS = 0.
            batch_id = 0
            epoch = 1
            total_epoch = math.ceil(train_iters / validate_every)
            logger.info('Epoch %d out of %d', epoch, total_epoch)
            bar = Progbar(target=min(validate_every, data_loader.samples))
            while global_step < train_iters:
                for words, tags, arcs, rels in data_loader.get_batches(
                        batch_size=train_batch_size, shuffle=True):
                    with autograd.record():
                        arc_accuracy, _, _, loss = parser.forward(words, tags, arcs, rels)
                        loss_value = loss.asscalar()
                    loss.backward()
                    trainer.step(train_batch_size)
                    batch_id += 1
                    try:
                        bar.update(batch_id,
                                   exact=[('UAS', arc_accuracy, 2),
                                          ('loss', loss_value)])
                    except OverflowError:
                        pass  # sometimes loss can be 0 or infinity, crashes the bar

                    global_step += 1
                    if global_step % validate_every == 0:
                        bar = Progbar(target=min(validate_every, train_iters - global_step))
                        batch_id = 0
                        UAS, LAS, speed = evaluate_official_script(parser, vocab,
                                                                   num_buckets_valid,
                                                                   test_batch_size,
                                                                   dev_file,
                                                                   os.path.join(save_dir,
                                                                                'valid_tmp'))
                        logger.info('Dev: UAS %.2f%% LAS %.2f%% %d sents/s', UAS, LAS, speed)
                        epoch += 1
                        if global_step < train_iters:
                            logger.info('Epoch %d out of %d', epoch, total_epoch)
                        if global_step > save_after and UAS > best_UAS:
                            logger.info('- new best score!')
                            best_UAS = UAS
                            parser.save(config.save_model_path)

            # When validate_every is too big no checkpoint may exist yet, so
            # always persist the final weights in that case.
            # NOTE(review): when best_UAS != UAS this overwrites the best
            # checkpoint with the final (possibly worse) weights — looks
            # intentional ("keep the last model"), but worth confirming.
            if not os.path.isfile(config.save_model_path) or best_UAS != UAS:
                parser.save(config.save_model_path)

        return self

    def load(self, path):
        """Load from disk

        Parameters
        ----------
        path : str
            path to the directory which typically contains a config.pkl file and a model.bin file

        Returns
        -------
        DepParser
            parser itself
        """
        config = _Config.load(os.path.join(path, 'config.pkl'))
        config.save_dir = path  # redirect root path to what user specified
        self._vocab = vocab = ParserVocabulary.load(config.save_vocab_path)
        with mx.Context(mxnet_prefer_gpu()):
            self._parser = BiaffineParser(vocab, config.word_dims, config.tag_dims,
                                          config.dropout_emb,
                                          config.lstm_layers,
                                          config.lstm_hiddens, config.dropout_lstm_input,
                                          config.dropout_lstm_hidden,
                                          config.mlp_arc_size, config.mlp_rel_size,
                                          config.dropout_mlp, config.debug)
            self._parser.load(config.save_model_path)
        return self

    def evaluate(self, test_file, save_dir=None, logger=None,
                 num_buckets_test=10, test_batch_size=5000):
        """Run evaluation on test set

        Parameters
        ----------
        test_file : str
            path to test set
        save_dir : str
            where to store intermediate results and log; required despite the
            None default
        logger : logging.logger
            logger for printing results
        num_buckets_test : int
            number of clusters for sentences from test set
        test_batch_size : int
            batch size of test set

        Returns
        -------
        tuple
            UAS, LAS

        Raises
        ------
        ValueError
            if save_dir is None
        """
        # Fix: save_dir=None previously crashed deep inside os.path.join with
        # an opaque TypeError; fail fast with an actionable message instead.
        if save_dir is None:
            raise ValueError('save_dir is required: it stores intermediate '
                             'evaluation results and the test log')
        parser = self._parser
        vocab = self._vocab
        with mx.Context(mxnet_prefer_gpu()):
            UAS, LAS, speed = evaluate_official_script(parser, vocab, num_buckets_test,
                                                       test_batch_size, test_file,
                                                       os.path.join(save_dir, 'valid_tmp'))
        if logger is None:
            logger = init_logger(save_dir, 'test.log')
        logger.info('Test: UAS %.2f%% LAS %.2f%% %d sents/s', UAS, LAS, speed)

        return UAS, LAS

    def parse(self, sentence):
        """Parse raw sentence into ConllSentence

        Parameters
        ----------
        sentence : list
            a list of (word, tag) tuples

        Returns
        -------
        ConllSentence
            ConllSentence object
        """
        # Column vectors of ids, with the ROOT token prepended at position 0.
        words = np.zeros((len(sentence) + 1, 1), np.int32)
        tags = np.zeros((len(sentence) + 1, 1), np.int32)
        words[0, 0] = ParserVocabulary.ROOT
        tags[0, 0] = ParserVocabulary.ROOT
        vocab = self._vocab

        for i, (word, tag) in enumerate(sentence):
            words[i + 1, 0], tags[i + 1, 0] = vocab.word2id(word.lower()), vocab.tag2id(tag)

        with mx.Context(mxnet_prefer_gpu()):
            outputs = self._parser.forward(words, tags)
        words = []
        # outputs[0] holds (arcs, rels) for the single input sentence.
        for arc, rel, (word, tag) in zip(outputs[0][0], outputs[0][1], sentence):
            words.append(ConllWord(idx=len(words) + 1, form=word, pos=tag,
                                   head=arc, relation=vocab.id2rel(rel)))
        return ConllSentence(words)


if __name__ == '__main__':
    dep_parser = DepParser()
    dep_parser.train(train_file='tests/data/biaffine/ptb/train.conllx',
                     dev_file='tests/data/biaffine/ptb/dev.conllx',
                     test_file='tests/data/biaffine/ptb/test.conllx',
                     save_dir='tests/data/biaffine/model',
                     pretrained_embeddings=('glove', 'glove.6B.100d'))
    dep_parser.load('tests/data/biaffine/model')
    dep_parser.evaluate(test_file='tests/data/biaffine/ptb/test.conllx',
                        save_dir='tests/data/biaffine/model')

    sent = [('Is', 'VBZ'), ('this', 'DT'), ('the', 'DT'), ('future', 'NN'),
            ('of', 'IN'), ('chamber', 'NN'), ('music', 'NN'), ('?', '.')]
    print(dep_parser.parse(sent))