1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Classes for representing a bijective mapping between an arbitrary entry
19 // of type T and a signed integral ID.
20 
21 #ifndef FST_BI_TABLE_H_
22 #define FST_BI_TABLE_H_
23 
24 #include <deque>
25 #include <functional>
26 #include <memory>
27 #include <type_traits>
28 #include <unordered_set>
29 #include <vector>
30 
31 #include <fst/types.h>
32 #include <fst/log.h>
33 #include <fst/memory.h>
34 #include <fst/windows_defs.inc>
35 #include <unordered_map>
36 #include <unordered_set>
37 
38 namespace fst {
39 
// Bitables model bijective mappings between entries of an arbitrary type T and
// a signed integral ID of type I. The IDs are allocated starting from 0 in
// order.
43 //
44 // template <class I, class T>
45 // class BiTable {
46 //  public:
47 //
48 //   // Required constructors.
49 //   BiTable();
50 //
//   // Looks up integer ID from entry. If it doesn't exist and insert
//   // is true, adds it; otherwise, returns -1.
53 //   I FindId(const T &entry, bool insert = true);
54 //
55 //   // Looks up entry from integer ID.
56 //   const T &FindEntry(I) const;
57 //
58 //   // Returns number of stored entries.
59 //   I Size() const;
60 // };
61 
// HashBiTable: entry -> ID lookups go through a std::unordered_map keyed by
// entry (hashed with H, compared with E); ID -> entry lookups index a vector.
template <class I, class T, class H, class E = std::equal_to<T>>
class HashBiTable {
 public:
  // Creates an empty table. A nonzero table_size is used as the initial bucket
  // count of the map and as the reserved capacity of the entry vector.
  explicit HashBiTable(size_t table_size = 0, const H &h = H(),
                       const E &e = E())
      : hash_func_(h),
        hash_equal_(e),
        entry2id_(table_size, hash_func_, hash_equal_) {
    if (table_size != 0) {
      id2entry_.reserve(table_size);
    }
  }

  // Copy constructor: duplicates the functors and both directions of the
  // mapping.
  HashBiTable(const HashBiTable<I, T, H, E> &table)
      : hash_func_(table.hash_func_),
        hash_equal_(table.hash_equal_),
        entry2id_(table.entry2id_.begin(), table.entry2id_.end(),
                  table.entry2id_.size(), hash_func_, hash_equal_),
        id2entry_(table.id2entry_) {}

  // Returns the ID of entry. When entry is unknown, a fresh ID (equal to the
  // current Size()) is assigned if insert is true; otherwise -1 is returned.
  I FindId(const T &entry, bool insert = true) {
    if (insert) {
      // The map stores ID + 1, so a freshly default-inserted mapped value of 0
      // signals an entry seen for the first time.
      I &stored = entry2id_[entry];
      if (stored == 0) {
        id2entry_.push_back(entry);
        stored = id2entry_.size();
      }
      return stored - 1;
    }
    const auto pos = entry2id_.find(entry);
    if (pos == entry2id_.end()) return -1;
    return pos->second - 1;
  }

  // Returns the entry with ID s; s must be in [0, Size()).
  const T &FindEntry(I s) const { return id2entry_[s]; }

  // Returns the number of stored entries.
  I Size() const { return id2entry_.size(); }

  // Removes every entry.
  // TODO(riley): Add fancy clear-to-size, as in CompactHashBiTable.
  void Clear() {
    id2entry_.clear();
    entry2id_.clear();
  }

 private:
  H hash_func_;
  E hash_equal_;
  std::unordered_map<T, I, H, E> entry2id_;  // Entry -> ID + 1.
  std::vector<T> id2entry_;                  // ID -> entry.
};
112 
// Enables alternative hash set representations below.
// Used as the HS template parameter of HashSet, CompactHashBiTable and
// VectorHashBiTable. In this header HashSet has a single (unspecialized)
// definition, so both values currently yield the same unordered_set-based set.
enum HSType { HS_STL, HS_FLAT };
115 
116 // Default hash set is STL hash_set.
117 template <class K, class H, class E, HSType HS>
118 struct HashSet : public std::unordered_set<K, H, E, PoolAllocator<K>> {
119   explicit HashSet(size_t n = 0, const H &h = H(), const E &e = E())
120       : std::unordered_set<K, H, E, PoolAllocator<K>>(n, h, e) {}
121 
rehashHashSet122   void rehash(size_t n) {}
123 };
124 
125 // An implementation using a hash set for the entry to ID mapping. The hash set
126 // holds keys which are either the ID or kCurrentKey. These keys can be mapped
127 // to entries either by looking up in the entry vector or, if kCurrentKey, in
128 // current_entry_. The hash and key equality functions map to entries first. H
129 // is the hash function and E is the equality function.
130 template <class I, class T, class H, class E = std::equal_to<T>,
131           HSType HS = HS_FLAT>
132 class CompactHashBiTable {
133   static_assert(HS == HS_STL || HS == HS_FLAT, "Unsupported hash set type");
134 
135  public:
136   friend class HashFunc;
137   friend class HashEqual;
138 
139   // Reserves space for table_size elements.
140   explicit CompactHashBiTable(size_t table_size = 0, const H &h = H(),
141                               const E &e = E())
hash_func_(h)142       : hash_func_(h),
143         hash_equal_(e),
144         compact_hash_func_(*this),
145         compact_hash_equal_(*this),
146         keys_(table_size, compact_hash_func_, compact_hash_equal_) {
147     if (table_size) id2entry_.reserve(table_size);
148   }
149 
CompactHashBiTable(const CompactHashBiTable<I,T,H,E,HS> & table)150   CompactHashBiTable(const CompactHashBiTable<I, T, H, E, HS> &table)
151       : hash_func_(table.hash_func_),
152         hash_equal_(table.hash_equal_),
153         compact_hash_func_(*this),
154         compact_hash_equal_(*this),
155         keys_(table.keys_.size(), compact_hash_func_, compact_hash_equal_),
156         id2entry_(table.id2entry_) {
157     keys_.insert(table.keys_.begin(), table.keys_.end());
158   }
159 
160   I FindId(const T &entry, bool insert = true) {
161     current_entry_ = &entry;
162     if (insert) {
163       auto result = keys_.insert(kCurrentKey);
164       if (!result.second) return *result.first;  // Already exists.
165       // Overwrites kCurrentKey with a new key value; this is safe because it
166       // doesn't affect hashing or equality testing.
167       I key = id2entry_.size();
168       const_cast<I &>(*result.first) = key;
169       id2entry_.push_back(entry);
170       return key;
171     }
172     const auto it = keys_.find(kCurrentKey);
173     return it == keys_.end() ? -1 : *it;
174   }
175 
FindEntry(I s)176   const T &FindEntry(I s) const { return id2entry_[s]; }
177 
Size()178   I Size() const { return id2entry_.size(); }
179 
180   // Clears content; with argument, erases last n IDs.
181   void Clear(ssize_t n = -1) {
182     if (n < 0 || n >= id2entry_.size()) {  // Clears completely.
183       keys_.clear();
184       id2entry_.clear();
185     } else if (n == id2entry_.size() - 1) {  // Leaves only key 0.
186       const T entry = FindEntry(0);
187       keys_.clear();
188       id2entry_.clear();
189       FindId(entry, true);
190     } else {
191       while (n-- > 0) {
192         I key = id2entry_.size() - 1;
193         keys_.erase(key);
194         id2entry_.pop_back();
195       }
196       keys_.rehash(0);
197     }
198   }
199 
200  private:
201   static_assert(std::is_signed<I>::value, "I must be a signed type");
202   // ... otherwise >= kCurrentKey comparisons as used below don't work.
203   // TODO(rybach): (1) don't use >= for key comparison, (2) allow unsigned key
204   // types.
205   static constexpr I kCurrentKey = -1;
206 
207   class HashFunc {
208    public:
HashFunc(const CompactHashBiTable & ht)209     explicit HashFunc(const CompactHashBiTable &ht) : ht_(&ht) {}
210 
operator()211     size_t operator()(I k) const {
212       if (k >= kCurrentKey) {
213         return (ht_->hash_func_)(ht_->Key2Entry(k));
214       } else {
215         return 0;
216       }
217     }
218 
219    private:
220     const CompactHashBiTable *ht_;
221   };
222 
223   class HashEqual {
224    public:
HashEqual(const CompactHashBiTable & ht)225     explicit HashEqual(const CompactHashBiTable &ht) : ht_(&ht) {}
226 
operator()227     bool operator()(I k1, I k2) const {
228       if (k1 == k2) {
229         return true;
230       } else if (k1 >= kCurrentKey && k2 >= kCurrentKey) {
231         return (ht_->hash_equal_)(ht_->Key2Entry(k1), ht_->Key2Entry(k2));
232       } else {
233         return false;
234       }
235     }
236 
237    private:
238     const CompactHashBiTable *ht_;
239   };
240 
241   using KeyHashSet = HashSet<I, HashFunc, HashEqual, HS>;
242 
Key2Entry(I k)243   const T &Key2Entry(I k) const {
244     if (k == kCurrentKey) {
245       return *current_entry_;
246     } else {
247       return id2entry_[k];
248     }
249   }
250 
251   H hash_func_;
252   E hash_equal_;
253   HashFunc compact_hash_func_;
254   HashEqual compact_hash_equal_;
255   KeyHashSet keys_;
256   std::vector<T> id2entry_;
257   const T *current_entry_;
258 };
259 
// Out-of-class definition of the static constexpr member. Before C++17 this is
// required because kCurrentKey is ODR-used: it is bound to a const reference
// by keys_.insert(kCurrentKey) and keys_.find(kCurrentKey).
template <class I, class T, class H, class E, HSType HS>
constexpr I CompactHashBiTable<I, T, H, E, HS>::kCurrentKey;
262 
// An implementation using a vector for the entry to ID mapping. It is passed a
// function object FP that should fingerprint entries uniquely to a
// non-negative integer that can be used as a vector index. Normally,
// VectorBiTable constructs the FP object. The user can instead pass in this
// object.
template <class I, class T, class FP>
class VectorBiTable {
 public:
  // Reserves table_size cells of space.
  explicit VectorBiTable(const FP &fp = FP(), size_t table_size = 0) : fp_(fp) {
    if (table_size) id2entry_.reserve(table_size);
  }

  VectorBiTable(const VectorBiTable<I, T, FP> &table)
      : fp_(table.fp_), fp2id_(table.fp2id_), id2entry_(table.id2entry_) {}

  // Looks up the ID of entry. If it is absent and insert is true, stores it
  // under the next ID (Size()); otherwise returns -1. A lookup with
  // insert == false never grows the fingerprint vector.
  I FindId(const T &entry, bool insert = true) {
    const ssize_t fp = (fp_)(entry);
    // Fingerprints are required to be non-negative vector indices.
    const size_t index = static_cast<size_t>(fp);
    if (index >= fp2id_.size()) {
      // Nothing is stored past the end, so a pure lookup can answer directly.
      if (!insert) return -1;
      fp2id_.resize(index + 1);
    }
    I &id_ref = fp2id_[index];
    if (id_ref == 0) {  // T not found.
      if (!insert) return -1;
      // Stores and assigns a new ID.
      id2entry_.push_back(entry);
      id_ref = id2entry_.size();
    }
    return id_ref - 1;  // NB: id_ref = ID + 1.
  }

  // Returns the entry with ID s; s must be in [0, Size()).
  const T &FindEntry(I s) const { return id2entry_[s]; }

  // Returns the number of stored entries.
  I Size() const { return id2entry_.size(); }

  // Returns the fingerprinting functor.
  const FP &Fingerprint() const { return fp_; }

 private:
  FP fp_;
  std::vector<I> fp2id_;     // Fingerprint -> ID + 1 (0 means absent).
  std::vector<T> id2entry_;  // ID -> entry.
};
304 
305 // An implementation using a vector and a compact hash table. The selecting
306 // functor S returns true for entries to be hashed in the vector. The
307 // fingerprinting functor FP returns a unique fingerprint for each entry to be
308 // hashed in the vector (these need to be suitable for indexing in a vector).
309 // The hash functor H is used when hashing entry into the compact hash table.
310 template <class I, class T, class S, class FP, class H, HSType HS = HS_FLAT>
311 class VectorHashBiTable {
312  public:
313   friend class HashFunc;
314   friend class HashEqual;
315 
316   explicit VectorHashBiTable(const S &s = S(), const FP &fp = FP(),
317                              const H &h = H(), size_t vector_size = 0,
318                              size_t entry_size = 0)
selector_(s)319       : selector_(s),
320         fp_(fp),
321         h_(h),
322         hash_func_(*this),
323         hash_equal_(*this),
324         keys_(0, hash_func_, hash_equal_) {
325     if (vector_size) fp2id_.reserve(vector_size);
326     if (entry_size) id2entry_.reserve(entry_size);
327   }
328 
VectorHashBiTable(const VectorHashBiTable<I,T,S,FP,H,HS> & table)329   VectorHashBiTable(const VectorHashBiTable<I, T, S, FP, H, HS> &table)
330       : selector_(table.s_),
331         fp_(table.fp_),
332         h_(table.h_),
333         id2entry_(table.id2entry_),
334         fp2id_(table.fp2id_),
335         hash_func_(*this),
336         hash_equal_(*this),
337         keys_(table.keys_.size(), hash_func_, hash_equal_) {
338     keys_.insert(table.keys_.begin(), table.keys_.end());
339   }
340 
341   I FindId(const T &entry, bool insert = true) {
342     if ((selector_)(entry)) {  // Uses the vector if selector_(entry) == true.
343       uint64 fp = (fp_)(entry);
344       if (fp2id_.size() <= fp) fp2id_.resize(fp + 1, 0);
345       if (fp2id_[fp] == 0) {  // T not found.
346         if (insert) {         // Stores and assigns a new ID.
347           id2entry_.push_back(entry);
348           fp2id_[fp] = id2entry_.size();
349         } else {
350           return -1;
351         }
352       }
353       return fp2id_[fp] - 1;  // NB: assoc_value = ID + 1.
354     } else {                  // Uses the hash table otherwise.
355       current_entry_ = &entry;
356       const auto it = keys_.find(kCurrentKey);
357       if (it == keys_.end()) {
358         if (insert) {
359           I key = id2entry_.size();
360           id2entry_.push_back(entry);
361           keys_.insert(key);
362           return key;
363         } else {
364           return -1;
365         }
366       } else {
367         return *it;
368       }
369     }
370   }
371 
FindEntry(I s)372   const T &FindEntry(I s) const { return id2entry_[s]; }
373 
Size()374   I Size() const { return id2entry_.size(); }
375 
Selector()376   const S &Selector() const { return selector_; }
377 
Fingerprint()378   const FP &Fingerprint() const { return fp_; }
379 
Hash()380   const H &Hash() const { return h_; }
381 
382  private:
383   static constexpr I kCurrentKey = -1;
384 
385   class HashFunc {
386    public:
HashFunc(const VectorHashBiTable & ht)387     explicit HashFunc(const VectorHashBiTable &ht) : ht_(&ht) {}
388 
operator()389     size_t operator()(I k) const {
390       if (k >= kCurrentKey) {
391         return (ht_->h_)(ht_->Key2Entry(k));
392       } else {
393         return 0;
394       }
395     }
396 
397    private:
398     const VectorHashBiTable *ht_;
399   };
400 
401   class HashEqual {
402    public:
HashEqual(const VectorHashBiTable & ht)403     explicit HashEqual(const VectorHashBiTable &ht) : ht_(&ht) {}
404 
operator()405     bool operator()(I k1, I k2) const {
406       if (k1 >= kCurrentKey && k2 >= kCurrentKey) {
407         return ht_->Key2Entry(k1) == ht_->Key2Entry(k2);
408       } else {
409         return k1 == k2;
410       }
411     }
412 
413    private:
414     const VectorHashBiTable *ht_;
415   };
416 
417   using KeyHashSet = HashSet<I, HashFunc, HashEqual, HS>;
418 
Key2Entry(I k)419   const T &Key2Entry(I k) const {
420     if (k == kCurrentKey) {
421       return *current_entry_;
422     } else {
423       return id2entry_[k];
424     }
425   }
426 
427   S selector_;  // True if entry hashed into vector.
428   FP fp_;       // Fingerprint used for hashing into vector.
429   H h_;         // Hash funcion used for hashing into hash_set.
430 
431   std::vector<T> id2entry_;  // Maps state IDs to entry.
432   std::vector<I> fp2id_;     // Maps entry fingerprints to IDs.
433 
434   // Compact implementation of the hash table mapping entries to state IDs
435   // using the hash function h_.
436   HashFunc hash_func_;
437   HashEqual hash_equal_;
438   KeyHashSet keys_;
439   const T *current_entry_;
440 };
441 
// Out-of-class definition of the static constexpr member. Before C++17 this is
// required because kCurrentKey is ODR-used: it is bound to a const reference
// by keys_.find(kCurrentKey).
template <class I, class T, class S, class FP, class H, HSType HS>
constexpr I VectorHashBiTable<I, T, S, FP, H, HS>::kCurrentKey;
444 
// An implementation using a hash map for the entry to ID mapping. This version
// permits erasing of arbitrary states. The entry T must have == defined and
// its value-initialized form (T{}) must produce an entry that will never be
// seen; that value marks erased slots. F is the hash function.
template <class I, class T, class F>
class ErasableBiTable {
 public:
  ErasableBiTable() : first_(0) {}

  // Looks up the ID of entry. If entry is absent and insert is true, stores it
  // under the next ID; otherwise returns -1. A lookup with insert == false
  // leaves the table unchanged.
  I FindId(const T &entry, bool insert = true) {
    if (!insert) {
      // Uses find() rather than operator[] so that a failed lookup does not
      // leave a placeholder element behind in entry2id_.
      const auto it = entry2id_.find(entry);
      return it == entry2id_.end() ? -1 : it->second - 1;
    }
    I &id_ref = entry2id_[entry];
    if (id_ref == 0) {  // T not found; stores and assigns a new ID.
      id2entry_.push_back(entry);
      id_ref = id2entry_.size() + first_;
    }
    return id_ref - 1;  // NB: id_ref = ID + 1.
  }

  // Returns the entry with ID s; s must still be held (s >= first_).
  const T &FindEntry(I s) const { return id2entry_[s - first_]; }

  // Returns the number of slots currently held: live entries plus erased ones
  // that have not yet been reclaimed from the front.
  I Size() const { return id2entry_.size(); }

  // Erases the entry with ID s. Its slot is overwritten with the empty entry,
  // then leading empty slots are dropped, advancing first_.
  void Erase(I s) {
    auto &ref = id2entry_[s - first_];
    entry2id_.erase(ref);
    ref = empty_entry_;
    while (!id2entry_.empty() && id2entry_.front() == empty_entry_) {
      id2entry_.pop_front();
      ++first_;
    }
  }

 private:
  std::unordered_map<T, I, F> entry2id_;  // Entry -> ID + 1.
  std::deque<T> id2entry_;                // Slot i holds ID first_ + i.
  // Value-initialized so the table is constructible even when T is a scalar
  // or another type without a user-provided default constructor.
  const T empty_entry_{};
  I first_;  // ID of first element in the deque.
};
487 
488 }  // namespace fst
489 
490 #endif  // FST_BI_TABLE_H_
491