1
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
14 // Copyright 2005-2010 Google, Inc.
15 // Author: wuke
16
17 #ifndef FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
18 #define FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
19
20 #include <string>
21 #include <vector>
22 using std::vector;
23
24 #include <fst/compat.h>
25 #include <fst/extensions/linear/linear-fst.h>
26 #include <fst/extensions/linear/linear-fst-data-builder.h>
27 #include <iostream>
28 #include <fstream>
29 #include <sstream>
30 #include <fst/symbol-table.h>
31 #include <fst/script/arg-packs.h>
32 #include <fst/script/script-impl.h>
33
// Command-line flags read by the helpers in this header. They are only
// declared here; the DEFINE_* side presumably lives in a companion .cc file
// (confirm).
//
// --delimiter: only its first character is used, as the separator inside
//   the list-valued fields of vocabulary and model lines
//   (see `GetVocabRecord()` / `GetModelRecord()`).
DECLARE_string(delimiter);
// --empty_symbol: a list field equal to this string yields no labels
//   (see `SplitAndPush()`).
DECLARE_string(empty_symbol);
// --start_symbol / --end_symbol: tokens that `LookUp()` maps to the
//   start-/end-of-sentence labels; when both flags hold the same string,
//   that token is mapped to `kNoLabel` and resolved later.
DECLARE_string(start_symbol);
DECLARE_string(end_symbol);
// --classifier: when true, `LinearCompileTpl()` builds a
//   `LinearClassifierFst` instead of a `LinearTaggerFst`.
DECLARE_bool(classifier);
39
40 namespace fst {
41 namespace script {
// Argument package consumed by `LinearCompileTpl()`. Fields, in order:
// epsilon symbol, unknown symbol, vocabulary file, array of model files,
// number of model files, output FST file, and the files to which the
// input, feature and output symbol tables are written as text (an empty
// string skips writing that table).
typedef args::Package<const string &, const string &, const string &, char **,
                      int, const string &, const string &, const string &,
                      const string &> LinearCompileArgs;

// Flag validators, defined out of line. Judging by their names they check
// `FLAGS_delimiter` and `FLAGS_empty_symbol` respectively — confirm against
// the definitions.
bool ValidateDelimiter();
bool ValidateEmptySymbol();
48
49 // Returns the proper label given the symbol. For symbols other than
50 // `FLAGS_start_symbol` or `FLAGS_end_symbol`, looks up the symbol
51 // table to decide the label. Depending on whether
52 // `FLAGS_start_symbol` and `FLAGS_end_symbol` are identical, it
53 // either returns `kNoLabel` for later processing or decides the label
54 // right away.
55 template <class Arc>
LookUp(const string & str,SymbolTable * syms)56 inline typename Arc::Label LookUp(const string &str, SymbolTable *syms) {
57 if (str == FLAGS_start_symbol)
58 return str == FLAGS_end_symbol ? kNoLabel
59 : LinearFstData<Arc>::kStartOfSentence;
60 else if (str == FLAGS_end_symbol)
61 return LinearFstData<Arc>::kEndOfSentence;
62 else
63 return syms->AddSymbol(str);
64 }
65
66 // Splits `str` with `delim` as the delimiter and stores the labels in
67 // `output`.
68 template <class Arc>
SplitAndPush(const string & str,const char delim,SymbolTable * syms,vector<typename Arc::Label> * output)69 void SplitAndPush(const string &str, const char delim, SymbolTable *syms,
70 vector<typename Arc::Label> *output) {
71 if (str == FLAGS_empty_symbol) return;
72 istringstream strm(str);
73 string buf;
74 while (std::getline(strm, buf, delim))
75 output->push_back(LookUp<Arc>(buf, syms));
76 }
77
// Copies [first, last) to the range beginning at `result`, writing
// `new_value` in place of each element that compares equal to `old_value`
// (the same copying behavior as `std::replace_copy`). Unlike
// `std::replace_copy`, it returns the number of elements that were
// replaced instead of the end of the output range.
template <class InputIterator, class OutputIterator, class T>
size_t ReplaceCopy(InputIterator first, InputIterator last,
                   OutputIterator result, const T &old_value,
                   const T &new_value) {
  size_t num_replaced = 0;
  for (; first != last; ++first, ++result) {
    if (!(*first == old_value)) {
      *result = *first;  // Not a match: copy through unchanged.
      continue;
    }
    *result = new_value;
    ++num_replaced;
  }
  return num_replaced;
}
96
// Reads and parses the next line of vocabulary file `vocab` from `strm`.
// Returns false when no further line can be read. Otherwise it increments
// `*num_line` and fills `*word`, `*feature_labels` and `*possible_labels`.
// Defined near the end of this header.
template <class Arc>
bool GetVocabRecord(const string &vocab, istream &strm,  // NOLINT
                    SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms,
                    typename Arc::Label *word,
                    vector<typename Arc::Label> *feature_labels,
                    vector<typename Arc::Label> *possible_labels,
                    size_t *num_line);

// Reads and parses the next line of model file `model` from `strm`.
// Returns false when no further line can be read. Otherwise it increments
// `*num_line` and fills `*input_labels`, `*output_labels` and `*weight`.
// Defined near the end of this header.
template <class Arc>
bool GetModelRecord(const string &model, istream &strm,  // NOLINT
                    SymbolTable *fsyms, SymbolTable *osyms,
                    vector<typename Arc::Label> *input_labels,
                    vector<typename Arc::Label> *output_labels,
                    typename Arc::Weight *weight, size_t *num_line);
111
112 // Reads in vocabulary file. Each line is in the following format
113 //
114 // word <whitespace> features [ <whitespace> possible output ]
115 //
116 // where features and possible output are `FLAGS_delimiter`-delimited lists of
117 // tokens
118 template <class Arc>
AddVocab(const string & vocab,SymbolTable * isyms,SymbolTable * fsyms,SymbolTable * osyms,LinearFstDataBuilder<Arc> * builder)119 void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms,
120 SymbolTable *osyms, LinearFstDataBuilder<Arc> *builder) {
121 ifstream in(vocab.c_str());
122 if (!in) LOG(FATAL) << "Can't open file: " << vocab;
123 size_t num_line = 0, num_added = 0;
124 vector<string> fields;
125 vector<typename Arc::Label> feature_labels, possible_labels;
126 typename Arc::Label word;
127 while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
128 &feature_labels, &possible_labels, &num_line)) {
129 if (word == kNoLabel) {
130 LOG(WARNING) << "Ignored: boundary word: " << fields[0];
131 continue;
132 }
133 if (possible_labels.empty())
134 num_added += builder->AddWord(word, feature_labels);
135 else
136 num_added += builder->AddWord(word, feature_labels, possible_labels);
137 }
138 VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from "
139 << vocab;
140 }
141
142 template <class Arc>
AddVocab(const string & vocab,SymbolTable * isyms,SymbolTable * fsyms,SymbolTable * osyms,LinearClassifierFstDataBuilder<Arc> * builder)143 void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms,
144 SymbolTable *osyms,
145 LinearClassifierFstDataBuilder<Arc> *builder) {
146 ifstream in(vocab.c_str());
147 if (!in) LOG(FATAL) << "Can't open file: " << vocab;
148 size_t num_line = 0, num_added = 0;
149 vector<string> fields;
150 vector<typename Arc::Label> feature_labels, possible_labels;
151 typename Arc::Label word;
152 while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
153 &feature_labels, &possible_labels, &num_line)) {
154 if (!possible_labels.empty())
155 LOG(FATAL)
156 << "Classifier vocabulary should not have possible output constraint";
157 if (word == kNoLabel) {
158 LOG(WARNING) << "Ignored: boundary word: " << fields[0];
159 continue;
160 }
161 num_added += builder->AddWord(word, feature_labels);
162 }
163 VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from "
164 << vocab;
165 }
166
167 // Reads in model file. The first line is an integer designating the
168 // size of future window in the input sequences. After this, each line
169 // is in the following format
170 //
171 // input sequence <whitespace> output sequence <whitespace> weight
172 //
173 // input sequence is a `FLAGS_delimiter`-delimited sequence of feature
174 // labels (see `AddVocab()`) . output sequence is a
175 // `FLAGS_delimiter`-delimited sequence of output labels where the
176 // last label is the output of the feature position before the history
177 // boundary.
178 template <class Arc>
AddModel(const string & model,SymbolTable * fsyms,SymbolTable * osyms,LinearFstDataBuilder<Arc> * builder)179 void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms,
180 LinearFstDataBuilder<Arc> *builder) {
181 ifstream in(model.c_str());
182 if (!in) LOG(FATAL) << "Can't open file: " << model;
183 string line;
184 std::getline(in, line);
185 if (!in) LOG(FATAL) << "Empty file: " << model;
186 size_t future_size;
187 {
188 istringstream strm(line);
189 strm >> future_size;
190 if (!strm) LOG(FATAL) << "Can't read future size: " << model;
191 }
192 size_t num_line = 1, num_added = 0;
193 const int group = builder->AddGroup(future_size);
194 CHECK_GE(group, 0);
195 VLOG(1) << "Group " << group << ": from " << model << "; future size is "
196 << future_size << ".";
197 // Add the rest of lines as a single feature group
198 vector<string> fields;
199 vector<typename Arc::Label> input_labels, output_labels;
200 typename Arc::Weight weight;
201 while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
202 &output_labels, &weight, &num_line)) {
203 if (output_labels.empty())
204 LOG(FATAL) << "Empty output sequence in source " << model << ", line "
205 << num_line;
206
207 const typename Arc::Label marks[] = {LinearFstData<Arc>::kStartOfSentence,
208 LinearFstData<Arc>::kEndOfSentence};
209
210 vector<typename Arc::Label> copy_input(input_labels.size()),
211 copy_output(output_labels.size());
212 for (int i = 0; i < 2; ++i) {
213 for (int j = 0; j < 2; ++j) {
214 size_t num_input_changes =
215 ReplaceCopy(input_labels.begin(), input_labels.end(),
216 copy_input.begin(), kNoLabel, marks[i]);
217 size_t num_output_changes =
218 ReplaceCopy(output_labels.begin(), output_labels.end(),
219 copy_output.begin(), kNoLabel, marks[j]);
220 if ((num_input_changes > 0 || i == 0) &&
221 (num_output_changes > 0 || j == 0))
222 num_added +=
223 builder->AddWeight(group, copy_input, copy_output, weight);
224 }
225 }
226 }
227 VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in "
228 << num_line << " lines.";
229 }
230
231 template <class Arc>
AddModel(const string & model,SymbolTable * fsyms,SymbolTable * osyms,LinearClassifierFstDataBuilder<Arc> * builder)232 void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms,
233 LinearClassifierFstDataBuilder<Arc> *builder) {
234 ifstream in(model.c_str());
235 if (!in) LOG(FATAL) << "Can't open file: " << model;
236 string line;
237 std::getline(in, line);
238 if (!in) LOG(FATAL) << "Empty file: " << model;
239 size_t future_size;
240 {
241 istringstream strm(line);
242 strm >> future_size;
243 if (!strm) LOG(FATAL) << "Can't read future size: " << model;
244 }
245 if (future_size != 0)
246 LOG(FATAL) << "Classifier model must have future size = 0; got "
247 << future_size << " from " << model;
248 size_t num_line = 1, num_added = 0;
249 const int group = builder->AddGroup();
250 CHECK_GE(group, 0);
251 VLOG(1) << "Group " << group << ": from " << model << "; future size is "
252 << future_size << ".";
253 // Add the rest of lines as a single feature group
254 vector<string> fields;
255 vector<typename Arc::Label> input_labels, output_labels;
256 typename Arc::Weight weight;
257 while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
258 &output_labels, &weight, &num_line)) {
259 if (output_labels.size() != 1)
260 LOG(FATAL) << "Output not a single label in source " << model << ", line "
261 << num_line;
262
263 const typename Arc::Label marks[] = {LinearFstData<Arc>::kStartOfSentence,
264 LinearFstData<Arc>::kEndOfSentence};
265
266 typename Arc::Label pred = output_labels[0];
267
268 vector<typename Arc::Label> copy_input(input_labels.size());
269 for (int i = 0; i < 2; ++i) {
270 size_t num_input_changes =
271 ReplaceCopy(input_labels.begin(), input_labels.end(),
272 copy_input.begin(), kNoLabel, marks[i]);
273 if (num_input_changes > 0 || i == 0)
274 num_added += builder->AddWeight(group, copy_input, pred, weight);
275 }
276 }
277 VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in "
278 << num_line << " lines.";
279 }
280
// Defined out of line. Fills `*out` with the fields of `str`; judging by
// the name and its use on vocabulary/model lines, the fields are
// whitespace-separated (confirm against the definition).
void SplitByWhitespace(const string &str, vector<string> *out);
// Defined out of line. Returns the number of classes for the
// `models_length` model files in `models`. The result is passed to
// `LinearClassifierFstDataBuilder` and `LinearClassifierFst`; presumably
// it is derived by scanning the models' output labels (confirm).
int ScanNumClasses(char **models, int models_length);
283
284 template <class Arc>
LinearCompileTpl(LinearCompileArgs * args)285 void LinearCompileTpl(LinearCompileArgs *args) {
286 const string &epsilon_symbol = args->arg1;
287 const string &unknown_symbol = args->arg2;
288 const string &vocab = args->arg3;
289 char **models = args->arg4;
290 const int models_length = args->arg5;
291 const string &out = args->arg6;
292 const string &save_isymbols = args->arg7;
293 const string &save_fsymbols = args->arg8;
294 const string &save_osymbols = args->arg9;
295
296 SymbolTable isyms, // input (e.g. word tokens)
297 osyms, // output (e.g. tags)
298 fsyms; // feature (e.g. word identity, suffix, etc.)
299 isyms.AddSymbol(epsilon_symbol);
300 osyms.AddSymbol(epsilon_symbol);
301 fsyms.AddSymbol(epsilon_symbol);
302 isyms.AddSymbol(unknown_symbol);
303
304 VLOG(1) << "start-of-sentence label is "
305 << LinearFstData<Arc>::kStartOfSentence;
306 VLOG(1) << "end-of-sentence label is " << LinearFstData<Arc>::kEndOfSentence;
307
308 if (FLAGS_classifier) {
309 int num_classes = ScanNumClasses(models, models_length);
310 LinearClassifierFstDataBuilder<Arc> builder(num_classes, &isyms, &fsyms,
311 &osyms);
312
313 AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
314 for (int i = 0; i < models_length; ++i)
315 AddModel(models[i], &fsyms, &osyms, &builder);
316
317 LinearClassifierFst<Arc> fst(builder.Dump(), num_classes, &isyms, &osyms);
318 fst.Write(out);
319 } else {
320 LinearFstDataBuilder<Arc> builder(&isyms, &fsyms, &osyms);
321
322 AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
323 for (int i = 0; i < models_length; ++i)
324 AddModel(models[i], &fsyms, &osyms, &builder);
325
326 LinearTaggerFst<Arc> fst(builder.Dump(), &isyms, &osyms);
327 fst.Write(out);
328 }
329
330 if (!save_isymbols.empty()) isyms.WriteText(save_isymbols);
331 if (!save_fsymbols.empty()) fsyms.WriteText(save_fsymbols);
332 if (!save_osymbols.empty()) osyms.WriteText(save_osymbols);
333 }
334
// Script-level entry point, defined out of line. Its parameters mirror the
// fields of `LinearCompileArgs` (with `models_len` as the model count).
// Presumably it dispatches on `arc_type` to the `LinearCompileTpl<Arc>`
// registered via `REGISTER_FST_LINEAR_OPERATIONS` (confirm against the
// definition).
void LinearCompile(const string &arc_type, const string &epsilon_symbol,
                   const string &unknown_symbol, const string &vocab,
                   char **models, int models_len, const string &out,
                   const string &save_isymbols, const string &save_fsymbols,
                   const string &save_osymbols);
