// Copyright 2005-2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Classes for representing a bijective mapping between an arbitrary entry
// of type T and a signed integral ID.

#ifndef FST_BI_TABLE_H_
#define FST_BI_TABLE_H_

#include <cstddef>
#include <deque>
#include <functional>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <fst/types.h>
#include <fst/log.h>
#include <fst/memory.h>
#include <fst/windows_defs.inc>

namespace fst {

// Bitables model bijective mappings between entries of an arbitrary type T and
// a signed integral ID of type I. The IDs are allocated starting from 0 in
// order.
//
// template <class I, class T>
// class BiTable {
//  public:
//
//   // Required constructors.
//   BiTable();
//
//   // Looks up integer ID from entry. If it doesn't exist and insert
//   // is true, adds it; otherwise, returns -1.
//   I FindId(const T &entry, bool insert = true);
//
//   // Looks up entry from integer ID.
//   const T &FindEntry(I) const;
//
//   // Returns number of stored entries.
//   I Size() const;
// };

// An implementation using a hash map for the entry to ID mapping.
// H is the hash function and E is the equality function.
//
// The hash map stores ID + 1 for every entry, so a value-initialized (zero)
// mapped value means "no ID assigned yet"; id2entry_ maps IDs back to entries.
template <class I, class T, class H, class E = std::equal_to<T>>
class HashBiTable {
 public:
  // Reserves space for table_size elements.
  explicit HashBiTable(size_t table_size = 0, const H &h = H(),
                       const E &e = E())
      : hash_func_(h),
        hash_equal_(e),
        entry2id_(table_size, hash_func_, hash_equal_) {
    if (table_size) id2entry_.reserve(table_size);
  }

  // Copies both directions of the mapping; the new hash map is built with this
  // object's own copies of the hash and equality functors.
  HashBiTable(const HashBiTable<I, T, H, E> &table)
      : hash_func_(table.hash_func_),
        hash_equal_(table.hash_equal_),
        entry2id_(table.entry2id_.begin(), table.entry2id_.end(),
                  table.entry2id_.size(), hash_func_, hash_equal_),
        id2entry_(table.id2entry_) {}

  // Looks up integer ID from entry. If it doesn't exist and insert is true,
  // adds it; otherwise, returns -1.
  I FindId(const T &entry, bool insert = true) {
    if (!insert) {
      // Pure lookup: find() leaves the map untouched when entry is absent.
      const auto it = entry2id_.find(entry);
      return it == entry2id_.end() ? -1 : it->second - 1;
    }
    // operator[] value-initializes the mapped value to 0 for a new entry.
    I &id_ref = entry2id_[entry];
    if (id_ref == 0) {  // T not found; stores and assigns a new ID.
      id2entry_.push_back(entry);
      id_ref = id2entry_.size();
    }
    return id_ref - 1;  // NB: id_ref = ID + 1.
  }

  // Looks up entry from integer ID. No bounds check is performed on s.
  const T &FindEntry(I s) const { return id2entry_[s]; }

  // Returns number of stored entries.
  I Size() const { return id2entry_.size(); }

  // Removes all entries.
  // TODO(riley): Add fancy clear-to-size, as in CompactHashBiTable.
  void Clear() {
    entry2id_.clear();
    id2entry_.clear();
  }

