1 #ifndef STAN_MATH_REV_FUN_REP_MATRIX_HPP
2 #define STAN_MATH_REV_FUN_REP_MATRIX_HPP
3
4 #include <stan/math/prim/fun/Eigen.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/prim/meta.hpp>
7 #include <stan/math/prim/err.hpp>
8 #include <stan/math/prim/fun/rep_matrix.hpp>
9
10 namespace stan {
11 namespace math {
12
13 /**
14 * Impl of rep_matrix returning an `var_value<Eigen::Matrix>` with a var scalar
15 * type.
16 * @tparam Ret A `var_value` with inner Eigen type.
17 * @tparam T A Scalar type.
18 * @param x A Scalar whose values are propogated to all values in the return
19 * matrix.
20 * @param m Number or rows.
21 * @param n Number of columns.
22 */
23 template <typename Ret, typename T, require_var_matrix_t<Ret>* = nullptr,
24 require_var_t<T>* = nullptr>
rep_matrix(const T & x,int m,int n)25 inline auto rep_matrix(const T& x, int m, int n) {
26 check_nonnegative("rep_matrix", "rows", m);
27 check_nonnegative("rep_matrix", "cols", n);
28 return make_callback_var(
29 value_type_t<Ret>::Constant(m, n, x.val()),
30 [x](auto& rep) mutable { x.adj() += rep.adj().sum(); });
31 }
32
33 /**
34 * Impl of rep_matrix returning a `var_value<Eigen::Matrix>` from a `var_value`
35 * with an inner Eigen vector type.
36 * @tparam Ret A `var_value` with inner Eigen dynamic matrix type.
37 * @tparam Vec A `var_value` with an inner Eigen vector type.
38 * @param x A `var_value` with inner Eigen vector type. For Row vectors the
39 * values are replacated rowwise and for column vectors the values are
40 * repliacated colwise.
41 * @param n Number of rows or columns.
42 */
43 template <typename Vec, require_var_matrix_t<Vec>* = nullptr>
rep_matrix(const Vec & x,int n)44 inline auto rep_matrix(const Vec& x, int n) {
45 if (is_row_vector<Vec>::value) {
46 check_nonnegative("rep_matrix", "rows", n);
47 return make_callback_var(x.val().replicate(n, 1), [x](auto& rep) mutable {
48 x.adj() += rep.adj().colwise().sum();
49 });
50 } else {
51 check_nonnegative("rep_matrix", "cols", n);
52 return make_callback_var(x.val().replicate(1, n), [x](auto& rep) mutable {
53 x.adj() += rep.adj().rowwise().sum();
54 });
55 }
56 }
57
58 } // namespace math
59 } // namespace stan
60
61 #endif
62