1 #ifndef STAN_MATH_REV_FUN_COLUMNS_DOT_PRODUCT_HPP
2 #define STAN_MATH_REV_FUN_COLUMNS_DOT_PRODUCT_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/prim/fun/Eigen.hpp>
7 #include <stan/math/rev/core.hpp>
8 #include <stan/math/rev/core/typedefs.hpp>
9 #include <stan/math/rev/fun/dot_product.hpp>
10 #include <stan/math/prim/meta.hpp>
11 
12 #include <type_traits>
13 
14 namespace stan {
15 namespace math {
16 
17 /**
18  * Returns the dot product of columns of the specified matrices.
19  *
20  * @tparam Mat1 type of the first matrix (must be derived from \c
21  * Eigen::MatrixBase)
22  * @tparam Mat2 type of the second matrix (must be derived from \c
23  * Eigen::MatrixBase)
24  *
25  * @param v1 Matrix of first vectors.
26  * @param v2 Matrix of second vectors.
27  * @return Dot product of the vectors.
28  * @throw std::domain_error If the vectors are not the same
29  * size or if they are both not vector dimensioned.
30  */
31 template <typename Mat1, typename Mat2,
32           require_all_eigen_t<Mat1, Mat2>* = nullptr,
33           require_any_eigen_vt<is_var, Mat1, Mat2>* = nullptr>
34 inline Eigen::Matrix<return_type_t<Mat1, Mat2>, 1, Mat1::ColsAtCompileTime>
columns_dot_product(const Mat1 & v1,const Mat2 & v2)35 columns_dot_product(const Mat1& v1, const Mat2& v2) {
36   check_matching_sizes("dot_product", "v1", v1, "v2", v2);
37   Eigen::Matrix<var, 1, Mat1::ColsAtCompileTime> ret(1, v1.cols());
38   for (size_type j = 0; j < v1.cols(); ++j) {
39     ret.coeffRef(j) = dot_product(v1.col(j), v2.col(j));
40   }
41   return ret;
42 }
43 
44 /**
45  * Returns the dot product of columns of the specified matrices.
46  *
47  * This overload is used when at least one of Mat1 and Mat2 is
48  * a `var_value<T>` where `T` inherits from `EigenBase`. The other
49  * argument can be another `var_value` or a type that inherits from
50  * `EigenBase`.
51  *
52  * @tparam Mat1 type of the first matrix
53  * @tparam Mat2 type of the second matrix
54  *
55  * @param v1 Matrix of first vectors.
56  * @param v2 Matrix of second vectors.
57  * @return Dot product of the vectors.
58  * @throw std::domain_error If the vectors are not the same
59  * size or if they are both not vector dimensioned.
60  */
61 template <typename Mat1, typename Mat2,
62           require_all_matrix_t<Mat1, Mat2>* = nullptr,
63           require_any_var_matrix_t<Mat1, Mat2>* = nullptr>
columns_dot_product(const Mat1 & v1,const Mat2 & v2)64 inline auto columns_dot_product(const Mat1& v1, const Mat2& v2) {
65   check_matching_sizes("columns_dot_product", "v1", v1, "v2", v2);
66   using inner_return_t = decltype(
67       (value_of(v1).array() * value_of(v2).array()).colwise().sum().matrix());
68   using return_t = return_var_matrix_t<inner_return_t, Mat1, Mat2>;
69 
70   if (!is_constant<Mat1>::value && !is_constant<Mat2>::value) {
71     arena_t<promote_scalar_t<var, Mat1>> arena_v1 = v1;
72     arena_t<promote_scalar_t<var, Mat2>> arena_v2 = v2;
73 
74     return_t res
75         = (arena_v1.val().array() * arena_v2.val().array()).colwise().sum();
76 
77     reverse_pass_callback([arena_v1, arena_v2, res]() mutable {
78       if (is_var_matrix<Mat1>::value) {
79         arena_v1.adj().noalias() += arena_v2.val() * res.adj().asDiagonal();
80       } else {
81         arena_v1.adj() += arena_v2.val() * res.adj().asDiagonal();
82       }
83       if (is_var_matrix<Mat2>::value) {
84         arena_v2.adj().noalias() += arena_v1.val() * res.adj().asDiagonal();
85       } else {
86         arena_v2.adj() += arena_v1.val() * res.adj().asDiagonal();
87       }
88     });
89 
90     return res;
91   } else if (!is_constant<Mat2>::value) {
92     arena_t<promote_scalar_t<double, Mat1>> arena_v1 = value_of(v1);
93     arena_t<promote_scalar_t<var, Mat2>> arena_v2 = v2;
94 
95     return_t res = (arena_v1.array() * arena_v2.val().array()).colwise().sum();
96 
97     reverse_pass_callback([arena_v1, arena_v2, res]() mutable {
98       if (is_var_matrix<Mat2>::value) {
99         arena_v2.adj().noalias() += arena_v1 * res.adj().asDiagonal();
100       } else {
101         arena_v2.adj() += arena_v1 * res.adj().asDiagonal();
102       }
103     });
104 
105     return res;
106   } else {
107     arena_t<promote_scalar_t<var, Mat1>> arena_v1 = v1;
108     arena_t<promote_scalar_t<double, Mat2>> arena_v2 = value_of(v2);
109 
110     return_t res = (arena_v1.val().array() * arena_v2.array()).colwise().sum();
111 
112     reverse_pass_callback([arena_v1, arena_v2, res]() mutable {
113       if (is_var_matrix<Mat2>::value) {
114         arena_v1.adj().noalias() += arena_v2 * res.adj().asDiagonal();
115       } else {
116         arena_v1.adj() += arena_v2 * res.adj().asDiagonal();
117       }
118     });
119 
120     return res;
121   }
122 }
123 
124 }  // namespace math
125 }  // namespace stan
126 #endif
127