1 #ifndef DIY_DETAIL_ALGORITHMS_KDTREE_HPP
2 #define DIY_DETAIL_ALGORITHMS_KDTREE_HPP
3 
#include <cassert>
#include <climits>
#include <map>
#include <set>
#include <stdexcept>
#include <utility>
#include <vector>

#include "../../partners/all-reduce.hpp"
#include "../../log.hpp"
8 
9 namespace diy
10 {
11 namespace detail
12 {
13 
14 struct KDTreePartners;
15 
16 template<class Block, class Point>
17 struct KDTreePartition
18 {
19     typedef     diy::RegularContinuousLink      RCLink;
20     typedef     diy::ContinuousBounds           Bounds;
21 
22     typedef     std::vector<size_t>             Histogram;
23 
KDTreePartitiondiy::detail::KDTreePartition24                 KDTreePartition(int                             dim,
25                                 std::vector<Point>  Block::*    points,
26                                 size_t                          bins):
27                     dim_(dim), points_(points), bins_(bins)            {}
28 
29     void        operator()(Block* b, const diy::ReduceProxy& srp, const KDTreePartners& partners) const;
30 
31     int         divide_gid(int gid, bool lower, int round, int rounds) const;
32     void        update_links(Block* b, const diy::ReduceProxy& srp, int dim, int round, int rounds, bool wrap, const Bounds& domain) const;
33     void        split_to_neighbors(Block* b, const diy::ReduceProxy& srp, int dim) const;
34     diy::Direction
35                 find_wrap(const Bounds& bounds, const Bounds& nbr_bounds, const Bounds& domain) const;
36 
37     void        compute_local_histogram(Block* b, const diy::ReduceProxy& srp, int dim) const;
38     void        add_histogram(Block* b, const diy::ReduceProxy& srp, Histogram& histogram) const;
39     void        receive_histogram(Block* b, const diy::ReduceProxy& srp,       Histogram& histogram) const;
40     void        forward_histogram(Block* b, const diy::ReduceProxy& srp, const Histogram& histogram) const;
41 
42     void        enqueue_exchange(Block* b, const diy::ReduceProxy& srp, int dim, const Histogram& histogram) const;
43     void        dequeue_exchange(Block* b, const diy::ReduceProxy& srp, int dim) const;
44 
45     void        update_neighbor_bounds(Bounds& bounds, float split, int dim, bool lower) const;
46     bool        intersects(const Bounds& x, const Bounds& y, int dim, bool wrap, const Bounds& domain) const;
47     float       find_split(const Bounds& changed, const Bounds& original) const;
48 
49     int                             dim_;
50     std::vector<Point>  Block::*    points_;
51     size_t                          bins_;
52 };
53 
54 }
55 }
56 
57 struct diy::detail::KDTreePartners
58 {
59   // bool = are we in a swap (vs histogram) round
60   // int  = round within that partner
61   typedef           std::pair<bool, int>                    RoundType;
62   typedef           diy::ContinuousBounds                   Bounds;
63 
KDTreePartnersdiy::detail::KDTreePartners64                     KDTreePartners(int dim, int nblocks, bool wrap_, const Bounds& domain_):
65                         decomposer(1, interval(0,nblocks-1), nblocks),
66                         histogram(decomposer, 2),
67                         swap(decomposer, 2, false),
68                         wrap(wrap_),
69                         domain(domain_)
70   {
71     for (unsigned i = 0; i < swap.rounds(); ++i)
72     {
73       // fill histogram rounds
74       for (unsigned j = 0; j < histogram.rounds(); ++j)
75       {
76         rounds_.push_back(std::make_pair(false, j));
77         dim_.push_back(i % dim);
78         if (j == histogram.rounds() / 2 - 1 - i)
79             j += 2*i;
80       }
81 
82       // fill swap round
83       rounds_.push_back(std::make_pair(true, i));
84       dim_.push_back(i % dim);
85 
86       // fill link round
87       rounds_.push_back(std::make_pair(true, -1));          // (true, -1) signals link round
88       dim_.push_back(i % dim);
89     }
90   }
91 
roundsdiy::detail::KDTreePartners92   size_t        rounds() const                              { return rounds_.size(); }
swap_roundsdiy::detail::KDTreePartners93   size_t        swap_rounds() const                         { return swap.rounds(); }
94 
dimdiy::detail::KDTreePartners95   int           dim(int round) const                        { return dim_[round]; }
swap_rounddiy::detail::KDTreePartners96   bool          swap_round(int round) const                 { return rounds_[round].first; }
sub_rounddiy::detail::KDTreePartners97   int           sub_round(int round) const                  { return rounds_[round].second; }
98 
activediy::detail::KDTreePartners99   inline bool   active(int round, int gid, const diy::Master& m) const
100   {
101     if (round == (int) rounds())
102         return true;
103     else if (swap_round(round) && sub_round(round) < 0)     // link round
104         return true;
105     else if (swap_round(round))
106         return swap.active(sub_round(round), gid, m);
107     else
108         return histogram.active(sub_round(round), gid, m);
109   }
110 
incomingdiy::detail::KDTreePartners111   inline void   incoming(int round, int gid, std::vector<int>& partners, const diy::Master& m) const
112   {
113     if (round == (int) rounds())
114         link_neighbors(-1, gid, partners, m);
115     else if (swap_round(round) && sub_round(round) < 0)       // link round
116         swap.incoming(sub_round(round - 1) + 1, gid, partners, m);
117     else if (swap_round(round))
118         histogram.incoming(histogram.rounds(), gid, partners, m);
119     else
120     {
121         if (round > 0 && sub_round(round) == 0)
122             link_neighbors(-1, gid, partners, m);
123         else if (round > 0 && sub_round(round - 1) != sub_round(round) - 1)        // jump through the histogram rounds
124             histogram.incoming(sub_round(round - 1) + 1, gid, partners, m);
125         else
126             histogram.incoming(sub_round(round), gid, partners, m);
127     }
128   }
129 
outgoingdiy::detail::KDTreePartners130   inline void   outgoing(int round, int gid, std::vector<int>& partners, const diy::Master& m) const
131   {
132     if (round == (int) rounds())
133         swap.outgoing(sub_round(round-1) + 1, gid, partners, m);
134     else if (swap_round(round) && sub_round(round) < 0)       // link round
135         link_neighbors(-1, gid, partners, m);
136     else if (swap_round(round))
137         swap.outgoing(sub_round(round), gid, partners, m);
138     else
139         histogram.outgoing(sub_round(round), gid, partners, m);
140   }
141 
link_neighborsdiy::detail::KDTreePartners142   inline void   link_neighbors(int, int gid, std::vector<int>& partners, const diy::Master& m) const
143   {
144     int         lid  = m.lid(gid);
145     diy::Link*  link = m.link(lid);
146 
147     std::set<int> result;       // partners must be unique
148     for (int i = 0; i < link->size(); ++i)
149         result.insert(link->target(i).gid);
150 
151     for (std::set<int>::const_iterator it = result.begin(); it != result.end(); ++it)
152         partners.push_back(*it);
153   }
154 
155   // 1-D domain to feed into histogram and swap
156   diy::RegularDecomposer<diy::DiscreteBounds>   decomposer;
157 
158   diy::RegularAllReducePartners     histogram;
159   diy::RegularSwapPartners          swap;
160 
161   std::vector<RoundType>            rounds_;
162   std::vector<int>                  dim_;
163 
164   bool                              wrap;
165   Bounds                            domain;
166 };
167 
168 template<class Block, class Point>
169 void
170 diy::detail::KDTreePartition<Block,Point>::
operator ()(Block * b,const diy::ReduceProxy & srp,const KDTreePartners & partners) const171 operator()(Block* b, const diy::ReduceProxy& srp, const KDTreePartners& partners) const
172 {
173     int dim;
174     if (srp.round() < partners.rounds())
175         dim = partners.dim(srp.round());
176     else
177         dim = partners.dim(srp.round() - 1);
178 
179     if (srp.round() == partners.rounds())
180         update_links(b, srp, dim, partners.sub_round(srp.round() - 2), partners.swap_rounds(), partners.wrap, partners.domain); // -1 would be the "uninformative" link round
181     else if (partners.swap_round(srp.round()) && partners.sub_round(srp.round()) < 0)       // link round
182     {
183         dequeue_exchange(b, srp, dim);         // from the swap round
184         split_to_neighbors(b, srp, dim);
185     }
186     else if (partners.swap_round(srp.round()))
187     {
188         Histogram   histogram;
189         receive_histogram(b, srp, histogram);
190         enqueue_exchange(b, srp, dim, histogram);
191     } else if (partners.sub_round(srp.round()) == 0)
192     {
193         if (srp.round() > 0)
194         {
195             int prev_dim = dim - 1;
196             if (prev_dim < 0)
197                 prev_dim += dim_;
198             update_links(b, srp, prev_dim, partners.sub_round(srp.round() - 2), partners.swap_rounds(), partners.wrap, partners.domain);    // -1 would be the "uninformative" link round
199         }
200 
201         compute_local_histogram(b, srp, dim);
202     } else if (partners.sub_round(srp.round()) < (int) partners.histogram.rounds()/2)
203     {
204         Histogram   histogram(bins_);
205         add_histogram(b, srp, histogram);
206         srp.enqueue(srp.out_link().target(0), histogram);
207     }
208     else
209     {
210         Histogram   histogram(bins_);
211         add_histogram(b, srp, histogram);
212         forward_histogram(b, srp, histogram);
213     }
214 }
215 
216 template<class Block, class Point>
217 int
218 diy::detail::KDTreePartition<Block,Point>::
divide_gid(int gid,bool lower,int round,int rounds) const219 divide_gid(int gid, bool lower, int round, int rounds) const
220 {
221     if (lower)
222         gid &= ~(1 << (rounds - 1 - round));
223     else
224         gid |=  (1 << (rounds - 1 - round));
225     return gid;
226 }
227 
228 // round here is the outer iteration of the algorithm
229 template<class Block, class Point>
230 void
231 diy::detail::KDTreePartition<Block,Point>::
update_links(Block * b,const diy::ReduceProxy & srp,int dim,int round,int rounds,bool wrap,const Bounds & domain) const232 update_links(Block* b, const diy::ReduceProxy& srp, int dim, int round, int rounds, bool wrap, const Bounds& domain) const
233 {
234     int         gid  = srp.gid();
235     int         lid  = srp.master()->lid(gid);
236     RCLink*     link = static_cast<RCLink*>(srp.master()->link(lid));
237 
238     // (gid, dir) -> i
239     std::map<std::pair<int,diy::Direction>, int> link_map;
240     for (int i = 0; i < link->size(); ++i)
241         link_map[std::make_pair(link->target(i).gid, link->direction(i))] = i;
242 
243     // NB: srp.enqueue(..., ...) should match the link
244     std::vector<float>  splits(link->size());
245     for (int i = 0; i < link->size(); ++i)
246     {
247         float split; diy::Direction dir(dim_,0);
248 
249         int in_gid = link->target(i).gid;
250         while(srp.incoming(in_gid))
251         {
252             srp.dequeue(in_gid, split);
253             srp.dequeue(in_gid, dir);
254 
255             // reverse dir
256             for (int j = 0; j < dim_; ++j)
257                 dir[j] = -dir[j];
258 
259             int k = link_map[std::make_pair(in_gid, dir)];
260             splits[k] = split;
261         }
262     }
263 
264     RCLink      new_link(dim_, link->core(), link->core());
265 
266     bool lower = !(gid & (1 << (rounds - 1 - round)));
267 
268     // fill out the new link
269     for (int i = 0; i < link->size(); ++i)
270     {
271         diy::Direction  dir      = link->direction(i);
272         //diy::Direction  wrap_dir = link->wrap(i);     // we don't use existing wrap, but restore it from scratch
273         if (dir[dim] != 0)
274         {
275             if ((dir[dim] < 0 && lower) || (dir[dim] > 0 && !lower))
276             {
277                 int nbr_gid = divide_gid(link->target(i).gid, !lower, round, rounds);
278                 diy::BlockID nbr = { nbr_gid, srp.assigner().rank(nbr_gid) };
279                 new_link.add_neighbor(nbr);
280 
281                 new_link.add_direction(dir);
282 
283                 Bounds bounds = link->bounds(i);
284                 update_neighbor_bounds(bounds, splits[i], dim, !lower);
285                 new_link.add_bounds(bounds);
286 
287                 if (wrap)
288                     new_link.add_wrap(find_wrap(new_link.bounds(), bounds, domain));
289                 else
290                     new_link.add_wrap(diy::Direction(dim_,0));
291             }
292         } else // non-aligned side
293         {
294             for (int j = 0; j < 2; ++j)
295             {
296                 int nbr_gid = divide_gid(link->target(i).gid, j == 0, round, rounds);
297 
298                 Bounds  bounds  = link->bounds(i);
299                 update_neighbor_bounds(bounds, splits[i], dim, j == 0);
300 
301                 if (intersects(bounds, new_link.bounds(), dim, wrap, domain))
302                 {
303                     diy::BlockID nbr = { nbr_gid, srp.assigner().rank(nbr_gid) };
304                     new_link.add_neighbor(nbr);
305                     new_link.add_direction(dir);
306                     new_link.add_bounds(bounds);
307 
308                     if (wrap)
309                         new_link.add_wrap(find_wrap(new_link.bounds(), bounds, domain));
310                     else
311                         new_link.add_wrap(diy::Direction(dim_, 0));
312                 }
313             }
314         }
315     }
316 
317     // add link to the dual block
318     int dual_gid = divide_gid(gid, !lower, round, rounds);
319     diy::BlockID dual = { dual_gid, srp.assigner().rank(dual_gid) };
320     new_link.add_neighbor(dual);
321 
322     Bounds nbr_bounds = link->bounds();     // old block bounds
323     update_neighbor_bounds(nbr_bounds, find_split(new_link.bounds(), nbr_bounds), dim, !lower);
324     new_link.add_bounds(nbr_bounds);
325 
326     new_link.add_wrap(diy::Direction(dim_,0));    // dual block cannot be wrapped
327 
328     if (lower)
329     {
330         diy::Direction right(dim_,0);
331         right[dim] = 1;
332         new_link.add_direction(right);
333     } else
334     {
335         diy::Direction left(dim_,0);
336         left[dim] = -1;
337         new_link.add_direction(left);
338     }
339 
340     // update the link; notice that this won't conflict with anything since
341     // reduce is using its own notion of the link constructed through the
342     // partners
343     link->swap(new_link);
344 }
345 
346 template<class Block, class Point>
347 void
348 diy::detail::KDTreePartition<Block,Point>::
split_to_neighbors(Block * b,const diy::ReduceProxy & srp,int dim) const349 split_to_neighbors(Block* b, const diy::ReduceProxy& srp, int dim) const
350 {
351     int         lid  = srp.master()->lid(srp.gid());
352     RCLink*     link = static_cast<RCLink*>(srp.master()->link(lid));
353 
354     // determine split
355     float split = find_split(link->core(), link->bounds());
356 
357     for (int i = 0; i < link->size(); ++i)
358     {
359         srp.enqueue(link->target(i), split);
360         srp.enqueue(link->target(i), link->direction(i));
361     }
362 }
363 
364 template<class Block, class Point>
365 void
366 diy::detail::KDTreePartition<Block,Point>::
compute_local_histogram(Block * b,const diy::ReduceProxy & srp,int dim) const367 compute_local_histogram(Block* b, const diy::ReduceProxy& srp, int dim) const
368 {
369     int         lid  = srp.master()->lid(srp.gid());
370     RCLink*     link = static_cast<RCLink*>(srp.master()->link(lid));
371 
372     // compute and enqueue local histogram
373     Histogram histogram(bins_);
374 
375     float   width = (link->core().max[dim] - link->core().min[dim])/bins_;
376     for (size_t i = 0; i < (b->*points_).size(); ++i)
377     {
378         float x = (b->*points_)[i][dim];
379         int loc = (x - link->core().min[dim]) / width;
380         if (loc < 0)
381             throw std::runtime_error(fmt::format("{} {} {}", loc, x, link->core().min[dim]));
382         if (loc >= (int) bins_)
383             loc = bins_ - 1;
384         ++(histogram[loc]);
385     }
386 
387     srp.enqueue(srp.out_link().target(0), histogram);
388 }
389 
390 template<class Block, class Point>
391 void
392 diy::detail::KDTreePartition<Block,Point>::
add_histogram(Block * b,const diy::ReduceProxy & srp,Histogram & histogram) const393 add_histogram(Block* b, const diy::ReduceProxy& srp, Histogram& histogram) const
394 {
395     // dequeue and add up the histograms
396     for (int i = 0; i < srp.in_link().size(); ++i)
397     {
398         int nbr_gid = srp.in_link().target(i).gid;
399 
400         Histogram hist;
401         srp.dequeue(nbr_gid, hist);
402         for (size_t j = 0; j < hist.size(); ++j)
403             histogram[j] += hist[j];
404     }
405 }
406 
407 template<class Block, class Point>
408 void
409 diy::detail::KDTreePartition<Block,Point>::
receive_histogram(Block * b,const diy::ReduceProxy & srp,Histogram & histogram) const410 receive_histogram(Block* b, const diy::ReduceProxy& srp, Histogram& histogram) const
411 {
412     srp.dequeue(srp.in_link().target(0).gid, histogram);
413 }
414 
415 template<class Block, class Point>
416 void
417 diy::detail::KDTreePartition<Block,Point>::
forward_histogram(Block * b,const diy::ReduceProxy & srp,const Histogram & histogram) const418 forward_histogram(Block* b, const diy::ReduceProxy& srp, const Histogram& histogram) const
419 {
420     for (int i = 0; i < srp.out_link().size(); ++i)
421         srp.enqueue(srp.out_link().target(i), histogram);
422 }
423 
424 template<class Block, class Point>
425 void
426 diy::detail::KDTreePartition<Block,Point>::
enqueue_exchange(Block * b,const diy::ReduceProxy & srp,int dim,const Histogram & histogram) const427 enqueue_exchange(Block* b, const diy::ReduceProxy& srp, int dim, const Histogram& histogram) const
428 {
429     auto        log = get_logger();
430 
431     int         lid  = srp.master()->lid(srp.gid());
432     RCLink*     link = static_cast<RCLink*>(srp.master()->link(lid));
433 
434     int k = srp.out_link().size();
435 
436     if (k == 0)        // final round; nothing needs to be sent; this is actually redundant
437         return;
438 
439     // pick split points
440     size_t total = 0;
441     for (size_t i = 0; i < histogram.size(); ++i)
442         total += histogram[i];
443     log->trace("Histogram total: {}", total);
444 
445     size_t cur   = 0;
446     float  width = (link->core().max[dim] - link->core().min[dim])/bins_;
447     float  split = 0;
448     size_t i = 0;
449     for (; i < histogram.size(); ++i)
450     {
451         if (cur + histogram[i] > total/2)
452             break;
453         cur += histogram[i];
454     }
455     if (i == 0)
456         ++i;
457     else if (i >= histogram.size() - 1)
458         i = histogram.size() - 2;
459     split = link->core().min[dim] + width*i;
460     log->trace("Found split: {} (dim={}) in {} - {}", split, dim, link->core().min[dim], link->core().max[dim]);
461 
462     // subset and enqueue
463     std::vector< std::vector<Point> > out_points(srp.out_link().size());
464     for (size_t i = 0; i < (b->*points_).size(); ++i)
465     {
466       float x = (b->*points_)[i][dim];
467       int loc = x < split ? 0 : 1;
468       out_points[loc].push_back((b->*points_)[i]);
469     }
470     int pos = -1;
471     for (int i = 0; i < k; ++i)
472     {
473       if (srp.out_link().target(i).gid == srp.gid())
474       {
475         (b->*points_).swap(out_points[i]);
476         pos = i;
477       }
478       else
479         srp.enqueue(srp.out_link().target(i), out_points[i]);
480     }
481     if (pos == 0)
482         link->core().max[dim] = split;
483     else
484         link->core().min[dim] = split;
485 }
486 
487 template<class Block, class Point>
488 void
489 diy::detail::KDTreePartition<Block,Point>::
dequeue_exchange(Block * b,const diy::ReduceProxy & srp,int dim) const490 dequeue_exchange(Block* b, const diy::ReduceProxy& srp, int dim) const
491 {
492     int         lid  = srp.master()->lid(srp.gid());
493     RCLink*     link = static_cast<RCLink*>(srp.master()->link(lid));
494 
495     for (int i = 0; i < srp.in_link().size(); ++i)
496     {
497       int nbr_gid = srp.in_link().target(i).gid;
498       if (nbr_gid == srp.gid())
499           continue;
500 
501       std::vector<Point>   in_points;
502       srp.dequeue(nbr_gid, in_points);
503       for (size_t j = 0; j < in_points.size(); ++j)
504       {
505         if (in_points[j][dim] < link->core().min[dim] || in_points[j][dim] > link->core().max[dim])
506             throw std::runtime_error(fmt::format("Dequeued {} outside [{},{}] ({})",
507                                      in_points[j][dim], link->core().min[dim], link->core().max[dim], dim));
508         (b->*points_).push_back(in_points[j]);
509       }
510     }
511 }
512 
513 template<class Block, class Point>
514 void
515 diy::detail::KDTreePartition<Block,Point>::
update_neighbor_bounds(Bounds & bounds,float split,int dim,bool lower) const516 update_neighbor_bounds(Bounds& bounds, float split, int dim, bool lower) const
517 {
518     if (lower)
519         bounds.max[dim] = split;
520     else
521         bounds.min[dim] = split;
522 }
523 
524 template<class Block, class Point>
525 bool
526 diy::detail::KDTreePartition<Block,Point>::
intersects(const Bounds & x,const Bounds & y,int dim,bool wrap,const Bounds & domain) const527 intersects(const Bounds& x, const Bounds& y, int dim, bool wrap, const Bounds& domain) const
528 {
529     if (wrap)
530     {
531         if (x.min[dim] == domain.min[dim] && y.max[dim] == domain.max[dim])
532             return true;
533         if (y.min[dim] == domain.min[dim] && x.max[dim] == domain.max[dim])
534             return true;
535     }
536     return x.min[dim] <= y.max[dim] && y.min[dim] <= x.max[dim];
537 }
538 
539 template<class Block, class Point>
540 float
541 diy::detail::KDTreePartition<Block,Point>::
find_split(const Bounds & changed,const Bounds & original) const542 find_split(const Bounds& changed, const Bounds& original) const
543 {
544     for (int i = 0; i < dim_; ++i)
545     {
546         if (changed.min[i] != original.min[i])
547             return changed.min[i];
548         if (changed.max[i] != original.max[i])
549             return changed.max[i];
550     }
551     assert(0);
552     return -1;
553 }
554 
555 template<class Block, class Point>
556 diy::Direction
557 diy::detail::KDTreePartition<Block,Point>::
find_wrap(const Bounds & bounds,const Bounds & nbr_bounds,const Bounds & domain) const558 find_wrap(const Bounds& bounds, const Bounds& nbr_bounds, const Bounds& domain) const
559 {
560     diy::Direction wrap(dim_,0);
561     for (int i = 0; i < dim_; ++i)
562     {
563         if (bounds.min[i] == domain.min[i] && nbr_bounds.max[i] == domain.max[i])
564             wrap[i] = -1;
565         if (bounds.max[i] == domain.max[i] && nbr_bounds.min[i] == domain.min[i])
566             wrap[i] =  1;
567     }
568     return wrap;
569 }
570 
571 
572 #endif
573