1 #ifndef STAN_MATH_OPENCL_REV_COPY_HPP
2 #define STAN_MATH_OPENCL_REV_COPY_HPP
3 #ifdef STAN_OPENCL
4 
5 #include <stan/math/opencl/rev/vari.hpp>
6 #include <stan/math/opencl/rev/arena_type.hpp>
7 #include <stan/math/opencl/rev/to_arena.hpp>
8 #include <stan/math/opencl/copy.hpp>
9 #include <stan/math/rev/core.hpp>
10 #include <stan/math/rev/meta.hpp>
11 #include <stan/math/prim/err.hpp>
12 #include <stan/math/prim/fun/Eigen.hpp>
13 #include <stan/math/prim/fun/vec_concat.hpp>
14 
15 #include <CL/opencl.hpp>
16 #include <iostream>
17 #include <vector>
18 #include <algorithm>
19 #include <type_traits>
20 
21 namespace stan {
22 namespace math {
23 
24 /** \ingroup opencl
25  * Copies the source var containing Eigen matrices to destination var that has
26  * data stored on the OpenCL device.
27  *
28  * @tparam T type of the Eigen matrix
29  * @param a source Eigen matrix
30  * @return var with a copy of the data on the OpenCL device
31  */
32 template <typename T>
to_matrix_cl(const var_value<T> & a)33 inline var_value<matrix_cl<value_type_t<T>>> to_matrix_cl(
34     const var_value<T>& a) {
35   return make_callback_var(to_matrix_cl(a.val()), [a](auto& res_vari) mutable {
36     a.adj() += from_matrix_cl<plain_type_t<T>>(res_vari.adj());
37   });
38 }
39 
40 /** \ingroup opencl
41  * Copies the source std::vector of vars to a destination var that has
42  * data stored on the OpenCL device.
43  *
44  * @tparam T type of the std::vector
45  * @param a source Eigen matrix
46  * @return var with a copy of the data on the OpenCL device
47  */
48 template <typename T, require_stan_scalar_t<T>* = nullptr>
to_matrix_cl(const std::vector<var_value<T>> & a)49 inline var_value<matrix_cl<value_type_t<T>>> to_matrix_cl(
50     const std::vector<var_value<T>>& a) {
51   return to_matrix_cl(
52       Eigen::Map<const Eigen::Matrix<var_value<T>, Eigen::Dynamic, 1>>(
53           a.data(), a.size()));
54 }
55 
56 /** \ingroup opencl
57  * Copies the source Eigen matrix of vars to
58  * the destination matrix that is stored
59  * on the OpenCL device.
60  *
61  * @tparam R Compile time rows of the Eigen matrix
62  * @tparam C Compile time columns of the Eigen matrix
63  * @param src source Eigen matrix
64  * @return matrix_cl with a copy of the data in the source matrix
65  */
66 template <typename T, require_eigen_vt<is_var, T>* = nullptr>
to_matrix_cl(const T & src)67 inline var_value<matrix_cl<value_type_t<value_type_t<T>>>> to_matrix_cl(
68     const T& src) {
69   arena_t<T> src_stacked = src;
70 
71   return make_callback_var(
72       to_matrix_cl(src_stacked.val()), [src_stacked](auto& res_vari) mutable {
73         src_stacked.adj() += from_matrix_cl<
74             Eigen::Matrix<double, T::RowsAtCompileTime, T::ColsAtCompileTime>>(
75             res_vari.adj());
76       });
77 }
78 
79 /** \ingroup opencl
80  * Copies the source vector of Eigen matrices of vars to
81  * the destination matrix that is stored
82  * on the OpenCL device. Each element of the vector is stored into one column of
83  * the returned matrix_cl.
84  *
85  * @param src source vector of Eigen matrices
86  * @return matrix_cl with a copy of the data in the source matrix
87  */
88 template <typename T, require_eigen_vt<is_var, T>* = nullptr>
to_matrix_cl(const std::vector<T> & src)89 inline var_value<matrix_cl<value_type_t<value_type_t<T>>>> to_matrix_cl(
90     const std::vector<T>& src) {
91   auto src_stacked = to_arena(src);
92 
93   return make_callback_var(
94       to_matrix_cl(value_of(src_stacked)),
95       [src_stacked](auto& res_vari) mutable {
96         Eigen::MatrixXd adj = from_matrix_cl(res_vari.adj());
97         for (int i = 0; i < src_stacked.size(); i++) {
98           src_stacked[i].adj()
99               += Eigen::Map<plain_type_t<decltype(src_stacked[i].adj())>>(
100                   adj.data() + adj.rows() * i, src_stacked[i].rows(),
101                   src_stacked[i].cols());
102         }
103       });
104 }
105 
106 /** \ingroup opencl
107  * Copies the source var that has data stored on the OpenCL device to
108  * destination var containing Eigen matrix.
109  *
110  * @tparam T_dst destination type
111  * @tparam T type of the matrix or expression on the OpenCL device
112  * @param a source matrix_cl or expression
113  * @return var with a copy of the data on the host
114  */
115 template <typename T_dst, typename T,
116           require_var_vt<is_eigen, T_dst>* = nullptr,
117           require_all_kernel_expressions_t<T>* = nullptr>
from_matrix_cl(const var_value<T> & a)118 inline T_dst from_matrix_cl(const var_value<T>& a) {
119   return make_callback_var(
120       from_matrix_cl<Eigen::Matrix<double, T_dst::RowsAtCompileTime,
121                                    T_dst::ColsAtCompileTime>>(a.val()),
122       [a](auto& res_vari) mutable { a.adj() += to_matrix_cl(res_vari.adj()); });
123 }
124 
125 /** \ingroup opencl
126  * Copies the source var that has data stored on the OpenCL device to
127  * destination Eigen matrix containing vars.
128  *
129  * @tparam T_dst destination type
130  * @tparam T type of the matrix or expression on the OpenCL device
131  * @param a source matrix_cl or expression
132  * @return var with a copy of the data on the host
133  */
134 template <typename T_dst, typename T,
135           require_eigen_vt<is_var, T_dst>* = nullptr,
136           require_all_kernel_expressions_t<T>* = nullptr>
from_matrix_cl(const var_value<T> & a)137 inline T_dst from_matrix_cl(const var_value<T>& a) {
138   arena_t<T_dst> res
139       = from_matrix_cl<Eigen::Matrix<double, T_dst::RowsAtCompileTime,
140                                      T_dst::ColsAtCompileTime>>(a.val());
141   reverse_pass_callback(
142       [a, res]() mutable { a.adj() += to_matrix_cl(res.adj()); });
143   return res;
144 }
145 
146 /** \ingroup opencl
147  * Copies the source var that has data stored on the OpenCL device to
148  * destination `std::vector` containing vars.
149  *
150  * @tparam T_dst destination type
151  * @tparam T type of the matrix or expression on the OpenCL device
152  * @param a source matrix_cl or expression
153  * @return var with a copy of the data on the host
154  */
155 template <typename T_dst, typename T,
156           require_std_vector_vt<is_var, T_dst>* = nullptr,
157           require_all_stan_scalar_t<value_type_t<T_dst>>* = nullptr,
158           require_all_kernel_expressions_t<T>* = nullptr>
from_matrix_cl(const var_value<T> & a)159 inline T_dst from_matrix_cl(const var_value<T>& a) {
160   check_size_match("from_matrix_cl<std::vector<var>>", "src.cols()", a.cols(),
161                    "dst.cols()", 1);
162   std::vector<double> val = from_matrix_cl<std::vector<double>>(a.val());
163   arena_t<T_dst> res(val.begin(), val.end());
164   reverse_pass_callback([a, res]() mutable {
165     a.adj() += to_matrix_cl(as_column_vector_or_scalar(res).adj());
166   });
167   return {res.begin(), res.end()};
168 }
169 
170 /** \ingroup opencl
171  * Copies the source var that has data stored on the OpenCL device to
172  * destination std::vector containing either Eigen vectors of vars or vars
173  * containing Eigen vectors.
174  *
175  * @tparam T_dst destination type
176  * @tparam T type of the matrix or expression on the OpenCL device
177  * @param a source matrix_cl or expression
178  * @return var with a copy of the data on the host
179  */
180 template <typename T_dst, typename T, require_std_vector_t<T_dst>* = nullptr,
181           require_rev_vector_t<value_type_t<T_dst>>* = nullptr,
182           require_all_kernel_expressions_t<T>* = nullptr>
from_matrix_cl(const var_value<T> & a)183 inline T_dst from_matrix_cl(const var_value<T>& a) {
184   Eigen::MatrixXd val = from_matrix_cl(a.val());
185   arena_t<T_dst> res;
186   res.reserve(a.cols());
187   for (int i = 0; i < a.cols(); i++) {
188     res.emplace_back(val.col(i));
189   }
190   reverse_pass_callback([a, res]() mutable {
191     Eigen::MatrixXd adj(a.rows(), a.cols());
192     for (int i = 0; i < a.cols(); i++) {
193       adj.col(i) = res[i].adj();
194     }
195     a.adj() += to_matrix_cl(adj);
196   });
197   return {res.begin(), res.end()};
198 }
199 
200 /** \ingroup opencl
201  * Copies the source var that has data stored on the OpenCL device to
202  * destination Eigen matrix containing vars.
203  *
204  * @tparam T type of the matrix or expression on the OpenCL device
205  * @param src source matrix_cl or expression
206  * @return var with a copy of the data on the host
207  */
208 template <typename T, require_all_kernel_expressions_t<T>* = nullptr>
from_matrix_cl(const var_value<T> & src)209 auto from_matrix_cl(const var_value<T>& src) {
210   return from_matrix_cl<var_value<Eigen::MatrixXd>>(src);
211 }
212 
213 }  // namespace math
214 }  // namespace stan
215 #endif
216 #endif
217