1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2013-2014 Kyle Lutz <kyle.r.lutz@gmail.com>
3 //
4 // Distributed under the Boost Software License, Version 1.0
5 // See accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt
7 //
8 // See http://boostorg.github.com/compute for more information.
9 //---------------------------------------------------------------------------//
10
11 #ifndef BOOST_COMPUTE_CLOSURE_HPP
12 #define BOOST_COMPUTE_CLOSURE_HPP
13
14 #include <string>
15 #include <sstream>
16
17 #include <boost/config.hpp>
18 #include <boost/fusion/adapted/boost_tuple.hpp>
19 #include <boost/fusion/algorithm/iteration/for_each.hpp>
20 #include <boost/mpl/for_each.hpp>
21 #include <boost/mpl/transform.hpp>
22 #include <boost/typeof/typeof.hpp>
23 #include <boost/static_assert.hpp>
24 #include <boost/algorithm/string.hpp>
25 #include <boost/tuple/tuple.hpp>
26 #include <boost/type_traits/function_traits.hpp>
27
28 #include <boost/compute/cl.hpp>
29 #include <boost/compute/function.hpp>
30 #include <boost/compute/type_traits/type_name.hpp>
31 #include <boost/compute/type_traits/detail/capture_traits.hpp>
32
33 namespace boost {
34 namespace compute {
35 namespace detail {
36
37 template<class ResultType, class ArgTuple, class CaptureTuple>
38 class invoked_closure
39 {
40 public:
41 typedef ResultType result_type;
42
43 BOOST_STATIC_CONSTANT(
44 size_t, arity = boost::tuples::length<ArgTuple>::value
45 );
46
invoked_closure(const std::string & name,const std::string & source,const std::map<std::string,std::string> & definitions,const ArgTuple & args,const CaptureTuple & capture)47 invoked_closure(const std::string &name,
48 const std::string &source,
49 const std::map<std::string, std::string> &definitions,
50 const ArgTuple &args,
51 const CaptureTuple &capture)
52 : m_name(name),
53 m_source(source),
54 m_definitions(definitions),
55 m_args(args),
56 m_capture(capture)
57 {
58 }
59
name() const60 std::string name() const
61 {
62 return m_name;
63 }
64
source() const65 std::string source() const
66 {
67 return m_source;
68 }
69
definitions() const70 const std::map<std::string, std::string>& definitions() const
71 {
72 return m_definitions;
73 }
74
args() const75 const ArgTuple& args() const
76 {
77 return m_args;
78 }
79
capture() const80 const CaptureTuple& capture() const
81 {
82 return m_capture;
83 }
84
85 private:
86 std::string m_name;
87 std::string m_source;
88 std::map<std::string, std::string> m_definitions;
89 ArgTuple m_args;
90 CaptureTuple m_capture;
91 };
92
93 } // end detail namespace
94
95 /// \internal_
96 template<class Signature, class CaptureTuple>
97 class closure
98 {
99 public:
100 typedef typename
101 boost::function_traits<Signature>::result_type result_type;
102
103 BOOST_STATIC_CONSTANT(
104 size_t, arity = boost::function_traits<Signature>::arity
105 );
106
closure(const std::string & name,const CaptureTuple & capture,const std::string & source)107 closure(const std::string &name,
108 const CaptureTuple &capture,
109 const std::string &source)
110 : m_name(name),
111 m_source(source),
112 m_capture(capture)
113 {
114 }
115
~closure()116 ~closure()
117 {
118 }
119
name() const120 std::string name() const
121 {
122 return m_name;
123 }
124
125 /// \internal_
source() const126 std::string source() const
127 {
128 return m_source;
129 }
130
131 /// \internal_
define(std::string name,std::string value=std::string ())132 void define(std::string name, std::string value = std::string())
133 {
134 m_definitions[name] = value;
135 }
136
137 /// \internal_
138 detail::invoked_closure<result_type, boost::tuple<>, CaptureTuple>
operator ()() const139 operator()() const
140 {
141 BOOST_STATIC_ASSERT_MSG(
142 arity == 0,
143 "Non-nullary closure function invoked with zero arguments"
144 );
145
146 return detail::invoked_closure<result_type, boost::tuple<>, CaptureTuple>(
147 m_name, m_source, m_definitions, boost::make_tuple(), m_capture
148 );
149 }
150
151 /// \internal_
152 template<class Arg1>
153 detail::invoked_closure<result_type, boost::tuple<Arg1>, CaptureTuple>
operator ()(const Arg1 & arg1) const154 operator()(const Arg1 &arg1) const
155 {
156 BOOST_STATIC_ASSERT_MSG(
157 arity == 1,
158 "Non-unary closure function invoked with one argument"
159 );
160
161 return detail::invoked_closure<result_type, boost::tuple<Arg1>, CaptureTuple>(
162 m_name, m_source, m_definitions, boost::make_tuple(arg1), m_capture
163 );
164 }
165
166 /// \internal_
167 template<class Arg1, class Arg2>
168 detail::invoked_closure<result_type, boost::tuple<Arg1, Arg2>, CaptureTuple>
operator ()(const Arg1 & arg1,const Arg2 & arg2) const169 operator()(const Arg1 &arg1, const Arg2 &arg2) const
170 {
171 BOOST_STATIC_ASSERT_MSG(
172 arity == 2,
173 "Non-binary closure function invoked with two arguments"
174 );
175
176 return detail::invoked_closure<result_type, boost::tuple<Arg1, Arg2>, CaptureTuple>(
177 m_name, m_source, m_definitions, boost::make_tuple(arg1, arg2), m_capture
178 );
179 }
180
181 /// \internal_
182 template<class Arg1, class Arg2, class Arg3>
183 detail::invoked_closure<result_type, boost::tuple<Arg1, Arg2, Arg3>, CaptureTuple>
operator ()(const Arg1 & arg1,const Arg2 & arg2,const Arg3 & arg3) const184 operator()(const Arg1 &arg1, const Arg2 &arg2, const Arg3 &arg3) const
185 {
186 BOOST_STATIC_ASSERT_MSG(
187 arity == 3,
188 "Non-ternary closure function invoked with three arguments"
189 );
190
191 return detail::invoked_closure<result_type, boost::tuple<Arg1, Arg2, Arg3>, CaptureTuple>(
192 m_name, m_source, m_definitions, boost::make_tuple(arg1, arg2, arg3), m_capture
193 );
194 }
195
196 private:
197 std::string m_name;
198 std::string m_source;
199 std::map<std::string, std::string> m_definitions;
200 CaptureTuple m_capture;
201 };
202
203 namespace detail {
204
205 struct closure_signature_argument_inserter
206 {
closure_signature_argument_inserterboost::compute::detail::closure_signature_argument_inserter207 closure_signature_argument_inserter(std::stringstream &s_,
208 const char *capture_string,
209 size_t last)
210 : s(s_)
211 {
212 n = 0;
213 m_last = last;
214
215 size_t capture_string_length = std::strlen(capture_string);
216 BOOST_ASSERT(capture_string[0] == '(' &&
217 capture_string[capture_string_length-1] == ')');
218 std::string capture_string_(capture_string + 1, capture_string_length - 2);
219 boost::split(m_capture_names, capture_string_ , boost::is_any_of(","));
220 }
221
222 template<class T>
operator ()boost::compute::detail::closure_signature_argument_inserter223 void operator()(const T&) const
224 {
225 BOOST_ASSERT(n < m_capture_names.size());
226
227 // get captured variable name
228 std::string variable_name = m_capture_names[n];
229
230 // remove leading and trailing whitespace from variable name
231 boost::trim(variable_name);
232
233 s << capture_traits<T>::type_name() << " " << variable_name;
234 if(n+1 < m_last){
235 s << ", ";
236 }
237 n++;
238 }
239
240 mutable size_t n;
241 size_t m_last;
242 std::vector<std::string> m_capture_names;
243 std::stringstream &s;
244 };
245
246 template<class Signature, class CaptureTuple>
247 inline std::string
make_closure_declaration(const char * name,const char * arguments,const CaptureTuple & capture_tuple,const char * capture_string)248 make_closure_declaration(const char *name,
249 const char *arguments,
250 const CaptureTuple &capture_tuple,
251 const char *capture_string)
252 {
253 typedef typename
254 boost::function_traits<Signature>::result_type result_type;
255 typedef typename
256 boost::function_types::parameter_types<Signature>::type parameter_types;
257 typedef typename
258 mpl::size<parameter_types>::type arity_type;
259
260 std::stringstream s;
261 s << "inline " << type_name<result_type>() << " " << name;
262 s << "(";
263
264 // insert function arguments
265 signature_argument_inserter i(s, arguments, arity_type::value);
266 mpl::for_each<
267 typename mpl::transform<parameter_types, boost::add_pointer<mpl::_1>
268 >::type>(i);
269 s << ", ";
270
271 // insert capture arguments
272 closure_signature_argument_inserter j(
273 s, capture_string, boost::tuples::length<CaptureTuple>::value
274 );
275 fusion::for_each(capture_tuple, j);
276
277 s << ")";
278 return s.str();
279 }
280
281 // used by the BOOST_COMPUTE_CLOSURE() macro to create a closure
282 // function with the given signature, name, capture, and source.
283 template<class Signature, class CaptureTuple>
284 inline closure<Signature, CaptureTuple>
make_closure_impl(const char * name,const char * arguments,const CaptureTuple & capture,const char * capture_string,const std::string & source)285 make_closure_impl(const char *name,
286 const char *arguments,
287 const CaptureTuple &capture,
288 const char *capture_string,
289 const std::string &source)
290 {
291 std::stringstream s;
292 s << make_closure_declaration<Signature>(name, arguments, capture, capture_string);
293 s << source;
294
295 return closure<Signature, CaptureTuple>(name, capture, s.str());
296 }
297
298 } // end detail namespace
299 } // end compute namespace
300 } // end boost namespace
301
302 /// Creates a closure function object with \p name and \p source.
303 ///
304 /// \param return_type The return type for the function.
305 /// \param name The name of the function.
306 /// \param arguments A list of arguments for the function.
307 /// \param capture A list of variables to capture.
308 /// \param source The OpenCL C source code for the function.
309 ///
310 /// For example, to create a function which checks if a 2D point is
311 /// contained in a circle of a given radius:
312 /// \code
313 /// // radius variable declared in C++
314 /// float radius = 1.5f;
315 ///
316 /// // create a closure function which returns true if the 2D point
317 /// // argument is contained within a circle of the given radius
318 /// BOOST_COMPUTE_CLOSURE(bool, is_in_circle, (const float2_ p), (radius),
319 /// {
320 /// return sqrt(p.x*p.x + p.y*p.y) < radius;
321 /// });
322 ///
323 /// // vector of 2D points
324 /// boost::compute::vector<float2_> points = ...
325 ///
326 /// // count number of points in the circle
327 /// size_t count = boost::compute::count_if(
328 /// points.begin(), points.end(), is_in_circle, queue
329 /// );
330 /// \endcode
331 ///
332 /// \see BOOST_COMPUTE_FUNCTION()
#ifdef BOOST_COMPUTE_DOXYGEN_INVOKED
#define BOOST_COMPUTE_CLOSURE(return_type, name, arguments, capture, source)
#else
// Declares a variable `name` of type closure<return_type arguments, ...>
// and initialises it through detail::make_closure_impl(). The captured
// variables are bound by reference with tie(), and the name, argument
// list, capture list and source are all passed on as stringified text.
// The source is accepted through __VA_ARGS__ so that commas inside it do
// not split it into several macro arguments.
//
// Every library name is written with a leading `::` so the expansion
// resolves to the global boost namespace wherever the macro is used.
#define BOOST_COMPUTE_CLOSURE(return_type, name, arguments, capture, ...) \
    ::boost::compute::closure< \
        return_type arguments, BOOST_TYPEOF(::boost::tie capture) \
    > name = \
        ::boost::compute::detail::make_closure_impl< \
            return_type arguments \
        >( \
            #name, #arguments, ::boost::tie capture, #capture, #__VA_ARGS__ \
        )
#endif
346
347 #endif // BOOST_COMPUTE_CLOSURE_HPP
348