 private:
  H hash_func_;
  E hash_equal_;
  std::unordered_map<T, I, H, E> entry2id_;  // Entry -> ID + 1.
  std::vector<T> id2entry_;                  // ID -> entry.
};

// Enables alternative hash set representations below.
enum HSType { HS_STL, HS_FLAT };

// Default hash set is STL hash_set.
117 template <class K, class H, class E, HSType HS> 118 struct HashSet : public std::unordered_set<K, H, E, PoolAllocator<K>> { 119 explicit HashSet(size_t n = 0, const H &h = H(), const E &e = E()) 120 : std::unordered_set<K, H, E, PoolAllocator<K>>(n, h, e) {} 121 rehashHashSet122 void rehash(size_t n) {} 123 }; 124 125 // An implementation using a hash set for the entry to ID mapping. The hash set 126 // holds keys which are either the ID or kCurrentKey. These keys can be mapped 127 // to entries either by looking up in the entry vector or, if kCurrentKey, in 128 // current_entry_. The hash and key equality functions map to entries first. H 129 // is the hash function and E is the equality function. 130 template <class I, class T, class H, class E = std::equal_to<T>, 131 HSType HS = HS_FLAT> 132 class CompactHashBiTable { 133 static_assert(HS == HS_STL || HS == HS_FLAT, "Unsupported hash set type"); 134 135 public: 136 friend class HashFunc; 137 friend class HashEqual; 138 139 // Reserves space for table_size elements. 
140 explicit CompactHashBiTable(size_t table_size = 0, const H &h = H(), 141 const E &e = E()) hash_func_(h)142 : hash_func_(h), 143 hash_equal_(e), 144 compact_hash_func_(*this), 145 compact_hash_equal_(*this), 146 keys_(table_size, compact_hash_func_, compact_hash_equal_) { 147 if (table_size) id2entry_.reserve(table_size); 148 } 149 CompactHashBiTable(const CompactHashBiTable<I,T,H,E,HS> & table)150 CompactHashBiTable(const CompactHashBiTable<I, T, H, E, HS> &table) 151 : hash_func_(table.hash_func_), 152 hash_equal_(table.hash_equal_), 153 compact_hash_func_(*this), 154 compact_hash_equal_(*this), 155 keys_(table.keys_.size(), compact_hash_func_, compact_hash_equal_), 156 id2entry_(table.id2entry_) { 157 keys_.insert(table.keys_.begin(), table.keys_.end()); 158 } 159 160 I FindId(const T &entry, bool insert = true) { 161 current_entry_ = &entry; 162 if (insert) { 163 auto result = keys_.insert(kCurrentKey); 164 if (!result.second) return *result.first; // Already exists. 165 // Overwrites kCurrentKey with a new key value; this is safe because it 166 // doesn't affect hashing or equality testing. 167 I key = id2entry_.size(); 168 const_cast<I &>(*result.first) = key; 169 id2entry_.push_back(entry); 170 return key; 171 } 172 const auto it = keys_.find(kCurrentKey); 173 return it == keys_.end() ? -1 : *it; 174 } 175 FindEntry(I s)176 const T &FindEntry(I s) const { return id2entry_[s]; } 177 Size()178 I Size() const { return id2entry_.size(); } 179 180 // Clears content; with argument, erases last n IDs. 181 void Clear(ssize_t n = -1) { 182 if (n < 0 || n >= id2entry_.size()) { // Clears completely. 183 keys_.clear(); 184 id2entry_.clear(); 185 } else if (n == id2entry_.size() - 1) { // Leaves only key 0. 
186 const T entry = FindEntry(0); 187 keys_.clear(); 188 id2entry_.clear(); 189 FindId(entry, true); 190 } else { 191 while (n-- > 0) { 192 I key = id2entry_.size() - 1; 193 keys_.erase(key); 194 id2entry_.pop_back(); 195 } 196 keys_.rehash(0); 197 } 198 } 199 200 private: 201 static_assert(std::is_signed<I>::value, "I must be a signed type"); 202 // ... otherwise >= kCurrentKey comparisons as used below don't work. 203 // TODO(rybach): (1) don't use >= for key comparison, (2) allow unsigned key 204 // types. 205 static constexpr I kCurrentKey = -1; 206 207 class HashFunc { 208 public: HashFunc(const CompactHashBiTable & ht)209 explicit HashFunc(const CompactHashBiTable &ht) : ht_(&ht) {} 210 operator()211 size_t operator()(I k) const { 212 if (k >= kCurrentKey) { 213 return (ht_->hash_func_)(ht_->Key2Entry(k)); 214 } else { 215 return 0; 216 } 217 } 218 219 private: 220 const CompactHashBiTable *ht_; 221 }; 222 223 class HashEqual { 224 public: HashEqual(const CompactHashBiTable & ht)225 explicit HashEqual(const CompactHashBiTable &ht) : ht_(&ht) {} 226 operator()227 bool operator()(I k1, I k2) const { 228 if (k1 == k2) { 229 return true; 230 } else if (k1 >= kCurrentKey && k2 >= kCurrentKey) { 231 return (ht_->hash_equal_)(ht_->Key2Entry(k1), ht_->Key2Entry(k2)); 232 } else { 233 return false; 234 } 235 } 236 237 private: 238 const CompactHashBiTable *ht_; 239 }; 240 241 using KeyHashSet = HashSet<I, HashFunc, HashEqual, HS>; 242 Key2Entry(I k)243 const T &Key2Entry(I k) const { 244 if (k == kCurrentKey) { 245 return *current_entry_; 246 } else { 247 return id2entry_[k]; 248 } 249 } 250 251 H hash_func_; 252 E hash_equal_; 253 HashFunc compact_hash_func_; 254 HashEqual compact_hash_equal_; 255 KeyHashSet keys_; 256 std::vector<T> id2entry_; 257 const T *current_entry_; 258 }; 259 260 template <class I, class T, class H, class E, HSType HS> 261 constexpr I CompactHashBiTable<I, T, H, E, HS>::kCurrentKey; 262 263 // An implementation using a vector for the entry to 
// ID mapping. It is passed a
// function object FP that should fingerprint entries uniquely to an integer
// that can be used as a vector index. Normally, VectorBiTable constructs the
// FP object. The user can instead pass in this object.
template <class I, class T, class FP>
class VectorBiTable {
 public:
  // Reserves table_size cells of space.
  explicit VectorBiTable(const FP &fp = FP(), size_t table_size = 0) : fp_(fp) {
    if (table_size) id2entry_.reserve(table_size);
  }

