1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
// Class to compile a binary FST from textual input.
19 
20 #ifndef FST_SCRIPT_COMPILE_IMPL_H_
21 #define FST_SCRIPT_COMPILE_IMPL_H_
22 
23 #include <iostream>
24 #include <memory>
25 #include <sstream>
26 #include <string>
27 #include <vector>
28 
29 #include <fst/fst.h>
30 #include <fst/util.h>
31 #include <fst/vector-fst.h>
32 #include <unordered_map>
33 #include <string_view>
34 
35 DECLARE_string(fst_field_separator);
36 
37 namespace fst {
38 
39 // Compile a binary Fst from textual input, helper class for fstcompile.cc
40 // WARNING: Stand-alone use of this class not recommended, most code should
41 // read/write using the binary format which is much more efficient.
42 template <class Arc>
43 class FstCompiler {
44  public:
45   using Label = typename Arc::Label;
46   using StateId = typename Arc::StateId;
47   using Weight = typename Arc::Weight;
48 
49   // WARNING: use of negative labels not recommended as it may cause conflicts.
50   // If add_symbols_ is true, then the symbols will be dynamically added to the
51   // symbol tables. This is only useful if you set the (i/o)keep flag to attach
52   // the final symbol table, or use the accessors. (The input symbol tables are
53   // const and therefore not changed.)
54   FstCompiler(std::istream &istrm, const std::string &source,
55               const SymbolTable *isyms, const SymbolTable *osyms,
56               const SymbolTable *ssyms, bool accep, bool ikeep, bool okeep,
57               bool nkeep, bool allow_negative_labels = false) {
58     std::unique_ptr<SymbolTable> misyms(isyms ? isyms->Copy() : nullptr);
59     std::unique_ptr<SymbolTable> mosyms(osyms ? osyms->Copy() : nullptr);
60     std::unique_ptr<SymbolTable> mssyms(ssyms ? ssyms->Copy() : nullptr);
61     Init(istrm, source, misyms.get(), mosyms.get(), mssyms.get(), accep, ikeep,
62          okeep, nkeep, allow_negative_labels, false);
63   }
64 
FstCompiler(std::istream & istrm,const std::string & source,SymbolTable * isyms,SymbolTable * osyms,SymbolTable * ssyms,bool accep,bool ikeep,bool okeep,bool nkeep,bool allow_negative_labels,bool add_symbols)65   FstCompiler(std::istream &istrm, const std::string &source,
66               SymbolTable *isyms, SymbolTable *osyms, SymbolTable *ssyms,
67               bool accep, bool ikeep, bool okeep, bool nkeep,
68               bool allow_negative_labels, bool add_symbols) {
69     Init(istrm, source, isyms, osyms, ssyms, accep, ikeep, okeep, nkeep,
70          allow_negative_labels, add_symbols);
71   }
72 
Init(std::istream & istrm,const std::string & source,SymbolTable * isyms,SymbolTable * osyms,SymbolTable * ssyms,bool accep,bool ikeep,bool okeep,bool nkeep,bool allow_negative_labels,bool add_symbols)73   void Init(std::istream &istrm, const std::string &source, SymbolTable *isyms,
74             SymbolTable *osyms, SymbolTable *ssyms, bool accep, bool ikeep,
75             bool okeep, bool nkeep, bool allow_negative_labels,
76             bool add_symbols) {
77     nline_ = 0;
78     source_ = source;
79     isyms_ = isyms;
80     osyms_ = osyms;
81     ssyms_ = ssyms;
82     nstates_ = 0;
83     keep_state_numbering_ = nkeep;
84     allow_negative_labels_ = allow_negative_labels;
85     add_symbols_ = add_symbols;
86     bool start_state_populated = false;
87     char line[kLineLen];
88     const std::string separator =
89         FST_FLAGS_fst_field_separator + "\n";
90     while (istrm.getline(line, kLineLen)) {
91       ++nline_;
92       std::vector<std::string_view> col = SplitString(line, separator, true);
93       if (col.empty() || col[0].empty()) continue;
94       if (col.size() > 5 || (col.size() > 4 && accep) ||
95           (col.size() == 3 && !accep)) {
96         FSTERROR() << "FstCompiler: Bad number of columns, source = " << source_
97                    << ", line = " << nline_;
98         fst_.SetProperties(kError, kError);
99         return;
100       }
101       StateId s = StrToStateId(col[0]);
102       while (s >= fst_.NumStates()) fst_.AddState();
103       if (!start_state_populated) {
104         fst_.SetStart(s);
105         start_state_populated = true;
106       }
107 
108       Arc arc;
109       StateId d = s;
110       switch (col.size()) {
111         case 1:
112           fst_.SetFinal(s, Weight::One());
113           break;
114         case 2:
115           fst_.SetFinal(s, StrToWeight(col[1], true));
116           break;
117         case 3:
118           arc.nextstate = d = StrToStateId(col[1]);
119           arc.ilabel = StrToILabel(col[2]);
120           arc.olabel = arc.ilabel;
121           arc.weight = Weight::One();
122           fst_.AddArc(s, arc);
123           break;
124         case 4:
125           arc.nextstate = d = StrToStateId(col[1]);
126           arc.ilabel = StrToILabel(col[2]);
127           if (accep) {
128             arc.olabel = arc.ilabel;
129             arc.weight = StrToWeight(col[3], true);
130           } else {
131             arc.olabel = StrToOLabel(col[3]);
132             arc.weight = Weight::One();
133           }
134           fst_.AddArc(s, arc);
135           break;
136         case 5:
137           arc.nextstate = d = StrToStateId(col[1]);
138           arc.ilabel = StrToILabel(col[2]);
139           arc.olabel = StrToOLabel(col[3]);
140           arc.weight = StrToWeight(col[4], true);
141           fst_.AddArc(s, arc);
142       }
143       while (d >= fst_.NumStates()) fst_.AddState();
144     }
145     if (ikeep) fst_.SetInputSymbols(isyms);
146     if (okeep) fst_.SetOutputSymbols(osyms);
147   }
148 
Fst()149   const VectorFst<Arc> &Fst() const { return fst_; }
150 
151  private:
152   // Maximum line length in text file.
153   static constexpr int kLineLen = 8096;
154 
155   StateId StrToId(std::string_view s, SymbolTable *syms, const char *name,
156                   bool allow_negative = false) const {
157     StateId n = 0;
158     if (syms) {
159       n = (add_symbols_) ? syms->AddSymbol(s) : syms->Find(s);
160       if (n == -1 || (!allow_negative && n < 0)) {
161         FSTERROR() << "FstCompiler: Symbol \"" << s
162                    << "\" is not mapped to any integer " << name
163                    << ", symbol table = " << syms->Name()
164                    << ", source = " << source_ << ", line = " << nline_;
165         fst_.SetProperties(kError, kError);
166       }
167     } else {
168       auto maybe_n = ParseInt64(s);
169       if (!maybe_n.has_value() || (!allow_negative && *maybe_n < 0)) {
170         FSTERROR() << "FstCompiler: Bad " << name << " integer = \"" << s
171                    << "\", source = " << source_ << ", line = " << nline_;
172         fst_.SetProperties(kError, kError);
173       }
174       n = *maybe_n;
175     }
176     return n;
177   }
178 
StrToStateId(std::string_view s)179   StateId StrToStateId(std::string_view s) {
180     StateId n = StrToId(s, ssyms_, "state ID");
181     if (keep_state_numbering_) return n;
182     // Remaps state IDs to make dense set.
183     const auto it = states_.find(n);
184     if (it == states_.end()) {
185       states_[n] = nstates_;
186       return nstates_++;
187     } else {
188       return it->second;
189     }
190   }
191 
StrToILabel(std::string_view s)192   StateId StrToILabel(std::string_view s) const {
193     return StrToId(s, isyms_, "arc ilabel", allow_negative_labels_);
194   }
195 
StrToOLabel(std::string_view s)196   StateId StrToOLabel(std::string_view s) const {
197     return StrToId(s, osyms_, "arc olabel", allow_negative_labels_);
198   }
199 
StrToWeight(std::string_view s,bool allow_zero)200   Weight StrToWeight(std::string_view s, bool allow_zero) const {
201     Weight w;
202     std::istringstream strm(std::string{s});
203     strm >> w;
204     if (!strm || (!allow_zero && w == Weight::Zero())) {
205       FSTERROR() << "FstCompiler: Bad weight = \"" << s
206                  << "\", source = " << source_ << ", line = " << nline_;
207       fst_.SetProperties(kError, kError);
208       w = Weight::NoWeight();
209     }
210     return w;
211   }
212 
213   mutable VectorFst<Arc> fst_;
214   size_t nline_;
215   std::string source_;  // Text FST source name.
216   SymbolTable *isyms_;  // ilabel symbol table (not owned).
217   SymbolTable *osyms_;  // olabel symbol table (not owned).
218   SymbolTable *ssyms_;  // slabel symbol table (not owned).
219   std::unordered_map<StateId, StateId> states_;  // State ID map.
220   StateId nstates_;                               // Number of seen states.
221   bool keep_state_numbering_;
222   bool allow_negative_labels_;  // Not recommended; may cause conflicts.
223   bool add_symbols_;            // Add to symbol tables on-the fly.
224 
225   FstCompiler(const FstCompiler &) = delete;
226   FstCompiler &operator=(const FstCompiler &) = delete;
227 };
228 
229 }  // namespace fst
230 
231 #endif  // FST_SCRIPT_COMPILE_IMPL_H_
232