1 #ifndef STAN_MATH_OPENCL_REV_COPY_HPP
2 #define STAN_MATH_OPENCL_REV_COPY_HPP
3 #ifdef STAN_OPENCL
4
5 #include <stan/math/opencl/rev/vari.hpp>
6 #include <stan/math/opencl/rev/arena_type.hpp>
7 #include <stan/math/opencl/rev/to_arena.hpp>
8 #include <stan/math/opencl/copy.hpp>
9 #include <stan/math/rev/core.hpp>
10 #include <stan/math/rev/meta.hpp>
11 #include <stan/math/prim/err.hpp>
12 #include <stan/math/prim/fun/Eigen.hpp>
13 #include <stan/math/prim/fun/vec_concat.hpp>
14
15 #include <CL/opencl.hpp>
16 #include <iostream>
17 #include <vector>
18 #include <algorithm>
19 #include <type_traits>
20
21 namespace stan {
22 namespace math {
23
24 /** \ingroup opencl
25 * Copies the source var containing Eigen matrices to destination var that has
26 * data stored on the OpenCL device.
27 *
28 * @tparam T type of the Eigen matrix
29 * @param a source Eigen matrix
30 * @return var with a copy of the data on the OpenCL device
31 */
32 template <typename T>
to_matrix_cl(const var_value<T> & a)33 inline var_value<matrix_cl<value_type_t<T>>> to_matrix_cl(
34 const var_value<T>& a) {
35 return make_callback_var(to_matrix_cl(a.val()), [a](auto& res_vari) mutable {
36 a.adj() += from_matrix_cl<plain_type_t<T>>(res_vari.adj());
37 });
38 }
39
40 /** \ingroup opencl
41 * Copies the source std::vector of vars to a destination var that has
42 * data stored on the OpenCL device.
43 *
44 * @tparam T type of the std::vector
45 * @param a source Eigen matrix
46 * @return var with a copy of the data on the OpenCL device
47 */
48 template <typename T, require_stan_scalar_t<T>* = nullptr>
to_matrix_cl(const std::vector<var_value<T>> & a)49 inline var_value<matrix_cl<value_type_t<T>>> to_matrix_cl(
50 const std::vector<var_value<T>>& a) {
51 return to_matrix_cl(
52 Eigen::Map<const Eigen::Matrix<var_value<T>, Eigen::Dynamic, 1>>(
53 a.data(), a.size()));
54 }
55
56 /** \ingroup opencl
57 * Copies the source Eigen matrix of vars to
58 * the destination matrix that is stored
59 * on the OpenCL device.
60 *
61 * @tparam R Compile time rows of the Eigen matrix
62 * @tparam C Compile time columns of the Eigen matrix
63 * @param src source Eigen matrix
64 * @return matrix_cl with a copy of the data in the source matrix
65 */
66 template <typename T, require_eigen_vt<is_var, T>* = nullptr>
to_matrix_cl(const T & src)67 inline var_value<matrix_cl<value_type_t<value_type_t<T>>>> to_matrix_cl(
68 const T& src) {
69 arena_t<T> src_stacked = src;
70
71 return make_callback_var(
72 to_matrix_cl(src_stacked.val()), [src_stacked](auto& res_vari) mutable {
73 src_stacked.adj() += from_matrix_cl<
74 Eigen::Matrix<double, T::RowsAtCompileTime, T::ColsAtCompileTime>>(
75 res_vari.adj());
76 });
77 }
78
79 /** \ingroup opencl
80 * Copies the source vector of Eigen matrices of vars to
81 * the destination matrix that is stored
82 * on the OpenCL device. Each element of the vector is stored into one column of
83 * the returned matrix_cl.
84 *
85 * @param src source vector of Eigen matrices
86 * @return matrix_cl with a copy of the data in the source matrix
87 */
88 template <typename T, require_eigen_vt<is_var, T>* = nullptr>
to_matrix_cl(const std::vector<T> & src)89 inline var_value<matrix_cl<value_type_t<value_type_t<T>>>> to_matrix_cl(
90 const std::vector<T>& src) {
91 auto src_stacked = to_arena(src);
92
93 return make_callback_var(
94 to_matrix_cl(value_of(src_stacked)),
95 [src_stacked](auto& res_vari) mutable {
96 Eigen::MatrixXd adj = from_matrix_cl(res_vari.adj());
97 for (int i = 0; i < src_stacked.size(); i++) {
98 src_stacked[i].adj()
99 += Eigen::Map<plain_type_t<decltype(src_stacked[i].adj())>>(
100 adj.data() + adj.rows() * i, src_stacked[i].rows(),
101 src_stacked[i].cols());
102 }
103 });
104 }
105
106 /** \ingroup opencl
107 * Copies the source var that has data stored on the OpenCL device to
108 * destination var containing Eigen matrix.
109 *
110 * @tparam T_dst destination type
111 * @tparam T type of the matrix or expression on the OpenCL device
112 * @param a source matrix_cl or expression
113 * @return var with a copy of the data on the host
114 */
115 template <typename T_dst, typename T,
116 require_var_vt<is_eigen, T_dst>* = nullptr,
117 require_all_kernel_expressions_t<T>* = nullptr>
from_matrix_cl(const var_value<T> & a)118 inline T_dst from_matrix_cl(const var_value<T>& a) {
119 return make_callback_var(
120 from_matrix_cl<Eigen::Matrix<double, T_dst::RowsAtCompileTime,
121 T_dst::ColsAtCompileTime>>(a.val()),
122 [a](auto& res_vari) mutable { a.adj() += to_matrix_cl(res_vari.adj()); });
123 }
124
125 /** \ingroup opencl
126 * Copies the source var that has data stored on the OpenCL device to
127 * destination Eigen matrix containing vars.
128 *
129 * @tparam T_dst destination type
130 * @tparam T type of the matrix or expression on the OpenCL device
131 * @param a source matrix_cl or expression
132 * @return var with a copy of the data on the host
133 */
134 template <typename T_dst, typename T,
135 require_eigen_vt<is_var, T_dst>* = nullptr,
136 require_all_kernel_expressions_t<T>* = nullptr>
from_matrix_cl(const var_value<T> & a)137 inline T_dst from_matrix_cl(const var_value<T>& a) {
138 arena_t<T_dst> res
139 = from_matrix_cl<Eigen::Matrix<double, T_dst::RowsAtCompileTime,
140 T_dst::ColsAtCompileTime>>(a.val());
141 reverse_pass_callback(
142 [a, res]() mutable { a.adj() += to_matrix_cl(res.adj()); });
143 return res;
144 }
145
146 /** \ingroup opencl
147 * Copies the source var that has data stored on the OpenCL device to
148 * destination `std::vector` containing vars.
149 *
150 * @tparam T_dst destination type
151 * @tparam T type of the matrix or expression on the OpenCL device
152 * @param a source matrix_cl or expression
153 * @return var with a copy of the data on the host
154 */
155 template <typename T_dst, typename T,
156 require_std_vector_vt<is_var, T_dst>* = nullptr,
157 require_all_stan_scalar_t<value_type_t<T_dst>>* = nullptr,
158 require_all_kernel_expressions_t<T>* = nullptr>
from_matrix_cl(const var_value<T> & a)159 inline T_dst from_matrix_cl(const var_value<T>& a) {
160 check_size_match("from_matrix_cl<std::vector<var>>", "src.cols()", a.cols(),
161 "dst.cols()", 1);
162 std::vector<double> val = from_matrix_cl<std::vector<double>>(a.val());
163 arena_t<T_dst> res(val.begin(), val.end());
164 reverse_pass_callback([a, res]() mutable {
165 a.adj() += to_matrix_cl(as_column_vector_or_scalar(res).adj());
166 });
167 return {res.begin(), res.end()};
168 }
169
170 /** \ingroup opencl
171 * Copies the source var that has data stored on the OpenCL device to
172 * destination std::vector containing either Eigen vectors of vars or vars
173 * containing Eigen vectors.
174 *
175 * @tparam T_dst destination type
176 * @tparam T type of the matrix or expression on the OpenCL device
177 * @param a source matrix_cl or expression
178 * @return var with a copy of the data on the host
179 */
180 template <typename T_dst, typename T, require_std_vector_t<T_dst>* = nullptr,
181 require_rev_vector_t<value_type_t<T_dst>>* = nullptr,
182 require_all_kernel_expressions_t<T>* = nullptr>
from_matrix_cl(const var_value<T> & a)183 inline T_dst from_matrix_cl(const var_value<T>& a) {
184 Eigen::MatrixXd val = from_matrix_cl(a.val());
185 arena_t<T_dst> res;
186 res.reserve(a.cols());
187 for (int i = 0; i < a.cols(); i++) {
188 res.emplace_back(val.col(i));
189 }
190 reverse_pass_callback([a, res]() mutable {
191 Eigen::MatrixXd adj(a.rows(), a.cols());
192 for (int i = 0; i < a.cols(); i++) {
193 adj.col(i) = res[i].adj();
194 }
195 a.adj() += to_matrix_cl(adj);
196 });
197 return {res.begin(), res.end()};
198 }
199
200 /** \ingroup opencl
201 * Copies the source var that has data stored on the OpenCL device to
202 * destination Eigen matrix containing vars.
203 *
204 * @tparam T type of the matrix or expression on the OpenCL device
205 * @param src source matrix_cl or expression
206 * @return var with a copy of the data on the host
207 */
208 template <typename T, require_all_kernel_expressions_t<T>* = nullptr>
from_matrix_cl(const var_value<T> & src)209 auto from_matrix_cl(const var_value<T>& src) {
210 return from_matrix_cl<var_value<Eigen::MatrixXd>>(src);
211 }
212
213 } // namespace math
214 } // namespace stan
215 #endif
216 #endif
217