1 
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
14 // Copyright 2005-2010 Google, Inc.
15 // Author: wuke
16 
17 #ifndef FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
18 #define FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
19 
20 #include <string>
21 #include <vector>
22 using std::vector;
23 
24 #include <fst/compat.h>
25 #include <fst/extensions/linear/linear-fst.h>
26 #include <fst/extensions/linear/linear-fst-data-builder.h>
27 #include <iostream>
28 #include <fstream>
29 #include <sstream>
30 #include <fst/symbol-table.h>
31 #include <fst/script/arg-packs.h>
32 #include <fst/script/script-impl.h>
33 
// Command-line flags consumed by the helpers in this header (the flags are
// only declared here; their definitions live elsewhere).
//
// Only the first character of `FLAGS_delimiter` is used, as the separator
// inside feature / output lists (see `GetVocabRecord()`, `GetModelRecord()`).
DECLARE_string(delimiter);
// A list field equal to this symbol is treated as an empty list (see
// `SplitAndPush()`).
DECLARE_string(empty_symbol);
// Symbols recognized by `LookUp()` as the sentence boundary markers.
DECLARE_string(start_symbol);
DECLARE_string(end_symbol);
// When true, `LinearCompileTpl()` builds a `LinearClassifierFst` instead of a
// `LinearTaggerFst`.
DECLARE_bool(classifier);
39 
40 namespace fst {
41 namespace script {
// Argument package for `LinearCompileTpl()`. In order, as unpacked there:
//   arg1: epsilon symbol          arg2: unknown symbol
//   arg3: vocabulary file name    arg4: array of model file names
//   arg5: number of model files   arg6: output FST file name
//   arg7: file to save the input symbols to (skipped when empty)
//   arg8: file to save the feature symbols to (skipped when empty)
//   arg9: file to save the output symbols to (skipped when empty)
typedef args::Package<const string &, const string &, const string &, char **,
                      int, const string &, const string &, const string &,
                      const string &> LinearCompileArgs;
45 
// Flag validators; defined outside this header.
// NOTE(review): presumably they check `FLAGS_delimiter` and
// `FLAGS_empty_symbol` respectively and return true when the value is
// acceptable -- confirm against the definitions.
bool ValidateDelimiter();
bool ValidateEmptySymbol();
48 
49 // Returns the proper label given the symbol. For symbols other than
50 // `FLAGS_start_symbol` or `FLAGS_end_symbol`, looks up the symbol
51 // table to decide the label. Depending on whether
52 // `FLAGS_start_symbol` and `FLAGS_end_symbol` are identical, it
53 // either returns `kNoLabel` for later processing or decides the label
54 // right away.
55 template <class Arc>
LookUp(const string & str,SymbolTable * syms)56 inline typename Arc::Label LookUp(const string &str, SymbolTable *syms) {
57   if (str == FLAGS_start_symbol)
58     return str == FLAGS_end_symbol ? kNoLabel
59                                    : LinearFstData<Arc>::kStartOfSentence;
60   else if (str == FLAGS_end_symbol)
61     return LinearFstData<Arc>::kEndOfSentence;
62   else
63     return syms->AddSymbol(str);
64 }
65 
66 // Splits `str` with `delim` as the delimiter and stores the labels in
67 // `output`.
68 template <class Arc>
SplitAndPush(const string & str,const char delim,SymbolTable * syms,vector<typename Arc::Label> * output)69 void SplitAndPush(const string &str, const char delim, SymbolTable *syms,
70                   vector<typename Arc::Label> *output) {
71   if (str == FLAGS_empty_symbol) return;
72   istringstream strm(str);
73   string buf;
74   while (std::getline(strm, buf, delim))
75     output->push_back(LookUp<Arc>(buf, syms));
76 }
77 
// Copies the range [`first`, `last`) to the range starting at `result`,
// writing `new_value` in place of every element that compares equal to
// `old_value` (the same job as `std::replace_copy`). Unlike the standard
// algorithm, the return value is the number of elements that were replaced.
template <class InputIterator, class OutputIterator, class T>
size_t ReplaceCopy(InputIterator first, InputIterator last,
                   OutputIterator result, const T &old_value,
                   const T &new_value) {
  size_t num_replaced = 0;
  for (; first != last; ++first) {
    const bool is_match = (*first == old_value);
    if (is_match) {
      *result = new_value;
    } else {
      *result = *first;
    }
    ++result;
    num_replaced += is_match ? 1 : 0;
  }
  return num_replaced;
}
96 
// Forward declaration; the definition appears later in this header.
// Reads one line of the vocabulary file from `strm` into `*word`,
// `*feature_labels` and `*possible_labels`, incrementing `*num_line`. Returns
// false once no further line can be read. `vocab` is the source name used in
// error messages.
template <class Arc>
bool GetVocabRecord(const string &vocab, istream &strm,  // NOLINT
                    SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms,
                    typename Arc::Label *word,
                    vector<typename Arc::Label> *feature_labels,
                    vector<typename Arc::Label> *possible_labels,
                    size_t *num_line);
104 
// Forward declaration; the definition appears later in this header.
// Reads one line of a model file from `strm` into `*input_labels`,
// `*output_labels` and `*weight`, incrementing `*num_line`. Returns false once
// no further line can be read. `model` is the source name used in error
// messages.
template <class Arc>
bool GetModelRecord(const string &model, istream &strm,  // NOLINT
                    SymbolTable *fsyms, SymbolTable *osyms,
                    vector<typename Arc::Label> *input_labels,
                    vector<typename Arc::Label> *output_labels,
                    typename Arc::Weight *weight, size_t *num_line);
111 
112 // Reads in vocabulary file. Each line is in the following format
113 //
114 //   word <whitespace> features [ <whitespace> possible output ]
115 //
116 // where features and possible output are `FLAGS_delimiter`-delimited lists of
117 // tokens
118 template <class Arc>
AddVocab(const string & vocab,SymbolTable * isyms,SymbolTable * fsyms,SymbolTable * osyms,LinearFstDataBuilder<Arc> * builder)119 void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms,
120               SymbolTable *osyms, LinearFstDataBuilder<Arc> *builder) {
121   ifstream in(vocab.c_str());
122   if (!in) LOG(FATAL) << "Can't open file: " << vocab;
123   size_t num_line = 0, num_added = 0;
124   vector<string> fields;
125   vector<typename Arc::Label> feature_labels, possible_labels;
126   typename Arc::Label word;
127   while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
128                              &feature_labels, &possible_labels, &num_line)) {
129     if (word == kNoLabel) {
130       LOG(WARNING) << "Ignored: boundary word: " << fields[0];
131       continue;
132     }
133     if (possible_labels.empty())
134       num_added += builder->AddWord(word, feature_labels);
135     else
136       num_added += builder->AddWord(word, feature_labels, possible_labels);
137   }
138   VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from "
139           << vocab;
140 }
141 
142 template <class Arc>
AddVocab(const string & vocab,SymbolTable * isyms,SymbolTable * fsyms,SymbolTable * osyms,LinearClassifierFstDataBuilder<Arc> * builder)143 void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms,
144               SymbolTable *osyms,
145               LinearClassifierFstDataBuilder<Arc> *builder) {
146   ifstream in(vocab.c_str());
147   if (!in) LOG(FATAL) << "Can't open file: " << vocab;
148   size_t num_line = 0, num_added = 0;
149   vector<string> fields;
150   vector<typename Arc::Label> feature_labels, possible_labels;
151   typename Arc::Label word;
152   while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
153                              &feature_labels, &possible_labels, &num_line)) {
154     if (!possible_labels.empty())
155       LOG(FATAL)
156           << "Classifier vocabulary should not have possible output constraint";
157     if (word == kNoLabel) {
158       LOG(WARNING) << "Ignored: boundary word: " << fields[0];
159       continue;
160     }
161     num_added += builder->AddWord(word, feature_labels);
162   }
163   VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from "
164           << vocab;
165 }
166 
167 // Reads in model file. The first line is an integer designating the
168 // size of future window in the input sequences. After this, each line
169 // is in the following format
170 //
171 //   input sequence <whitespace> output sequence <whitespace> weight
172 //
173 // input sequence is a `FLAGS_delimiter`-delimited sequence of feature
174 // labels (see `AddVocab()`) . output sequence is a
175 // `FLAGS_delimiter`-delimited sequence of output labels where the
176 // last label is the output of the feature position before the history
177 // boundary.
178 template <class Arc>
AddModel(const string & model,SymbolTable * fsyms,SymbolTable * osyms,LinearFstDataBuilder<Arc> * builder)179 void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms,
180               LinearFstDataBuilder<Arc> *builder) {
181   ifstream in(model.c_str());
182   if (!in) LOG(FATAL) << "Can't open file: " << model;
183   string line;
184   std::getline(in, line);
185   if (!in) LOG(FATAL) << "Empty file: " << model;
186   size_t future_size;
187   {
188     istringstream strm(line);
189     strm >> future_size;
190     if (!strm) LOG(FATAL) << "Can't read future size: " << model;
191   }
192   size_t num_line = 1, num_added = 0;
193   const int group = builder->AddGroup(future_size);
194   CHECK_GE(group, 0);
195   VLOG(1) << "Group " << group << ": from " << model << "; future size is "
196           << future_size << ".";
197   // Add the rest of lines as a single feature group
198   vector<string> fields;
199   vector<typename Arc::Label> input_labels, output_labels;
200   typename Arc::Weight weight;
201   while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
202                              &output_labels, &weight, &num_line)) {
203     if (output_labels.empty())
204       LOG(FATAL) << "Empty output sequence in source " << model << ", line "
205                  << num_line;
206 
207     const typename Arc::Label marks[] = {LinearFstData<Arc>::kStartOfSentence,
208                                          LinearFstData<Arc>::kEndOfSentence};
209 
210     vector<typename Arc::Label> copy_input(input_labels.size()),
211         copy_output(output_labels.size());
212     for (int i = 0; i < 2; ++i) {
213       for (int j = 0; j < 2; ++j) {
214         size_t num_input_changes =
215             ReplaceCopy(input_labels.begin(), input_labels.end(),
216                         copy_input.begin(), kNoLabel, marks[i]);
217         size_t num_output_changes =
218             ReplaceCopy(output_labels.begin(), output_labels.end(),
219                         copy_output.begin(), kNoLabel, marks[j]);
220         if ((num_input_changes > 0 || i == 0) &&
221             (num_output_changes > 0 || j == 0))
222           num_added +=
223               builder->AddWeight(group, copy_input, copy_output, weight);
224       }
225     }
226   }
227   VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in "
228           << num_line << " lines.";
229 }
230 
231 template <class Arc>
AddModel(const string & model,SymbolTable * fsyms,SymbolTable * osyms,LinearClassifierFstDataBuilder<Arc> * builder)232 void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms,
233               LinearClassifierFstDataBuilder<Arc> *builder) {
234   ifstream in(model.c_str());
235   if (!in) LOG(FATAL) << "Can't open file: " << model;
236   string line;
237   std::getline(in, line);
238   if (!in) LOG(FATAL) << "Empty file: " << model;
239   size_t future_size;
240   {
241     istringstream strm(line);
242     strm >> future_size;
243     if (!strm) LOG(FATAL) << "Can't read future size: " << model;
244   }
245   if (future_size != 0)
246     LOG(FATAL) << "Classifier model must have future size = 0; got "
247                << future_size << " from " << model;
248   size_t num_line = 1, num_added = 0;
249   const int group = builder->AddGroup();
250   CHECK_GE(group, 0);
251   VLOG(1) << "Group " << group << ": from " << model << "; future size is "
252           << future_size << ".";
253   // Add the rest of lines as a single feature group
254   vector<string> fields;
255   vector<typename Arc::Label> input_labels, output_labels;
256   typename Arc::Weight weight;
257   while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
258                              &output_labels, &weight, &num_line)) {
259     if (output_labels.size() != 1)
260       LOG(FATAL) << "Output not a single label in source " << model << ", line "
261                  << num_line;
262 
263     const typename Arc::Label marks[] = {LinearFstData<Arc>::kStartOfSentence,
264                                          LinearFstData<Arc>::kEndOfSentence};
265 
266     typename Arc::Label pred = output_labels[0];
267 
268     vector<typename Arc::Label> copy_input(input_labels.size());
269     for (int i = 0; i < 2; ++i) {
270       size_t num_input_changes =
271           ReplaceCopy(input_labels.begin(), input_labels.end(),
272                       copy_input.begin(), kNoLabel, marks[i]);
273       if (num_input_changes > 0 || i == 0)
274         num_added += builder->AddWeight(group, copy_input, pred, weight);
275     }
276   }
277   VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in "
278           << num_line << " lines.";
279 }
280 
// Helpers defined outside this header.
//
// `SplitByWhitespace()` is used by the record readers in this header to break
// a line into whitespace-separated fields stored in `*out`.
// NOTE(review): whether `*out` is cleared first is not visible here -- confirm
// against the definition.
void SplitByWhitespace(const string &str, vector<string> *out);
// `ScanNumClasses()` supplies the class count handed to the classifier builder
// in `LinearCompileTpl()`, given the `models_length` model file names.
// NOTE(review): presumably it scans the model files for distinct output
// labels -- confirm against the definition.
int ScanNumClasses(char **models, int models_length);
283 
284 template <class Arc>
LinearCompileTpl(LinearCompileArgs * args)285 void LinearCompileTpl(LinearCompileArgs *args) {
286   const string &epsilon_symbol = args->arg1;
287   const string &unknown_symbol = args->arg2;
288   const string &vocab = args->arg3;
289   char **models = args->arg4;
290   const int models_length = args->arg5;
291   const string &out = args->arg6;
292   const string &save_isymbols = args->arg7;
293   const string &save_fsymbols = args->arg8;
294   const string &save_osymbols = args->arg9;
295 
296   SymbolTable isyms,  // input (e.g. word tokens)
297       osyms,          // output (e.g. tags)
298       fsyms;          // feature (e.g. word identity, suffix, etc.)
299   isyms.AddSymbol(epsilon_symbol);
300   osyms.AddSymbol(epsilon_symbol);
301   fsyms.AddSymbol(epsilon_symbol);
302   isyms.AddSymbol(unknown_symbol);
303 
304   VLOG(1) << "start-of-sentence label is "
305           << LinearFstData<Arc>::kStartOfSentence;
306   VLOG(1) << "end-of-sentence label is " << LinearFstData<Arc>::kEndOfSentence;
307 
308   if (FLAGS_classifier) {
309     int num_classes = ScanNumClasses(models, models_length);
310     LinearClassifierFstDataBuilder<Arc> builder(num_classes, &isyms, &fsyms,
311                                                 &osyms);
312 
313     AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
314     for (int i = 0; i < models_length; ++i)
315       AddModel(models[i], &fsyms, &osyms, &builder);
316 
317     LinearClassifierFst<Arc> fst(builder.Dump(), num_classes, &isyms, &osyms);
318     fst.Write(out);
319   } else {
320     LinearFstDataBuilder<Arc> builder(&isyms, &fsyms, &osyms);
321 
322     AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
323     for (int i = 0; i < models_length; ++i)
324       AddModel(models[i], &fsyms, &osyms, &builder);
325 
326     LinearTaggerFst<Arc> fst(builder.Dump(), &isyms, &osyms);
327     fst.Write(out);
328   }
329 
330   if (!save_isymbols.empty()) isyms.WriteText(save_isymbols);
331   if (!save_fsymbols.empty()) fsyms.WriteText(save_fsymbols);
332   if (!save_osymbols.empty()) osyms.WriteText(save_osymbols);
333 }
334 
// Non-template entry point; defined outside this header. Apart from
// `arc_type`, the parameters match, in order, the fields that
// `LinearCompileTpl()` unpacks from `LinearCompileArgs`.
// NOTE(review): presumably dispatches to `LinearCompileTpl<Arc>` for the arc
// type named by `arc_type` -- confirm against the definition.
void LinearCompile(const string &arc_type, const string &epsilon_symbol,
                   const string &unknown_symbol, const string &vocab,
                   char **models, int models_len, const string &out,
                   const string &save_isymbols, const string &save_fsymbols,
                   const string &save_osymbols);
