1/*************************************************************************
2*
3* Copyright (c) 2015 Kohei Yoshida
4*
5* Permission is hereby granted, free of charge, to any person
6* obtaining a copy of this software and associated documentation
7* files (the "Software"), to deal in the Software without
8* restriction, including without limitation the rights to use,
9* copy, modify, merge, publish, distribute, sublicense, and/or sell
10* copies of the Software, and to permit persons to whom the
11* Software is furnished to do so, subject to the following
12* conditions:
13*
14* The above copyright notice and this permission notice shall be
15* included in all copies or substantial portions of the Software.
16*
17* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
19* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
21* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
22* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
24* OTHER DEALINGS IN THE SOFTWARE.
25*
26************************************************************************/
27
28#include <algorithm>
29
30namespace mdds {
31
32namespace __st {
33
34template<typename T, typename _Inserter>
35void descend_tree_for_search(
36    typename T::key_type point, const __st::node_base* pnode, _Inserter& result)
37{
38    typedef typename T::node leaf_node;
39    typedef typename T::nonleaf_node nonleaf_node;
40
41    typedef typename T::nonleaf_value_type nonleaf_value_type;
42    typedef typename T::leaf_value_type leaf_value_type;
43
44    if (!pnode)
45        // This should never happen, but just in case.
46        return;
47
48    if (pnode->is_leaf)
49    {
50        result(static_cast<const leaf_node*>(pnode)->value_leaf.data_chain);
51        return;
52    }
53
54    const nonleaf_node* pnonleaf = static_cast<const nonleaf_node*>(pnode);
55    const nonleaf_value_type& v = pnonleaf->value_nonleaf;
56    if (point < v.low || v.high <= point)
57        // Query point is out-of-range.
58        return;
59
60    result(v.data_chain);
61
62    // Check the left child node first, then the right one.
63    __st::node_base* pchild = pnonleaf->left;
64    if (!pchild)
65        return;
66
67    assert(pnonleaf->right ? pchild->is_leaf == pnonleaf->right->is_leaf : true);
68
69    if (pchild->is_leaf)
70    {
71        // The child node are leaf nodes.
72        const leaf_value_type& vleft = static_cast<const leaf_node*>(pchild)->value_leaf;
73        if (point < vleft.key)
74        {
75            // Out-of-range.  Nothing more to do.
76            return;
77        }
78
79        if (pnonleaf->right)
80        {
81            assert(pnonleaf->right->is_leaf);
82            const leaf_value_type& vright = static_cast<const leaf_node*>(pnonleaf->right)->value_leaf;
83            if (vright.key <= point)
84                // Follow the right node.
85                pchild = pnonleaf->right;
86        }
87    }
88    else
89    {
90        // This child nodes are non-leaf nodes.
91
92        const nonleaf_value_type& vleft =
93            static_cast<const nonleaf_node*>(pchild)->value_nonleaf;
94
95        if (point < vleft.low)
96        {
97            // Out-of-range.  Nothing more to do.
98            return;
99        }
100        if (vleft.high <= point && pnonleaf->right)
101            // Follow the right child.
102            pchild = pnonleaf->right;
103
104        assert(static_cast<const nonleaf_node*>(pchild)->value_nonleaf.low <= point &&
105               point < static_cast<const nonleaf_node*>(pchild)->value_nonleaf.high);
106    }
107
108    descend_tree_for_search<T,_Inserter>(point, pchild, result);
109}
110
111} // namespace __st
112
113template<typename _Key, typename _Value>
114segment_tree<_Key, _Value>::segment_tree()
115    : m_root_node(nullptr)
116    , m_valid_tree(false)
117{
118}
119
120template<typename _Key, typename _Value>
121segment_tree<_Key, _Value>::segment_tree(const segment_tree& r)
122    : m_segment_data(r.m_segment_data)
123    , m_root_node(nullptr)
124    , m_valid_tree(r.m_valid_tree)
125{
126    if (m_valid_tree)
127        build_tree();
128}
129
130template<typename _Key, typename _Value>
131segment_tree<_Key, _Value>::~segment_tree()
132{
133    clear_all_nodes();
134}
135
136template<typename _Key, typename _Value>
137bool segment_tree<_Key, _Value>::operator==(const segment_tree& r) const
138{
139    if (m_valid_tree != r.m_valid_tree)
140        return false;
141
142    // Sort the data by key values first.
143    sorted_segment_map_type seg1(m_segment_data.begin(), m_segment_data.end());
144    sorted_segment_map_type seg2(r.m_segment_data.begin(), r.m_segment_data.end());
145    typename sorted_segment_map_type::const_iterator itr1 = seg1.begin(), itr1_end = seg1.end();
146    typename sorted_segment_map_type::const_iterator itr2 = seg2.begin(), itr2_end = seg2.end();
147
148    for (; itr1 != itr1_end; ++itr1, ++itr2)
149    {
150        if (itr2 == itr2_end)
151            return false;
152
153        if (*itr1 != *itr2)
154            return false;
155    }
156    if (itr2 != itr2_end)
157        return false;
158
159    return true;
160}
161
162template<typename _Key, typename _Value>
163void segment_tree<_Key, _Value>::build_tree()
164{
165    build_leaf_nodes();
166    m_nonleaf_node_pool.clear();
167
168    // Count the number of leaf nodes.
169    size_t leaf_count = __st::count_leaf_nodes(m_left_leaf.get(), m_right_leaf.get());
170
171    // Determine the total number of non-leaf nodes needed to build the whole tree.
172    size_t nonleaf_count = __st::count_needed_nonleaf_nodes(leaf_count);
173
174    m_nonleaf_node_pool.resize(nonleaf_count);
175
176    mdds::__st::tree_builder<segment_tree> builder(m_nonleaf_node_pool);
177    m_root_node = builder.build(m_left_leaf);
178
179    // Start "inserting" all segments from the root.
180    typename segment_map_type::const_iterator itr,
181        itr_beg = m_segment_data.begin(), itr_end = m_segment_data.end();
182
183    data_node_map_type tagged_node_map;
184    for (itr = itr_beg; itr != itr_end; ++itr)
185    {
186        value_type pdata = itr->first;
187        auto r = tagged_node_map.insert(
188            typename data_node_map_type::value_type(
189                pdata, make_unique<node_list_type>()));
190
191        node_list_type* plist = r.first->second.get();
192        plist->reserve(10);
193
194        descend_tree_and_mark(m_root_node, pdata, itr->second.first, itr->second.second, plist);
195    }
196
197    m_tagged_node_map.swap(tagged_node_map);
198    m_valid_tree = true;
199}
200
201template<typename _Key, typename _Value>
202void segment_tree<_Key, _Value>::descend_tree_and_mark(
203    __st::node_base* pnode, value_type pdata, key_type begin_key, key_type end_key, node_list_type* plist)
204{
205    if (!pnode)
206        return;
207
208    if (pnode->is_leaf)
209    {
210        // This is a leaf node.
211        node* pleaf = static_cast<node*>(pnode);
212        if (begin_key <= pleaf->value_leaf.key && pleaf->value_leaf.key < end_key)
213        {
214            leaf_value_type& v = pleaf->value_leaf;
215            if (!v.data_chain)
216                v.data_chain = new data_chain_type;
217            v.data_chain->push_back(pdata);
218            plist->push_back(pnode);
219        }
220        return;
221    }
222
223    nonleaf_node* pnonleaf = static_cast<nonleaf_node*>(pnode);
224    if (end_key < pnonleaf->value_nonleaf.low || pnonleaf->value_nonleaf.high <= begin_key)
225        return;
226
227    nonleaf_value_type& v = pnonleaf->value_nonleaf;
228    if (begin_key <= v.low && v.high < end_key)
229    {
230        // mark this non-leaf node and stop.
231        if (!v.data_chain)
232            v.data_chain = new data_chain_type;
233        v.data_chain->push_back(pdata);
234        plist->push_back(pnode);
235        return;
236    }
237
238    descend_tree_and_mark(pnonleaf->left, pdata, begin_key, end_key, plist);
239    descend_tree_and_mark(pnonleaf->right, pdata, begin_key, end_key, plist);
240}
241
242template<typename _Key, typename _Value>
243void segment_tree<_Key, _Value>::build_leaf_nodes()
244{
245    using namespace std;
246
247    disconnect_leaf_nodes(m_left_leaf.get(), m_right_leaf.get());
248
249    // In 1st pass, collect unique end-point values and sort them.
250    vector<key_type> keys_uniq;
251    keys_uniq.reserve(m_segment_data.size()*2);
252    typename segment_map_type::const_iterator itr, itr_beg = m_segment_data.begin(), itr_end = m_segment_data.end();
253    for (itr = itr_beg; itr != itr_end; ++itr)
254    {
255        keys_uniq.push_back(itr->second.first);
256        keys_uniq.push_back(itr->second.second);
257    }
258
259    // sort and remove duplicates.
260    sort(keys_uniq.begin(), keys_uniq.end());
261    keys_uniq.erase(unique(keys_uniq.begin(), keys_uniq.end()), keys_uniq.end());
262
263    create_leaf_node_instances(keys_uniq, m_left_leaf, m_right_leaf);
264}
265
266template<typename _Key, typename _Value>
267void segment_tree<_Key, _Value>::create_leaf_node_instances(const ::std::vector<key_type>& keys, node_ptr& left, node_ptr& right)
268{
269    if (keys.empty() || keys.size() < 2)
270        // We need at least two keys in order to build tree.
271        return;
272
273    typename ::std::vector<key_type>::const_iterator itr = keys.begin(), itr_end = keys.end();
274
275    // left-most node
276    left.reset(new node);
277    left->value_leaf.key = *itr;
278
279    // move on to next.
280    left->next.reset(new node);
281    node_ptr prev_node = left;
282    node_ptr cur_node = left->next;
283    cur_node->prev = prev_node;
284
285    for (++itr; itr != itr_end; ++itr)
286    {
287        cur_node->value_leaf.key = *itr;
288
289        // move on to next
290        cur_node->next.reset(new node);
291        prev_node = cur_node;
292        cur_node = cur_node->next;
293        cur_node->prev = prev_node;
294    }
295
296    // Remove the excess node.
297    prev_node->next.reset();
298    right = prev_node;
299}
300
301template<typename _Key, typename _Value>
302bool segment_tree<_Key, _Value>::insert(key_type begin_key, key_type end_key, value_type pdata)
303{
304    if (begin_key >= end_key)
305        return false;
306
307    if (m_segment_data.find(pdata) != m_segment_data.end())
308        // Insertion of duplicate data is not allowed.
309        return false;
310
311    ::std::pair<key_type, key_type> range;
312    range.first = begin_key;
313    range.second = end_key;
314    m_segment_data.insert(typename segment_map_type::value_type(pdata, range));
315
316    m_valid_tree = false;
317    return true;
318}
319
320template<typename _Key, typename _Value>
321bool segment_tree<_Key, _Value>::search(key_type point, search_result_type& result) const
322{
323    if (!m_valid_tree)
324        // Tree is invalidated.
325        return false;
326
327    if (!m_root_node)
328        // Tree doesn't exist.  Since the tree is flagged valid, this means no
329        // segments have been inserted.
330        return true;
331
332    search_result_vector_inserter result_inserter(result);
333    typedef segment_tree<_Key,_Value> tree_type;
334    __st::descend_tree_for_search<
335        tree_type, search_result_vector_inserter>(point, m_root_node, result_inserter);
336    return true;
337}
338
339template<typename _Key, typename _Value>
340typename segment_tree<_Key, _Value>::search_result
341segment_tree<_Key, _Value>::search(key_type point) const
342{
343    search_result result;
344    if (!m_valid_tree || !m_root_node)
345        return result;
346
347    search_result_inserter result_inserter(result);
348    typedef segment_tree<_Key,_Value> tree_type;
349    __st::descend_tree_for_search<tree_type, search_result_inserter>(
350        point, m_root_node, result_inserter);
351
352    return result;
353}
354
355template<typename _Key, typename _Value>
356void segment_tree<_Key, _Value>::search(key_type point, search_result_base& result) const
357{
358    if (!m_valid_tree || !m_root_node)
359        return;
360
361    search_result_inserter result_inserter(result);
362    typedef segment_tree<_Key,_Value> tree_type;
363    __st::descend_tree_for_search<tree_type>(point, m_root_node, result_inserter);
364}
365
366template<typename _Key, typename _Value>
367void segment_tree<_Key, _Value>::remove(value_type value)
368{
369    using namespace std;
370
371    typename data_node_map_type::iterator itr = m_tagged_node_map.find(value);
372    if (itr != m_tagged_node_map.end())
373    {
374        // Tagged node list found.  Remove all the tags from the tree nodes.
375        node_list_type* plist = itr->second.get();
376        if (!plist)
377            return;
378
379        remove_data_from_nodes(plist, value);
380
381        // Remove the tags associated with this pointer from the data set.
382        m_tagged_node_map.erase(itr);
383    }
384
385    // Remove from the segment data array.
386    m_segment_data.erase(value);
387}
388
389template<typename _Key, typename _Value>
390void segment_tree<_Key, _Value>::clear()
391{
392    m_tagged_node_map.clear();
393    m_segment_data.clear();
394    clear_all_nodes();
395    m_valid_tree = false;
396}
397
398template<typename _Key, typename _Value>
399size_t segment_tree<_Key, _Value>::size() const
400{
401    return m_segment_data.size();
402}
403
404template<typename _Key, typename _Value>
405bool segment_tree<_Key, _Value>::empty() const
406{
407    return m_segment_data.empty();
408}
409
410template<typename _Key, typename _Value>
411size_t segment_tree<_Key, _Value>::leaf_size() const
412{
413    return __st::count_leaf_nodes(m_left_leaf.get(), m_right_leaf.get());
414}
415
416template<typename _Key, typename _Value>
417void segment_tree<_Key, _Value>::remove_data_from_nodes(node_list_type* plist, const value_type pdata)
418{
419    typename node_list_type::iterator itr = plist->begin(), itr_end = plist->end();
420    for (; itr != itr_end; ++itr)
421    {
422        data_chain_type* chain = nullptr;
423        __st::node_base* p = *itr;
424        if (p->is_leaf)
425            chain = static_cast<node*>(p)->value_leaf.data_chain;
426        else
427            chain = static_cast<nonleaf_node*>(p)->value_nonleaf.data_chain;
428
429        if (!chain)
430            continue;
431
432        remove_data_from_chain(*chain, pdata);
433    }
434}
435
436template<typename _Key, typename _Value>
437void segment_tree<_Key, _Value>::remove_data_from_chain(data_chain_type& chain, const value_type pdata)
438{
439    typename data_chain_type::iterator itr = ::std::find(chain.begin(), chain.end(), pdata);
440    if (itr != chain.end())
441    {
442        *itr = chain.back();
443        chain.pop_back();
444    }
445}
446
447template<typename _Key, typename _Value>
448void segment_tree<_Key, _Value>::clear_all_nodes()
449{
450    disconnect_leaf_nodes(m_left_leaf.get(), m_right_leaf.get());
451    m_nonleaf_node_pool.clear();
452    m_left_leaf.reset();
453    m_right_leaf.reset();
454    m_root_node = nullptr;
455}
456
457#ifdef MDDS_UNIT_TEST
458template<typename _Key, typename _Value>
459void segment_tree<_Key, _Value>::dump_tree() const
460{
461    using ::std::cout;
462    using ::std::endl;
463
464    if (!m_valid_tree)
465        assert(!"attempted to dump an invalid tree!");
466
467    cout << "dump tree ------------------------------------------------------" << endl;
468    size_t node_count = mdds::__st::tree_dumper<node, nonleaf_node>::dump(m_root_node);
469    size_t node_instance_count = node::get_instance_count();
470
471    cout << "tree node count = " << node_count << "    node instance count = " << node_instance_count << endl;
472}
473
474template<typename _Key, typename _Value>
475void segment_tree<_Key, _Value>::dump_leaf_nodes() const
476{
477    using ::std::cout;
478    using ::std::endl;
479
480    cout << "dump leaf nodes ------------------------------------------------" << endl;
481
482    node* p = m_left_leaf.get();
483    while (p)
484    {
485        print_leaf_value(p->value_leaf);
486        p = p->next.get();
487    }
488    cout << "  node instance count = " << node::get_instance_count() << endl;
489}
490
491template<typename _Key, typename _Value>
492void segment_tree<_Key, _Value>::dump_segment_data() const
493{
494    using namespace std;
495    cout << "dump segment data ----------------------------------------------" << endl;
496
497    segment_map_printer func;
498    for_each(m_segment_data.begin(), m_segment_data.end(), func);
499}
500
501template<typename _Key, typename _Value>
502bool segment_tree<_Key, _Value>::verify_node_lists() const
503{
504    using namespace std;
505
506    typename data_node_map_type::const_iterator
507        itr = m_tagged_node_map.begin(), itr_end = m_tagged_node_map.end();
508    for (; itr != itr_end; ++itr)
509    {
510        // Print stored nodes.
511        cout << "node list " << itr->first->name << ": ";
512        const node_list_type* plist = itr->second.get();
513        assert(plist);
514        node_printer func;
515        for_each(plist->begin(), plist->end(), func);
516        cout << endl;
517
518        // Verify that all of these nodes have the data pointer.
519        if (!has_data_pointer(*plist, itr->first))
520            return false;
521    }
522    return true;
523}
524
525template<typename _Key, typename _Value>
526bool segment_tree<_Key, _Value>::verify_leaf_nodes(const ::std::vector<leaf_node_check>& checks) const
527{
528    using namespace std;
529
530    node* cur_node = m_left_leaf.get();
531    typename ::std::vector<leaf_node_check>::const_iterator itr = checks.begin(), itr_end = checks.end();
532    for (; itr != itr_end; ++itr)
533    {
534        if (!cur_node)
535            // Position past the right-mode node.  Invalid.
536            return false;
537
538        if (cur_node->value_leaf.key != itr->key)
539            // Key values differ.
540            return false;
541
542        if (itr->data_chain.empty())
543        {
544            if (cur_node->value_leaf.data_chain)
545                // The data chain should be empty (i.e. the pointer should be nullptr).
546                return false;
547        }
548        else
549        {
550            if (!cur_node->value_leaf.data_chain)
551                // This node should have data pointers!
552                return false;
553
554            data_chain_type chain1 = itr->data_chain;
555            data_chain_type chain2 = *cur_node->value_leaf.data_chain;
556
557            if (chain1.size() != chain2.size())
558                return false;
559
560            ::std::vector<value_type> test1, test2;
561            test1.reserve(chain1.size());
562            test2.reserve(chain2.size());
563            copy(chain1.begin(), chain1.end(), back_inserter(test1));
564            copy(chain2.begin(), chain2.end(), back_inserter(test2));
565
566            // Sort both arrays before comparing them.
567            sort(test1.begin(), test1.end());
568            sort(test2.begin(), test2.end());
569
570            if (test1 != test2)
571                return false;
572        }
573
574        cur_node = cur_node->next.get();
575    }
576
577    if (cur_node)
578        // At this point, we expect the current node to be at the position
579        // past the right-most node, which is nullptr.
580        return false;
581
582    return true;
583}
584
585template<typename _Key, typename _Value>
586bool segment_tree<_Key, _Value>::verify_segment_data(const segment_map_type& checks) const
587{
588    // Sort the data by key values first.
589    sorted_segment_map_type seg1(checks.begin(), checks.end());
590    sorted_segment_map_type seg2(m_segment_data.begin(), m_segment_data.end());
591
592    typename sorted_segment_map_type::const_iterator itr1 = seg1.begin(), itr1_end = seg1.end();
593    typename sorted_segment_map_type::const_iterator itr2 = seg2.begin(), itr2_end = seg2.end();
594    for (; itr1 != itr1_end; ++itr1, ++itr2)
595    {
596        if (itr2 == itr2_end)
597            return false;
598
599        if (*itr1 != *itr2)
600            return false;
601    }
602    if (itr2 != itr2_end)
603        return false;
604
605    return true;
606}
607
608template<typename _Key, typename _Value>
609bool segment_tree<_Key, _Value>::has_data_pointer(const node_list_type& node_list, const value_type pdata)
610{
611    using namespace std;
612
613    typename node_list_type::const_iterator
614        itr = node_list.begin(), itr_end = node_list.end();
615
616    for (; itr != itr_end; ++itr)
617    {
618        // Check each node, and make sure each node has the pdata pointer
619        // listed.
620        const __st::node_base* pnode = *itr;
621        const data_chain_type* chain = nullptr;
622        if (pnode->is_leaf)
623            chain = static_cast<const node*>(pnode)->value_leaf.data_chain;
624        else
625            chain = static_cast<const nonleaf_node*>(pnode)->value_nonleaf.data_chain;
626
627        if (!chain)
628            return false;
629
630        if (find(chain->begin(), chain->end(), pdata) == chain->end())
631            return false;
632    }
633    return true;
634}
635
636template<typename _Key, typename _Value>
637void segment_tree<_Key, _Value>::print_leaf_value(const leaf_value_type& v)
638{
639    using namespace std;
640    cout << v.key << ": { ";
641    if (v.data_chain)
642    {
643        const data_chain_type* pchain = v.data_chain;
644        typename data_chain_type::const_iterator itr, itr_beg = pchain->begin(), itr_end = pchain->end();
645        for (itr = itr_beg; itr != itr_end; ++itr)
646        {
647            if (itr != itr_beg)
648                cout << ", ";
649            cout << (*itr)->name;
650        }
651    }
652    cout << " }" << endl;
653}
654#endif
655
656}
657