1 namespace diy
2 {
3 namespace detail
4 {
5 struct CollectiveOp
6 {
7 virtual void init() =0;
8 virtual void update(const CollectiveOp& other) =0;
9 virtual void global(const mpi::communicator& comm) =0;
10 virtual void copy_from(const CollectiveOp& other) =0;
11 virtual void result_out(void* dest) const =0;
~CollectiveOpdiy::detail::CollectiveOp12 virtual ~CollectiveOp() {}
13 };
14
15 template<class T, class Op>
16 struct AllReduceOp: public CollectiveOp
17 {
AllReduceOpdiy::detail::AllReduceOp18 AllReduceOp(const T& x, Op op):
19 in_(x), op_(op) {}
20
initdiy::detail::AllReduceOp21 void init() { out_ = in_; }
updatediy::detail::AllReduceOp22 void update(const CollectiveOp& other) { out_ = op_(out_, static_cast<const AllReduceOp&>(other).in_); }
globaldiy::detail::AllReduceOp23 void global(const mpi::communicator& comm) { T res; mpi::all_reduce(comm, out_, res, op_); out_ = res; }
copy_fromdiy::detail::AllReduceOp24 void copy_from(const CollectiveOp& other) { out_ = static_cast<const AllReduceOp&>(other).out_; }
result_outdiy::detail::AllReduceOp25 void result_out(void* dest) const { *reinterpret_cast<T*>(dest) = out_; }
26
27 private:
28 T in_, out_;
29 Op op_;
30 };
31
32 template<class T>
33 struct Scratch: public CollectiveOp
34 {
Scratchdiy::detail::Scratch35 Scratch(const T& x):
36 x_(x) {}
37
initdiy::detail::Scratch38 void init() {}
updatediy::detail::Scratch39 void update(const CollectiveOp&) {}
globaldiy::detail::Scratch40 void global(const mpi::communicator&) {}
copy_fromdiy::detail::Scratch41 void copy_from(const CollectiveOp&) {}
result_outdiy::detail::Scratch42 void result_out(void* dest) const { *reinterpret_cast<T*>(dest) = x_; }
43
44 private:
45 T x_;
46 };
47 }
48
49 struct Master::Collective
50 {
Collectivediy::Master::Collective51 Collective():
52 cop_(0) {}
Collectivediy::Master::Collective53 Collective(detail::CollectiveOp* cop):
54 cop_(cop) {}
Collectivediy::Master::Collective55 Collective(Collective&& other):
56 cop_(0) { swap(const_cast<Collective&>(other)); }
~Collectivediy::Master::Collective57 ~Collective() { delete cop_; }
58
59 Collective& operator=(const Collective& other) = delete;
60 Collective(Collective& other) = delete;
61
initdiy::Master::Collective62 void init() { cop_->init(); }
swapdiy::Master::Collective63 void swap(Collective& other) { std::swap(cop_, other.cop_); }
updatediy::Master::Collective64 void update(const Collective& other) { cop_->update(*other.cop_); }
globaldiy::Master::Collective65 void global(const mpi::communicator& c) { cop_->global(c); }
copy_fromdiy::Master::Collective66 void copy_from(Collective& other) const { cop_->copy_from(*other.cop_); }
result_outdiy::Master::Collective67 void result_out(void* x) const { cop_->result_out(x); }
68
69 detail::CollectiveOp* cop_;
70 };
71
72 struct Master::CollectivesList: public std::list<Collective>
73 {};
74
75 struct Master::CollectivesMap: public std::map<int, CollectivesList>
76 {};
77 }
78
79 diy::Master::CollectivesMap&
80 diy::Master::
collectives()81 collectives()
82 {
83 return *collectives_;
84 }
85
86 diy::Master::CollectivesList&
87 diy::Master::
collectives(int gid__)88 collectives(int gid__)
89 {
90 return (*collectives_)[gid__];
91 }
92
93 void
94 diy::Master::
process_collectives()95 process_collectives()
96 {
97 auto scoped = prof.scoped("collectives");
98 DIY_UNUSED(scoped);
99
100 if (collectives().empty())
101 return;
102
103 using CollectivesIterator = CollectivesList::iterator;
104 std::vector<CollectivesIterator> iters;
105 std::vector<int> gids;
106 for (auto& x : collectives())
107 {
108 gids.push_back(x.first);
109 iters.push_back(x.second.begin());
110 }
111
112 while (iters[0] != collectives().begin()->second.end())
113 {
114 iters[0]->init();
115 for (unsigned j = 1; j < iters.size(); ++j)
116 {
117 // NB: this assumes that the operations are commutative
118 iters[0]->update(*iters[j]);
119 }
120 iters[0]->global(comm_); // do the mpi collective
121
122 for (unsigned j = 1; j < iters.size(); ++j)
123 {
124 iters[j]->copy_from(*iters[0]);
125 ++iters[j];
126 }
127
128 ++iters[0];
129 }
130 }
131
132