1 /**
2  * @file core/tree/binary_space_tree/ub_tree_split.hpp
3  * @author Mikhail Lozhnikov
4  *
5  * Definition of UBTreeSplit, a class that splits the space according
6  * to the median address of points contained in the node.
7  *
8  * mlpack is free software; you may redistribute it and/or modify it under the
9  * terms of the 3-clause BSD license.  You should have received a copy of the
10  * 3-clause BSD license along with mlpack.  If not, see
11  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12  */
13 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_HPP
14 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "../address.hpp"
18 
19 namespace mlpack {
20 namespace tree /** Trees and tree-building procedures. */ {
21 
22 /**
23  * Split a node into two parts according to the median address of points
24  * contained in the node. The class reorders the dataset such that points
25  * with lower addresses belong to the left subtree and points with high
26  * addresses belong to the right subtree.
27  */
28 template<typename BoundType, typename MatType = arma::mat>
29 class UBTreeSplit
30 {
31  public:
32   //! The type of an address element.
33   typedef typename std::conditional<
34       sizeof(typename MatType::elem_type) * CHAR_BIT <= 32,
35       uint32_t,
36       uint64_t>::type AddressElemType;
37 
38   //! An information about the partition.
39   struct SplitInfo
40   {
41     //! This vector contains addresses of all points in the dataset.
42     std::vector<std::pair<arma::Col<AddressElemType>, size_t>>* addresses;
43   };
44 
45   /**
46    * Split the node according to the median address of points contained in the
47    * node.
48    *
49    * @param bound The bound used for this node.
50    * @param data The dataset used by the binary space tree.
51    * @param begin Index of the starting point in the dataset that belongs to
52    *    this node.
53    * @param count Number of points in this node.
54    * @param splitInfo An information about the split (not used here).
55    */
56   bool SplitNode(BoundType& bound,
57                  MatType& data,
58                  const size_t begin,
59                  const size_t count,
60                  SplitInfo&  splitInfo);
61 
62   /**
63    * Rearrange the dataset according to the addresses.
64    *
65    * @param data The dataset used by the binary space tree.
66    * @param begin Index of the starting point in the dataset that belongs to
67    *    this node.
68    * @param count Number of points in this node.
69    * @param splitInfo The information about the split.
70    */
71   static size_t PerformSplit(MatType& data,
72                              const size_t begin,
73                              const size_t count,
74                              const SplitInfo& splitInfo);
75 
76   /**
77    * Rearrange the dataset according to the addresses and return the list
78    * of changed indices.
79    *
80    * @param data The dataset used by the binary space tree.
81    * @param begin Index of the starting point in the dataset that belongs to
82    *    this node.
83    * @param count Number of points in this node.
84    * @param splitInfo The information about the split.
85    * @param oldFromNew Vector which will be filled with the old positions for
86    *    each new point.
87    */
88   static size_t PerformSplit(MatType& data,
89                              const size_t begin,
90                              const size_t count,
91                              const SplitInfo& splitInfo,
92                              std::vector<size_t>& oldFromNew);
93 
94  private:
95   //! This vector contains addresses of all points in the dataset.
96   std::vector<std::pair<arma::Col<AddressElemType>, size_t>> addresses;
97 
98   /**
99    * Calculate addresses for all points in the dataset.
100    *
101    * @param data The dataset used by the binary space tree.
102    */
103   void InitializeAddresses(const MatType& data);
104 
105   //! A comparator for sorting addresses.
ComparePair(const std::pair<arma::Col<AddressElemType>,size_t> & p1,const std::pair<arma::Col<AddressElemType>,size_t> & p2)106   static bool ComparePair(
107       const std::pair<arma::Col<AddressElemType>, size_t>& p1,
108       const std::pair<arma::Col<AddressElemType>, size_t>& p2)
109   {
110     return bound::addr::CompareAddresses(p1.first, p2.first) < 0;
111   }
112 };
113 
114 } // namespace tree
115 } // namespace mlpack
116 
117 // Include implementation.
118 #include "ub_tree_split_impl.hpp"
119 
120 #endif
121