1 #ifndef STAN_MATH_REV_FUN_DIAG_PRE_MULTIPLY_HPP
2 #define STAN_MATH_REV_FUN_DIAG_PRE_MULTIPLY_HPP
3 
4 #include <stan/math/prim/fun/Eigen.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/rev/core.hpp>
7 
8 namespace stan {
9 namespace math {
10 
11 /**
12  * Return the product of the diagonal matrix formed from the vector
13  * or row_vector and a matrix.
14  *
15  * @tparam T1 type of the vector/row_vector
16  * @tparam T2 type of the matrix
17  * @param m1 input vector/row_vector
18  * @param m2 input matrix
19  *
20  * @return product of the diagonal matrix formed from the
21  * vector or row_vector and a matrix.
22  */
23 template <typename T1, typename T2, require_vector_t<T1>* = nullptr,
24           require_matrix_t<T2>* = nullptr,
25           require_any_st_var<T1, T2>* = nullptr>
diag_pre_multiply(const T1 & m1,const T2 & m2)26 auto diag_pre_multiply(const T1& m1, const T2& m2) {
27   check_size_match("diag_pre_multiply", "m1.size()", m1.size(), "m2.rows()",
28                    m2.rows());
29   using inner_ret_type = decltype(value_of(m1).asDiagonal() * value_of(m2));
30   using ret_type = return_var_matrix_t<inner_ret_type, T1, T2>;
31   if (!is_constant<T1>::value && !is_constant<T2>::value) {
32     arena_t<promote_scalar_t<var, T1>> arena_m1 = m1;
33     arena_t<promote_scalar_t<var, T2>> arena_m2 = m2;
34     arena_t<ret_type> ret(arena_m1.val().asDiagonal() * arena_m2.val());
35     reverse_pass_callback([ret, arena_m1, arena_m2]() mutable {
36       arena_m1.adj() += arena_m2.val().cwiseProduct(ret.adj()).rowwise().sum();
37       arena_m2.adj() += arena_m1.val().asDiagonal() * ret.adj();
38     });
39     return ret_type(ret);
40   } else if (!is_constant<T1>::value) {
41     arena_t<promote_scalar_t<var, T1>> arena_m1 = m1;
42     arena_t<promote_scalar_t<double, T2>> arena_m2 = value_of(m2);
43     arena_t<ret_type> ret(arena_m1.val().asDiagonal() * arena_m2);
44     reverse_pass_callback([ret, arena_m1, arena_m2]() mutable {
45       arena_m1.adj() += arena_m2.val().cwiseProduct(ret.adj()).rowwise().sum();
46     });
47     return ret_type(ret);
48   } else if (!is_constant<T2>::value) {
49     arena_t<promote_scalar_t<double, T1>> arena_m1 = value_of(m1);
50     arena_t<promote_scalar_t<var, T2>> arena_m2 = m2;
51     arena_t<ret_type> ret(arena_m1.asDiagonal() * arena_m2.val());
52     reverse_pass_callback([ret, arena_m1, arena_m2]() mutable {
53       arena_m2.adj() += arena_m1.val().asDiagonal() * ret.adj();
54     });
55     return ret_type(ret);
56   }
57 }
58 
59 }  // namespace math
60 }  // namespace stan
61 
62 #endif
63