1 #ifndef STAN_MATH_REV_FUN_DIAG_POST_MULTIPLY_HPP
2 #define STAN_MATH_REV_FUN_DIAG_POST_MULTIPLY_HPP
3
4 #include <stan/math/prim/fun/Eigen.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/rev/core.hpp>
7
8 namespace stan {
9 namespace math {
10
11 /**
12 * Return the product of the matrix and a diagonal matrix formed from the vector
13 * or row_vector.
14 *
15 * @tparam T1 type of the matrix
16 * @tparam T2 type of the vector/row_vector
17 * @param m1 input matrix
18 * @param m2 input vector/row_vector
19 *
20 * @return product of the matrix and the diagonal matrix formed from the
21 * vector or row_vector.
22 */
23 template <typename T1, typename T2, require_matrix_t<T1>* = nullptr,
24 require_vector_t<T2>* = nullptr,
25 require_any_st_var<T1, T2>* = nullptr>
diag_post_multiply(const T1 & m1,const T2 & m2)26 auto diag_post_multiply(const T1& m1, const T2& m2) {
27 check_size_match("diag_post_multiply", "m2.size()", m2.size(), "m1.cols()",
28 m1.cols());
29 using inner_ret_type = decltype(value_of(m1) * value_of(m2).asDiagonal());
30 using ret_type = return_var_matrix_t<inner_ret_type, T1, T2>;
31
32 if (!is_constant<T1>::value && !is_constant<T2>::value) {
33 arena_t<promote_scalar_t<var, T1>> arena_m1 = m1;
34 arena_t<promote_scalar_t<var, T2>> arena_m2 = m2;
35 arena_t<ret_type> ret(arena_m1.val() * arena_m2.val().asDiagonal());
36 reverse_pass_callback([ret, arena_m1, arena_m2]() mutable {
37 arena_m2.adj() += arena_m1.val().cwiseProduct(ret.adj()).colwise().sum();
38 arena_m1.adj() += ret.adj() * arena_m2.val().asDiagonal();
39 });
40 return ret_type(ret);
41 } else if (!is_constant<T1>::value) {
42 arena_t<promote_scalar_t<var, T1>> arena_m1 = m1;
43 arena_t<promote_scalar_t<double, T2>> arena_m2 = value_of(m2);
44 arena_t<ret_type> ret(arena_m1.val() * arena_m2.asDiagonal());
45 reverse_pass_callback([ret, arena_m1, arena_m2]() mutable {
46 arena_m1.adj() += ret.adj() * arena_m2.val().asDiagonal();
47 });
48 return ret_type(ret);
49 } else if (!is_constant<T2>::value) {
50 arena_t<promote_scalar_t<double, T1>> arena_m1 = value_of(m1);
51 arena_t<promote_scalar_t<var, T2>> arena_m2 = m2;
52 arena_t<ret_type> ret(arena_m1 * arena_m2.val().asDiagonal());
53 reverse_pass_callback([ret, arena_m1, arena_m2]() mutable {
54 arena_m2.adj() += arena_m1.val().cwiseProduct(ret.adj()).colwise().sum();
55 });
56 return ret_type(ret);
57 }
58 }
59
60 } // namespace math
61 } // namespace stan
62
63 #endif
64