1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
3 //
4 // Distributed under the Boost Software License, Version 1.0
5 // See accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt
7 //
8 // See http://boostorg.github.com/compute for more information.
9 //---------------------------------------------------------------------------//
10 
11 #ifndef BOOST_COMPUTE_LAMBDA_CONTEXT_HPP
12 #define BOOST_COMPUTE_LAMBDA_CONTEXT_HPP
13 
14 #include <boost/proto/core.hpp>
15 #include <boost/proto/context.hpp>
16 #include <boost/type_traits.hpp>
17 #include <boost/preprocessor/repetition.hpp>
18 
19 #include <boost/compute/config.hpp>
20 #include <boost/compute/function.hpp>
21 #include <boost/compute/lambda/result_of.hpp>
22 #include <boost/compute/lambda/functional.hpp>
23 #include <boost/compute/type_traits/result_of.hpp>
24 #include <boost/compute/type_traits/type_name.hpp>
25 #include <boost/compute/detail/meta_kernel.hpp>
26 
27 namespace boost {
28 namespace compute {
29 namespace lambda {
30 
31 namespace mpl = boost::mpl;
32 namespace proto = boost::proto;
33 
// Defines, inside a lambda context, the operator() overload for the binary
// proto tag `tag`, writing "<lhs> op <rhs>" to `stream`. An operand that is
// itself a compound expression (proto arity > 0) is wrapped in parentheses so
// the emitted code keeps the expression tree's grouping independent of
// operator precedence; terminals (arity 0, including placeholders) are
// written without parentheses.
#define BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(tag, op) \
    template<class LHS, class RHS> \
    void operator()(tag, const LHS &lhs, const RHS &rhs) \
    { \
        const bool group_lhs = proto::arity_of<LHS>::value > 0; \
        const bool group_rhs = proto::arity_of<RHS>::value > 0; \
        \
        if(group_lhs){ stream << '('; } \
        proto::eval(lhs, *this); \
        if(group_lhs){ stream << ')'; } \
        \
        stream << op; \
        \
        if(group_rhs){ stream << '('; } \
        proto::eval(rhs, *this); \
        if(group_rhs){ stream << ')'; } \
    }
58 
59 // lambda expression context
60 template<class Args>
61 struct context : proto::callable_context<context<Args> >
62 {
63     typedef void result_type;
64     typedef Args args_tuple;
65 
66     // create a lambda context for kernel with args
contextboost::compute::lambda::context67     context(boost::compute::detail::meta_kernel &kernel, const Args &args_)
68         : stream(kernel),
69           args(args_)
70     {
71     }
72 
73     // handle terminals
74     template<class T>
operator ()boost::compute::lambda::context75     void operator()(proto::tag::terminal, const T &x)
76     {
77         // terminal values in lambda expressions are always literals
78         stream << stream.lit(x);
79     }
80 
operator ()boost::compute::lambda::context81     void operator()(proto::tag::terminal, const uchar_ &x)
82     {
83         stream << "(uchar)(" << stream.lit(uint_(x)) << "u)";
84     }
85 
operator ()boost::compute::lambda::context86     void operator()(proto::tag::terminal, const char_ &x)
87     {
88         stream << "(char)(" << stream.lit(int_(x)) << ")";
89     }
90 
operator ()boost::compute::lambda::context91     void operator()(proto::tag::terminal, const ushort_ &x)
92     {
93         stream << "(ushort)(" << stream.lit(x) << "u)";
94     }
95 
operator ()boost::compute::lambda::context96     void operator()(proto::tag::terminal, const short_ &x)
97     {
98         stream << "(short)(" << stream.lit(x) << ")";
99     }
100 
operator ()boost::compute::lambda::context101     void operator()(proto::tag::terminal, const uint_ &x)
102     {
103         stream << "(" << stream.lit(x) << "u)";
104     }
105 
operator ()boost::compute::lambda::context106     void operator()(proto::tag::terminal, const ulong_ &x)
107     {
108         stream << "(" << stream.lit(x) << "ul)";
109     }
110 
operator ()boost::compute::lambda::context111     void operator()(proto::tag::terminal, const long_ &x)
112     {
113         stream << "(" << stream.lit(x) << "l)";
114     }
115 
116     // handle placeholders
117     template<int I>
operator ()boost::compute::lambda::context118     void operator()(proto::tag::terminal, placeholder<I>)
119     {
120         stream << boost::get<I>(args);
121     }
122 
123     // handle functions
124     #define BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG(z, n, unused) \
125         BOOST_PP_COMMA_IF(n) BOOST_PP_CAT(const Arg, n) BOOST_PP_CAT(&arg, n)
126 
127     #define BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION(z, n, unused) \
128     template<class F, BOOST_PP_ENUM_PARAMS(n, class Arg)> \
129     void operator()( \
130         proto::tag::function, \
131         const F &function, \
132         BOOST_PP_REPEAT(n, BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG, ~) \
133     ) \
134     { \
135         proto::value(function).apply(*this, BOOST_PP_ENUM_PARAMS(n, arg)); \
136     }
137 
138     BOOST_PP_REPEAT_FROM_TO(1, BOOST_COMPUTE_MAX_ARITY, BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION, ~)
139 
140     #undef BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION
141 
142     // operators
143     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::plus, '+')
144     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::minus, '-')
145     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::multiplies, '*')
146     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::divides, '/')
147     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::modulus, '%')
148     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::less, '<')
149     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::greater, '>')
150     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::less_equal, "<=")
151     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::greater_equal, ">=")
152     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::equal_to, "==")
153     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::not_equal_to, "!=")
154     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::logical_and, "&&")
155     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::logical_or, "||")
156     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_and, '&')
157     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_or, '|')
158     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_xor, '^')
159     BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::assign, '=')
160 
161     // subscript operator
162     template<class LHS, class RHS>
operator ()boost::compute::lambda::context163     void operator()(proto::tag::subscript, const LHS &lhs, const RHS &rhs)
164     {
165         proto::eval(lhs, *this);
166         stream << '[';
167         proto::eval(rhs, *this);
168         stream << ']';
169     }
170 
171     // ternary conditional operator
172     template<class Pred, class Arg1, class Arg2>
operator ()boost::compute::lambda::context173     void operator()(proto::tag::if_else_, const Pred &p, const Arg1 &x, const Arg2 &y)
174     {
175         proto::eval(p, *this);
176         stream << '?';
177         proto::eval(x, *this);
178         stream << ':';
179         proto::eval(y, *this);
180     }
181 
182     boost::compute::detail::meta_kernel &stream;
183     Args args;
184 };
185 
186 namespace detail {
187 
188 template<class Expr, class Arg>
189 struct invoked_unary_expression
190 {
191     typedef typename ::boost::compute::result_of<Expr(Arg)>::type result_type;
192 
invoked_unary_expressionboost::compute::lambda::detail::invoked_unary_expression193     invoked_unary_expression(const Expr &expr, const Arg &arg)
194         : m_expr(expr),
195           m_arg(arg)
196     {
197     }
198 
199     Expr m_expr;
200     Arg m_arg;
201 };
202 
203 template<class Expr, class Arg>
204 boost::compute::detail::meta_kernel&
operator <<(boost::compute::detail::meta_kernel & kernel,const invoked_unary_expression<Expr,Arg> & expr)205 operator<<(boost::compute::detail::meta_kernel &kernel,
206            const invoked_unary_expression<Expr, Arg> &expr)
207 {
208     context<boost::tuple<Arg> > ctx(kernel, boost::make_tuple(expr.m_arg));
209     proto::eval(expr.m_expr, ctx);
210 
211     return kernel;
212 }
213 
214 template<class Expr, class Arg1, class Arg2>
215 struct invoked_binary_expression
216 {
217     typedef typename ::boost::compute::result_of<Expr(Arg1, Arg2)>::type result_type;
218 
invoked_binary_expressionboost::compute::lambda::detail::invoked_binary_expression219     invoked_binary_expression(const Expr &expr,
220                               const Arg1 &arg1,
221                               const Arg2 &arg2)
222         : m_expr(expr),
223           m_arg1(arg1),
224           m_arg2(arg2)
225     {
226     }
227 
228     Expr m_expr;
229     Arg1 m_arg1;
230     Arg2 m_arg2;
231 };
232 
233 template<class Expr, class Arg1, class Arg2>
234 boost::compute::detail::meta_kernel&
operator <<(boost::compute::detail::meta_kernel & kernel,const invoked_binary_expression<Expr,Arg1,Arg2> & expr)235 operator<<(boost::compute::detail::meta_kernel &kernel,
236            const invoked_binary_expression<Expr, Arg1, Arg2> &expr)
237 {
238     context<boost::tuple<Arg1, Arg2> > ctx(
239         kernel,
240         boost::make_tuple(expr.m_arg1, expr.m_arg2)
241     );
242     proto::eval(expr.m_expr, ctx);
243 
244     return kernel;
245 }
246 
247 } // end detail namespace
248 
// forward declare domain: expression<> (below) names it as its proto domain,
// while domain itself is defined in terms of expression (through
// proto::generator<expression>), so domain has to be declared first
struct domain;
251 
252 // lambda expression wrapper
253 template<class Expr>
254 struct expression : proto::extends<Expr, expression<Expr>, domain>
255 {
256     typedef proto::extends<Expr, expression<Expr>, domain> base_type;
257 
258     BOOST_PROTO_EXTENDS_USING_ASSIGN(expression)
259 
expressionboost::compute::lambda::expression260     expression(const Expr &expr = Expr())
261         : base_type(expr)
262     {
263     }
264 
265     // result_of protocol
266     template<class Signature>
267     struct result
268     {
269     };
270 
271     template<class This>
272     struct result<This()>
273     {
274         typedef
275             typename ::boost::compute::lambda::result_of<Expr>::type type;
276     };
277 
278     template<class This, class Arg>
279     struct result<This(Arg)>
280     {
281         typedef
282             typename ::boost::compute::lambda::result_of<
283                 Expr,
284                 typename boost::tuple<Arg>
285             >::type type;
286     };
287 
288     template<class This, class Arg1, class Arg2>
289     struct result<This(Arg1, Arg2)>
290     {
291         typedef typename
292             ::boost::compute::lambda::result_of<
293                 Expr,
294                 typename boost::tuple<Arg1, Arg2>
295             >::type type;
296     };
297 
298     template<class Arg>
299     detail::invoked_unary_expression<expression<Expr>, Arg>
operator ()boost::compute::lambda::expression300     operator()(const Arg &x) const
301     {
302         return detail::invoked_unary_expression<expression<Expr>, Arg>(*this, x);
303     }
304 
305     template<class Arg1, class Arg2>
306     detail::invoked_binary_expression<expression<Expr>, Arg1, Arg2>
operator ()boost::compute::lambda::expression307     operator()(const Arg1 &x, const Arg2 &y) const
308     {
309         return detail::invoked_binary_expression<
310                    expression<Expr>,
311                    Arg1,
312                    Arg2
313                 >(*this, x, y);
314     }
315 
316     // function<> conversion operator
317     template<class R, class A1>
operator function<R(A1)>boost::compute::lambda::expression318     operator function<R(A1)>() const
319     {
320         using ::boost::compute::detail::meta_kernel;
321 
322         std::stringstream source;
323 
324         ::boost::compute::detail::meta_kernel_variable<A1> arg1("x");
325 
326         source << "inline " << type_name<R>() << " lambda"
327                << ::boost::compute::detail::generate_argument_list<R(A1)>('x')
328                << "{\n"
329                << "    return " << meta_kernel::expr_to_string((*this)(arg1)) << ";\n"
330                << "}\n";
331 
332         return make_function_from_source<R(A1)>("lambda", source.str());
333     }
334 
335     template<class R, class A1, class A2>
operator function<R(A1, A2)>boost::compute::lambda::expression336     operator function<R(A1, A2)>() const
337     {
338         using ::boost::compute::detail::meta_kernel;
339 
340         std::stringstream source;
341 
342         ::boost::compute::detail::meta_kernel_variable<A1> arg1("x");
343         ::boost::compute::detail::meta_kernel_variable<A1> arg2("y");
344 
345         source << "inline " << type_name<R>() << " lambda"
346                << ::boost::compute::detail::generate_argument_list<R(A1, A2)>('x')
347                << "{\n"
348                << "    return " << meta_kernel::expr_to_string((*this)(arg1, arg2)) << ";\n"
349                << "}\n";
350 
351         return make_function_from_source<R(A1, A2)>("lambda", source.str());
352     }
353 };
354 
355 // lambda expression domain
356 struct domain : proto::domain<proto::generator<expression> >
357 {
358 };
359 
360 } // end lambda namespace
361 } // end compute namespace
362 } // end boost namespace
363 
364 #endif // BOOST_COMPUTE_LAMBDA_CONTEXT_HPP
365