1 #ifndef STAN_MATH_REV_FUN_REP_MATRIX_HPP
2 #define STAN_MATH_REV_FUN_REP_MATRIX_HPP
3 
4 #include <stan/math/prim/fun/Eigen.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/prim/meta.hpp>
7 #include <stan/math/prim/err.hpp>
8 #include <stan/math/prim/fun/rep_matrix.hpp>
9 
10 namespace stan {
11 namespace math {
12 
13 /**
14  * Impl of rep_matrix returning an `var_value<Eigen::Matrix>` with a var scalar
15  * type.
16  * @tparam Ret A `var_value` with inner Eigen type.
17  * @tparam T A Scalar type.
18  * @param x A Scalar whose values are propogated to all values in the return
19  * matrix.
20  * @param m Number or rows.
21  * @param n Number of columns.
22  */
23 template <typename Ret, typename T, require_var_matrix_t<Ret>* = nullptr,
24           require_var_t<T>* = nullptr>
rep_matrix(const T & x,int m,int n)25 inline auto rep_matrix(const T& x, int m, int n) {
26   check_nonnegative("rep_matrix", "rows", m);
27   check_nonnegative("rep_matrix", "cols", n);
28   return make_callback_var(
29       value_type_t<Ret>::Constant(m, n, x.val()),
30       [x](auto& rep) mutable { x.adj() += rep.adj().sum(); });
31 }
32 
33 /**
34  * Impl of rep_matrix returning a `var_value<Eigen::Matrix>` from a `var_value`
35  * with an inner Eigen vector type.
36  * @tparam Ret A `var_value` with inner Eigen dynamic matrix type.
37  * @tparam Vec A `var_value` with an inner Eigen vector type.
38  * @param x A `var_value` with inner Eigen vector type. For Row vectors the
39  * values are replacated rowwise and for column vectors the values are
40  * repliacated colwise.
41  * @param n Number of rows or columns.
42  */
43 template <typename Vec, require_var_matrix_t<Vec>* = nullptr>
rep_matrix(const Vec & x,int n)44 inline auto rep_matrix(const Vec& x, int n) {
45   if (is_row_vector<Vec>::value) {
46     check_nonnegative("rep_matrix", "rows", n);
47     return make_callback_var(x.val().replicate(n, 1), [x](auto& rep) mutable {
48       x.adj() += rep.adj().colwise().sum();
49     });
50   } else {
51     check_nonnegative("rep_matrix", "cols", n);
52     return make_callback_var(x.val().replicate(1, n), [x](auto& rep) mutable {
53       x.adj() += rep.adj().rowwise().sum();
54     });
55   }
56 }
57 
58 }  // namespace math
59 }  // namespace stan
60 
61 #endif
62