340 
341 template <class Arc>
GetVocabRecord(const string & vocab,istream & strm,SymbolTable * isyms,SymbolTable * fsyms,SymbolTable * osyms,typename Arc::Label * word,vector<typename Arc::Label> * feature_labels,vector<typename Arc::Label> * possible_labels,size_t * num_line)342 bool GetVocabRecord(const string &vocab, istream &strm,  // NOLINT
343                     SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms,
344                     typename Arc::Label *word,
345                     vector<typename Arc::Label> *feature_labels,
346                     vector<typename Arc::Label> *possible_labels,
347                     size_t *num_line) {
348   string line;
349   if (!std::getline(strm, line)) return false;
350   ++(*num_line);
351 
352   vector<string> fields;
353   SplitByWhitespace(line, &fields);
354   if (fields.size() != 3)
355     LOG(FATAL) << "Wrong number of fields in source " << vocab << ", line "
356                << num_line;
357 
358   feature_labels->clear();
359   possible_labels->clear();
360 
361   *word = LookUp<Arc>(fields[0], isyms);
362 
363   const char delim = FLAGS_delimiter[0];
364   SplitAndPush<Arc>(fields[1], delim, fsyms, feature_labels);
365   SplitAndPush<Arc>(fields[2], delim, osyms, possible_labels);
366 
367   return true;
368 }
369 
370 template <class Arc>
GetModelRecord(const string & model,istream & strm,SymbolTable * fsyms,SymbolTable * osyms,vector<typename Arc::Label> * input_labels,vector<typename Arc::Label> * output_labels,typename Arc::Weight * weight,size_t * num_line)371 bool GetModelRecord(const string &model, istream &strm,  // NOLINT
372                     SymbolTable *fsyms, SymbolTable *osyms,
373                     vector<typename Arc::Label> *input_labels,
374                     vector<typename Arc::Label> *output_labels,
375                     typename Arc::Weight *weight, size_t *num_line) {
376   string line;
377   if (!std::getline(strm, line)) return false;
378   ++(*num_line);
379 
380   vector<string> fields;
381   SplitByWhitespace(line, &fields);
382   if (fields.size() != 3)
383     LOG(FATAL) << "Wrong number of fields in source " << model << ", line "
384                << num_line;
385 
386   input_labels->clear();
387   output_labels->clear();
388 
389   const char delim = FLAGS_delimiter[0];
390   SplitAndPush<Arc>(fields[0], delim, fsyms, input_labels);
391   SplitAndPush<Arc>(fields[1], delim, osyms, output_labels);
392 
393   *weight = StrToWeight<typename Arc::Weight>(fields[2], model, *num_line);
394 
395   GuessStartOrEnd<Arc>(input_labels, kNoLabel);
396   GuessStartOrEnd<Arc>(output_labels, kNoLabel);
397 
398   return true;
399 }
400 }  // namespace script
401 }  // namespace fst
402 
// Expands to `REGISTER_FST_OPERATION(LinearCompileTpl, Arc, LinearCompileArgs);`
// for the given arc type. The expansion already ends with a semicolon, so a
// use of this macro does not need one of its own.
// NOTE(review): presumably this registers `LinearCompileTpl<Arc>` with the
// FST script operation registry -- confirm against `REGISTER_FST_OPERATION`.
#define REGISTER_FST_LINEAR_OPERATIONS(Arc) \
  REGISTER_FST_OPERATION(LinearCompileTpl, Arc, LinearCompileArgs);
405 
406 #endif  // FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
407