340
341 template <class Arc>
GetVocabRecord(const string & vocab,istream & strm,SymbolTable * isyms,SymbolTable * fsyms,SymbolTable * osyms,typename Arc::Label * word,vector<typename Arc::Label> * feature_labels,vector<typename Arc::Label> * possible_labels,size_t * num_line)342 bool GetVocabRecord(const string &vocab, istream &strm, // NOLINT
343 SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms,
344 typename Arc::Label *word,
345 vector<typename Arc::Label> *feature_labels,
346 vector<typename Arc::Label> *possible_labels,
347 size_t *num_line) {
348 string line;
349 if (!std::getline(strm, line)) return false;
350 ++(*num_line);
351
352 vector<string> fields;
353 SplitByWhitespace(line, &fields);
354 if (fields.size() != 3)
355 LOG(FATAL) << "Wrong number of fields in source " << vocab << ", line "
356 << num_line;
357
358 feature_labels->clear();
359 possible_labels->clear();
360
361 *word = LookUp<Arc>(fields[0], isyms);
362
363 const char delim = FLAGS_delimiter[0];
364 SplitAndPush<Arc>(fields[1], delim, fsyms, feature_labels);
365 SplitAndPush<Arc>(fields[2], delim, osyms, possible_labels);
366
367 return true;
368 }
369
370 template <class Arc>
GetModelRecord(const string & model,istream & strm,SymbolTable * fsyms,SymbolTable * osyms,vector<typename Arc::Label> * input_labels,vector<typename Arc::Label> * output_labels,typename Arc::Weight * weight,size_t * num_line)371 bool GetModelRecord(const string &model, istream &strm, // NOLINT
372 SymbolTable *fsyms, SymbolTable *osyms,
373 vector<typename Arc::Label> *input_labels,
374 vector<typename Arc::Label> *output_labels,
375 typename Arc::Weight *weight, size_t *num_line) {
376 string line;
377 if (!std::getline(strm, line)) return false;
378 ++(*num_line);
379
380 vector<string> fields;
381 SplitByWhitespace(line, &fields);
382 if (fields.size() != 3)
383 LOG(FATAL) << "Wrong number of fields in source " << model << ", line "
384 << num_line;
385
386 input_labels->clear();
387 output_labels->clear();
388
389 const char delim = FLAGS_delimiter[0];
390 SplitAndPush<Arc>(fields[0], delim, fsyms, input_labels);
391 SplitAndPush<Arc>(fields[1], delim, osyms, output_labels);
392
393 *weight = StrToWeight<typename Arc::Weight>(fields[2], model, *num_line);
394
395 GuessStartOrEnd<Arc>(input_labels, kNoLabel);
396 GuessStartOrEnd<Arc>(output_labels, kNoLabel);
397
398 return true;
399 }
400 } // namespace script
401 } // namespace fst
402
// Registers `LinearCompileTpl<Arc>` (taking a `LinearCompileArgs` package)
// as a script operation for arc type `Arc`, via `REGISTER_FST_OPERATION`
// from fst/script/script-impl.h. The expansion already ends in a semicolon.
#define REGISTER_FST_LINEAR_OPERATIONS(Arc) \
  REGISTER_FST_OPERATION(LinearCompileTpl, Arc, LinearCompileArgs);
405
406 #endif // FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
407