1 /*******************************************************************************
2 * thrill/api/all_reduce.hpp
3 *
4 * Part of Project Thrill - http://project-thrill.org
5 *
6 * Copyright (C) 2015 Matthias Stumpp <mstumpp@gmail.com>
7 * Copyright (C) 2015 Sebastian Lamm <seba.lamm@gmail.com>
8 *
9 * All rights reserved. Published under the BSD-2 license in the LICENSE file.
10 ******************************************************************************/
11
12 #pragma once
13 #ifndef THRILL_API_ALL_REDUCE_HEADER
14 #define THRILL_API_ALL_REDUCE_HEADER
15
16 #include <thrill/api/action_node.hpp>
17 #include <thrill/api/dia.hpp>
18
19 #include <type_traits>
20
21 namespace thrill {
22 namespace api {
23
24 /*!
25 * \ingroup api_layer
26 */
27 template <typename ValueType, typename ReduceFunction>
28 class AllReduceNode final : public ActionResultNode<ValueType>
29 {
30 static constexpr bool debug = false;
31
32 using Super = ActionResultNode<ValueType>;
33 using Super::context_;
34
35 public:
36 template <typename ParentDIA>
AllReduceNode(const ParentDIA & parent,const char * label,const ValueType & initial_value=ValueType (),bool with_initial_value=false,const ReduceFunction & reduce_function=ReduceFunction ())37 AllReduceNode(const ParentDIA& parent,
38 const char* label,
39 const ValueType& initial_value = ValueType(),
40 bool with_initial_value = false,
41 const ReduceFunction& reduce_function = ReduceFunction())
42 : Super(parent.ctx(), label, { parent.id() }, { parent.node() }),
43 reduce_function_(reduce_function),
44 sum_(initial_value),
45 // add to initial value if defined and we are first worker
46 first_(!(with_initial_value && parent.ctx().my_rank() == 0)) {
47 // Hook PreOp(s)
__anoned10ab5c0102(const ValueType& input) 48 auto pre_op_fn = [this](const ValueType& input) {
49 PreOp(input);
50 };
51
52 auto lop_chain = parent.stack().push(pre_op_fn).fold();
53 parent.node()->AddChild(this, lop_chain);
54 }
55
PreOp(const ValueType & input)56 void PreOp(const ValueType& input) {
57 if (TLX_UNLIKELY(first_)) {
58 first_ = false;
59 sum_ = input;
60 }
61 else {
62 sum_ = reduce_function_(sum_, input);
63 }
64 }
65
66 //! Executes the sum operation.
Execute()67 void Execute() final {
68 // process the reduce
69 sum_ = context_.net.AllReduce(sum_, reduce_function_);
70 }
71
72 //! Returns result of global sum.
result() const73 const ValueType& result() const final {
74 return sum_;
75 }
76
77 private:
78 //! The sum function which is applied to two values.
79 ReduceFunction reduce_function_;
80 //! Local/global sum to be used in all reduce operation.
81 ValueType sum_;
82 //! indicate that sum_ is the default constructed first value. Worker 0's
83 //! value is already set to initial_value.
84 bool first_;
85 };
86
87 template <typename ValueType, typename Stack>
88 template <typename ReduceFunction>
AllReduce(const ReduceFunction & sum_function) const89 ValueType DIA<ValueType, Stack>::AllReduce(
90 const ReduceFunction& sum_function) const {
91 assert(IsValid());
92
93 using AllReduceNode = api::AllReduceNode<ValueType, ReduceFunction>;
94
95 static_assert(
96 std::is_convertible<
97 ValueType,
98 typename FunctionTraits<ReduceFunction>::template arg<0> >::value,
99 "ReduceFunction has the wrong input type");
100
101 static_assert(
102 std::is_convertible<
103 ValueType,
104 typename FunctionTraits<ReduceFunction>::template arg<1> >::value,
105 "ReduceFunction has the wrong input type");
106
107 static_assert(
108 std::is_convertible<
109 typename FunctionTraits<ReduceFunction>::result_type,
110 ValueType>::value,
111 "ReduceFunction has the wrong input type");
112
113 auto node = tlx::make_counting<AllReduceNode>(
114 *this, "AllReduce", ValueType(), /* with_initial_value */ false,
115 sum_function);
116
117 node->RunScope();
118
119 return node->result();
120 }
121
122 template <typename ValueType, typename Stack>
123 template <typename ReduceFunction>
AllReduce(const ReduceFunction & sum_function,const ValueType & initial_value) const124 ValueType DIA<ValueType, Stack>::AllReduce(
125 const ReduceFunction& sum_function, const ValueType& initial_value) const {
126 assert(IsValid());
127
128 using AllReduceNode = api::AllReduceNode<ValueType, ReduceFunction>;
129
130 static_assert(
131 std::is_convertible<
132 ValueType,
133 typename FunctionTraits<ReduceFunction>::template arg<0> >::value,
134 "ReduceFunction has the wrong input type");
135
136 static_assert(
137 std::is_convertible<
138 ValueType,
139 typename FunctionTraits<ReduceFunction>::template arg<1> >::value,
140 "ReduceFunction has the wrong input type");
141
142 static_assert(
143 std::is_convertible<
144 typename FunctionTraits<ReduceFunction>::result_type,
145 ValueType>::value,
146 "ReduceFunction has the wrong input type");
147
148 auto node = tlx::make_counting<AllReduceNode>(
149 *this, "AllReduce", initial_value, /* with_initial_value */ true,
150 sum_function);
151
152 node->RunScope();
153
154 return node->result();
155 }
156
157 template <typename ValueType, typename Stack>
158 template <typename ReduceFunction>
AllReduceFuture(const ReduceFunction & sum_function) const159 Future<ValueType> DIA<ValueType, Stack>::AllReduceFuture(
160 const ReduceFunction& sum_function) const {
161 assert(IsValid());
162
163 using AllReduceNode = api::AllReduceNode<ValueType, ReduceFunction>;
164
165 static_assert(
166 std::is_convertible<
167 ValueType,
168 typename FunctionTraits<ReduceFunction>::template arg<0> >::value,
169 "ReduceFunction has the wrong input type");
170
171 static_assert(
172 std::is_convertible<
173 ValueType,
174 typename FunctionTraits<ReduceFunction>::template arg<1> >::value,
175 "ReduceFunction has the wrong input type");
176
177 static_assert(
178 std::is_convertible<
179 typename FunctionTraits<ReduceFunction>::result_type,
180 ValueType>::value,
181 "ReduceFunction has the wrong input type");
182
183 auto node = tlx::make_counting<AllReduceNode>(
184 *this, "AllReduce", ValueType(), /* with_initial_value */ false,
185 sum_function);
186
187 return Future<ValueType>(node);
188 }
189
190 template <typename ValueType, typename Stack>
191 template <typename ReduceFunction>
AllReduceFuture(const ReduceFunction & sum_function,const ValueType & initial_value) const192 Future<ValueType> DIA<ValueType, Stack>::AllReduceFuture(
193 const ReduceFunction& sum_function, const ValueType& initial_value) const {
194 assert(IsValid());
195
196 using AllReduceNode = api::AllReduceNode<ValueType, ReduceFunction>;
197
198 static_assert(
199 std::is_convertible<
200 ValueType,
201 typename FunctionTraits<ReduceFunction>::template arg<0> >::value,
202 "ReduceFunction has the wrong input type");
203
204 static_assert(
205 std::is_convertible<
206 ValueType,
207 typename FunctionTraits<ReduceFunction>::template arg<1> >::value,
208 "ReduceFunction has the wrong input type");
209
210 static_assert(
211 std::is_convertible<
212 typename FunctionTraits<ReduceFunction>::result_type,
213 ValueType>::value,
214 "ReduceFunction has the wrong input type");
215
216 auto node = tlx::make_counting<AllReduceNode>(
217 *this, "AllReduce", initial_value, /* with_initial_value */ true,
218 sum_function);
219
220 return Future<ValueType>(node);
221 }
222
223 } // namespace api
224 } // namespace thrill
225
226 #endif // !THRILL_API_ALL_REDUCE_HEADER
227
228 /******************************************************************************/
229