1 #ifndef STAN_MATH_REV_FUN_ROWS_DOT_PRODUCT_HPP
2 #define STAN_MATH_REV_FUN_ROWS_DOT_PRODUCT_HPP
3
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/rev/core/typedefs.hpp>
7 #include <stan/math/rev/fun/dot_product.hpp>
8 #include <stan/math/prim/err.hpp>
9 #include <stan/math/prim/fun/Eigen.hpp>
10 #include <stan/math/prim/fun/rows_dot_product.hpp>
11 #include <type_traits>
12
13 namespace stan {
14 namespace math {
15
16 /**
17 * Returns the dot product of rows of the specified matrices.
18 *
19 * @tparam Mat1 type of the first matrix (must be derived from \c
20 * Eigen::MatrixBase)
21 * @tparam Mat2 type of the second matrix (must be derived from \c
22 * Eigen::MatrixBase)
23 *
24 * @param v1 Matrix of first vectors.
25 * @param v2 Matrix of second vectors.
26 * @return Dot product of the vectors.
27 * @throw std::domain_error If the vectors are not the same
28 * size or if they are both not vector dimensioned.
29 */
30 template <typename Mat1, typename Mat2,
31 require_all_eigen_t<Mat1, Mat2>* = nullptr,
32 require_any_eigen_vt<is_var, Mat1, Mat2>* = nullptr>
rows_dot_product(const Mat1 & v1,const Mat2 & v2)33 inline Eigen::Matrix<var, Mat1::RowsAtCompileTime, 1> rows_dot_product(
34 const Mat1& v1, const Mat2& v2) {
35 check_matching_sizes("dot_product", "v1", v1, "v2", v2);
36 Eigen::Matrix<var, Mat1::RowsAtCompileTime, 1> ret(v1.rows(), 1);
37 for (size_type j = 0; j < v1.rows(); ++j) {
38 ret.coeffRef(j) = dot_product(v1.row(j), v2.row(j));
39 }
40 return ret;
41 }
42
43 /**
44 * Returns the dot product of rows of the specified matrices.
45 *
46 * This overload is used when at least one of Mat1 and Mat2 is
47 * a `var_value<T>` where `T` inherits from `EigenBase`. The other
48 * argument can be another `var_value` or a type that inherits from
49 * `EigenBase`.
50 *
51 * @tparam Mat1 type of the first matrix
52 * @tparam Mat2 type of the second matrix
53 *
54 * @param v1 Matrix of first vectors.
55 * @param v2 Matrix of second vectors.
56 * @return Dot product of the vectors.
57 * @throw std::domain_error If the vectors are not the same
58 * size or if they are both not vector dimensioned.
59 */
60 template <typename Mat1, typename Mat2,
61 require_all_matrix_t<Mat1, Mat2>* = nullptr,
62 require_any_var_matrix_t<Mat1, Mat2>* = nullptr>
rows_dot_product(const Mat1 & v1,const Mat2 & v2)63 inline auto rows_dot_product(const Mat1& v1, const Mat2& v2) {
64 check_matching_sizes("rows_dot_product", "v1", v1, "v2", v2);
65
66 using return_t = return_var_matrix_t<
67 decltype((v1.val().array() * v2.val().array()).rowwise().sum().matrix()),
68 Mat1, Mat2>;
69
70 if (!is_constant<Mat1>::value && !is_constant<Mat2>::value) {
71 arena_t<promote_scalar_t<var, Mat1>> arena_v1 = v1;
72 arena_t<promote_scalar_t<var, Mat2>> arena_v2 = v2;
73
74 return_t res
75 = (arena_v1.val().array() * arena_v2.val().array()).rowwise().sum();
76
77 reverse_pass_callback([arena_v1, arena_v2, res]() mutable {
78 if (is_var_matrix<Mat1>::value) {
79 arena_v1.adj().noalias() += res.adj().asDiagonal() * arena_v2.val();
80 } else {
81 arena_v1.adj() += res.adj().asDiagonal() * arena_v2.val();
82 }
83 if (is_var_matrix<Mat2>::value) {
84 arena_v2.adj().noalias() += res.adj().asDiagonal() * arena_v1.val();
85 } else {
86 arena_v2.adj() += res.adj().asDiagonal() * arena_v1.val();
87 }
88 });
89
90 return res;
91 } else if (!is_constant<Mat2>::value) {
92 arena_t<promote_scalar_t<double, Mat1>> arena_v1 = value_of(v1);
93 arena_t<promote_scalar_t<var, Mat2>> arena_v2 = v2;
94
95 return_t res = (arena_v1.array() * arena_v2.val().array()).rowwise().sum();
96
97 reverse_pass_callback([arena_v1, arena_v2, res]() mutable {
98 if (is_var_matrix<Mat2>::value) {
99 arena_v2.adj().noalias() += res.adj().asDiagonal() * arena_v1;
100 } else {
101 arena_v2.adj() += res.adj().asDiagonal() * arena_v1;
102 }
103 });
104
105 return res;
106 } else {
107 arena_t<promote_scalar_t<var, Mat1>> arena_v1 = v1;
108 arena_t<promote_scalar_t<double, Mat2>> arena_v2 = value_of(v2);
109
110 return_t res = (arena_v1.val().array() * arena_v2.array()).rowwise().sum();
111
112 reverse_pass_callback([arena_v1, arena_v2, res]() mutable {
113 if (is_var_matrix<Mat2>::value) {
114 arena_v1.adj().noalias() += res.adj().asDiagonal() * arena_v2;
115 } else {
116 arena_v1.adj() += res.adj().asDiagonal() * arena_v2;
117 }
118 });
119
120 return res;
121 }
122 }
123
124 } // namespace math
125 } // namespace stan
126 #endif
127