# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Training config."""

import os
import pickle

from scripts.parsing.common.savable import Savable


class _Config(Savable):
    """Picklable bag of training hyper-parameters plus derived output paths."""

    def __init__(self, train_file, dev_file, test_file, save_dir,
                 pretrained_embeddings_file=None, min_occur_count=2,
                 lstm_layers=3, word_dims=100, tag_dims=100,
                 dropout_emb=0.33, lstm_hiddens=400,
                 dropout_lstm_input=0.33, dropout_lstm_hidden=0.33,
                 mlp_arc_size=500, mlp_rel_size=100, dropout_mlp=0.33,
                 learning_rate=2e-3, decay=.75, decay_steps=5000,
                 beta_1=.9, beta_2=.9, epsilon=1e-12,
                 num_buckets_train=40, num_buckets_valid=10,
                 num_buckets_test=10,
                 train_iters=50000, train_batch_size=5000, debug=False):
        """Record every hyper-parameter as an attribute of the same name.

        Nothing is validated or derived here: each argument is stored
        unchanged on ``self``. A dict could hold the same data. This class
        adds convenient path properties and a :meth:`save` that pickles the
        whole object. Loading is presumably provided by the ``Savable`` base
        class (not visible here; confirm).

        Parameters
        ----------
        train_file, dev_file, test_file
            Locations of the train / dev / test data. They are only stored
            here, not opened (presumably file paths).
        save_dir
            Directory that ``save_model_path``, ``save_vocab_path`` and
            ``save_config_path`` are joined onto with ``os.path.join``.
        pretrained_embeddings_file : optional
            Defaults to ``None``.
        min_occur_count
            Defaults to 2. Presumably a vocabulary frequency cut-off;
            confirm against the vocabulary builder.
        lstm_layers, word_dims, tag_dims, lstm_hiddens
            Defaults 3, 100, 100 and 400. Presumably network sizes.
        dropout_emb, dropout_lstm_input, dropout_lstm_hidden, dropout_mlp
            All default to 0.33. Presumably dropout rates.
        mlp_arc_size, mlp_rel_size
            Defaults 500 and 100.
        learning_rate, decay, decay_steps
            Defaults 2e-3, 0.75 and 5000. How the decay is applied is up to
            the training code, not this class.
        beta_1, beta_2, epsilon
            Defaults 0.9, 0.9 and 1e-12. Presumably optimizer settings.
        num_buckets_train, num_buckets_valid, num_buckets_test
            Defaults 40, 10 and 10.
        train_iters, train_batch_size
            Defaults 50000 and 5000.
        debug
            Defaults to ``False``.
        """
        super(_Config, self).__init__()
        # Assignment order matches the historical order so the instance
        # __dict__ (and anything that iterates it) sees the same sequence.
        self.pretrained_embeddings_file = pretrained_embeddings_file
        self.train_file, self.dev_file, self.test_file = (
            train_file, dev_file, test_file)
        self.min_occur_count = min_occur_count
        self.save_dir = save_dir
        self.lstm_layers, self.word_dims, self.tag_dims = (
            lstm_layers, word_dims, tag_dims)
        self.dropout_emb = dropout_emb
        self.lstm_hiddens = lstm_hiddens
        self.dropout_lstm_input, self.dropout_lstm_hidden = (
            dropout_lstm_input, dropout_lstm_hidden)
        self.mlp_arc_size, self.mlp_rel_size = mlp_arc_size, mlp_rel_size
        self.dropout_mlp = dropout_mlp
        self.learning_rate, self.decay, self.decay_steps = (
            learning_rate, decay, decay_steps)
        self.beta_1, self.beta_2, self.epsilon = beta_1, beta_2, epsilon
        self.num_buckets_train, self.num_buckets_valid, self.num_buckets_test = (
            num_buckets_train, num_buckets_valid, num_buckets_test)
        self.train_iters, self.train_batch_size = train_iters, train_batch_size
        self.debug = debug

    def _in_save_dir(self, filename):
        """Return ``filename`` joined onto ``self.save_dir``."""
        return os.path.join(self.save_dir, filename)

    @property
    def save_model_path(self):
        """``<save_dir>/model.bin``."""
        return self._in_save_dir('model.bin')

    @property
    def save_vocab_path(self):
        """``<save_dir>/vocab.pkl``."""
        return self._in_save_dir('vocab.pkl')

    @property
    def save_config_path(self):
        """``<save_dir>/config.pkl``: the default target of :meth:`save`."""
        return self._in_save_dir('config.pkl')

    def save(self, path=None):
        """Pickle this config object to a file.

        Parameters
        ----------
        path : optional
            Destination file. Any falsy value (``None``, ``''``) falls back
            to ``save_config_path``. The file is opened in ``'wb'`` mode, so
            existing content is overwritten. The parent directory is not
            created here.
        """
        target = path or self.save_config_path
        with open(target, 'wb') as handle:
            pickle.dump(self, handle)