1 #ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_AS_OPERATION_CL_HPP
2 #define STAN_MATH_OPENCL_KERNEL_GENERATOR_AS_OPERATION_CL_HPP
3 #ifdef STAN_OPENCL
4
5 #include <stan/math/opencl/kernel_generator/operation_cl.hpp>
6 #include <stan/math/opencl/kernel_generator/load.hpp>
7 #include <stan/math/opencl/kernel_generator/scalar.hpp>
8 #include <stan/math/opencl/matrix_cl.hpp>
9 #include <stan/math/prim/meta.hpp>
10 #include <type_traits>
11
12 namespace stan {
13 namespace math {
14
15 /** \addtogroup opencl_kernel_generator
16 * @{
17 */
18
19 /**
20 * Converts any valid kernel generator expression into an operation. This is an
21 * overload for operations - a no-op
22 * @tparam T_operation type of the input operation
23 * @param a an operation
24 * @return operation
25 */
26 template <typename T_operation,
27 typename = std::enable_if_t<std::is_base_of<
28 operation_cl_base, std::remove_reference_t<T_operation>>::value>>
as_operation_cl(T_operation && a)29 inline T_operation&& as_operation_cl(T_operation&& a) {
30 return std::forward<T_operation>(a);
31 }
32
33 /**
34 * Converts any valid kernel generator expression into an operation. This is an
35 * overload for scalars (arithmetic types). It wraps them into \c scalar_.
36 * @tparam T_scalar type of the input scalar
37 * @param a scalar
38 * @return \c scalar_ wrapping the input
39 */
40 template <typename T_scalar, typename = require_arithmetic_t<T_scalar>,
41 require_not_same_t<T_scalar, bool>* = nullptr>
as_operation_cl(const T_scalar a)42 inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
43 return scalar_<T_scalar>(a);
44 }
45
46 /**
47 * Converts any valid kernel generator expression into an operation. This is an
48 * overload for bool scalars. It wraps them into \c scalar_<char> as \c bool can
49 * not be used as a type of a kernel argument.
50 * @param a scalar
51 * @return \c scalar_<char> wrapping the input
52 */
as_operation_cl(const bool a)53 inline scalar_<char> as_operation_cl(const bool a) { return scalar_<char>(a); }
54
55 /**
56 * Converts any valid kernel generator expression into an operation. This is an
57 * overload for \c matrix_cl. It wraps them into into \c load_.
58 * @tparam T_matrix_cl \c matrix_cl
59 * @param a \c matrix_cl
60 * @return \c load_ wrapping the input
61 */
62 template <typename T_matrix_cl,
63 typename = require_any_t<is_matrix_cl<T_matrix_cl>,
64 is_arena_matrix_cl<T_matrix_cl>>>
as_operation_cl(T_matrix_cl && a)65 inline load_<T_matrix_cl> as_operation_cl(T_matrix_cl&& a) {
66 return load_<T_matrix_cl>(std::forward<T_matrix_cl>(a));
67 }
68
/**
 * Type obtained by converting a valid kernel generator expression of type
 * \c T into an operation. A function taking a forwarding reference
 * <code>T&& a</code> should store the result of <code>as_operation_cl(a)</code>
 * in a variable of type <code>as_operation_cl_t<T></code>.
 *
 * If \c T is an lvalue reference, this is exactly the return type of
 * \c as_operation_cl(). Otherwise any reference is stripped from that return
 * type, so that a variable of this type holds the value itself rather than an
 * rvalue reference to it.
 */
template <typename T>
using as_operation_cl_t = typename std::conditional<
    std::is_lvalue_reference<T>::value,
    decltype(as_operation_cl(std::declval<T>())),
    typename std::remove_reference<decltype(
        as_operation_cl(std::declval<T>()))>::type>::type;
82
83 /** @}*/
84 } // namespace math
85 } // namespace stan
86
87 #endif
88 #endif
89