1 /**
2  * @file core/tree/binary_space_tree/ub_tree_split_impl.hpp
3  * @author Mikhail Lozhnikov
4  *
5  * Implementation of UBTreeSplit, a class that splits a node according
6  * to the median address of points contained in the node.
7  *
8  * mlpack is free software; you may redistribute it and/or modify it under the
9  * terms of the 3-clause BSD license.  You should have received a copy of the
10  * 3-clause BSD license along with mlpack.  If not, see
11  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12  */
13 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_IMPL_HPP
14 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_IMPL_HPP
15 
16 #include "ub_tree_split.hpp"
17 #include <mlpack/core/tree/bounds.hpp>
18 
19 namespace mlpack {
20 namespace tree {
21 
22 template<typename BoundType, typename MatType>
SplitNode(BoundType & bound,MatType & data,const size_t begin,const size_t count,SplitInfo & splitInfo)23 bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
24                                                 MatType& data,
25                                                 const size_t begin,
26                                                 const size_t count,
27                                                 SplitInfo& splitInfo)
28 {
29   constexpr size_t order = sizeof(AddressElemType) * CHAR_BIT;
30   if (begin == 0 && count == data.n_cols)
31   {
32     // Calculate all addresses.
33     InitializeAddresses(data);
34 
35     // Probably this is not a good idea. Maybe it is better to get
36     // a number of distinct samples and find the median.
37     std::sort(addresses.begin(), addresses.end(), ComparePair);
38 
39     // Save the vector in order to rearrange the dataset later.
40     splitInfo.addresses = &addresses;
41   }
42   else
43   {
44     // We have already rearranged the dataset.
45     splitInfo.addresses = NULL;
46   }
47 
48   // The bound shouldn't contain too many subrectangles.
49   // In order to minimize the number of hyperrectangles we set last bits
50   // of the last address in the node to 1 and last bits of the first  address
51   // in the next node to zero in such a way that the ordering is not
52   // disturbed.
53   if (begin + count < data.n_cols)
54   {
55     // Omit leading equal bits.
56     size_t row = 0;
57     arma::Col<AddressElemType>& lo = addresses[begin + count - 1].first;
58     const arma::Col<AddressElemType>& hi = addresses[begin + count].first;
59 
60     for (; row < data.n_rows; row++)
61       if (lo[row] != hi[row])
62         break;
63 
64     size_t bit = 0;
65 
66     for (; bit < order; bit++)
67       if ((lo[row] & ((AddressElemType) 1 << (order - 1 - bit))) !=
68           (hi[row] & ((AddressElemType) 1 << (order - 1 - bit))))
69         break;
70 
71     bit++;
72 
73     // Replace insignificant bits.
74     if (bit == order)
75     {
76       bit = 0;
77       row++;
78     }
79     else
80     {
81       for (; bit < order; bit++)
82         lo[row] |= ((AddressElemType) 1 << (order - 1 - bit));
83       row++;
84     }
85 
86     for (; row < data.n_rows; row++)
87       for (; bit < order; bit++)
88         lo[row] |= ((AddressElemType) 1 << (order - 1 - bit));
89   }
90 
91   // The bound shouldn't contain too many subrectangles.
92   // In order to minimize the number of hyperrectangles we set last bits
93   // of the first address in the next node to 0 and last bits of the last
94   // address in the previous node to 1 in such a way that the ordering is not
95   // disturbed.
96   if (begin > 0)
97   {
98     // Omit leading equal bits.
99     size_t row = 0;
100     const arma::Col<AddressElemType>& lo = addresses[begin - 1].first;
101     arma::Col<AddressElemType>& hi = addresses[begin].first;
102 
103     for (; row < data.n_rows; row++)
104       if (lo[row] != hi[row])
105         break;
106 
107     size_t bit = 0;
108 
109     for (; bit < order; bit++)
110       if ((lo[row] & ((AddressElemType) 1 << (order - 1 - bit))) !=
111           (hi[row] & ((AddressElemType) 1 << (order - 1 - bit))))
112         break;
113 
114     bit++;
115 
116     // Replace insignificant bits.
117     if (bit == order)
118     {
119       bit = 0;
120       row++;
121     }
122     else
123     {
124       for (; bit < order; bit++)
125         hi[row] &= ~((AddressElemType) 1 << (order - 1 - bit));
126       row++;
127     }
128 
129     for (; row < data.n_rows; row++)
130       for (; bit < order; bit++)
131         hi[row] &= ~((AddressElemType) 1 << (order - 1 - bit));
132   }
133 
134   // Set the minimum and the maximum addresses.
135   for (size_t k = 0; k < bound.Dim(); ++k)
136   {
137     bound.LoAddress()[k] = addresses[begin].first[k];
138     bound.HiAddress()[k] = addresses[begin + count - 1].first[k];
139   }
140   bound.UpdateAddressBounds(data.cols(begin, begin + count - 1));
141 
142   return true;
143 }
144 
145 template<typename BoundType, typename MatType>
InitializeAddresses(const MatType & data)146 void UBTreeSplit<BoundType, MatType>::InitializeAddresses(const MatType& data)
147 {
148   addresses.resize(data.n_cols);
149 
150   // Calculate all addresses.
151   for (size_t i = 0; i < data.n_cols; ++i)
152   {
153     addresses[i].first.zeros(data.n_rows);
154     bound::addr::PointToAddress(addresses[i].first, data.col(i));
155     addresses[i].second = i;
156   }
157 }
158 
159 template<typename BoundType, typename MatType>
PerformSplit(MatType & data,const size_t begin,const size_t count,const SplitInfo & splitInfo)160 size_t UBTreeSplit<BoundType, MatType>::PerformSplit(
161     MatType& data,
162     const size_t begin,
163     const size_t count,
164     const SplitInfo& splitInfo)
165 {
166   // For the first time we have to rearrange the dataset.
167   if (splitInfo.addresses)
168   {
169     std::vector<size_t> newFromOld(data.n_cols);
170     std::vector<size_t> oldFromNew(data.n_cols);
171 
172     for (size_t i = 0; i < splitInfo.addresses->size(); ++i)
173     {
174       newFromOld[i] = i;
175       oldFromNew[i] = i;
176     }
177 
178     for (size_t i = 0; i < splitInfo.addresses->size(); ++i)
179     {
180       size_t index = (*splitInfo.addresses)[i].second;
181       size_t oldI = oldFromNew[i];
182       size_t newIndex = newFromOld[index];
183 
184       data.swap_cols(i, newFromOld[index]);
185 
186       size_t tmp = newFromOld[index];
187       newFromOld[index] = i;
188       newFromOld[oldI] = tmp;
189 
190       tmp = oldFromNew[i];
191       oldFromNew[i] = oldFromNew[newIndex];
192       oldFromNew[newIndex] = tmp;
193     }
194   }
195 
196   // Since the dataset is sorted we can easily obtain the split column.
197   return begin + count / 2;
198 }
199 
200 template<typename BoundType, typename MatType>
PerformSplit(MatType & data,const size_t begin,const size_t count,const SplitInfo & splitInfo,std::vector<size_t> & oldFromNew)201 size_t UBTreeSplit<BoundType, MatType>::PerformSplit(
202     MatType& data,
203     const size_t begin,
204     const size_t count,
205     const SplitInfo& splitInfo,
206     std::vector<size_t>& oldFromNew)
207 {
208   // For the first time we have to rearrange the dataset.
209   if (splitInfo.addresses)
210   {
211     std::vector<size_t> newFromOld(data.n_cols);
212 
213     for (size_t i = 0; i < splitInfo.addresses->size(); ++i)
214       newFromOld[i] = i;
215 
216     for (size_t i = 0; i < splitInfo.addresses->size(); ++i)
217     {
218       size_t index = (*splitInfo.addresses)[i].second;
219       size_t oldI = oldFromNew[i];
220       size_t newIndex = newFromOld[index];
221 
222       data.swap_cols(i, newFromOld[index]);
223 
224       size_t tmp = newFromOld[index];
225       newFromOld[index] = i;
226       newFromOld[oldI] = tmp;
227 
228       tmp = oldFromNew[i];
229       oldFromNew[i] = oldFromNew[newIndex];
230       oldFromNew[newIndex] = tmp;
231     }
232   }
233 
234   // Since the dataset is sorted we can easily obtain the split column.
235   return begin + count / 2;
236 }
237 
238 } // namespace tree
239 } // namespace mlpack
240 
241 #endif
242