/*++
Copyright (c) 2013 Microsoft Corporation

Module Name:

    heap_trie.h

Abstract:

    Heap trie structure.

    Structure that lets you retrieve point-wise smaller entries
    of a tuple. A lookup is to identify entries whose keys
    are point-wise dominated by the lookup key.

Author:

    Nikolaj Bjorner (nbjorner) 2013-02-15.

Notes:

    Tries are unordered vectors of keys. This could be enhanced to use either
    heaps or sorting. The problem with using the heap implementation directly is that there is no way to
    retrieve elements less or equal to a key that is not already in the heap.
    If nodes have only a few elements, then this would also be a bloated data-structure to maintain.

    Nodes are not de-allocated. Their reference count indicates if they are valid.
    Possibly, add garbage collection.

    Maintaining sorted ranges for larger domains is another option.

    Another possible enhancement is to re-splay the tree.
    Keep current key index in the nodes.

34 35 --*/ 36 37 #pragma once 38 39 #include <cstring> 40 #include "util/map.h" 41 #include "util/vector.h" 42 #include "util/buffer.h" 43 #include "util/statistics.h" 44 #include "util/small_object_allocator.h" 45 46 template<typename Key, typename KeyLE, typename KeyHash, typename Value> 47 class heap_trie { 48 49 struct stats { 50 unsigned m_num_inserts; 51 unsigned m_num_removes; 52 unsigned m_num_find_eq; 53 unsigned m_num_find_le; 54 unsigned m_num_find_le_nodes; statsstats55 stats() { reset(); } resetstats56 void reset() { memset(this, 0, sizeof(*this)); } 57 }; 58 59 enum node_t { 60 trie_t, 61 leaf_t 62 }; 63 64 class node { 65 node_t m_type; 66 unsigned m_ref; 67 public: node(node_t t)68 node(node_t t): m_type(t), m_ref(0) {} ~node()69 virtual ~node() {} type()70 node_t type() const { return m_type; } inc_ref()71 void inc_ref() { ++m_ref; } dec_ref()72 void dec_ref() { SASSERT(m_ref > 0); --m_ref; } ref_count()73 unsigned ref_count() const { return m_ref; } 74 virtual void display(std::ostream& out, unsigned indent) const = 0; 75 virtual unsigned num_nodes() const = 0; 76 virtual unsigned num_leaves() const = 0; 77 }; 78 79 class leaf : public node { 80 Value m_value; 81 public: leaf()82 leaf(): node(leaf_t) {} ~leaf()83 ~leaf() override {} get_value()84 Value const& get_value() const { return m_value; } set_value(Value const & v)85 void set_value(Value const& v) { m_value = v; } display(std::ostream & out,unsigned indent)86 void display(std::ostream& out, unsigned indent) const override { 87 out << " value: " << m_value; 88 } num_nodes()89 unsigned num_nodes() const override { return 1; } num_leaves()90 unsigned num_leaves() const override { return this->ref_count()>0?1:0; } 91 }; 92 93 typedef buffer<std::pair<Key,node*>, true, 2> children_t; 94 95 // lean trie node 96 class trie : public node { 97 children_t m_nodes; 98 public: trie()99 trie(): node(trie_t) {} 100 ~trie()101 ~trie() override { 102 } 103 find_or_insert(Key k,node * n)104 node* 
find_or_insert(Key k, node* n) { 105 for (unsigned i = 0; i < m_nodes.size(); ++i) { 106 if (m_nodes[i].first == k) { 107 return m_nodes[i].second; 108 } 109 } 110 m_nodes.push_back(std::make_pair(k, n)); 111 return n; 112 } 113 find(Key k,node * & n)114 bool find(Key k, node*& n) const { 115 for (unsigned i = 0; i < m_nodes.size(); ++i) { 116 if (m_nodes[i].first == k) { 117 n = m_nodes[i].second; 118 return n->ref_count() > 0; 119 } 120 } 121 return false; 122 } 123 124 // push nodes whose keys are <= key into vector. find_le(KeyLE & le,Key key,ptr_vector<node> & nodes)125 void find_le(KeyLE& le, Key key, ptr_vector<node>& nodes) { 126 for (unsigned i = 0; i < m_nodes.size(); ++i) { 127 if (le.le(m_nodes[i].first, key)) { 128 node* n = m_nodes[i].second; 129 if (n->ref_count() > 0){ 130 nodes.push_back(n); 131 } 132 } 133 } 134 } 135 nodes()136 children_t const& nodes() const { return m_nodes; } nodes()137 children_t & nodes() { return m_nodes; } 138 display(std::ostream & out,unsigned indent)139 void display(std::ostream& out, unsigned indent) const override { 140 for (unsigned j = 0; j < m_nodes.size(); ++j) { 141 if (j != 0 || indent > 0) { 142 out << "\n"; 143 } 144 for (unsigned i = 0; i < indent; ++i) { 145 out << " "; 146 } 147 node* n = m_nodes[j].second; 148 out << m_nodes[j].first << " refs: " << n->ref_count(); 149 n->display(out, indent + 1); 150 } 151 } 152 num_nodes()153 unsigned num_nodes() const override { 154 unsigned sz = 1; 155 for (unsigned j = 0; j < m_nodes.size(); ++j) { 156 sz += m_nodes[j].second->num_nodes(); 157 } 158 return sz; 159 } 160 num_leaves()161 unsigned num_leaves() const override { 162 unsigned sz = 0; 163 for (unsigned j = 0; j < m_nodes.size(); ++j) { 164 sz += m_nodes[j].second->num_leaves(); 165 } 166 return sz; 167 } 168 169 private: contains(Key k)170 bool contains(Key k) { 171 for (unsigned j = 0; j < m_nodes.size(); ++j) { 172 if (m_nodes[j].first == k) { 173 return true; 174 } 175 } 176 return false; 177 } 178 }; 179 
180 small_object_allocator m_alloc; 181 KeyLE& m_le; 182 unsigned m_num_keys; 183 unsigned_vector m_keys; 184 unsigned m_do_reshuffle; 185 node* m_root; 186 stats m_stats; 187 node* m_spare_leaf; 188 node* m_spare_trie; 189 190 public: 191 heap_trie(KeyLE & le)192 heap_trie(KeyLE& le): 193 m_alloc("heap_trie"), 194 m_le(le), 195 m_num_keys(0), 196 m_do_reshuffle(4), 197 m_root(nullptr), 198 m_spare_leaf(nullptr), 199 m_spare_trie(nullptr) 200 {} 201 ~heap_trie()202 ~heap_trie() { 203 del_node(m_root); 204 del_node(m_spare_leaf); 205 del_node(m_spare_trie); 206 } 207 size()208 unsigned size() const { 209 return m_root?m_root->num_leaves():0; 210 } 211 reset(unsigned num_keys)212 void reset(unsigned num_keys) { 213 del_node(m_root); 214 del_node(m_spare_leaf); 215 del_node(m_spare_trie); 216 m_num_keys = num_keys; 217 m_keys.resize(num_keys); 218 for (unsigned i = 0; i < num_keys; ++i) { 219 m_keys[i] = i; 220 } 221 m_root = mk_trie(); 222 m_spare_trie = mk_trie(); 223 m_spare_leaf = mk_leaf(); 224 } 225 insert(Key const * keys,Value const & val)226 void insert(Key const* keys, Value const& val) { 227 ++m_stats.m_num_inserts; 228 insert(m_root, num_keys(), keys, m_keys.c_ptr(), val); 229 #if 0 230 if (m_stats.m_num_inserts == (1 << m_do_reshuffle)) { 231 m_do_reshuffle++; 232 reorder_keys(); 233 } 234 #endif 235 } 236 find_eq(Key const * keys,Value & value)237 bool find_eq(Key const* keys, Value& value) { 238 ++m_stats.m_num_find_eq; 239 node* n = m_root; 240 node* m; 241 for (unsigned i = 0; i < num_keys(); ++i) { 242 if (!to_trie(n)->find(get_key(keys, i), m)) { 243 return false; 244 } 245 n = m; 246 } 247 value = to_leaf(n)->get_value(); 248 return true; 249 } 250 find_all_le(Key const * keys,vector<Value> & values)251 void find_all_le(Key const* keys, vector<Value>& values) { 252 ++m_stats.m_num_find_le; 253 ptr_vector<node> todo[2]; 254 todo[0].push_back(m_root); 255 bool index = false; 256 for (unsigned i = 0; i < num_keys(); ++i) { 257 for (unsigned j = 0; j < 
todo[index].size(); ++j) { 258 ++m_stats.m_num_find_le_nodes; 259 to_trie(todo[index][j])->find_le(m_le, get_key(keys, i), todo[!index]); 260 } 261 todo[index].reset(); 262 index = !index; 263 } 264 for (unsigned j = 0; j < todo[index].size(); ++j) { 265 values.push_back(to_leaf(todo[index][j])->get_value()); 266 } 267 } 268 269 // callback based find function 270 class check_value { 271 public: 272 virtual bool operator()(Value const& v) = 0; 273 }; 274 find_le(Key const * keys,check_value & check)275 bool find_le(Key const* keys, check_value& check) { 276 ++m_stats.m_num_find_le; 277 ++m_stats.m_num_find_le_nodes; 278 return find_le(m_root, 0, keys, check); 279 } 280 remove(Key const * keys)281 void remove(Key const* keys) { 282 ++m_stats.m_num_removes; 283 // assumption: key is in table. 284 node* n = m_root; 285 node* m = nullptr; 286 for (unsigned i = 0; i < num_keys(); ++i) { 287 n->dec_ref(); 288 VERIFY (to_trie(n)->find(get_key(keys, i), m)); 289 n = m; 290 } 291 n->dec_ref(); 292 } 293 reset_statistics()294 void reset_statistics() { 295 m_stats.reset(); 296 } 297 collect_statistics(statistics & st)298 void collect_statistics(statistics& st) const { 299 st.update("heap_trie.num_inserts", m_stats.m_num_inserts); 300 st.update("heap_trie.num_removes", m_stats.m_num_removes); 301 st.update("heap_trie.num_find_eq", m_stats.m_num_find_eq); 302 st.update("heap_trie.num_find_le", m_stats.m_num_find_le); 303 st.update("heap_trie.num_find_le_nodes", m_stats.m_num_find_le_nodes); 304 if (m_root) st.update("heap_trie.num_nodes", m_root->num_nodes()); 305 unsigned_vector nums; 306 ptr_vector<node> todo; 307 if (m_root) todo.push_back(m_root); 308 while (!todo.empty()) { 309 node* n = todo.back(); 310 todo.pop_back(); 311 if (is_trie(n)) { 312 trie* t = to_trie(n); 313 unsigned sz = t->nodes().size(); 314 if (nums.size() <= sz) { 315 nums.resize(sz+1); 316 } 317 ++nums[sz]; 318 for (unsigned i = 0; i < sz; ++i) { 319 todo.push_back(t->nodes()[i].second); 320 } 321 } 322 
} 323 if (nums.size() < 16) nums.resize(16); 324 st.update("heap_trie.num_1_children", nums[1]); 325 st.update("heap_trie.num_2_children", nums[2]); 326 st.update("heap_trie.num_3_children", nums[3]); 327 st.update("heap_trie.num_4_children", nums[4]); 328 st.update("heap_trie.num_5_children", nums[5]); 329 st.update("heap_trie.num_6_children", nums[6]); 330 st.update("heap_trie.num_7_children", nums[7]); 331 st.update("heap_trie.num_8_children", nums[8]); 332 st.update("heap_trie.num_9_children", nums[9]); 333 st.update("heap_trie.num_10_children", nums[10]); 334 st.update("heap_trie.num_11_children", nums[11]); 335 st.update("heap_trie.num_12_children", nums[12]); 336 st.update("heap_trie.num_13_children", nums[13]); 337 st.update("heap_trie.num_14_children", nums[14]); 338 st.update("heap_trie.num_15_children", nums[15]); 339 unsigned sz = 0; 340 for (unsigned i = 16; i < nums.size(); ++i) { 341 sz += nums[i]; 342 } 343 st.update("heap_trie.num_16+_children", sz); 344 } 345 display(std::ostream & out)346 void display(std::ostream& out) const { 347 m_root->display(out, 0); 348 out << "\n"; 349 } 350 351 class iterator { 352 ptr_vector<node> m_path; 353 unsigned_vector m_idx; 354 vector<Key> m_keys; 355 unsigned m_count; 356 public: iterator(node * n)357 iterator(node* n) { 358 if (!n) { 359 m_count = UINT_MAX; 360 } 361 else { 362 m_count = 0; 363 first(n); 364 } 365 } keys()366 Key const* keys() { 367 return m_keys.c_ptr(); 368 } 369 value()370 Value const& value() const { 371 return to_leaf(m_path.back())->get_value(); 372 } 373 iterator& operator++() { fwd(); return *this; } 374 iterator operator++(int) { iterator tmp = *this; ++*this; return tmp; } 375 bool operator==(iterator const& it) const {return m_count == it.m_count; } 376 bool operator!=(iterator const& it) const {return m_count != it.m_count; } 377 378 private: first(node * r)379 void first(node* r) { 380 SASSERT(r->ref_count() > 0); 381 while (is_trie(r)) { 382 trie* t = to_trie(r); 383 
m_path.push_back(r); 384 unsigned sz = t->nodes().size(); 385 for (unsigned i = 0; i < sz; ++i) { 386 r = t->nodes()[i].second; 387 if (r->ref_count() > 0) { 388 m_idx.push_back(i); 389 m_keys.push_back(t->nodes()[i].first); 390 break; 391 } 392 } 393 } 394 SASSERT(is_leaf(r)); 395 m_path.push_back(r); 396 } 397 fwd()398 void fwd() { 399 if (m_path.empty()) { 400 m_count = UINT_MAX; 401 return; 402 } 403 m_path.pop_back(); 404 while (!m_path.empty()) { 405 trie* t = to_trie(m_path.back()); 406 unsigned idx = m_idx.back(); 407 unsigned sz = t->nodes().size(); 408 m_idx.pop_back(); 409 m_keys.pop_back(); 410 for (unsigned i = idx+1; i < sz; ++i) { 411 node* r = t->nodes()[i].second; 412 if (r->ref_count() > 0) { 413 m_idx.push_back(i); 414 m_keys.push_back(t->nodes()[i].first); 415 first(r); 416 ++m_count; 417 return; 418 } 419 } 420 m_path.pop_back(); 421 } 422 m_count = UINT_MAX; 423 } 424 }; 425 begin()426 iterator begin() const { 427 return iterator(m_root); 428 } 429 end()430 iterator end() const { 431 return iterator(0); 432 } 433 434 435 private: 436 num_keys()437 inline unsigned num_keys() const { 438 return m_num_keys; 439 } 440 get_key(Key const * keys,unsigned i)441 inline Key const& get_key(Key const* keys, unsigned i) const { 442 return keys[m_keys[i]]; 443 } 444 445 struct KeyEq { operatorKeyEq446 bool operator()(Key const& k1, Key const& k2) const { 447 return k1 == k2; 448 } 449 }; 450 451 452 typedef hashtable<Key, KeyHash, KeyEq> key_set; 453 454 struct key_info { 455 unsigned m_index; 456 unsigned m_index_size; key_infokey_info457 key_info(unsigned i, unsigned sz): 458 m_index(i), 459 m_index_size(sz) 460 {} 461 462 bool operator<(key_info const& other) const { 463 return 464 (m_index_size < other.m_index_size) || 465 ((m_index_size == other.m_index_size) && 466 (m_index < other.m_index)); 467 } 468 }; 469 reorder_keys()470 void reorder_keys() { 471 vector<key_set> weights; 472 weights.resize(num_keys()); 473 unsigned_vector depth; 474 
ptr_vector<node> nodes; 475 depth.push_back(0); 476 nodes.push_back(m_root); 477 while (!nodes.empty()) { 478 node* n = nodes.back(); 479 unsigned d = depth.back(); 480 nodes.pop_back(); 481 depth.pop_back(); 482 if (is_trie(n)) { 483 trie* t = to_trie(n); 484 unsigned sz = t->nodes().size(); 485 for (unsigned i = 0; i < sz; ++i) { 486 nodes.push_back(t->nodes()[i].second); 487 depth.push_back(d+1); 488 weights[d].insert(t->nodes()[i].first); 489 } 490 } 491 } 492 SASSERT(weights.size() == num_keys()); 493 svector<key_info> infos; 494 unsigned sz = 0; 495 bool is_sorted = true; 496 for (unsigned i = 0; i < weights.size(); ++i) { 497 unsigned sz2 = weights[i].size(); 498 if (sz > sz2) { 499 is_sorted = false; 500 } 501 sz = sz2; 502 infos.push_back(key_info(i, sz)); 503 } 504 if (is_sorted) { 505 return; 506 } 507 std::sort(infos.begin(), infos.end()); 508 unsigned_vector sorted_keys, new_keys; 509 for (unsigned i = 0; i < num_keys(); ++i) { 510 unsigned j = infos[i].m_index; 511 sorted_keys.push_back(j); 512 new_keys.push_back(m_keys[j]); 513 } 514 // m_keys: i |-> key_index 515 // new_keys: i |-> new_key_index 516 // permutation: key_index |-> new_key_index 517 SASSERT(sorted_keys.size() == num_keys()); 518 SASSERT(new_keys.size() == num_keys()); 519 SASSERT(m_keys.size() == num_keys()); 520 iterator it = begin(); 521 trie* new_root = mk_trie(); 522 IF_VERBOSE(2, verbose_stream() << "before reshuffle: " << m_root->num_nodes() << " nodes\n";); 523 for (; it != end(); ++it) { 524 IF_VERBOSE(2, 525 for (unsigned i = 0; i < num_keys(); ++i) { 526 for (unsigned j = 0; j < num_keys(); ++j) { 527 if (m_keys[j] == i) { 528 verbose_stream() << it.keys()[j] << " "; 529 break; 530 } 531 } 532 } 533 verbose_stream() << " |-> " << it.value() << "\n";); 534 535 insert(new_root, num_keys(), it.keys(), sorted_keys.c_ptr(), it.value()); 536 } 537 del_node(m_root); 538 m_root = new_root; 539 for (unsigned i = 0; i < m_keys.size(); ++i) { 540 m_keys[i] = new_keys[i]; 541 } 542 543 
IF_VERBOSE(2, verbose_stream() << "after reshuffle: " << new_root->num_nodes() << " nodes\n";); 544 IF_VERBOSE(2, 545 it = begin(); 546 for (; it != end(); ++it) { 547 for (unsigned i = 0; i < num_keys(); ++i) { 548 for (unsigned j = 0; j < num_keys(); ++j) { 549 if (m_keys[j] == i) { 550 verbose_stream() << it.keys()[j] << " "; 551 break; 552 } 553 } 554 } 555 verbose_stream() << " |-> " << it.value() << "\n"; 556 }); 557 } 558 find_le(node * n,unsigned index,Key const * keys,check_value & check)559 bool find_le(node* n, unsigned index, Key const* keys, check_value& check) { 560 if (index == num_keys()) { 561 SASSERT(n->ref_count() > 0); 562 bool r = check(to_leaf(n)->get_value()); 563 IF_VERBOSE(2, 564 for (unsigned j = 0; j < index; ++j) { 565 verbose_stream() << " "; 566 } 567 verbose_stream() << to_leaf(n)->get_value() << (r?" hit\n":" miss\n");); 568 return r; 569 } 570 else { 571 Key const& key = get_key(keys, index); 572 children_t& nodes = to_trie(n)->nodes(); 573 for (unsigned i = 0; i < nodes.size(); ++i) { 574 ++m_stats.m_num_find_le_nodes; 575 node* m = nodes[i].second; 576 IF_VERBOSE(2, 577 for (unsigned j = 0; j < index; ++j) { 578 verbose_stream() << " "; 579 } 580 verbose_stream() << nodes[i].first << " <=? " << key << " rc:" << m->ref_count() << "\n";); 581 if (m->ref_count() > 0 && m_le.le(nodes[i].first, key) && find_le(m, index+1, keys, check)) { 582 if (i > 0) { 583 std::swap(nodes[i], nodes[0]); 584 } 585 return true; 586 } 587 } 588 return false; 589 } 590 } 591 insert(node * n,unsigned num_keys,Key const * keys,unsigned const * permutation,Value const & val)592 void insert(node* n, unsigned num_keys, Key const* keys, unsigned const* permutation, Value const& val) { 593 // assumption: key is not in table. 
594 for (unsigned i = 0; i < num_keys; ++i) { 595 n->inc_ref(); 596 n = insert_key(to_trie(n), (i + 1 == num_keys), keys[permutation[i]]); 597 } 598 n->inc_ref(); 599 to_leaf(n)->set_value(val); 600 SASSERT(n->ref_count() == 1); 601 } 602 insert_key(trie * n,bool is_leaf,Key const & key)603 node* insert_key(trie* n, bool is_leaf, Key const& key) { 604 node* m1 = is_leaf?m_spare_leaf:m_spare_trie; 605 node* m2 = n->find_or_insert(key, m1); 606 if (m1 == m2) { 607 if (is_leaf) { 608 m_spare_leaf = mk_leaf(); 609 } 610 else { 611 m_spare_trie = mk_trie(); 612 } 613 } 614 return m2; 615 } 616 mk_leaf()617 leaf* mk_leaf() { 618 void* mem = m_alloc.allocate(sizeof(leaf)); 619 return new (mem) leaf(); 620 } 621 mk_trie()622 trie* mk_trie() { 623 void* mem = m_alloc.allocate(sizeof(trie)); 624 return new (mem) trie(); 625 } 626 del_node(node * n)627 void del_node(node* n) { 628 if (!n) { 629 return; 630 } 631 if (is_trie(n)) { 632 trie* t = to_trie(n); 633 for (unsigned i = 0; i < t->nodes().size(); ++i) { 634 del_node(t->nodes()[i].second); 635 } 636 t->~trie(); 637 m_alloc.deallocate(sizeof(trie), t); 638 } 639 else { 640 leaf* l = to_leaf(n); 641 l->~leaf(); 642 m_alloc.deallocate(sizeof(leaf), l); 643 } 644 } 645 to_trie(node * n)646 static trie* to_trie(node* n) { 647 SASSERT(is_trie(n)); 648 return static_cast<trie*>(n); 649 } 650 to_leaf(node * n)651 static leaf* to_leaf(node* n) { 652 SASSERT(is_leaf(n)); 653 return static_cast<leaf*>(n); 654 } 655 is_leaf(node * n)656 static bool is_leaf(node* n) { 657 return n->type() == leaf_t; 658 } 659 is_trie(node * n)660 static bool is_trie(node* n) { 661 return n->type() == trie_t; 662 } 663 }; 664 665