1 #ifndef STAN_MATH_PRIM_FUN_CSR_TO_DENSE_MATRIX_HPP
2 #define STAN_MATH_PRIM_FUN_CSR_TO_DENSE_MATRIX_HPP
3
4 #include <stan/math/prim/err.hpp>
5 #include <stan/math/prim/fun/csr_u_to_z.hpp>
6 #include <stan/math/prim/fun/Eigen.hpp>
7 #include <stan/math/prim/fun/to_ref.hpp>
8 #include <vector>
9
10 namespace stan {
11 namespace math {
12
13 /** \addtogroup csr_format
14 * @{
15 */
16 /**
17 * Construct a dense Eigen matrix from the CSR format components.
18 *
19 * @tparam T type of the matrix
20 * @param[in] m Number of matrix rows.
21 * @param[in] n Number of matrix columns.
22 * @param[in] w Values of non-zero matrix entries.
23 * @param[in] v Column index for each value in w.
24 * @param[in] u Index of where each row starts in w.
25 * @return Dense matrix defined by previous arguments.
26 * @throw std::domain_error If the arguments do not define a matrix.
27 * @throw std::invalid_argument if m/n/w/v/u are not internally
28 * consistent, as defined by the indexing scheme. Extractors are
29 * defined in Stan which guarantee a consistent set of m/n/w/v/u
30 * for a given sparse matrix.
31 * @throw std::out_of_range if any of the indices are out of range.
32 */
33 template <typename T>
34 inline Eigen::Matrix<value_type_t<T>, Eigen::Dynamic, Eigen::Dynamic>
csr_to_dense_matrix(int m,int n,const T & w,const std::vector<int> & v,const std::vector<int> & u)35 csr_to_dense_matrix(int m, int n, const T& w, const std::vector<int>& v,
36 const std::vector<int>& u) {
37 using Eigen::Dynamic;
38 using Eigen::Matrix;
39
40 check_positive("csr_to_dense_matrix", "m", m);
41 check_positive("csr_to_dense_matrix", "n", n);
42 check_size_match("csr_to_dense_matrix", "m", m, "u", u.size() - 1);
43 check_size_match("csr_to_dense_matrix", "w", w.size(), "v", v.size());
44 check_size_match("csr_to_dense_matrix", "u/z",
45 u[m - 1] + csr_u_to_z(u, m - 1) - 1, "v", v.size());
46 for (int i : v) {
47 check_range("csr_to_dense_matrix", "v[]", n, i);
48 }
49 const auto& w_ref = to_ref(w);
50 Matrix<value_type_t<T>, Dynamic, Dynamic> result(m, n);
51 result.setZero();
52 for (int row = 0; row < m; ++row) {
53 int row_end_in_w = (u[row] - stan::error_index::value) + csr_u_to_z(u, row);
54 check_range("csr_to_dense_matrix", "w", w.size(), row_end_in_w);
55 for (int nze = u[row] - stan::error_index::value; nze < row_end_in_w;
56 ++nze) {
57 // row is row index, v[nze] is column index. w[nze] is entry value.
58 check_range("csr_to_dense_matrix", "j", n, v[nze]);
59 result(row, v[nze] - stan::error_index::value) = w_ref.coeff(nze);
60 }
61 }
62 return result;
63 }
64 /** @} */ // end of csr_format group
65
66 } // namespace math
67 } // namespace stan
68
69 #endif
70