  VectorBiTable(const VectorBiTable<I, T, FP> &table)
      : fp_(table.fp_), fp2id_(table.fp2id_), id2entry_(table.id2entry_) {}

  // Looks up integer ID from entry. If it doesn't exist and insert is true,
  // adds it; otherwise, returns -1.
  I FindId(const T &entry, bool insert = true) {
    const size_t fp = static_cast<size_t>((fp_)(entry));
    if (fp >= fp2id_.size()) {
      // A fingerprint beyond the table cannot have an ID; a pure lookup must
      // not grow the table.
      if (!insert) return -1;
      fp2id_.resize(fp + 1);
    }
    I &id_ref = fp2id_[fp];
    if (id_ref == 0) {  // T not found.
      if (insert) {     // Stores and assigns a new ID.
        id2entry_.push_back(entry);
        id_ref = id2entry_.size();
      } else {
        return -1;
      }
    }
    return id_ref - 1;  // NB: id_ref = ID + 1.
  }

  // Looks up entry from integer ID. No bounds check is performed on s.
  const T &FindEntry(I s) const { return id2entry_[s]; }

  // Returns number of stored entries.
  I Size() const { return id2entry_.size(); }

  // Returns the fingerprinting functor.
  const FP &Fingerprint() const { return fp_; }

