1 #ifndef STAN_MATH_OPENCL_PRIM_TO_MATRIX_HPP
2 #define STAN_MATH_OPENCL_PRIM_TO_MATRIX_HPP
3 #ifdef STAN_OPENCL
4
5 #include <stan/math/opencl/matrix_cl.hpp>
6 #include <stan/math/opencl/kernel_generator.hpp>
7
8 namespace stan {
9 namespace math {
10
11 /** \ingroup opencl
12 * Returns input matrix.
13 *
14 * @tparam T_x type of the matrix
15 *
16 * @param x matrix
17 * @return the matrix representation of the input
18 */
19 template <typename T_x,
20 require_nonscalar_prim_or_rev_kernel_expression_t<T_x>* = nullptr>
to_matrix(T_x && x)21 inline T_x to_matrix(T_x&& x) {
22 return std::forward<T_x>(x);
23 }
24
25 /**
26 * Returns a matrix representation of a vector or matrix in column-major
27 * order with the specified number of rows and columns.
28 *
29 * @tparam T_x type of the matrix
30 *
31 * @param x matrix
32 * @param m rows
33 * @param n columns
34 * @return Reshaped input matrix
35 * @throw <code>std::invalid_argument</code> if the sizes
36 * do not match
37 */
38 template <typename T_x,
39 require_all_kernel_expressions_and_none_scalar_t<T_x>* = nullptr>
to_matrix(const T_x & x,int m,int n)40 inline matrix_cl<return_type_t<T_x>> to_matrix(const T_x& x, int m, int n) {
41 using res_scal = return_type_t<T_x>;
42 check_size_match("to_matrix", "rows * columns", "", m * n, "input size", "",
43 x.size());
44 matrix_cl<res_scal> res(m, n);
45 matrix_cl<res_scal> tmp(res.buffer(), x.rows(), x.cols());
46 tmp = x;
47 for (cl::Event e : tmp.write_events()) {
48 res.add_write_event(e);
49 }
50 return res;
51 }
52
53 /**
54 * Returns a matrix representation of the vector or matrix in column-major or
55 * row major order with the specified number of rows and columns.
56 *
57 * @tparam T_x type of the matrix
58 * @param x matrix
59 * @param m rows
60 * @param n columns
61 * @param col_major column-major indicator
62 * if 1, output matrix is transversed in column-major order
63 * if 0, output matrix is transversed in row-major order
64 * @return the matrix representation of the input
65 * @throw <code>std::invalid_argument</code>
66 * if the sizes do not match
67 */
68 template <typename T_x,
69 require_nonscalar_prim_or_rev_kernel_expression_t<T_x>* = nullptr>
to_matrix(const T_x & x,int m,int n,bool col_major)70 inline auto to_matrix(const T_x& x, int m, int n, bool col_major)
71 -> decltype(to_matrix(x, m, n)) {
72 if (col_major) {
73 return to_matrix(x, m, n);
74 } else {
75 return transpose(to_matrix(transpose(x), n, m));
76 }
77 }
78
79 } // namespace math
80 } // namespace stan
81 #endif
82 #endif
83