1 #ifndef STAN_MATH_REV_FUN_ELT_MULTIPLY_HPP
2 #define STAN_MATH_REV_FUN_ELT_MULTIPLY_HPP
3 
4 #include <stan/math/prim/fun/Eigen.hpp>
5 #include <stan/math/prim/meta.hpp>
6 #include <stan/math/prim/err.hpp>
7 #include <stan/math/prim/fun/eval.hpp>
8 #include <stan/math/rev/core.hpp>
9 #include <stan/math/rev/fun/multiply.hpp>
10 
11 namespace stan {
12 namespace math {
13 
14 /**
15  * Return the elementwise multiplication of the specified
16  * matrices.
17  *
18  * @tparam Mat1 type of the first matrix or expression
19  * @tparam Mat2 type of the second matrix or expression
20  *
21  * @param m1 First matrix or expression
22  * @param m2 Second matrix or expression
23  * @return Elementwise product of matrices.
24  */
25 template <typename Mat1, typename Mat2,
26           require_all_matrix_t<Mat1, Mat2>* = nullptr,
27           require_any_rev_matrix_t<Mat1, Mat2>* = nullptr>
elt_multiply(const Mat1 & m1,const Mat2 & m2)28 auto elt_multiply(const Mat1& m1, const Mat2& m2) {
29   check_matching_dims("elt_multiply", "m1", m1, "m2", m2);
30   using inner_ret_type = decltype(value_of(m1).cwiseProduct(value_of(m2)));
31   using ret_type = return_var_matrix_t<inner_ret_type, Mat1, Mat2>;
32   if (!is_constant<Mat1>::value && !is_constant<Mat2>::value) {
33     arena_t<promote_scalar_t<var, Mat1>> arena_m1 = m1;
34     arena_t<promote_scalar_t<var, Mat2>> arena_m2 = m2;
35     arena_t<ret_type> ret(arena_m1.val().cwiseProduct(arena_m2.val()));
36     reverse_pass_callback([ret, arena_m1, arena_m2]() mutable {
37       for (Eigen::Index j = 0; j < arena_m2.cols(); ++j) {
38         for (Eigen::Index i = 0; i < arena_m2.rows(); ++i) {
39           const auto ret_adj = ret.adj().coeffRef(i, j);
40           arena_m1.adj().coeffRef(i, j) += arena_m2.val().coeff(i, j) * ret_adj;
41           arena_m2.adj().coeffRef(i, j) += arena_m1.val().coeff(i, j) * ret_adj;
42         }
43       }
44     });
45     return ret_type(ret);
46   } else if (!is_constant<Mat1>::value) {
47     arena_t<promote_scalar_t<var, Mat1>> arena_m1 = m1;
48     arena_t<promote_scalar_t<double, Mat2>> arena_m2 = value_of(m2);
49     arena_t<ret_type> ret(arena_m1.val().cwiseProduct(arena_m2));
50     reverse_pass_callback([ret, arena_m1, arena_m2]() mutable {
51       arena_m1.adj().array() += arena_m2.array() * ret.adj().array();
52     });
53     return ret_type(ret);
54   } else if (!is_constant<Mat2>::value) {
55     arena_t<promote_scalar_t<double, Mat1>> arena_m1 = value_of(m1);
56     arena_t<promote_scalar_t<var, Mat2>> arena_m2 = m2;
57     arena_t<ret_type> ret(arena_m1.cwiseProduct(arena_m2.val()));
58     reverse_pass_callback([ret, arena_m2, arena_m1]() mutable {
59       arena_m2.adj().array() += arena_m1.array() * ret.adj().array();
60     });
61     return ret_type(ret);
62   }
63 }
64 
65 }  // namespace math
66 }  // namespace stan
67 
68 #endif
69