 private:
  FP fp_;
  std::vector<I> fp2id_;     // Fingerprint -> ID + 1; 0 means unassigned.
  std::vector<T> id2entry_;  // ID -> entry.
};

// An implementation using a vector and a compact hash table. The selecting
// functor S returns true for entries to be hashed in the vector. The
// fingerprinting functor FP returns a unique fingerprint for each entry to be
// hashed in the vector (these need to be suitable for indexing in a vector).
// The hash functor H is used when hashing entry into the compact hash table.
310 template <class I, class T, class S, class FP, class H, HSType HS = HS_FLAT> 311 class VectorHashBiTable { 312 public: 313 friend class HashFunc; 314 friend class HashEqual; 315 316 explicit VectorHashBiTable(const S &s = S(), const FP &fp = FP(), 317 const H &h = H(), size_t vector_size = 0, 318 size_t entry_size = 0) selector_(s)319 : selector_(s), 320 fp_(fp), 321 h_(h), 322 hash_func_(*this), 323 hash_equal_(*this), 324 keys_(0, hash_func_, hash_equal_) { 325 if (vector_size) fp2id_.reserve(vector_size); 326 if (entry_size) id2entry_.reserve(entry_size); 327 } 328 VectorHashBiTable(const VectorHashBiTable<I,T,S,FP,H,HS> & table)329 VectorHashBiTable(const VectorHashBiTable<I, T, S, FP, H, HS> &table) 330 : selector_(table.s_), 331 fp_(table.fp_), 332 h_(table.h_), 333 id2entry_(table.id2entry_), 334 fp2id_(table.fp2id_), 335 hash_func_(*this), 336 hash_equal_(*this), 337 keys_(table.keys_.size(), hash_func_, hash_equal_) { 338 keys_.insert(table.keys_.begin(), table.keys_.end()); 339 } 340 341 I FindId(const T &entry, bool insert = true) { 342 if ((selector_)(entry)) { // Uses the vector if selector_(entry) == true. 343 uint64 fp = (fp_)(entry); 344 if (fp2id_.size() <= fp) fp2id_.resize(fp + 1, 0); 345 if (fp2id_[fp] == 0) { // T not found. 346 if (insert) { // Stores and assigns a new ID. 347 id2entry_.push_back(entry); 348 fp2id_[fp] = id2entry_.size(); 349 } else { 350 return -1; 351 } 352 } 353 return fp2id_[fp] - 1; // NB: assoc_value = ID + 1. 354 } else { // Uses the hash table otherwise. 
355 current_entry_ = &entry; 356 const auto it = keys_.find(kCurrentKey); 357 if (it == keys_.end()) { 358 if (insert) { 359 I key = id2entry_.size(); 360 id2entry_.push_back(entry); 361 keys_.insert(key); 362 return key; 363 } else { 364 return -1; 365 } 366 } else { 367 return *it; 368 } 369 } 370 } 371 FindEntry(I s)372 const T &FindEntry(I s) const { return id2entry_[s]; } 373 Size()374 I Size() const { return id2entry_.size(); } 375 Selector()376 const S &Selector() const { return selector_; } 377 Fingerprint()378 const FP &Fingerprint() const { return fp_; } 379 Hash()380 const H &Hash() const { return h_; } 381 382 private: 383 static constexpr I kCurrentKey = -1; 384 385 class HashFunc { 386 public: HashFunc(const VectorHashBiTable & ht)387 explicit HashFunc(const VectorHashBiTable &ht) : ht_(&ht) {} 388 operator()389 size_t operator()(I k) const { 390 if (k >= kCurrentKey) { 391 return (ht_->h_)(ht_->Key2Entry(k)); 392 } else { 393 return 0; 394 } 395 } 396 397 private: 398 const VectorHashBiTable *ht_; 399 }; 400 401 class HashEqual { 402 public: HashEqual(const VectorHashBiTable & ht)403 explicit HashEqual(const VectorHashBiTable &ht) : ht_(&ht) {} 404 operator()405 bool operator()(I k1, I k2) const { 406 if (k1 >= kCurrentKey && k2 >= kCurrentKey) { 407 return ht_->Key2Entry(k1) == ht_->Key2Entry(k2); 408 } else { 409 return k1 == k2; 410 } 411 } 412 413 private: 414 const VectorHashBiTable *ht_; 415 }; 416 417 using KeyHashSet = HashSet<I, HashFunc, HashEqual, HS>; 418 Key2Entry(I k)419 const T &Key2Entry(I k) const { 420 if (k == kCurrentKey) { 421 return *current_entry_; 422 } else { 423 return id2entry_[k]; 424 } 425 } 426 427 S selector_; // True if entry hashed into vector. 428 FP fp_; // Fingerprint used for hashing into vector. 429 H h_; // Hash funcion used for hashing into hash_set. 430 431 std::vector<T> id2entry_; // Maps state IDs to entry. 432 std::vector<I> fp2id_; // Maps entry fingerprints to IDs. 
433 434 // Compact implementation of the hash table mapping entries to state IDs 435 // using the hash function h_. 436 HashFunc hash_func_; 437 HashEqual hash_equal_; 438 KeyHashSet keys_; 439 const T *current_entry_; 440 }; 441 442 template <class I, class T, class S, class FP, class H, HSType HS> 443 constexpr I VectorHashBiTable<I, T, S, FP, H, HS>::kCurrentKey; 444 445 // An implementation using a hash map for the entry to ID mapping. This version 446 // permits erasing of arbitrary states. The entry T must have == defined and 447 // its default constructor must produce a entry that will never be seen. F is 448 // the hash function. 449 template <class I, class T, class F> 450 class ErasableBiTable { 451 public: ErasableBiTable()452 ErasableBiTable() : first_(0) {} 453 454 I FindId(const T &entry, bool insert = true) { 455 I &id_ref = entry2id_[entry]; 456 if (id_ref == 0) { // T not found. 457 if (insert) { // Stores and assigns a new ID. 458 id2entry_.push_back(entry); 459 id_ref = id2entry_.size() + first_; 460 } else { 461 return -1; 462 } 463 } 464 return id_ref - 1; // NB: id_ref = ID + 1. 465 } 466 FindEntry(I s)467 const T &FindEntry(I s) const { return id2entry_[s - first_]; } 468 Size()469 I Size() const { return id2entry_.size(); } 470 Erase(I s)471 void Erase(I s) { 472 auto &ref = id2entry_[s - first_]; 473 entry2id_.erase(ref); 474 ref = empty_entry_; 475 while (!id2entry_.empty() && id2entry_.front() == empty_entry_) { 476 id2entry_.pop_front(); 477 ++first_; 478 } 479 } 480 481 private: 482 std::unordered_map<T, I, F> entry2id_; 483 std::deque<T> id2entry_; 484 const T empty_entry_; 485 I first_; // I of first element in the deque. 486 }; 487 488 } // namespace fst 489 490 #endif // FST_BI_TABLE_H_ 491