1 namespace diy
2 {
3     namespace detail
4     {
5         struct CollectiveOp
6         {
7           virtual void    init()                                  =0;
8           virtual void    update(const CollectiveOp& other)       =0;
9           virtual void    global(const mpi::communicator& comm)   =0;
10           virtual void    copy_from(const CollectiveOp& other)    =0;
11           virtual void    result_out(void* dest) const            =0;
~CollectiveOpdiy::detail::CollectiveOp12           virtual         ~CollectiveOp()                         {}
13         };
14 
15         template<class T, class Op>
16         struct AllReduceOp: public CollectiveOp
17         {
AllReduceOpdiy::detail::AllReduceOp18                 AllReduceOp(const T& x, Op op):
19                   in_(x), op_(op)         {}
20 
initdiy::detail::AllReduceOp21           void  init()                                    { out_ = in_; }
updatediy::detail::AllReduceOp22           void  update(const CollectiveOp& other)         { out_ = op_(out_, static_cast<const AllReduceOp&>(other).in_); }
globaldiy::detail::AllReduceOp23           void  global(const mpi::communicator& comm)     { T res; mpi::all_reduce(comm, out_, res, op_); out_ = res; }
copy_fromdiy::detail::AllReduceOp24           void  copy_from(const CollectiveOp& other)      { out_ = static_cast<const AllReduceOp&>(other).out_; }
result_outdiy::detail::AllReduceOp25           void  result_out(void* dest) const              { *reinterpret_cast<T*>(dest) = out_; }
26 
27           private:
28             T     in_, out_;
29             Op    op_;
30         };
31 
32         template<class T>
33         struct Scratch: public CollectiveOp
34         {
Scratchdiy::detail::Scratch35                 Scratch(const T& x):
36                   x_(x)                                   {}
37 
initdiy::detail::Scratch38           void  init()                                    {}
updatediy::detail::Scratch39           void  update(const CollectiveOp&)               {}
globaldiy::detail::Scratch40           void  global(const mpi::communicator&)          {}
copy_fromdiy::detail::Scratch41           void  copy_from(const CollectiveOp&)            {}
result_outdiy::detail::Scratch42           void  result_out(void* dest) const              { *reinterpret_cast<T*>(dest) = x_; }
43 
44           private:
45             T     x_;
46         };
47     }
48 
49     struct Master::Collective
50     {
Collectivediy::Master::Collective51                     Collective():
52                       cop_(0)                           {}
Collectivediy::Master::Collective53                     Collective(detail::CollectiveOp* cop):
54                       cop_(cop)                         {}
Collectivediy::Master::Collective55                     Collective(Collective&& other):
56                       cop_(0)                           { swap(const_cast<Collective&>(other)); }
~Collectivediy::Master::Collective57                     ~Collective()                       { delete cop_; }
58 
59         Collective& operator=(const Collective& other)  = delete;
60                     Collective(Collective& other)       = delete;
61 
initdiy::Master::Collective62         void        init()                              { cop_->init(); }
swapdiy::Master::Collective63         void        swap(Collective& other)             { std::swap(cop_, other.cop_); }
updatediy::Master::Collective64         void        update(const Collective& other)     { cop_->update(*other.cop_); }
globaldiy::Master::Collective65         void        global(const mpi::communicator& c)  { cop_->global(c); }
copy_fromdiy::Master::Collective66         void        copy_from(Collective& other) const  { cop_->copy_from(*other.cop_); }
result_outdiy::Master::Collective67         void        result_out(void* x) const           { cop_->result_out(x); }
68 
69         detail::CollectiveOp*                           cop_;
70     };
71 
72     struct Master::CollectivesList: public std::list<Collective>
73     {};
74 
75     struct Master::CollectivesMap: public std::map<int, CollectivesList>
76     {};
77 }
78 
79 diy::Master::CollectivesMap&
80 diy::Master::
collectives()81 collectives()
82 {
83     return *collectives_;
84 }
85 
86 diy::Master::CollectivesList&
87 diy::Master::
collectives(int gid__)88 collectives(int gid__)
89 {
90     return (*collectives_)[gid__];
91 }
92 
93 void
94 diy::Master::
process_collectives()95 process_collectives()
96 {
97   auto scoped = prof.scoped("collectives");
98   DIY_UNUSED(scoped);
99 
100   if (collectives().empty())
101       return;
102 
103   using CollectivesIterator = CollectivesList::iterator;
104   std::vector<CollectivesIterator>  iters;
105   std::vector<int>                  gids;
106   for (auto& x : collectives())
107   {
108     gids.push_back(x.first);
109     iters.push_back(x.second.begin());
110   }
111 
112   while (iters[0] != collectives().begin()->second.end())
113   {
114     iters[0]->init();
115     for (unsigned j = 1; j < iters.size(); ++j)
116     {
117       // NB: this assumes that the operations are commutative
118       iters[0]->update(*iters[j]);
119     }
120     iters[0]->global(comm_);        // do the mpi collective
121 
122     for (unsigned j = 1; j < iters.size(); ++j)
123     {
124       iters[j]->copy_from(*iters[0]);
125       ++iters[j];
126     }
127 
128     ++iters[0];
129   }
130 }
131 
132