1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Training config."""
18
19import os
20import pickle
21
22from scripts.parsing.common.savable import Savable
23
24
class _Config(Savable):
    """Plain container of training hyper parameters that can be pickled.

    Every constructor argument is stored unchanged as an instance attribute
    of the same name. On top of that the class derives three file paths from
    ``save_dir`` (``save_model_path``, ``save_vocab_path`` and
    ``save_config_path``) and can pickle itself with :meth:`save`.
    """

    def __init__(self, train_file, dev_file, test_file, save_dir,
                 pretrained_embeddings_file=None, min_occur_count=2,
                 lstm_layers=3, word_dims=100, tag_dims=100, dropout_emb=0.33, lstm_hiddens=400,
                 dropout_lstm_input=0.33,
                 dropout_lstm_hidden=0.33, mlp_arc_size=500, mlp_rel_size=100,
                 dropout_mlp=0.33, learning_rate=2e-3, decay=.75, decay_steps=5000,
                 beta_1=.9, beta_2=.9, epsilon=1e-12,
                 num_buckets_train=40,
                 num_buckets_valid=10, num_buckets_test=10,
                 train_iters=50000, train_batch_size=5000, debug=False):
        """Internal structure for hyper parameters, intended for pickle serialization.

        May be replaced by a dict, but this class provides intuitive properties
        and saving/loading mechanism

        This constructor performs no validation and no I/O; it only records
        the values. NOTE(review): the one-line meanings below are inferred
        from the parameter names -- the authoritative semantics are defined
        by the code that consumes this config and should be confirmed there.

        Parameters
        ----------
        train_file
            Presumably the path of the training data file.
        dev_file
            Presumably the path of the development (validation) data file.
        test_file
            Presumably the path of the test data file.
        save_dir
            Directory under which ``model.bin``, ``vocab.pkl`` and
            ``config.pkl`` are located (see the ``save_*_path`` properties).
        pretrained_embeddings_file
            Presumably the path of a pre-trained embedding file. Default ``None``.
        min_occur_count
            Presumably a minimum word frequency threshold. Default ``2``.
        lstm_layers
            Presumably the number of LSTM layers. Default ``3``.
        word_dims
            Presumably the word embedding size. Default ``100``.
        tag_dims
            Presumably the tag embedding size. Default ``100``.
        dropout_emb
            Presumably the dropout rate on embeddings. Default ``0.33``.
        lstm_hiddens
            Presumably the LSTM hidden size. Default ``400``.
        dropout_lstm_input
            Presumably the dropout rate on LSTM inputs. Default ``0.33``.
        dropout_lstm_hidden
            Presumably the dropout rate on LSTM hidden states. Default ``0.33``.
        mlp_arc_size
            Presumably the size of the arc MLP. Default ``500``.
        mlp_rel_size
            Presumably the size of the relation MLP. Default ``100``.
        dropout_mlp
            Presumably the dropout rate on MLP outputs. Default ``0.33``.
        learning_rate
            Presumably the optimizer learning rate. Default ``2e-3``.
        decay
            Presumably the learning rate decay factor. Default ``0.75``.
        decay_steps
            Presumably the step interval of learning rate decay. Default ``5000``.
        beta_1
            Presumably an optimizer moment coefficient. Default ``0.9``.
        beta_2
            Presumably an optimizer moment coefficient. Default ``0.9``.
        epsilon
            Presumably the optimizer's numerical-stability term. Default ``1e-12``.
        num_buckets_train
            Presumably the number of buckets for training data. Default ``40``.
        num_buckets_valid
            Presumably the number of buckets for validation data. Default ``10``.
        num_buckets_test
            Presumably the number of buckets for test data. Default ``10``.
        train_iters
            Presumably the number of training iterations. Default ``50000``.
        train_batch_size
            Presumably the training batch size. Default ``5000``.
        debug
            Presumably a debug-mode switch. Default ``False``.
        """
        super(_Config, self).__init__()
        # Assignment order is kept stable on purpose: it determines the order
        # of the instance ``__dict__`` and therefore the pickled layout.
        self.pretrained_embeddings_file = pretrained_embeddings_file
        self.train_file = train_file
        self.dev_file = dev_file
        self.test_file = test_file
        self.min_occur_count = min_occur_count
        self.save_dir = save_dir
        self.lstm_layers = lstm_layers
        self.word_dims = word_dims
        self.tag_dims = tag_dims
        self.dropout_emb = dropout_emb
        self.lstm_hiddens = lstm_hiddens
        self.dropout_lstm_input = dropout_lstm_input
        self.dropout_lstm_hidden = dropout_lstm_hidden
        self.mlp_arc_size = mlp_arc_size
        self.mlp_rel_size = mlp_rel_size
        self.dropout_mlp = dropout_mlp
        self.learning_rate = learning_rate
        self.decay = decay
        self.decay_steps = decay_steps
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.num_buckets_train = num_buckets_train
        self.num_buckets_valid = num_buckets_valid
        self.num_buckets_test = num_buckets_test
        self.train_iters = train_iters
        self.train_batch_size = train_batch_size
        self.debug = debug

    @property
    def save_model_path(self):
        """Path of ``model.bin`` inside ``save_dir``."""
        return os.path.join(self.save_dir, 'model.bin')

    @property
    def save_vocab_path(self):
        """Path of ``vocab.pkl`` inside ``save_dir``."""
        return os.path.join(self.save_dir, 'vocab.pkl')

    @property
    def save_config_path(self):
        """Path of ``config.pkl`` inside ``save_dir``; default target of ``save``."""
        return os.path.join(self.save_dir, 'config.pkl')

    def save(self, path=None):
        """Pickle this config object to a file.

        Parameters
        ----------
        path
            Destination file. When omitted (or otherwise falsy, e.g. ``''``)
            ``save_config_path`` is used.

        Raises
        ------
        OSError
            If the parent directory cannot be created or the file cannot be
            opened for writing.
        """
        if not path:
            path = self.save_config_path
        # Make sure the target directory exists so that saving into a fresh
        # ``save_dir`` does not fail when opening the file.
        parent = os.path.dirname(path)
        if parent:
            try:
                os.makedirs(parent)
            except OSError:
                # Already existing (possibly created concurrently) is fine;
                # anything else is a real error and is propagated.
                if not os.path.isdir(parent):
                    raise
        with open(path, 'wb') as f:
            pickle.dump(self, f)
119