1 /*++
2 Copyright (c) 2013 Microsoft Corporation
3 
4 Module Name:
5 
6     heap_trie.h
7 
8 Abstract:
9 
10     Heap trie structure.
11 
12     Structure that lets you retrieve point-wise smaller entries
13     of a tuple. A lookup is to identify entries whose keys
14     are point-wise dominated by the lookup key.
15 
16 Author:
17 
18     Nikolaj Bjorner (nbjorner) 2013-02-15.
19 
20 Notes:
21 
22     tries are unordered vectors of keys. This could be enhanced to use either
23     heaps or sorting. The problem with using the heap implementation directly is that there is no way to
24     retrieve elements less or equal to a key that is not already in the heap.
25     If nodes have only a few elements, then this would also be a bloated data-structure to maintain.
26 
27     Nodes are not de-allocated. Their reference count indicates if they are valid.
28     Possibly, add garbage collection.
29 
30     Maintaining sorted ranges for larger domains is another option.
31 
    Another possible enhancement is to re-splay the tree.
33     Keep current key index in the nodes.
34 
35 --*/
36 
37 #pragma once
38 
39 #include <cstring>
40 #include "util/map.h"
41 #include "util/vector.h"
42 #include "util/buffer.h"
43 #include "util/statistics.h"
44 #include "util/small_object_allocator.h"
45 
46 template<typename Key, typename KeyLE, typename KeyHash, typename Value>
47 class heap_trie {
48 
49     struct stats {
50         unsigned m_num_inserts;
51         unsigned m_num_removes;
52         unsigned m_num_find_eq;
53         unsigned m_num_find_le;
54         unsigned m_num_find_le_nodes;
statsstats55         stats() { reset(); }
resetstats56         void reset() { memset(this, 0, sizeof(*this)); }
57     };
58 
59     enum node_t {
60         trie_t,
61         leaf_t
62     };
63 
64     class node {
65         node_t   m_type;
66         unsigned m_ref;
67     public:
node(node_t t)68         node(node_t t): m_type(t), m_ref(0) {}
~node()69         virtual ~node() {}
type()70         node_t type() const { return m_type; }
inc_ref()71         void inc_ref() { ++m_ref; }
dec_ref()72         void dec_ref() { SASSERT(m_ref > 0); --m_ref; }
ref_count()73         unsigned ref_count() const { return m_ref; }
74         virtual void display(std::ostream& out, unsigned indent) const = 0;
75         virtual unsigned num_nodes() const = 0;
76         virtual unsigned num_leaves() const = 0;
77     };
78 
79     class leaf : public node {
80         Value m_value;
81     public:
leaf()82         leaf(): node(leaf_t) {}
~leaf()83         ~leaf() override {}
get_value()84         Value const& get_value() const { return m_value; }
set_value(Value const & v)85         void set_value(Value const& v) { m_value = v; }
display(std::ostream & out,unsigned indent)86         void display(std::ostream& out, unsigned indent) const override {
87             out << " value: " << m_value;
88         }
num_nodes()89         unsigned num_nodes() const override { return 1; }
num_leaves()90         unsigned num_leaves() const override { return this->ref_count()>0?1:0; }
91     };
92 
93     typedef buffer<std::pair<Key,node*>, true, 2> children_t;
94 
95     // lean trie node
96     class trie : public node {
97         children_t m_nodes;
98     public:
trie()99         trie(): node(trie_t) {}
100 
~trie()101         ~trie() override {
102         }
103 
find_or_insert(Key k,node * n)104         node* find_or_insert(Key k, node* n) {
105             for (unsigned i = 0; i < m_nodes.size(); ++i) {
106                 if (m_nodes[i].first == k) {
107                     return m_nodes[i].second;
108                 }
109             }
110             m_nodes.push_back(std::make_pair(k, n));
111             return n;
112         }
113 
find(Key k,node * & n)114         bool find(Key k, node*& n) const {
115             for (unsigned i = 0; i < m_nodes.size(); ++i) {
116                 if (m_nodes[i].first == k) {
117                     n = m_nodes[i].second;
118                     return n->ref_count() > 0;
119                 }
120             }
121             return false;
122         }
123 
124         // push nodes whose keys are <= key into vector.
find_le(KeyLE & le,Key key,ptr_vector<node> & nodes)125         void find_le(KeyLE& le, Key key, ptr_vector<node>& nodes) {
126             for (unsigned i = 0; i < m_nodes.size(); ++i) {
127                 if (le.le(m_nodes[i].first, key)) {
128                     node* n = m_nodes[i].second;
129                     if (n->ref_count() > 0){
130                         nodes.push_back(n);
131                     }
132                 }
133             }
134         }
135 
nodes()136         children_t const& nodes() const { return m_nodes; }
nodes()137         children_t & nodes() { return m_nodes; }
138 
display(std::ostream & out,unsigned indent)139         void display(std::ostream& out, unsigned indent) const override {
140             for (unsigned j = 0; j < m_nodes.size(); ++j) {
141                 if (j != 0 || indent > 0) {
142                     out << "\n";
143                 }
144                 for (unsigned i = 0; i < indent; ++i) {
145                     out << " ";
146                 }
147                 node* n = m_nodes[j].second;
148                 out << m_nodes[j].first << " refs: " << n->ref_count();
149                 n->display(out, indent + 1);
150             }
151         }
152 
num_nodes()153         unsigned num_nodes() const override {
154             unsigned sz = 1;
155             for (unsigned j = 0; j < m_nodes.size(); ++j) {
156                 sz += m_nodes[j].second->num_nodes();
157             }
158             return sz;
159         }
160 
num_leaves()161         unsigned num_leaves() const override {
162             unsigned sz = 0;
163             for (unsigned j = 0; j < m_nodes.size(); ++j) {
164                 sz += m_nodes[j].second->num_leaves();
165             }
166             return sz;
167         }
168 
169     private:
contains(Key k)170         bool contains(Key k) {
171             for (unsigned j = 0; j < m_nodes.size(); ++j) {
172                 if (m_nodes[j].first == k) {
173                     return true;
174                 }
175             }
176             return false;
177         }
178     };
179 
    // Allocator used by mk_leaf/mk_trie/del_node for all node storage.
    small_object_allocator m_alloc;
    // Caller-supplied ordering; its le(a, b) is what the find_le lookups call.
    KeyLE&   m_le;
    // Number of keys per tuple, i.e. the depth of the trie; set by reset().
    unsigned m_num_keys;
    // m_keys[i] is the position in the caller's key array that is branched on
    // at depth i (see get_key); reset() initializes it to the identity.
    unsigned_vector m_keys;
    // Exponent of the insert count that would trigger reorder_keys(); it is
    // only read inside the disabled (#if 0) section of insert().
    unsigned m_do_reshuffle;
    // Root trie node; nullptr until reset() has been called.
    node*    m_root;
    // Operation counters.
    stats    m_stats;
    // Pre-allocated nodes offered to trie::find_or_insert by insert_key();
    // each is replaced by a fresh node once it has been linked into the trie.
    node*    m_spare_leaf;
    node*    m_spare_trie;
189 
190 public:
191 
heap_trie(KeyLE & le)192     heap_trie(KeyLE& le):
193         m_alloc("heap_trie"),
194         m_le(le),
195         m_num_keys(0),
196         m_do_reshuffle(4),
197         m_root(nullptr),
198         m_spare_leaf(nullptr),
199         m_spare_trie(nullptr)
200     {}
201 
~heap_trie()202     ~heap_trie() {
203         del_node(m_root);
204         del_node(m_spare_leaf);
205         del_node(m_spare_trie);
206     }
207 
size()208     unsigned size() const {
209         return m_root?m_root->num_leaves():0;
210     }
211 
reset(unsigned num_keys)212     void reset(unsigned num_keys) {
213         del_node(m_root);
214         del_node(m_spare_leaf);
215         del_node(m_spare_trie);
216         m_num_keys = num_keys;
217         m_keys.resize(num_keys);
218         for (unsigned i = 0; i < num_keys; ++i) {
219             m_keys[i] = i;
220         }
221         m_root = mk_trie();
222         m_spare_trie = mk_trie();
223         m_spare_leaf = mk_leaf();
224     }
225 
insert(Key const * keys,Value const & val)226     void insert(Key const* keys, Value const& val) {
227         ++m_stats.m_num_inserts;
228         insert(m_root, num_keys(), keys, m_keys.c_ptr(), val);
229 #if 0
230         if (m_stats.m_num_inserts == (1 << m_do_reshuffle)) {
231             m_do_reshuffle++;
232             reorder_keys();
233         }
234 #endif
235     }
236 
find_eq(Key const * keys,Value & value)237     bool find_eq(Key const* keys, Value& value) {
238         ++m_stats.m_num_find_eq;
239         node* n = m_root;
240         node* m;
241         for (unsigned i = 0; i < num_keys(); ++i) {
242             if (!to_trie(n)->find(get_key(keys, i), m)) {
243                 return false;
244             }
245             n = m;
246         }
247         value = to_leaf(n)->get_value();
248         return true;
249     }
250 
find_all_le(Key const * keys,vector<Value> & values)251     void find_all_le(Key const* keys, vector<Value>& values) {
252         ++m_stats.m_num_find_le;
253         ptr_vector<node> todo[2];
254         todo[0].push_back(m_root);
255         bool index = false;
256         for (unsigned i = 0; i < num_keys(); ++i) {
257             for (unsigned j = 0; j < todo[index].size(); ++j) {
258                 ++m_stats.m_num_find_le_nodes;
259                 to_trie(todo[index][j])->find_le(m_le, get_key(keys, i), todo[!index]);
260             }
261             todo[index].reset();
262             index = !index;
263         }
264         for (unsigned j = 0; j < todo[index].size(); ++j) {
265             values.push_back(to_leaf(todo[index][j])->get_value());
266         }
267     }
268 
269     // callback based find function
270     class check_value {
271     public:
272         virtual bool operator()(Value const& v) = 0;
273     };
274 
find_le(Key const * keys,check_value & check)275     bool find_le(Key const* keys, check_value& check) {
276         ++m_stats.m_num_find_le;
277         ++m_stats.m_num_find_le_nodes;
278         return find_le(m_root, 0, keys, check);
279     }
280 
remove(Key const * keys)281     void remove(Key const* keys) {
282         ++m_stats.m_num_removes;
283         // assumption: key is in table.
284         node* n = m_root;
285         node* m = nullptr;
286         for (unsigned i = 0; i < num_keys(); ++i) {
287             n->dec_ref();
288             VERIFY (to_trie(n)->find(get_key(keys, i), m));
289             n = m;
290         }
291         n->dec_ref();
292     }
293 
reset_statistics()294     void reset_statistics() {
295         m_stats.reset();
296     }
297 
collect_statistics(statistics & st)298     void collect_statistics(statistics& st) const {
299         st.update("heap_trie.num_inserts", m_stats.m_num_inserts);
300         st.update("heap_trie.num_removes", m_stats.m_num_removes);
301         st.update("heap_trie.num_find_eq", m_stats.m_num_find_eq);
302         st.update("heap_trie.num_find_le", m_stats.m_num_find_le);
303         st.update("heap_trie.num_find_le_nodes", m_stats.m_num_find_le_nodes);
304         if (m_root) st.update("heap_trie.num_nodes", m_root->num_nodes());
305         unsigned_vector nums;
306         ptr_vector<node> todo;
307         if (m_root) todo.push_back(m_root);
308         while (!todo.empty()) {
309             node* n = todo.back();
310             todo.pop_back();
311             if (is_trie(n)) {
312                 trie* t = to_trie(n);
313                 unsigned sz = t->nodes().size();
314                 if (nums.size() <= sz) {
315                     nums.resize(sz+1);
316                 }
317                 ++nums[sz];
318                 for (unsigned i = 0; i < sz; ++i) {
319                     todo.push_back(t->nodes()[i].second);
320                 }
321             }
322         }
323         if (nums.size() < 16) nums.resize(16);
324         st.update("heap_trie.num_1_children", nums[1]);
325         st.update("heap_trie.num_2_children", nums[2]);
326         st.update("heap_trie.num_3_children", nums[3]);
327         st.update("heap_trie.num_4_children", nums[4]);
328         st.update("heap_trie.num_5_children", nums[5]);
329         st.update("heap_trie.num_6_children", nums[6]);
330         st.update("heap_trie.num_7_children", nums[7]);
331         st.update("heap_trie.num_8_children", nums[8]);
332         st.update("heap_trie.num_9_children", nums[9]);
333         st.update("heap_trie.num_10_children", nums[10]);
334         st.update("heap_trie.num_11_children", nums[11]);
335         st.update("heap_trie.num_12_children", nums[12]);
336         st.update("heap_trie.num_13_children", nums[13]);
337         st.update("heap_trie.num_14_children", nums[14]);
338         st.update("heap_trie.num_15_children", nums[15]);
339         unsigned sz = 0;
340         for (unsigned i = 16; i < nums.size(); ++i) {
341             sz += nums[i];
342         }
343         st.update("heap_trie.num_16+_children", sz);
344     }
345 
display(std::ostream & out)346     void display(std::ostream& out) const {
347         m_root->display(out, 0);
348         out << "\n";
349     }
350 
351     class iterator {
352         ptr_vector<node> m_path;
353         unsigned_vector  m_idx;
354         vector<Key>      m_keys;
355         unsigned         m_count;
356     public:
iterator(node * n)357         iterator(node* n) {
358             if (!n) {
359                 m_count = UINT_MAX;
360             }
361             else {
362                 m_count = 0;
363                 first(n);
364             }
365         }
keys()366         Key const* keys() {
367             return m_keys.c_ptr();
368         }
369 
value()370         Value const& value() const {
371             return to_leaf(m_path.back())->get_value();
372         }
373         iterator& operator++() { fwd(); return *this; }
374         iterator operator++(int) { iterator tmp = *this; ++*this; return tmp; }
375         bool operator==(iterator const& it) const {return m_count == it.m_count; }
376         bool operator!=(iterator const& it) const {return m_count != it.m_count; }
377 
378     private:
first(node * r)379         void first(node* r) {
380             SASSERT(r->ref_count() > 0);
381             while (is_trie(r)) {
382                 trie* t = to_trie(r);
383                 m_path.push_back(r);
384                 unsigned sz = t->nodes().size();
385                 for (unsigned i = 0; i < sz; ++i) {
386                     r = t->nodes()[i].second;
387                     if (r->ref_count() > 0) {
388                         m_idx.push_back(i);
389                         m_keys.push_back(t->nodes()[i].first);
390                         break;
391                     }
392                 }
393             }
394             SASSERT(is_leaf(r));
395             m_path.push_back(r);
396         }
397 
fwd()398         void fwd() {
399             if (m_path.empty()) {
400                 m_count = UINT_MAX;
401                 return;
402             }
403             m_path.pop_back();
404             while (!m_path.empty()) {
405                 trie* t = to_trie(m_path.back());
406                 unsigned idx = m_idx.back();
407                 unsigned sz = t->nodes().size();
408                 m_idx.pop_back();
409                 m_keys.pop_back();
410                 for (unsigned i = idx+1; i < sz; ++i) {
411                     node* r = t->nodes()[i].second;
412                     if (r->ref_count() > 0) {
413                         m_idx.push_back(i);
414                         m_keys.push_back(t->nodes()[i].first);
415                         first(r);
416                         ++m_count;
417                         return;
418                     }
419                 }
420                 m_path.pop_back();
421             }
422             m_count = UINT_MAX;
423         }
424     };
425 
begin()426     iterator begin() const {
427         return iterator(m_root);
428     }
429 
end()430     iterator end() const {
431         return iterator(0);
432     }
433 
434 
435 private:
436 
num_keys()437     inline unsigned num_keys() const {
438         return m_num_keys;
439     }
440 
get_key(Key const * keys,unsigned i)441     inline Key const& get_key(Key const* keys, unsigned i) const {
442         return keys[m_keys[i]];
443     }
444 
445     struct KeyEq {
operatorKeyEq446         bool operator()(Key const& k1, Key const& k2) const {
447             return k1 == k2;
448         }
449     };
450 
451 
452     typedef hashtable<Key, KeyHash, KeyEq> key_set;
453 
454     struct key_info {
455         unsigned m_index;
456         unsigned m_index_size;
key_infokey_info457         key_info(unsigned i, unsigned sz):
458             m_index(i),
459             m_index_size(sz)
460         {}
461 
462         bool operator<(key_info const& other) const {
463             return
464                 (m_index_size < other.m_index_size) ||
465                 ((m_index_size == other.m_index_size) &&
466                  (m_index < other.m_index));
467         }
468     };
469 
    // Rebuild the trie so that its levels are ordered by increasing number of
    // distinct keys seen at each level (ties broken by current level index).
    // All live entries are re-inserted into a fresh root and m_keys is
    // permuted to match. Returns without changes when the levels are already
    // in non-decreasing order. Within this file the only call site is the
    // disabled (#if 0) section of insert().
    void reorder_keys() {
        // Pass 1: collect, per depth, the set of keys used by any child entry
        // (entries with a zero reference count are included).
        vector<key_set> weights;
        weights.resize(num_keys());
        unsigned_vector depth;
        ptr_vector<node> nodes;
        depth.push_back(0);
        nodes.push_back(m_root);
        while (!nodes.empty()) {
            node* n = nodes.back();
            unsigned d = depth.back();
            nodes.pop_back();
            depth.pop_back();
            if (is_trie(n)) {
                trie* t = to_trie(n);
                unsigned sz = t->nodes().size();
                for (unsigned i = 0; i < sz; ++i) {
                    nodes.push_back(t->nodes()[i].second);
                    depth.push_back(d+1);
                    weights[d].insert(t->nodes()[i].first);
                }
            }
        }
        SASSERT(weights.size() == num_keys());
        // Pass 2: one key_info per level; detect whether the per-level sizes
        // are already non-decreasing.
        svector<key_info> infos;
        unsigned sz = 0;
        bool is_sorted = true;
        for (unsigned i = 0; i < weights.size(); ++i) {
            unsigned sz2 = weights[i].size();
            if (sz > sz2) {
                is_sorted = false;
            }
            sz = sz2;
            infos.push_back(key_info(i, sz));
        }
        if (is_sorted) {
            return;
        }
        std::sort(infos.begin(), infos.end());
        // sorted_keys[i]: old level that becomes new level i.
        // new_keys[i]:    caller key position branched on at new level i.
        unsigned_vector sorted_keys, new_keys;
        for (unsigned i = 0; i < num_keys(); ++i) {
            unsigned j = infos[i].m_index;
            sorted_keys.push_back(j);
            new_keys.push_back(m_keys[j]);
        }
        // m_keys:    i |-> key_index
        // new_keys:  i |-> new_key_index
        // permutation: key_index |-> new_key_index
        SASSERT(sorted_keys.size() == num_keys());
        SASSERT(new_keys.size() == num_keys());
        SASSERT(m_keys.size() == num_keys());
        // Pass 3: re-insert every live entry into a fresh root. it.keys() is
        // indexed by old level, so sorted_keys serves as the permutation.
        iterator it = begin();
        trie* new_root = mk_trie();
        IF_VERBOSE(2, verbose_stream() << "before reshuffle: " << m_root->num_nodes() << " nodes\n";);
        for (; it != end(); ++it) {
            IF_VERBOSE(2,
                       for (unsigned i = 0; i < num_keys(); ++i) {
                           for (unsigned j = 0; j < num_keys(); ++j) {
                               if (m_keys[j] == i) {
                                   verbose_stream() << it.keys()[j] << " ";
                                   break;
                               }
                           }
                       }
                       verbose_stream() << " |-> " << it.value() << "\n";);

            insert(new_root, num_keys(), it.keys(), sorted_keys.c_ptr(), it.value());
        }
        // Swap in the rebuilt tree and the matching key permutation.
        del_node(m_root);
        m_root = new_root;
        for (unsigned i = 0; i < m_keys.size(); ++i) {
            m_keys[i] = new_keys[i];
        }

        IF_VERBOSE(2, verbose_stream() << "after reshuffle: " << new_root->num_nodes() << " nodes\n";);
        IF_VERBOSE(2,
                   it = begin();
                   for (; it != end(); ++it) {
                       for (unsigned i = 0; i < num_keys(); ++i) {
                           for (unsigned j = 0; j < num_keys(); ++j) {
                               if (m_keys[j] == i) {
                                   verbose_stream() << it.keys()[j] << " ";
                                   break;
                               }
                           }
                       }
                       verbose_stream() << " |-> " << it.value() << "\n";
                   });
    }
558 
find_le(node * n,unsigned index,Key const * keys,check_value & check)559     bool find_le(node* n, unsigned index, Key const* keys, check_value& check) {
560         if (index == num_keys()) {
561             SASSERT(n->ref_count() > 0);
562             bool r = check(to_leaf(n)->get_value());
563             IF_VERBOSE(2,
564                        for (unsigned j = 0; j < index; ++j) {
565                            verbose_stream() << " ";
566                        }
567                        verbose_stream() << to_leaf(n)->get_value() << (r?" hit\n":" miss\n"););
568             return r;
569         }
570         else {
571             Key const& key = get_key(keys, index);
572             children_t& nodes = to_trie(n)->nodes();
573             for (unsigned i = 0; i < nodes.size(); ++i) {
574                 ++m_stats.m_num_find_le_nodes;
575                 node* m = nodes[i].second;
576                 IF_VERBOSE(2,
577                            for (unsigned j = 0; j < index; ++j) {
578                                verbose_stream() << " ";
579                            }
580                            verbose_stream() << nodes[i].first << " <=? " << key << " rc:" << m->ref_count() << "\n";);
581                 if (m->ref_count() > 0 && m_le.le(nodes[i].first, key) && find_le(m, index+1, keys, check)) {
582                     if (i > 0) {
583                         std::swap(nodes[i], nodes[0]);
584                     }
585                     return true;
586                 }
587             }
588             return false;
589         }
590     }
591 
insert(node * n,unsigned num_keys,Key const * keys,unsigned const * permutation,Value const & val)592     void insert(node* n, unsigned num_keys, Key const* keys, unsigned const* permutation, Value const& val) {
593         // assumption: key is not in table.
594         for (unsigned i = 0; i < num_keys; ++i) {
595             n->inc_ref();
596             n = insert_key(to_trie(n), (i + 1 == num_keys), keys[permutation[i]]);
597         }
598         n->inc_ref();
599         to_leaf(n)->set_value(val);
600         SASSERT(n->ref_count() == 1);
601     }
602 
insert_key(trie * n,bool is_leaf,Key const & key)603     node* insert_key(trie* n, bool is_leaf, Key const& key) {
604         node* m1 = is_leaf?m_spare_leaf:m_spare_trie;
605         node* m2 = n->find_or_insert(key, m1);
606         if (m1 == m2) {
607             if (is_leaf) {
608                 m_spare_leaf = mk_leaf();
609             }
610             else {
611                 m_spare_trie = mk_trie();
612             }
613         }
614         return m2;
615     }
616 
mk_leaf()617     leaf* mk_leaf() {
618         void* mem = m_alloc.allocate(sizeof(leaf));
619         return new (mem) leaf();
620     }
621 
mk_trie()622     trie* mk_trie() {
623         void* mem = m_alloc.allocate(sizeof(trie));
624         return new (mem) trie();
625     }
626 
del_node(node * n)627     void del_node(node* n) {
628         if (!n) {
629             return;
630         }
631         if (is_trie(n)) {
632             trie* t = to_trie(n);
633             for (unsigned i = 0; i < t->nodes().size(); ++i) {
634                 del_node(t->nodes()[i].second);
635             }
636             t->~trie();
637             m_alloc.deallocate(sizeof(trie), t);
638         }
639         else {
640             leaf* l = to_leaf(n);
641             l->~leaf();
642             m_alloc.deallocate(sizeof(leaf), l);
643         }
644     }
645 
to_trie(node * n)646     static trie* to_trie(node* n) {
647         SASSERT(is_trie(n));
648         return static_cast<trie*>(n);
649     }
650 
to_leaf(node * n)651     static leaf* to_leaf(node* n) {
652         SASSERT(is_leaf(n));
653         return static_cast<leaf*>(n);
654     }
655 
is_leaf(node * n)656     static bool is_leaf(node* n) {
657         return n->type() == leaf_t;
658     }
659 
is_trie(node * n)660     static bool is_trie(node* n) {
661         return n->type() == trie_t;
662     }
663 };
664 
665