1 /*******************************************************************************
2  * thrill/api/all_reduce.hpp
3  *
4  * Part of Project Thrill - http://project-thrill.org
5  *
6  * Copyright (C) 2015 Matthias Stumpp <mstumpp@gmail.com>
7  * Copyright (C) 2015 Sebastian Lamm <seba.lamm@gmail.com>
8  *
9  * All rights reserved. Published under the BSD-2 license in the LICENSE file.
10  ******************************************************************************/
11 
12 #pragma once
13 #ifndef THRILL_API_ALL_REDUCE_HEADER
14 #define THRILL_API_ALL_REDUCE_HEADER
15 
#include <thrill/api/action_node.hpp>
#include <thrill/api/dia.hpp>

#include <type_traits>
#include <utility>
20 
21 namespace thrill {
22 namespace api {
23 
24 /*!
25  * \ingroup api_layer
26  */
27 template <typename ValueType, typename ReduceFunction>
28 class AllReduceNode final : public ActionResultNode<ValueType>
29 {
30     static constexpr bool debug = false;
31 
32     using Super = ActionResultNode<ValueType>;
33     using Super::context_;
34 
35 public:
36     template <typename ParentDIA>
AllReduceNode(const ParentDIA & parent,const char * label,const ValueType & initial_value=ValueType (),bool with_initial_value=false,const ReduceFunction & reduce_function=ReduceFunction ())37     AllReduceNode(const ParentDIA& parent,
38                   const char* label,
39                   const ValueType& initial_value = ValueType(),
40                   bool with_initial_value = false,
41                   const ReduceFunction& reduce_function = ReduceFunction())
42         : Super(parent.ctx(), label, { parent.id() }, { parent.node() }),
43           reduce_function_(reduce_function),
44           sum_(initial_value),
45           // add to initial value if defined and we are first worker
46           first_(!(with_initial_value && parent.ctx().my_rank() == 0)) {
47         // Hook PreOp(s)
__anoned10ab5c0102(const ValueType& input) 48         auto pre_op_fn = [this](const ValueType& input) {
49                              PreOp(input);
50                          };
51 
52         auto lop_chain = parent.stack().push(pre_op_fn).fold();
53         parent.node()->AddChild(this, lop_chain);
54     }
55 
PreOp(const ValueType & input)56     void PreOp(const ValueType& input) {
57         if (TLX_UNLIKELY(first_)) {
58             first_ = false;
59             sum_ = input;
60         }
61         else {
62             sum_ = reduce_function_(sum_, input);
63         }
64     }
65 
66     //! Executes the sum operation.
Execute()67     void Execute() final {
68         // process the reduce
69         sum_ = context_.net.AllReduce(sum_, reduce_function_);
70     }
71 
72     //! Returns result of global sum.
result() const73     const ValueType& result() const final {
74         return sum_;
75     }
76 
77 private:
78     //! The sum function which is applied to two values.
79     ReduceFunction reduce_function_;
80     //! Local/global sum to be used in all reduce operation.
81     ValueType sum_;
82     //! indicate that sum_ is the default constructed first value. Worker 0's
83     //! value is already set to initial_value.
84     bool first_;
85 };
86 
87 template <typename ValueType, typename Stack>
88 template <typename ReduceFunction>
AllReduce(const ReduceFunction & sum_function) const89 ValueType DIA<ValueType, Stack>::AllReduce(
90     const ReduceFunction& sum_function) const {
91     assert(IsValid());
92 
93     using AllReduceNode = api::AllReduceNode<ValueType, ReduceFunction>;
94 
95     static_assert(
96         std::is_convertible<
97             ValueType,
98             typename FunctionTraits<ReduceFunction>::template arg<0> >::value,
99         "ReduceFunction has the wrong input type");
100 
101     static_assert(
102         std::is_convertible<
103             ValueType,
104             typename FunctionTraits<ReduceFunction>::template arg<1> >::value,
105         "ReduceFunction has the wrong input type");
106 
107     static_assert(
108         std::is_convertible<
109             typename FunctionTraits<ReduceFunction>::result_type,
110             ValueType>::value,
111         "ReduceFunction has the wrong input type");
112 
113     auto node = tlx::make_counting<AllReduceNode>(
114         *this, "AllReduce", ValueType(), /* with_initial_value */ false,
115         sum_function);
116 
117     node->RunScope();
118 
119     return node->result();
120 }
121 
122 template <typename ValueType, typename Stack>
123 template <typename ReduceFunction>
AllReduce(const ReduceFunction & sum_function,const ValueType & initial_value) const124 ValueType DIA<ValueType, Stack>::AllReduce(
125     const ReduceFunction& sum_function, const ValueType& initial_value) const {
126     assert(IsValid());
127 
128     using AllReduceNode = api::AllReduceNode<ValueType, ReduceFunction>;
129 
130     static_assert(
131         std::is_convertible<
132             ValueType,
133             typename FunctionTraits<ReduceFunction>::template arg<0> >::value,
134         "ReduceFunction has the wrong input type");
135 
136     static_assert(
137         std::is_convertible<
138             ValueType,
139             typename FunctionTraits<ReduceFunction>::template arg<1> >::value,
140         "ReduceFunction has the wrong input type");
141 
142     static_assert(
143         std::is_convertible<
144             typename FunctionTraits<ReduceFunction>::result_type,
145             ValueType>::value,
146         "ReduceFunction has the wrong input type");
147 
148     auto node = tlx::make_counting<AllReduceNode>(
149         *this, "AllReduce", initial_value, /* with_initial_value */ true,
150         sum_function);
151 
152     node->RunScope();
153 
154     return node->result();
155 }
156 
157 template <typename ValueType, typename Stack>
158 template <typename ReduceFunction>
AllReduceFuture(const ReduceFunction & sum_function) const159 Future<ValueType> DIA<ValueType, Stack>::AllReduceFuture(
160     const ReduceFunction& sum_function) const {
161     assert(IsValid());
162 
163     using AllReduceNode = api::AllReduceNode<ValueType, ReduceFunction>;
164 
165     static_assert(
166         std::is_convertible<
167             ValueType,
168             typename FunctionTraits<ReduceFunction>::template arg<0> >::value,
169         "ReduceFunction has the wrong input type");
170 
171     static_assert(
172         std::is_convertible<
173             ValueType,
174             typename FunctionTraits<ReduceFunction>::template arg<1> >::value,
175         "ReduceFunction has the wrong input type");
176 
177     static_assert(
178         std::is_convertible<
179             typename FunctionTraits<ReduceFunction>::result_type,
180             ValueType>::value,
181         "ReduceFunction has the wrong input type");
182 
183     auto node = tlx::make_counting<AllReduceNode>(
184         *this, "AllReduce", ValueType(), /* with_initial_value */ false,
185         sum_function);
186 
187     return Future<ValueType>(node);
188 }
189 
190 template <typename ValueType, typename Stack>
191 template <typename ReduceFunction>
AllReduceFuture(const ReduceFunction & sum_function,const ValueType & initial_value) const192 Future<ValueType> DIA<ValueType, Stack>::AllReduceFuture(
193     const ReduceFunction& sum_function, const ValueType& initial_value) const {
194     assert(IsValid());
195 
196     using AllReduceNode = api::AllReduceNode<ValueType, ReduceFunction>;
197 
198     static_assert(
199         std::is_convertible<
200             ValueType,
201             typename FunctionTraits<ReduceFunction>::template arg<0> >::value,
202         "ReduceFunction has the wrong input type");
203 
204     static_assert(
205         std::is_convertible<
206             ValueType,
207             typename FunctionTraits<ReduceFunction>::template arg<1> >::value,
208         "ReduceFunction has the wrong input type");
209 
210     static_assert(
211         std::is_convertible<
212             typename FunctionTraits<ReduceFunction>::result_type,
213             ValueType>::value,
214         "ReduceFunction has the wrong input type");
215 
216     auto node = tlx::make_counting<AllReduceNode>(
217         *this, "AllReduce", initial_value, /* with_initial_value */ true,
218         sum_function);
219 
220     return Future<ValueType>(node);
221 }
222 
223 } // namespace api
224 } // namespace thrill
225 
226 #endif // !THRILL_API_ALL_REDUCE_HEADER
227 
228 /******************************************************************************/
229