// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// All Rights Reserved.
//
// Author : Johan Schalkwyk
//
// \file
// Classes to provide symbol-to-integer and integer-to-symbol mappings.
21
22 #ifndef FST_LIB_SYMBOL_TABLE_H__
23 #define FST_LIB_SYMBOL_TABLE_H__
24
25 #include <cstring>
26 #include <unordered_map>
27 using std::unordered_map;
28 using std::unordered_multimap;
29 #include <memory>
30 #include <string>
31 #include <utility>
32 using std::pair; using std::make_pair;
33 #include <vector>
34 using std::vector;
35
36
37 #include <fst/compat.h>
38 #include <iostream>
39 #include <fstream>
40 #include <sstream>
41
42
43 #include <map>
44
45 DECLARE_bool(fst_compat_symbols);
46
47 namespace fst {
48
49 // WARNING: Reading via symbol table read options should
50 // not be used. This is a temporary work around for
51 // reading symbol ranges of previously stored symbol sets.
52 struct SymbolTableReadOptions {
SymbolTableReadOptionsSymbolTableReadOptions53 SymbolTableReadOptions() { }
54
SymbolTableReadOptionsSymbolTableReadOptions55 SymbolTableReadOptions(vector<pair<int64, int64> > string_hash_ranges_,
56 const string& source_)
57 : string_hash_ranges(string_hash_ranges_),
58 source(source_) { }
59
60 vector<pair<int64, int64> > string_hash_ranges;
61 string source;
62 };
63
64 struct SymbolTableTextOptions {
65 SymbolTableTextOptions();
66
67 bool allow_negative;
68 string fst_field_separator;
69 };
70
71 namespace internal {
72
73 // List of symbols with a dense hash for looking up symbol index.
74 // Hash uses linear probe, rehashes at 0.75% occupancy, avg 6 bytes overhead
75 // per entry. Rehash in place from symbol list.
76 //
77 // Symbols are stored as c strings to avoid adding memory overhead, but the
78 // performance penalty for this is high because rehash must call strlen on
79 // every symbol. AddSymbol can be another 2x faster if symbol lengths were
80 // stored.
81 class DenseSymbolMap {
82 public:
83 DenseSymbolMap();
84 DenseSymbolMap(const DenseSymbolMap& x);
85 ~DenseSymbolMap();
86
87 pair<int64, bool> InsertOrFind(const string& key);
88 int64 Find(const string& key) const;
89
size()90 const size_t size() const { return symbols_.size(); }
GetSymbol(size_t idx)91 const string GetSymbol(size_t idx) const {
92 return string(symbols_[idx], strlen(symbols_[idx]));
93 }
94
95 private:
96 void Rehash();
97 const char* NewSymbol(const string& sym);
98
99 int64 empty_;
100 vector<const char*> symbols_;
101 std::hash<string> str_hash_;
102 vector<int64> buckets_;
103 uint64 hash_mask_;
104 int size_;
105 };
106
107 } // namespace internal
108
109 class SymbolTableImpl {
110 public:
SymbolTableImpl(const string & name)111 SymbolTableImpl(const string &name)
112 : name_(name),
113 available_key_(0),
114 dense_key_limit_(0),
115 check_sum_finalized_(false) {}
116
SymbolTableImpl(const SymbolTableImpl & impl)117 explicit SymbolTableImpl(const SymbolTableImpl& impl)
118 : name_(impl.name_),
119 available_key_(impl.available_key_),
120 dense_key_limit_(impl.dense_key_limit_),
121 symbols_(impl.symbols_),
122 idx_key_(impl.idx_key_),
123 key_map_(impl.key_map_),
124 check_sum_finalized_(false) {}
125
126 int64 AddSymbol(const string& symbol, int64 key);
127
AddSymbol(const string & symbol)128 int64 AddSymbol(const string& symbol) {
129 return AddSymbol(symbol, available_key_);
130 }
131
132 static SymbolTableImpl* ReadText(
133 istream &strm, const string &name,
134 const SymbolTableTextOptions &opts = SymbolTableTextOptions());
135
136 static SymbolTableImpl* Read(istream &strm,
137 const SymbolTableReadOptions& opts);
138
139 bool Write(ostream &strm) const;
140
141 // Return the string associated with the key. If the key is out of
142 // range (<0, >max), return an empty string.
Find(int64 key)143 string Find(int64 key) const {
144 int64 idx = key;
145 if (key < 0 || key >= dense_key_limit_) {
146 map<int64, int64>::const_iterator iter
147 = key_map_.find(key);
148 if (iter == key_map_.end()) return "";
149 idx = iter->second;
150 }
151 if (idx < 0 || idx >= symbols_.size()) return "";
152 return symbols_.GetSymbol(idx);
153 }
154
155 // Return the key associated with the symbol. If the symbol
156 // does not exists, return SymbolTable::kNoSymbol.
Find(const string & symbol)157 int64 Find(const string& symbol) const {
158 int64 idx = symbols_.Find(symbol);
159 if (idx == -1 || idx < dense_key_limit_) return idx;
160 return idx_key_[idx - dense_key_limit_];
161 }
162
163 // Return the key associated with the symbol. If the symbol
164 // does not exists, return SymbolTable::kNoSymbol.
Find(const char * symbol)165 int64 Find(const char* symbol) const {
166 return Find(string(symbol));
167 }
168
GetNthKey(ssize_t pos)169 int64 GetNthKey(ssize_t pos) const {
170 if (pos < 0 || pos >= symbols_.size()) return -1;
171 if (pos < dense_key_limit_) return pos;
172 return Find(symbols_.GetSymbol(pos));
173 }
174
Name()175 const string& Name() const { return name_; }
176
IncrRefCount()177 int IncrRefCount() const {
178 return ref_count_.Incr();
179 }
DecrRefCount()180 int DecrRefCount() const {
181 return ref_count_.Decr();
182 }
RefCount()183 int RefCount() const {
184 return ref_count_.count();
185 }
186
CheckSum()187 string CheckSum() const {
188 MaybeRecomputeCheckSum();
189 return check_sum_string_;
190 }
191
LabeledCheckSum()192 string LabeledCheckSum() const {
193 MaybeRecomputeCheckSum();
194 return labeled_check_sum_string_;
195 }
196
AvailableKey()197 int64 AvailableKey() const {
198 return available_key_;
199 }
200
NumSymbols()201 size_t NumSymbols() const {
202 return symbols_.size();
203 }
204
205 private:
206 // Recomputes the checksums (both of them) if we've had changes since the last
207 // computation (i.e., if check_sum_finalized_ is false).
208 // Takes ~2.5 microseconds (dbg) or ~230 nanoseconds (opt) on a 2.67GHz Xeon
209 // if the checksum is up-to-date (requiring no recomputation).
210 void MaybeRecomputeCheckSum() const;
211
212 string name_;
213 int64 available_key_;
214 int64 dense_key_limit_;
215
216 internal::DenseSymbolMap symbols_;
217 vector<int64> idx_key_;
218 map<int64, int64> key_map_;
219
220 mutable RefCounter ref_count_;
221 mutable bool check_sum_finalized_;
222 mutable string check_sum_string_;
223 mutable string labeled_check_sum_string_;
224 mutable Mutex check_sum_mutex_;
225 };
226
227 //
228 // \class SymbolTable
229 // \brief Symbol (string) to int and reverse mapping
230 //
231 // The SymbolTable implements the mappings of labels to strings and reverse.
232 // SymbolTables are used to describe the alphabet of the input and output
233 // labels for arcs in a Finite State Transducer.
234 //
235 // SymbolTables are reference counted and can therefore be shared across
236 // multiple machines. For example a language model grammar G, with a
237 // SymbolTable for the words in the language model can share this symbol
238 // table with the lexical representation L o G.
239 //
240 class SymbolTable {
241 public:
242 static const int64 kNoSymbol = -1;
243
244 // Construct symbol table with an unspecified name.
SymbolTable()245 SymbolTable() : impl_(new SymbolTableImpl("<unspecified>")) {}
246
247 // Construct symbol table with a unique name.
SymbolTable(const string & name)248 explicit SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {}
249
250 // Create a reference counted copy.
SymbolTable(const SymbolTable & table)251 SymbolTable(const SymbolTable& table) : impl_(table.impl_) {
252 impl_->IncrRefCount();
253 }
254
255 // Derefence implentation object. When reference count hits 0, delete
256 // implementation.
~SymbolTable()257 virtual ~SymbolTable() {
258 if (!impl_->DecrRefCount()) delete impl_;
259 }
260
261 // Copys the implemenation from one symbol table to another.
262 void operator=(const SymbolTable &st) {
263 if (impl_ != st.impl_) {
264 st.impl_->IncrRefCount();
265 if (!impl_->DecrRefCount()) delete impl_;
266 impl_ = st.impl_;
267 }
268 }
269
270 // Read an ascii representation of the symbol table from an istream. Pass a
271 // name to give the resulting SymbolTable.
272 static SymbolTable* ReadText(
273 istream &strm, const string& name,
274 const SymbolTableTextOptions &opts = SymbolTableTextOptions()) {
275 SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm, name, opts);
276 if (!impl)
277 return 0;
278 else
279 return new SymbolTable(impl);
280 }
281
282 // read an ascii representation of the symbol table
283 static SymbolTable* ReadText(const string& filename,
284 const SymbolTableTextOptions &opts = SymbolTableTextOptions()) {
285 ifstream strm(filename.c_str(), ifstream::in);
286 if (!strm.good()) {
287 LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename;
288 return 0;
289 }
290 return ReadText(strm, filename, opts);
291 }
292
293
294 // WARNING: Reading via symbol table read options should
295 // not be used. This is a temporary work around.
Read(istream & strm,const SymbolTableReadOptions & opts)296 static SymbolTable* Read(istream &strm,
297 const SymbolTableReadOptions& opts) {
298 SymbolTableImpl* impl = SymbolTableImpl::Read(strm, opts);
299 if (!impl)
300 return 0;
301 else
302 return new SymbolTable(impl);
303 }
304
305 // read a binary dump of the symbol table from a stream
Read(istream & strm,const string & source)306 static SymbolTable* Read(istream &strm, const string& source) {
307 SymbolTableReadOptions opts;
308 opts.source = source;
309 return Read(strm, opts);
310 }
311
312 // read a binary dump of the symbol table
Read(const string & filename)313 static SymbolTable* Read(const string& filename) {
314 ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
315 if (!strm.good()) {
316 LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename;
317 return 0;
318 }
319 return Read(strm, filename);
320 }
321
322 //--------------------------------------------------------
323 // Derivable Interface (final)
324 //--------------------------------------------------------
325 // create a reference counted copy
Copy()326 virtual SymbolTable* Copy() const {
327 return new SymbolTable(*this);
328 }
329
330 // Add a symbol with given key to table. A symbol table also
331 // keeps track of the last available key (highest key value in
332 // the symbol table).
AddSymbol(const string & symbol,int64 key)333 virtual int64 AddSymbol(const string& symbol, int64 key) {
334 MutateCheck();
335 return impl_->AddSymbol(symbol, key);
336 }
337
338 // Add a symbol to the table. The associated value key is automatically
339 // assigned by the symbol table.
AddSymbol(const string & symbol)340 virtual int64 AddSymbol(const string& symbol) {
341 MutateCheck();
342 return impl_->AddSymbol(symbol);
343 }
344
345 // Add another symbol table to this table. All key values will be offset
346 // by the current available key (highest key value in the symbol table).
347 // Note string symbols with the same key value with still have the same
348 // key value after the symbol table has been merged, but a different
349 // value. Adding symbol tables do not result in changes in the base table.
350 virtual void AddTable(const SymbolTable& table);
351
352 // return the name of the symbol table
Name()353 virtual const string& Name() const {
354 return impl_->Name();
355 }
356
357 // Return the label-agnostic MD5 check-sum for this table. All new symbols
358 // added to the table will result in an updated checksum.
359 // DEPRECATED.
CheckSum()360 virtual string CheckSum() const {
361 return impl_->CheckSum();
362 }
363
364 // Same as CheckSum(), but this returns an label-dependent version.
LabeledCheckSum()365 virtual string LabeledCheckSum() const {
366 return impl_->LabeledCheckSum();
367 }
368
Write(ostream & strm)369 virtual bool Write(ostream &strm) const {
370 return impl_->Write(strm);
371 }
372
Write(const string & filename)373 bool Write(const string& filename) const {
374 ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
375 if (!strm.good()) {
376 LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename;
377 return false;
378 }
379 return Write(strm);
380 }
381
382 // Dump an ascii text representation of the symbol table via a stream
383 virtual bool WriteText(
384 ostream &strm,
385 const SymbolTableTextOptions &opts = SymbolTableTextOptions()) const;
386
387 // Dump an ascii text representation of the symbol table
WriteText(const string & filename)388 bool WriteText(const string& filename) const {
389 ofstream strm(filename.c_str());
390 if (!strm.good()) {
391 LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename;
392 return false;
393 }
394 return WriteText(strm);
395 }
396
397 // Return the string associated with the key. If the key is out of
398 // range (<0, >max), log error and return an empty string.
Find(int64 key)399 virtual string Find(int64 key) const {
400 return impl_->Find(key);
401 }
402
403 // Return the key associated with the symbol. If the symbol
404 // does not exists, log error and return SymbolTable::kNoSymbol
Find(const string & symbol)405 virtual int64 Find(const string& symbol) const {
406 return impl_->Find(symbol);
407 }
408
409 // Return the key associated with the symbol. If the symbol
410 // does not exists, log error and return SymbolTable::kNoSymbol
Find(const char * symbol)411 virtual int64 Find(const char* symbol) const {
412 return impl_->Find(symbol);
413 }
414
415 // Return the current available key (i.e highest key number+1) in
416 // the symbol table
AvailableKey(void)417 virtual int64 AvailableKey(void) const {
418 return impl_->AvailableKey();
419 }
420
421 // Return the current number of symbols in table (not necessarily
422 // equal to AvailableKey())
NumSymbols(void)423 virtual size_t NumSymbols(void) const {
424 return impl_->NumSymbols();
425 }
426
GetNthKey(ssize_t pos)427 virtual int64 GetNthKey(ssize_t pos) const {
428 return impl_->GetNthKey(pos);
429 }
430
431 private:
SymbolTable(SymbolTableImpl * impl)432 explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {}
433
MutateCheck()434 void MutateCheck() {
435 // Copy on write
436 if (impl_->RefCount() > 1) {
437 impl_->DecrRefCount();
438 impl_ = new SymbolTableImpl(*impl_);
439 }
440 }
441
Impl()442 const SymbolTableImpl* Impl() const {
443 return impl_;
444 }
445
446 private:
447 SymbolTableImpl* impl_;
448 };
449
450
451 //
452 // \class SymbolTableIterator
453 // \brief Iterator class for symbols in a symbol table
454 class SymbolTableIterator {
455 public:
SymbolTableIterator(const SymbolTable & table)456 SymbolTableIterator(const SymbolTable& table)
457 : table_(table),
458 pos_(0),
459 nsymbols_(table.NumSymbols()),
460 key_(table.GetNthKey(0)) { }
461
~SymbolTableIterator()462 ~SymbolTableIterator() { }
463
464 // is iterator done
Done(void)465 bool Done(void) {
466 return (pos_ == nsymbols_);
467 }
468
469 // return the Value() of the current symbol (int64 key)
Value(void)470 int64 Value(void) {
471 return key_;
472 }
473
474 // return the string of the current symbol
Symbol(void)475 string Symbol(void) {
476 return table_.Find(key_);
477 }
478
479 // advance iterator forward
Next(void)480 void Next(void) {
481 ++pos_;
482 if (pos_ < nsymbols_) key_ = table_.GetNthKey(pos_);
483 }
484
485 // reset iterator
Reset(void)486 void Reset(void) {
487 pos_ = 0;
488 key_ = table_.GetNthKey(0);
489 }
490
491 private:
492 const SymbolTable& table_;
493 ssize_t pos_;
494 size_t nsymbols_;
495 int64 key_;
496 };
497
498
499 // Tests compatibilty between two sets of symbol tables
500 inline bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
501 bool warning = true) {
502 if (!FLAGS_fst_compat_symbols) {
503 return true;
504 } else if (!syms1 && !syms2) {
505 return true;
506 } else if (syms1 && !syms2) {
507 if (warning)
508 LOG(WARNING) <<
509 "CompatSymbols: first symbol table present but second missing";
510 return false;
511 } else if (!syms1 && syms2) {
512 if (warning)
513 LOG(WARNING) <<
514 "CompatSymbols: second symbol table present but first missing";
515 return false;
516 } else if (syms1->LabeledCheckSum() != syms2->LabeledCheckSum()) {
517 if (warning)
518 LOG(WARNING) << "CompatSymbols: Symbol table check sums do not match";
519 return false;
520 } else {
521 return true;
522 }
523 }
524
525
526 // Relabels a symbol table as specified by the input vector of pairs
527 // (old label, new label). The new symbol table only retains symbols
528 // for which a relabeling is *explicitely* specified.
529 // TODO(allauzen): consider adding options to allow for some form
530 // of implicit identity relabeling.
531 template <class Label>
RelabelSymbolTable(const SymbolTable * table,const vector<pair<Label,Label>> & pairs)532 SymbolTable *RelabelSymbolTable(const SymbolTable *table,
533 const vector<pair<Label, Label> > &pairs) {
534 SymbolTable *new_table = new SymbolTable(
535 table->Name().empty() ? string() :
536 (string("relabeled_") + table->Name()));
537
538 for (size_t i = 0; i < pairs.size(); ++i)
539 new_table->AddSymbol(table->Find(pairs[i].first), pairs[i].second);
540
541 return new_table;
542 }
543
544 // Symbol Table Serialization
SymbolTableToString(const SymbolTable * table,string * result)545 inline void SymbolTableToString(const SymbolTable *table, string *result) {
546 ostringstream ostrm;
547 table->Write(ostrm);
548 *result = ostrm.str();
549 }
550
StringToSymbolTable(const string & s)551 inline SymbolTable *StringToSymbolTable(const string &s) {
552 istringstream istrm(s);
553 return SymbolTable::Read(istrm, SymbolTableReadOptions());
554 }
555
556 } // namespace fst
557
558 #endif // FST_LIB_SYMBOL_TABLE_H__
559