1 #ifndef STAN_MATH_OPENCL_PRIM_TO_MATRIX_HPP
2 #define STAN_MATH_OPENCL_PRIM_TO_MATRIX_HPP
3 #ifdef STAN_OPENCL
4 
5 #include <stan/math/opencl/matrix_cl.hpp>
6 #include <stan/math/opencl/kernel_generator.hpp>
7 
8 namespace stan {
9 namespace math {
10 
11 /** \ingroup opencl
12  * Returns input matrix.
13  *
14  * @tparam T_x type of the matrix
15  *
16  * @param x matrix
17  * @return the matrix representation of the input
18  */
19 template <typename T_x,
20           require_nonscalar_prim_or_rev_kernel_expression_t<T_x>* = nullptr>
to_matrix(T_x && x)21 inline T_x to_matrix(T_x&& x) {
22   return std::forward<T_x>(x);
23 }
24 
25 /**
26  * Returns a matrix representation of a vector or matrix in column-major
27  * order with the specified number of rows and columns.
28  *
29  * @tparam T_x type of the matrix
30  *
31  * @param x matrix
32  * @param m rows
33  * @param n columns
34  * @return Reshaped input matrix
35  * @throw <code>std::invalid_argument</code> if the sizes
36  * do not match
37  */
38 template <typename T_x,
39           require_all_kernel_expressions_and_none_scalar_t<T_x>* = nullptr>
to_matrix(const T_x & x,int m,int n)40 inline matrix_cl<return_type_t<T_x>> to_matrix(const T_x& x, int m, int n) {
41   using res_scal = return_type_t<T_x>;
42   check_size_match("to_matrix", "rows * columns", "", m * n, "input size", "",
43                    x.size());
44   matrix_cl<res_scal> res(m, n);
45   matrix_cl<res_scal> tmp(res.buffer(), x.rows(), x.cols());
46   tmp = x;
47   for (cl::Event e : tmp.write_events()) {
48     res.add_write_event(e);
49   }
50   return res;
51 }
52 
53 /**
54  * Returns a matrix representation of the vector or matrix in column-major or
55  * row major order with the specified number of rows and columns.
56  *
57  * @tparam T_x type of the matrix
58  * @param x matrix
59  * @param m rows
60  * @param n columns
61  * @param col_major column-major indicator
62  * if 1, output matrix is transversed in column-major order
63  * if 0, output matrix is transversed in row-major order
64  * @return the matrix representation of the input
65  * @throw <code>std::invalid_argument</code>
66  * if the sizes do not match
67  */
68 template <typename T_x,
69           require_nonscalar_prim_or_rev_kernel_expression_t<T_x>* = nullptr>
to_matrix(const T_x & x,int m,int n,bool col_major)70 inline auto to_matrix(const T_x& x, int m, int n, bool col_major)
71     -> decltype(to_matrix(x, m, n)) {
72   if (col_major) {
73     return to_matrix(x, m, n);
74   } else {
75     return transpose(to_matrix(transpose(x), n, m));
76   }
77 }
78 
79 }  // namespace math
80 }  // namespace stan
81 #endif
82 #endif
83