1 #ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_AS_OPERATION_CL_HPP
2 #define STAN_MATH_OPENCL_KERNEL_GENERATOR_AS_OPERATION_CL_HPP
3 #ifdef STAN_OPENCL
4 
5 #include <stan/math/opencl/kernel_generator/operation_cl.hpp>
6 #include <stan/math/opencl/kernel_generator/load.hpp>
7 #include <stan/math/opencl/kernel_generator/scalar.hpp>
8 #include <stan/math/opencl/matrix_cl.hpp>
9 #include <stan/math/prim/meta.hpp>
10 #include <type_traits>
11 
12 namespace stan {
13 namespace math {
14 
15 /** \addtogroup opencl_kernel_generator
16  *  @{
17  */
18 
19 /**
20  * Converts any valid kernel generator expression into an operation. This is an
21  * overload for operations - a no-op
22  * @tparam T_operation type of the input operation
23  * @param a an operation
24  * @return operation
25  */
26 template <typename T_operation,
27           typename = std::enable_if_t<std::is_base_of<
28               operation_cl_base, std::remove_reference_t<T_operation>>::value>>
as_operation_cl(T_operation && a)29 inline T_operation&& as_operation_cl(T_operation&& a) {
30   return std::forward<T_operation>(a);
31 }
32 
33 /**
34  * Converts any valid kernel generator expression into an operation. This is an
35  * overload for scalars (arithmetic types). It wraps them into \c scalar_.
36  * @tparam T_scalar type of the input scalar
37  * @param a scalar
38  * @return \c scalar_ wrapping the input
39  */
40 template <typename T_scalar, typename = require_arithmetic_t<T_scalar>,
41           require_not_same_t<T_scalar, bool>* = nullptr>
as_operation_cl(const T_scalar a)42 inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
43   return scalar_<T_scalar>(a);
44 }
45 
46 /**
47  * Converts any valid kernel generator expression into an operation. This is an
48  * overload for bool scalars. It wraps them into \c scalar_<char> as \c bool can
49  * not be used as a type of a kernel argument.
50  * @param a scalar
51  * @return \c scalar_<char> wrapping the input
52  */
as_operation_cl(const bool a)53 inline scalar_<char> as_operation_cl(const bool a) { return scalar_<char>(a); }
54 
55 /**
56  * Converts any valid kernel generator expression into an operation. This is an
57  * overload for \c matrix_cl. It wraps them into into \c load_.
58  * @tparam T_matrix_cl \c matrix_cl
59  * @param a \c matrix_cl
60  * @return \c load_ wrapping the input
61  */
62 template <typename T_matrix_cl,
63           typename = require_any_t<is_matrix_cl<T_matrix_cl>,
64                                    is_arena_matrix_cl<T_matrix_cl>>>
as_operation_cl(T_matrix_cl && a)65 inline load_<T_matrix_cl> as_operation_cl(T_matrix_cl&& a) {
66   return load_<T_matrix_cl>(std::forward<T_matrix_cl>(a));
67 }
68 
/**
 * Type that results from converting any valid kernel generator expression into
 * an operation. It is the type of \c as_operation_cl applied to a forwarded
 * argument of type \c T (modelled with \c std::declval<T>()).
 *
 * If a function accepts a forwarding reference \c T&& \c a, the result of
 * \c as_operation_cl(std::forward<T>(a)) should be stored in a variable of
 * type \c as_operation_cl_t<T>.
 * - When \c T is an lvalue reference, the result type of \c as_operation_cl()
 *   is used as is.
 * - Otherwise any reference is stripped from it. An rvalue reference returned
 *   by \c as_operation_cl() therefore becomes a plain value type, so the
 *   variable actually stores the value.
 */
template <typename T>
using as_operation_cl_t = std::conditional_t<
    !std::is_lvalue_reference<T>::value,
    std::remove_reference_t<decltype(as_operation_cl(std::declval<T>()))>,
    decltype(as_operation_cl(std::declval<T>()))>;
82 
83 /** @}*/
84 }  // namespace math
85 }  // namespace stan
86 
87 #endif
88 #endif
89