1 #ifndef DIY_DETAIL_ALGORITHMS_KDTREE_HPP
2 #define DIY_DETAIL_ALGORITHMS_KDTREE_HPP
3
#include <cassert>
#include <map>
#include <set>
#include <stdexcept>
#include <utility>
#include <vector>

#include "../../partners/all-reduce.hpp"
#include "../../log.hpp"
8
9 namespace diy
10 {
11 namespace detail
12 {
13
14 struct KDTreePartners;
15
16 template<class Block, class Point>
17 struct KDTreePartition
18 {
19 typedef diy::RegularContinuousLink RCLink;
20 typedef diy::ContinuousBounds Bounds;
21
22 typedef std::vector<size_t> Histogram;
23
KDTreePartitiondiy::detail::KDTreePartition24 KDTreePartition(int dim,
25 std::vector<Point> Block::* points,
26 size_t bins):
27 dim_(dim), points_(points), bins_(bins) {}
28
29 void operator()(Block* b, const diy::ReduceProxy& srp, const KDTreePartners& partners) const;
30
31 int divide_gid(int gid, bool lower, int round, int rounds) const;
32 void update_links(Block* b, const diy::ReduceProxy& srp, int dim, int round, int rounds, bool wrap, const Bounds& domain) const;
33 void split_to_neighbors(Block* b, const diy::ReduceProxy& srp, int dim) const;
34 diy::Direction
35 find_wrap(const Bounds& bounds, const Bounds& nbr_bounds, const Bounds& domain) const;
36
37 void compute_local_histogram(Block* b, const diy::ReduceProxy& srp, int dim) const;
38 void add_histogram(Block* b, const diy::ReduceProxy& srp, Histogram& histogram) const;
39 void receive_histogram(Block* b, const diy::ReduceProxy& srp, Histogram& histogram) const;
40 void forward_histogram(Block* b, const diy::ReduceProxy& srp, const Histogram& histogram) const;
41
42 void enqueue_exchange(Block* b, const diy::ReduceProxy& srp, int dim, const Histogram& histogram) const;
43 void dequeue_exchange(Block* b, const diy::ReduceProxy& srp, int dim) const;
44
45 void update_neighbor_bounds(Bounds& bounds, float split, int dim, bool lower) const;
46 bool intersects(const Bounds& x, const Bounds& y, int dim, bool wrap, const Bounds& domain) const;
47 float find_split(const Bounds& changed, const Bounds& original) const;
48
49 int dim_;
50 std::vector<Point> Block::* points_;
51 size_t bins_;
52 };
53
54 }
55 }
56
57 struct diy::detail::KDTreePartners
58 {
59 // bool = are we in a swap (vs histogram) round
60 // int = round within that partner
61 typedef std::pair<bool, int> RoundType;
62 typedef diy::ContinuousBounds Bounds;
63
KDTreePartnersdiy::detail::KDTreePartners64 KDTreePartners(int dim, int nblocks, bool wrap_, const Bounds& domain_):
65 decomposer(1, interval(0,nblocks-1), nblocks),
66 histogram(decomposer, 2),
67 swap(decomposer, 2, false),
68 wrap(wrap_),
69 domain(domain_)
70 {
71 for (unsigned i = 0; i < swap.rounds(); ++i)
72 {
73 // fill histogram rounds
74 for (unsigned j = 0; j < histogram.rounds(); ++j)
75 {
76 rounds_.push_back(std::make_pair(false, j));
77 dim_.push_back(i % dim);
78 if (j == histogram.rounds() / 2 - 1 - i)
79 j += 2*i;
80 }
81
82 // fill swap round
83 rounds_.push_back(std::make_pair(true, i));
84 dim_.push_back(i % dim);
85
86 // fill link round
87 rounds_.push_back(std::make_pair(true, -1)); // (true, -1) signals link round
88 dim_.push_back(i % dim);
89 }
90 }
91
roundsdiy::detail::KDTreePartners92 size_t rounds() const { return rounds_.size(); }
swap_roundsdiy::detail::KDTreePartners93 size_t swap_rounds() const { return swap.rounds(); }
94
dimdiy::detail::KDTreePartners95 int dim(int round) const { return dim_[round]; }
swap_rounddiy::detail::KDTreePartners96 bool swap_round(int round) const { return rounds_[round].first; }
sub_rounddiy::detail::KDTreePartners97 int sub_round(int round) const { return rounds_[round].second; }
98
activediy::detail::KDTreePartners99 inline bool active(int round, int gid, const diy::Master& m) const
100 {
101 if (round == (int) rounds())
102 return true;
103 else if (swap_round(round) && sub_round(round) < 0) // link round
104 return true;
105 else if (swap_round(round))
106 return swap.active(sub_round(round), gid, m);
107 else
108 return histogram.active(sub_round(round), gid, m);
109 }
110
incomingdiy::detail::KDTreePartners111 inline void incoming(int round, int gid, std::vector<int>& partners, const diy::Master& m) const
112 {
113 if (round == (int) rounds())
114 link_neighbors(-1, gid, partners, m);
115 else if (swap_round(round) && sub_round(round) < 0) // link round
116 swap.incoming(sub_round(round - 1) + 1, gid, partners, m);
117 else if (swap_round(round))
118 histogram.incoming(histogram.rounds(), gid, partners, m);
119 else
120 {
121 if (round > 0 && sub_round(round) == 0)
122 link_neighbors(-1, gid, partners, m);
123 else if (round > 0 && sub_round(round - 1) != sub_round(round) - 1) // jump through the histogram rounds
124 histogram.incoming(sub_round(round - 1) + 1, gid, partners, m);
125 else
126 histogram.incoming(sub_round(round), gid, partners, m);
127 }
128 }
129
outgoingdiy::detail::KDTreePartners130 inline void outgoing(int round, int gid, std::vector<int>& partners, const diy::Master& m) const
131 {
132 if (round == (int) rounds())
133 swap.outgoing(sub_round(round-1) + 1, gid, partners, m);
134 else if (swap_round(round) && sub_round(round) < 0) // link round
135 link_neighbors(-1, gid, partners, m);
136 else if (swap_round(round))
137 swap.outgoing(sub_round(round), gid, partners, m);
138 else
139 histogram.outgoing(sub_round(round), gid, partners, m);
140 }
141
link_neighborsdiy::detail::KDTreePartners142 inline void link_neighbors(int, int gid, std::vector<int>& partners, const diy::Master& m) const
143 {
144 int lid = m.lid(gid);
145 diy::Link* link = m.link(lid);
146
147 std::set<int> result; // partners must be unique
148 for (int i = 0; i < link->size(); ++i)
149 result.insert(link->target(i).gid);
150
151 for (std::set<int>::const_iterator it = result.begin(); it != result.end(); ++it)
152 partners.push_back(*it);
153 }
154
155 // 1-D domain to feed into histogram and swap
156 diy::RegularDecomposer<diy::DiscreteBounds> decomposer;
157
158 diy::RegularAllReducePartners histogram;
159 diy::RegularSwapPartners swap;
160
161 std::vector<RoundType> rounds_;
162 std::vector<int> dim_;
163
164 bool wrap;
165 Bounds domain;
166 };
167
168 template<class Block, class Point>
169 void
170 diy::detail::KDTreePartition<Block,Point>::
operator ()(Block * b,const diy::ReduceProxy & srp,const KDTreePartners & partners) const171 operator()(Block* b, const diy::ReduceProxy& srp, const KDTreePartners& partners) const
172 {
173 int dim;
174 if (srp.round() < partners.rounds())
175 dim = partners.dim(srp.round());
176 else
177 dim = partners.dim(srp.round() - 1);
178
179 if (srp.round() == partners.rounds())
180 update_links(b, srp, dim, partners.sub_round(srp.round() - 2), partners.swap_rounds(), partners.wrap, partners.domain); // -1 would be the "uninformative" link round
181 else if (partners.swap_round(srp.round()) && partners.sub_round(srp.round()) < 0) // link round
182 {
183 dequeue_exchange(b, srp, dim); // from the swap round
184 split_to_neighbors(b, srp, dim);
185 }
186 else if (partners.swap_round(srp.round()))
187 {
188 Histogram histogram;
189 receive_histogram(b, srp, histogram);
190 enqueue_exchange(b, srp, dim, histogram);
191 } else if (partners.sub_round(srp.round()) == 0)
192 {
193 if (srp.round() > 0)
194 {
195 int prev_dim = dim - 1;
196 if (prev_dim < 0)
197 prev_dim += dim_;
198 update_links(b, srp, prev_dim, partners.sub_round(srp.round() - 2), partners.swap_rounds(), partners.wrap, partners.domain); // -1 would be the "uninformative" link round
199 }
200
201 compute_local_histogram(b, srp, dim);
202 } else if (partners.sub_round(srp.round()) < (int) partners.histogram.rounds()/2)
203 {
204 Histogram histogram(bins_);
205 add_histogram(b, srp, histogram);
206 srp.enqueue(srp.out_link().target(0), histogram);
207 }
208 else
209 {
210 Histogram histogram(bins_);
211 add_histogram(b, srp, histogram);
212 forward_histogram(b, srp, histogram);
213 }
214 }
215
216 template<class Block, class Point>
217 int
218 diy::detail::KDTreePartition<Block,Point>::
divide_gid(int gid,bool lower,int round,int rounds) const219 divide_gid(int gid, bool lower, int round, int rounds) const
220 {
221 if (lower)
222 gid &= ~(1 << (rounds - 1 - round));
223 else
224 gid |= (1 << (rounds - 1 - round));
225 return gid;
226 }
227
228 // round here is the outer iteration of the algorithm
229 template<class Block, class Point>
230 void
231 diy::detail::KDTreePartition<Block,Point>::
update_links(Block * b,const diy::ReduceProxy & srp,int dim,int round,int rounds,bool wrap,const Bounds & domain) const232 update_links(Block* b, const diy::ReduceProxy& srp, int dim, int round, int rounds, bool wrap, const Bounds& domain) const
233 {
234 int gid = srp.gid();
235 int lid = srp.master()->lid(gid);
236 RCLink* link = static_cast<RCLink*>(srp.master()->link(lid));
237
238 // (gid, dir) -> i
239 std::map<std::pair<int,diy::Direction>, int> link_map;
240 for (int i = 0; i < link->size(); ++i)
241 link_map[std::make_pair(link->target(i).gid, link->direction(i))] = i;
242
243 // NB: srp.enqueue(..., ...) should match the link
244 std::vector<float> splits(link->size());
245 for (int i = 0; i < link->size(); ++i)
246 {
247 float split; diy::Direction dir(dim_,0);
248
249 int in_gid = link->target(i).gid;
250 while(srp.incoming(in_gid))
251 {
252 srp.dequeue(in_gid, split);
253 srp.dequeue(in_gid, dir);
254
255 // reverse dir
256 for (int j = 0; j < dim_; ++j)
257 dir[j] = -dir[j];
258
259 int k = link_map[std::make_pair(in_gid, dir)];
260 splits[k] = split;
261 }
262 }
263
264 RCLink new_link(dim_, link->core(), link->core());
265
266 bool lower = !(gid & (1 << (rounds - 1 - round)));
267
268 // fill out the new link
269 for (int i = 0; i < link->size(); ++i)
270 {
271 diy::Direction dir = link->direction(i);
272 //diy::Direction wrap_dir = link->wrap(i); // we don't use existing wrap, but restore it from scratch
273 if (dir[dim] != 0)
274 {
275 if ((dir[dim] < 0 && lower) || (dir[dim] > 0 && !lower))
276 {
277 int nbr_gid = divide_gid(link->target(i).gid, !lower, round, rounds);
278 diy::BlockID nbr = { nbr_gid, srp.assigner().rank(nbr_gid) };
279 new_link.add_neighbor(nbr);
280
281 new_link.add_direction(dir);
282
283 Bounds bounds = link->bounds(i);
284 update_neighbor_bounds(bounds, splits[i], dim, !lower);
285 new_link.add_bounds(bounds);
286
287 if (wrap)
288 new_link.add_wrap(find_wrap(new_link.bounds(), bounds, domain));
289 else
290 new_link.add_wrap(diy::Direction(dim_,0));
291 }
292 } else // non-aligned side
293 {
294 for (int j = 0; j < 2; ++j)
295 {
296 int nbr_gid = divide_gid(link->target(i).gid, j == 0, round, rounds);
297
298 Bounds bounds = link->bounds(i);
299 update_neighbor_bounds(bounds, splits[i], dim, j == 0);
300
301 if (intersects(bounds, new_link.bounds(), dim, wrap, domain))
302 {
303 diy::BlockID nbr = { nbr_gid, srp.assigner().rank(nbr_gid) };
304 new_link.add_neighbor(nbr);
305 new_link.add_direction(dir);
306 new_link.add_bounds(bounds);
307
308 if (wrap)
309 new_link.add_wrap(find_wrap(new_link.bounds(), bounds, domain));
310 else
311 new_link.add_wrap(diy::Direction(dim_, 0));
312 }
313 }
314 }
315 }
316
317 // add link to the dual block
318 int dual_gid = divide_gid(gid, !lower, round, rounds);
319 diy::BlockID dual = { dual_gid, srp.assigner().rank(dual_gid) };
320 new_link.add_neighbor(dual);
321
322 Bounds nbr_bounds = link->bounds(); // old block bounds
323 update_neighbor_bounds(nbr_bounds, find_split(new_link.bounds(), nbr_bounds), dim, !lower);
324 new_link.add_bounds(nbr_bounds);
325
326 new_link.add_wrap(diy::Direction(dim_,0)); // dual block cannot be wrapped
327
328 if (lower)
329 {
330 diy::Direction right(dim_,0);
331 right[dim] = 1;
332 new_link.add_direction(right);
333 } else
334 {
335 diy::Direction left(dim_,0);
336 left[dim] = -1;
337 new_link.add_direction(left);
338 }
339
340 // update the link; notice that this won't conflict with anything since
341 // reduce is using its own notion of the link constructed through the
342 // partners
343 link->swap(new_link);
344 }
345
346 template<class Block, class Point>
347 void
348 diy::detail::KDTreePartition<Block,Point>::
split_to_neighbors(Block * b,const diy::ReduceProxy & srp,int dim) const349 split_to_neighbors(Block* b, const diy::ReduceProxy& srp, int dim) const
350 {
351 int lid = srp.master()->lid(srp.gid());
352 RCLink* link = static_cast<RCLink*>(srp.master()->link(lid));
353
354 // determine split
355 float split = find_split(link->core(), link->bounds());
356
357 for (int i = 0; i < link->size(); ++i)
358 {
359 srp.enqueue(link->target(i), split);
360 srp.enqueue(link->target(i), link->direction(i));
361 }
362 }
363
364 template<class Block, class Point>
365 void
366 diy::detail::KDTreePartition<Block,Point>::
compute_local_histogram(Block * b,const diy::ReduceProxy & srp,int dim) const367 compute_local_histogram(Block* b, const diy::ReduceProxy& srp, int dim) const
368 {
369 int lid = srp.master()->lid(srp.gid());
370 RCLink* link = static_cast<RCLink*>(srp.master()->link(lid));
371
372 // compute and enqueue local histogram
373 Histogram histogram(bins_);
374
375 float width = (link->core().max[dim] - link->core().min[dim])/bins_;
376 for (size_t i = 0; i < (b->*points_).size(); ++i)
377 {
378 float x = (b->*points_)[i][dim];
379 int loc = (x - link->core().min[dim]) / width;
380 if (loc < 0)
381 throw std::runtime_error(fmt::format("{} {} {}", loc, x, link->core().min[dim]));
382 if (loc >= (int) bins_)
383 loc = bins_ - 1;
384 ++(histogram[loc]);
385 }
386
387 srp.enqueue(srp.out_link().target(0), histogram);
388 }
389
390 template<class Block, class Point>
391 void
392 diy::detail::KDTreePartition<Block,Point>::
add_histogram(Block * b,const diy::ReduceProxy & srp,Histogram & histogram) const393 add_histogram(Block* b, const diy::ReduceProxy& srp, Histogram& histogram) const
394 {
395 // dequeue and add up the histograms
396 for (int i = 0; i < srp.in_link().size(); ++i)
397 {
398 int nbr_gid = srp.in_link().target(i).gid;
399
400 Histogram hist;
401 srp.dequeue(nbr_gid, hist);
402 for (size_t j = 0; j < hist.size(); ++j)
403 histogram[j] += hist[j];
404 }
405 }
406
407 template<class Block, class Point>
408 void
409 diy::detail::KDTreePartition<Block,Point>::
receive_histogram(Block * b,const diy::ReduceProxy & srp,Histogram & histogram) const410 receive_histogram(Block* b, const diy::ReduceProxy& srp, Histogram& histogram) const
411 {
412 srp.dequeue(srp.in_link().target(0).gid, histogram);
413 }
414
415 template<class Block, class Point>
416 void
417 diy::detail::KDTreePartition<Block,Point>::
forward_histogram(Block * b,const diy::ReduceProxy & srp,const Histogram & histogram) const418 forward_histogram(Block* b, const diy::ReduceProxy& srp, const Histogram& histogram) const
419 {
420 for (int i = 0; i < srp.out_link().size(); ++i)
421 srp.enqueue(srp.out_link().target(i), histogram);
422 }
423
424 template<class Block, class Point>
425 void
426 diy::detail::KDTreePartition<Block,Point>::
enqueue_exchange(Block * b,const diy::ReduceProxy & srp,int dim,const Histogram & histogram) const427 enqueue_exchange(Block* b, const diy::ReduceProxy& srp, int dim, const Histogram& histogram) const
428 {
429 auto log = get_logger();
430
431 int lid = srp.master()->lid(srp.gid());
432 RCLink* link = static_cast<RCLink*>(srp.master()->link(lid));
433
434 int k = srp.out_link().size();
435
436 if (k == 0) // final round; nothing needs to be sent; this is actually redundant
437 return;
438
439 // pick split points
440 size_t total = 0;
441 for (size_t i = 0; i < histogram.size(); ++i)
442 total += histogram[i];
443 log->trace("Histogram total: {}", total);
444
445 size_t cur = 0;
446 float width = (link->core().max[dim] - link->core().min[dim])/bins_;
447 float split = 0;
448 size_t i = 0;
449 for (; i < histogram.size(); ++i)
450 {
451 if (cur + histogram[i] > total/2)
452 break;
453 cur += histogram[i];
454 }
455 if (i == 0)
456 ++i;
457 else if (i >= histogram.size() - 1)
458 i = histogram.size() - 2;
459 split = link->core().min[dim] + width*i;
460 log->trace("Found split: {} (dim={}) in {} - {}", split, dim, link->core().min[dim], link->core().max[dim]);
461
462 // subset and enqueue
463 std::vector< std::vector<Point> > out_points(srp.out_link().size());
464 for (size_t i = 0; i < (b->*points_).size(); ++i)
465 {
466 float x = (b->*points_)[i][dim];
467 int loc = x < split ? 0 : 1;
468 out_points[loc].push_back((b->*points_)[i]);
469 }
470 int pos = -1;
471 for (int i = 0; i < k; ++i)
472 {
473 if (srp.out_link().target(i).gid == srp.gid())
474 {
475 (b->*points_).swap(out_points[i]);
476 pos = i;
477 }
478 else
479 srp.enqueue(srp.out_link().target(i), out_points[i]);
480 }
481 if (pos == 0)
482 link->core().max[dim] = split;
483 else
484 link->core().min[dim] = split;
485 }
486
487 template<class Block, class Point>
488 void
489 diy::detail::KDTreePartition<Block,Point>::
dequeue_exchange(Block * b,const diy::ReduceProxy & srp,int dim) const490 dequeue_exchange(Block* b, const diy::ReduceProxy& srp, int dim) const
491 {
492 int lid = srp.master()->lid(srp.gid());
493 RCLink* link = static_cast<RCLink*>(srp.master()->link(lid));
494
495 for (int i = 0; i < srp.in_link().size(); ++i)
496 {
497 int nbr_gid = srp.in_link().target(i).gid;
498 if (nbr_gid == srp.gid())
499 continue;
500
501 std::vector<Point> in_points;
502 srp.dequeue(nbr_gid, in_points);
503 for (size_t j = 0; j < in_points.size(); ++j)
504 {
505 if (in_points[j][dim] < link->core().min[dim] || in_points[j][dim] > link->core().max[dim])
506 throw std::runtime_error(fmt::format("Dequeued {} outside [{},{}] ({})",
507 in_points[j][dim], link->core().min[dim], link->core().max[dim], dim));
508 (b->*points_).push_back(in_points[j]);
509 }
510 }
511 }
512
513 template<class Block, class Point>
514 void
515 diy::detail::KDTreePartition<Block,Point>::
update_neighbor_bounds(Bounds & bounds,float split,int dim,bool lower) const516 update_neighbor_bounds(Bounds& bounds, float split, int dim, bool lower) const
517 {
518 if (lower)
519 bounds.max[dim] = split;
520 else
521 bounds.min[dim] = split;
522 }
523
524 template<class Block, class Point>
525 bool
526 diy::detail::KDTreePartition<Block,Point>::
intersects(const Bounds & x,const Bounds & y,int dim,bool wrap,const Bounds & domain) const527 intersects(const Bounds& x, const Bounds& y, int dim, bool wrap, const Bounds& domain) const
528 {
529 if (wrap)
530 {
531 if (x.min[dim] == domain.min[dim] && y.max[dim] == domain.max[dim])
532 return true;
533 if (y.min[dim] == domain.min[dim] && x.max[dim] == domain.max[dim])
534 return true;
535 }
536 return x.min[dim] <= y.max[dim] && y.min[dim] <= x.max[dim];
537 }
538
539 template<class Block, class Point>
540 float
541 diy::detail::KDTreePartition<Block,Point>::
find_split(const Bounds & changed,const Bounds & original) const542 find_split(const Bounds& changed, const Bounds& original) const
543 {
544 for (int i = 0; i < dim_; ++i)
545 {
546 if (changed.min[i] != original.min[i])
547 return changed.min[i];
548 if (changed.max[i] != original.max[i])
549 return changed.max[i];
550 }
551 assert(0);
552 return -1;
553 }
554
555 template<class Block, class Point>
556 diy::Direction
557 diy::detail::KDTreePartition<Block,Point>::
find_wrap(const Bounds & bounds,const Bounds & nbr_bounds,const Bounds & domain) const558 find_wrap(const Bounds& bounds, const Bounds& nbr_bounds, const Bounds& domain) const
559 {
560 diy::Direction wrap(dim_,0);
561 for (int i = 0; i < dim_; ++i)
562 {
563 if (bounds.min[i] == domain.min[i] && nbr_bounds.max[i] == domain.max[i])
564 wrap[i] = -1;
565 if (bounds.max[i] == domain.max[i] && nbr_bounds.min[i] == domain.min[i])
566 wrap[i] = 1;
567 }
568 return wrap;
569 }
570
571
572 #endif
573