1 #ifndef STAN_MATH_OPENCL_REV_TO_MATRIX_HPP
2 #define STAN_MATH_OPENCL_REV_TO_MATRIX_HPP
3 #ifdef STAN_OPENCL
4 
5 #include <stan/math/opencl/matrix_cl.hpp>
6 #include <stan/math/opencl/kernel_generator.hpp>
7 #include <stan/math/rev/core.hpp>
8 
9 namespace stan {
10 namespace math {
11 
12 /**
13  * Returns a matrix representation of a vector or matrix in column-major
14  * order with the specified number of rows and columns.
15  *
16  * @tparam T_x type of the matrix
17  *
18  * @param x matrix
19  * @param m rows
20  * @param n columns
21  * @return Reshaped input matrix
22  * @throw <code>std::invalid_argument</code> if the sizes
23  * do not match
24  */
25 template <typename T_x,
26           require_all_kernel_expressions_and_none_scalar_t<T_x>* = nullptr>
to_matrix(const var_value<T_x> & x,int m,int n)27 inline var_value<matrix_cl<double>> to_matrix(const var_value<T_x>& x, int m,
28                                               int n) {
29   return make_callback_var(
30       to_matrix(value_of(x), m, n),
31       [x, m, n](vari_value<matrix_cl<double>>& res) mutable {
32         matrix_cl<double> x_adj_cpy = std::move(x.adj());
33         matrix_cl<double> reshaped(x_adj_cpy.buffer(), m, n);
34         for (cl::Event e : x_adj_cpy.read_events()) {
35           reshaped.add_read_event(e);
36         }
37         for (cl::Event e : x_adj_cpy.write_events()) {
38           reshaped.add_write_event(e);
39         }
40         reshaped += res.adj();
41         for (cl::Event e : reshaped.write_events()) {
42           x_adj_cpy.add_write_event(e);
43         }
44         x.adj() = std::move(x_adj_cpy);
45       });
46 }
47 
48 }  // namespace math
49 }  // namespace stan
50 #endif
51 #endif
52