1 #ifndef STAN_MATH_REV_FUN_MULTIPLY_HPP
2 #define STAN_MATH_REV_FUN_MULTIPLY_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/rev/core/typedefs.hpp>
7 #include <stan/math/prim/fun.hpp>
8 #include <type_traits>
9 
10 namespace stan {
11 namespace math {
12 
13 /**
14  * Return the product of two matrices.
15  *
16  * This version does not handle row vector times column vector
17  *
18  * @tparam T1 type of first matrix
19  * @tparam T2 type of second matrix
20  *
21  *
22  * @param[in] A first matrix
23  * @param[in] B second matrix
24  * @return A * B
25  */
26 template <typename T1, typename T2, require_all_matrix_t<T1, T2>* = nullptr,
27           require_return_type_t<is_var, T1, T2>* = nullptr,
28           require_not_row_and_col_vector_t<T1, T2>* = nullptr>
multiply(const T1 & A,const T2 & B)29 inline auto multiply(const T1& A, const T2& B) {
30   check_multiplicable("multiply", "A", A, "B", B);
31   if (!is_constant<T2>::value && !is_constant<T1>::value) {
32     arena_t<promote_scalar_t<var, T1>> arena_A = A;
33     arena_t<promote_scalar_t<var, T2>> arena_B = B;
34     auto arena_A_val = to_arena(arena_A.val());
35     auto arena_B_val = to_arena(arena_B.val());
36     using return_t
37         = return_var_matrix_t<decltype(arena_A_val * arena_B_val), T1, T2>;
38     arena_t<return_t> res = arena_A_val * arena_B_val;
39 
40     reverse_pass_callback(
41         [arena_A, arena_B, arena_A_val, arena_B_val, res]() mutable {
42           if (is_var_matrix<T1>::value || is_var_matrix<T2>::value) {
43             arena_A.adj() += res.adj_op() * arena_B_val.transpose();
44             arena_B.adj() += arena_A_val.transpose() * res.adj_op();
45           } else {
46             auto res_adj = res.adj().eval();
47             arena_A.adj() += res_adj * arena_B_val.transpose();
48             arena_B.adj() += arena_A_val.transpose() * res_adj;
49           }
50         });
51     return return_t(res);
52   } else if (!is_constant<T2>::value) {
53     arena_t<promote_scalar_t<double, T1>> arena_A = value_of(A);
54     arena_t<promote_scalar_t<var, T2>> arena_B = B;
55     using return_t
56         = return_var_matrix_t<decltype(arena_A * value_of(B).eval()), T1, T2>;
57     arena_t<return_t> res = arena_A * arena_B.val_op();
58     reverse_pass_callback([arena_B, arena_A, res]() mutable {
59       arena_B.adj() += arena_A.transpose() * res.adj_op();
60     });
61     return return_t(res);
62   } else {
63     arena_t<promote_scalar_t<var, T1>> arena_A = A;
64     arena_t<promote_scalar_t<double, T2>> arena_B = value_of(B);
65     using return_t
66         = return_var_matrix_t<decltype(value_of(arena_A).eval() * arena_B), T1,
67                               T2>;
68     arena_t<return_t> res = arena_A.val_op() * arena_B;
69     reverse_pass_callback([arena_A, arena_B, res]() mutable {
70       arena_A.adj() += res.adj_op() * arena_B.transpose();
71     });
72 
73     return return_t(res);
74   }
75 }
76 
77 /**
78  * Return the product of a row vector times a column vector as a scalar
79  *
80  * @tparam T1 type of row vector
81  * @tparam T2 type of column vector
82  *
83  * @param[in] A row vector
84  * @param[in] B column vector
85  * @return A * B as a scalar
86  */
87 template <typename T1, typename T2, require_all_matrix_t<T1, T2>* = nullptr,
88           require_return_type_t<is_var, T1, T2>* = nullptr,
89           require_row_and_col_vector_t<T1, T2>* = nullptr>
multiply(const T1 & A,const T2 & B)90 inline var multiply(const T1& A, const T2& B) {
91   check_multiplicable("multiply", "A", A, "B", B);
92   if (!is_constant<T2>::value && !is_constant<T1>::value) {
93     arena_t<promote_scalar_t<var, T1>> arena_A = A;
94     arena_t<promote_scalar_t<var, T2>> arena_B = B;
95     arena_t<promote_scalar_t<double, T1>> arena_A_val = value_of(arena_A);
96     arena_t<promote_scalar_t<double, T2>> arena_B_val = value_of(arena_B);
97     var res = arena_A_val.dot(arena_B_val);
98 
99     reverse_pass_callback(
100         [arena_A, arena_B, arena_A_val, arena_B_val, res]() mutable {
101           auto res_adj = res.adj();
102           arena_A.adj().array() += res_adj * arena_B_val.transpose().array();
103           arena_B.adj().array() += arena_A_val.transpose().array() * res_adj;
104         });
105     return res;
106   } else if (!is_constant<T2>::value) {
107     arena_t<promote_scalar_t<var, T2>> arena_B = B;
108     arena_t<promote_scalar_t<double, T1>> arena_A_val = value_of(A);
109     var res = arena_A_val.dot(value_of(arena_B));
110     reverse_pass_callback([arena_B, arena_A_val, res]() mutable {
111       arena_B.adj().array() += arena_A_val.transpose().array() * res.adj();
112     });
113     return res;
114   } else {
115     arena_t<promote_scalar_t<var, T1>> arena_A = A;
116     arena_t<promote_scalar_t<double, T2>> arena_B_val = value_of(B);
117     var res = value_of(arena_A).dot(arena_B_val);
118     reverse_pass_callback([arena_A, arena_B_val, res]() mutable {
119       arena_A.adj().array() += res.adj() * arena_B_val.transpose().array();
120     });
121     return res;
122   }
123 }
124 
125 /**
126  * Return specified matrix multiplied by specified scalar where at least one
127  * input has a scalar type of a `var_value`.
128  *
129  * @tparam T1 type of the scalar
130  * @tparam T2 type of the matrix or expression
131  *
132  * @param A scalar
133  * @param B matrix
134  * @return product of matrix and scalar
135  */
136 template <typename T1, typename T2, require_not_matrix_t<T1>* = nullptr,
137           require_matrix_t<T2>* = nullptr,
138           require_return_type_t<is_var, T1, T2>* = nullptr,
139           require_not_row_and_col_vector_t<T1, T2>* = nullptr>
multiply(const T1 & A,const T2 & B)140 inline auto multiply(const T1& A, const T2& B) {
141   if (!is_constant<T2>::value && !is_constant<T1>::value) {
142     arena_t<promote_scalar_t<var, T1>> arena_A = A;
143     arena_t<promote_scalar_t<var, T2>> arena_B = B;
144     using return_t = return_var_matrix_t<T2, T1, T2>;
145     arena_t<return_t> res = arena_A.val() * arena_B.val().array();
146     reverse_pass_callback([arena_A, arena_B, res]() mutable {
147       const auto a_val = arena_A.val();
148       for (Eigen::Index j = 0; j < res.cols(); ++j) {
149         for (Eigen::Index i = 0; i < res.rows(); ++i) {
150           const auto res_adj = res.adj().coeffRef(i, j);
151           arena_A.adj() += res_adj * arena_B.val().coeff(i, j);
152           arena_B.adj().coeffRef(i, j) += a_val * res_adj;
153         }
154       }
155     });
156     return return_t(res);
157   } else if (!is_constant<T2>::value) {
158     arena_t<promote_scalar_t<double, T1>> arena_A = value_of(A);
159     arena_t<promote_scalar_t<var, T2>> arena_B = B;
160     using return_t = return_var_matrix_t<T2, T1, T2>;
161     arena_t<return_t> res = arena_A * arena_B.val().array();
162     reverse_pass_callback([arena_A, arena_B, res]() mutable {
163       arena_B.adj().array() += arena_A * res.adj().array();
164     });
165     return return_t(res);
166   } else {
167     arena_t<promote_scalar_t<var, T1>> arena_A = A;
168     arena_t<promote_scalar_t<double, T2>> arena_B = value_of(B);
169     using return_t = return_var_matrix_t<T2, T1, T2>;
170     arena_t<return_t> res = arena_A.val() * arena_B.array();
171     reverse_pass_callback([arena_A, arena_B, res]() mutable {
172       arena_A.adj() += (res.adj().array() * arena_B.array()).sum();
173     });
174     return return_t(res);
175   }
176 }
177 
178 /**
179  * Return specified matrix multiplied by specified scalar where at least one
180  * input has a scalar type of a `var_value`.
181  *
182  * @tparam T1 type of the matrix or expression
183  * @tparam T2 type of the scalar
184  *
185  * @param A matrix
186  * @param B scalar
187  * @return product of matrix and scalar
188  */
189 template <typename T1, typename T2, require_matrix_t<T1>* = nullptr,
190           require_not_matrix_t<T2>* = nullptr,
191           require_any_st_var<T1, T2>* = nullptr,
192           require_not_row_and_col_vector_t<T1, T2>* = nullptr>
multiply(const T1 & A,const T2 & B)193 inline auto multiply(const T1& A, const T2& B) {
194   return multiply(B, A);
195 }
196 
197 }  // namespace math
198 }  // namespace stan
199 #endif
200