1 /**
2 * @file core/tree/binary_space_tree/ub_tree_split_impl.hpp
3 * @author Mikhail Lozhnikov
4 *
5 * Implementation of UBTreeSplit, a class that splits a node according
6 * to the median address of points contained in the node.
7 *
8 * mlpack is free software; you may redistribute it and/or modify it under the
9 * terms of the 3-clause BSD license. You should have received a copy of the
10 * 3-clause BSD license along with mlpack. If not, see
11 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12 */
13 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_IMPL_HPP
14 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_IMPL_HPP
15
16 #include "ub_tree_split.hpp"
17 #include <mlpack/core/tree/bounds.hpp>
18
19 namespace mlpack {
20 namespace tree {
21
22 template<typename BoundType, typename MatType>
SplitNode(BoundType & bound,MatType & data,const size_t begin,const size_t count,SplitInfo & splitInfo)23 bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
24 MatType& data,
25 const size_t begin,
26 const size_t count,
27 SplitInfo& splitInfo)
28 {
29 constexpr size_t order = sizeof(AddressElemType) * CHAR_BIT;
30 if (begin == 0 && count == data.n_cols)
31 {
32 // Calculate all addresses.
33 InitializeAddresses(data);
34
35 // Probably this is not a good idea. Maybe it is better to get
36 // a number of distinct samples and find the median.
37 std::sort(addresses.begin(), addresses.end(), ComparePair);
38
39 // Save the vector in order to rearrange the dataset later.
40 splitInfo.addresses = &addresses;
41 }
42 else
43 {
44 // We have already rearranged the dataset.
45 splitInfo.addresses = NULL;
46 }
47
48 // The bound shouldn't contain too many subrectangles.
49 // In order to minimize the number of hyperrectangles we set last bits
50 // of the last address in the node to 1 and last bits of the first address
51 // in the next node to zero in such a way that the ordering is not
52 // disturbed.
53 if (begin + count < data.n_cols)
54 {
55 // Omit leading equal bits.
56 size_t row = 0;
57 arma::Col<AddressElemType>& lo = addresses[begin + count - 1].first;
58 const arma::Col<AddressElemType>& hi = addresses[begin + count].first;
59
60 for (; row < data.n_rows; row++)
61 if (lo[row] != hi[row])
62 break;
63
64 size_t bit = 0;
65
66 for (; bit < order; bit++)
67 if ((lo[row] & ((AddressElemType) 1 << (order - 1 - bit))) !=
68 (hi[row] & ((AddressElemType) 1 << (order - 1 - bit))))
69 break;
70
71 bit++;
72
73 // Replace insignificant bits.
74 if (bit == order)
75 {
76 bit = 0;
77 row++;
78 }
79 else
80 {
81 for (; bit < order; bit++)
82 lo[row] |= ((AddressElemType) 1 << (order - 1 - bit));
83 row++;
84 }
85
86 for (; row < data.n_rows; row++)
87 for (; bit < order; bit++)
88 lo[row] |= ((AddressElemType) 1 << (order - 1 - bit));
89 }
90
91 // The bound shouldn't contain too many subrectangles.
92 // In order to minimize the number of hyperrectangles we set last bits
93 // of the first address in the next node to 0 and last bits of the last
94 // address in the previous node to 1 in such a way that the ordering is not
95 // disturbed.
96 if (begin > 0)
97 {
98 // Omit leading equal bits.
99 size_t row = 0;
100 const arma::Col<AddressElemType>& lo = addresses[begin - 1].first;
101 arma::Col<AddressElemType>& hi = addresses[begin].first;
102
103 for (; row < data.n_rows; row++)
104 if (lo[row] != hi[row])
105 break;
106
107 size_t bit = 0;
108
109 for (; bit < order; bit++)
110 if ((lo[row] & ((AddressElemType) 1 << (order - 1 - bit))) !=
111 (hi[row] & ((AddressElemType) 1 << (order - 1 - bit))))
112 break;
113
114 bit++;
115
116 // Replace insignificant bits.
117 if (bit == order)
118 {
119 bit = 0;
120 row++;
121 }
122 else
123 {
124 for (; bit < order; bit++)
125 hi[row] &= ~((AddressElemType) 1 << (order - 1 - bit));
126 row++;
127 }
128
129 for (; row < data.n_rows; row++)
130 for (; bit < order; bit++)
131 hi[row] &= ~((AddressElemType) 1 << (order - 1 - bit));
132 }
133
134 // Set the minimum and the maximum addresses.
135 for (size_t k = 0; k < bound.Dim(); ++k)
136 {
137 bound.LoAddress()[k] = addresses[begin].first[k];
138 bound.HiAddress()[k] = addresses[begin + count - 1].first[k];
139 }
140 bound.UpdateAddressBounds(data.cols(begin, begin + count - 1));
141
142 return true;
143 }
144
145 template<typename BoundType, typename MatType>
InitializeAddresses(const MatType & data)146 void UBTreeSplit<BoundType, MatType>::InitializeAddresses(const MatType& data)
147 {
148 addresses.resize(data.n_cols);
149
150 // Calculate all addresses.
151 for (size_t i = 0; i < data.n_cols; ++i)
152 {
153 addresses[i].first.zeros(data.n_rows);
154 bound::addr::PointToAddress(addresses[i].first, data.col(i));
155 addresses[i].second = i;
156 }
157 }
158
159 template<typename BoundType, typename MatType>
PerformSplit(MatType & data,const size_t begin,const size_t count,const SplitInfo & splitInfo)160 size_t UBTreeSplit<BoundType, MatType>::PerformSplit(
161 MatType& data,
162 const size_t begin,
163 const size_t count,
164 const SplitInfo& splitInfo)
165 {
166 // For the first time we have to rearrange the dataset.
167 if (splitInfo.addresses)
168 {
169 std::vector<size_t> newFromOld(data.n_cols);
170 std::vector<size_t> oldFromNew(data.n_cols);
171
172 for (size_t i = 0; i < splitInfo.addresses->size(); ++i)
173 {
174 newFromOld[i] = i;
175 oldFromNew[i] = i;
176 }
177
178 for (size_t i = 0; i < splitInfo.addresses->size(); ++i)
179 {
180 size_t index = (*splitInfo.addresses)[i].second;
181 size_t oldI = oldFromNew[i];
182 size_t newIndex = newFromOld[index];
183
184 data.swap_cols(i, newFromOld[index]);
185
186 size_t tmp = newFromOld[index];
187 newFromOld[index] = i;
188 newFromOld[oldI] = tmp;
189
190 tmp = oldFromNew[i];
191 oldFromNew[i] = oldFromNew[newIndex];
192 oldFromNew[newIndex] = tmp;
193 }
194 }
195
196 // Since the dataset is sorted we can easily obtain the split column.
197 return begin + count / 2;
198 }
199
200 template<typename BoundType, typename MatType>
PerformSplit(MatType & data,const size_t begin,const size_t count,const SplitInfo & splitInfo,std::vector<size_t> & oldFromNew)201 size_t UBTreeSplit<BoundType, MatType>::PerformSplit(
202 MatType& data,
203 const size_t begin,
204 const size_t count,
205 const SplitInfo& splitInfo,
206 std::vector<size_t>& oldFromNew)
207 {
208 // For the first time we have to rearrange the dataset.
209 if (splitInfo.addresses)
210 {
211 std::vector<size_t> newFromOld(data.n_cols);
212
213 for (size_t i = 0; i < splitInfo.addresses->size(); ++i)
214 newFromOld[i] = i;
215
216 for (size_t i = 0; i < splitInfo.addresses->size(); ++i)
217 {
218 size_t index = (*splitInfo.addresses)[i].second;
219 size_t oldI = oldFromNew[i];
220 size_t newIndex = newFromOld[index];
221
222 data.swap_cols(i, newFromOld[index]);
223
224 size_t tmp = newFromOld[index];
225 newFromOld[index] = i;
226 newFromOld[oldI] = tmp;
227
228 tmp = oldFromNew[i];
229 oldFromNew[i] = oldFromNew[newIndex];
230 oldFromNew[newIndex] = tmp;
231 }
232 }
233
234 // Since the dataset is sorted we can easily obtain the split column.
235 return begin + count / 2;
236 }
237
238 } // namespace tree
239 } // namespace mlpack
240
241 #endif
242