1 #ifndef STAN_MATH_REV_FUN_MULTIPLY_HPP
2 #define STAN_MATH_REV_FUN_MULTIPLY_HPP
3
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/rev/core/typedefs.hpp>
7 #include <stan/math/prim/fun.hpp>
8 #include <type_traits>
9
10 namespace stan {
11 namespace math {
12
13 /**
14 * Return the product of two matrices.
15 *
16 * This version does not handle row vector times column vector
17 *
18 * @tparam T1 type of first matrix
19 * @tparam T2 type of second matrix
20 *
21 *
22 * @param[in] A first matrix
23 * @param[in] B second matrix
24 * @return A * B
25 */
26 template <typename T1, typename T2, require_all_matrix_t<T1, T2>* = nullptr,
27 require_return_type_t<is_var, T1, T2>* = nullptr,
28 require_not_row_and_col_vector_t<T1, T2>* = nullptr>
multiply(const T1 & A,const T2 & B)29 inline auto multiply(const T1& A, const T2& B) {
30 check_multiplicable("multiply", "A", A, "B", B);
31 if (!is_constant<T2>::value && !is_constant<T1>::value) {
32 arena_t<promote_scalar_t<var, T1>> arena_A = A;
33 arena_t<promote_scalar_t<var, T2>> arena_B = B;
34 auto arena_A_val = to_arena(arena_A.val());
35 auto arena_B_val = to_arena(arena_B.val());
36 using return_t
37 = return_var_matrix_t<decltype(arena_A_val * arena_B_val), T1, T2>;
38 arena_t<return_t> res = arena_A_val * arena_B_val;
39
40 reverse_pass_callback(
41 [arena_A, arena_B, arena_A_val, arena_B_val, res]() mutable {
42 if (is_var_matrix<T1>::value || is_var_matrix<T2>::value) {
43 arena_A.adj() += res.adj_op() * arena_B_val.transpose();
44 arena_B.adj() += arena_A_val.transpose() * res.adj_op();
45 } else {
46 auto res_adj = res.adj().eval();
47 arena_A.adj() += res_adj * arena_B_val.transpose();
48 arena_B.adj() += arena_A_val.transpose() * res_adj;
49 }
50 });
51 return return_t(res);
52 } else if (!is_constant<T2>::value) {
53 arena_t<promote_scalar_t<double, T1>> arena_A = value_of(A);
54 arena_t<promote_scalar_t<var, T2>> arena_B = B;
55 using return_t
56 = return_var_matrix_t<decltype(arena_A * value_of(B).eval()), T1, T2>;
57 arena_t<return_t> res = arena_A * arena_B.val_op();
58 reverse_pass_callback([arena_B, arena_A, res]() mutable {
59 arena_B.adj() += arena_A.transpose() * res.adj_op();
60 });
61 return return_t(res);
62 } else {
63 arena_t<promote_scalar_t<var, T1>> arena_A = A;
64 arena_t<promote_scalar_t<double, T2>> arena_B = value_of(B);
65 using return_t
66 = return_var_matrix_t<decltype(value_of(arena_A).eval() * arena_B), T1,
67 T2>;
68 arena_t<return_t> res = arena_A.val_op() * arena_B;
69 reverse_pass_callback([arena_A, arena_B, res]() mutable {
70 arena_A.adj() += res.adj_op() * arena_B.transpose();
71 });
72
73 return return_t(res);
74 }
75 }
76
77 /**
78 * Return the product of a row vector times a column vector as a scalar
79 *
80 * @tparam T1 type of row vector
81 * @tparam T2 type of column vector
82 *
83 * @param[in] A row vector
84 * @param[in] B column vector
85 * @return A * B as a scalar
86 */
87 template <typename T1, typename T2, require_all_matrix_t<T1, T2>* = nullptr,
88 require_return_type_t<is_var, T1, T2>* = nullptr,
89 require_row_and_col_vector_t<T1, T2>* = nullptr>
multiply(const T1 & A,const T2 & B)90 inline var multiply(const T1& A, const T2& B) {
91 check_multiplicable("multiply", "A", A, "B", B);
92 if (!is_constant<T2>::value && !is_constant<T1>::value) {
93 arena_t<promote_scalar_t<var, T1>> arena_A = A;
94 arena_t<promote_scalar_t<var, T2>> arena_B = B;
95 arena_t<promote_scalar_t<double, T1>> arena_A_val = value_of(arena_A);
96 arena_t<promote_scalar_t<double, T2>> arena_B_val = value_of(arena_B);
97 var res = arena_A_val.dot(arena_B_val);
98
99 reverse_pass_callback(
100 [arena_A, arena_B, arena_A_val, arena_B_val, res]() mutable {
101 auto res_adj = res.adj();
102 arena_A.adj().array() += res_adj * arena_B_val.transpose().array();
103 arena_B.adj().array() += arena_A_val.transpose().array() * res_adj;
104 });
105 return res;
106 } else if (!is_constant<T2>::value) {
107 arena_t<promote_scalar_t<var, T2>> arena_B = B;
108 arena_t<promote_scalar_t<double, T1>> arena_A_val = value_of(A);
109 var res = arena_A_val.dot(value_of(arena_B));
110 reverse_pass_callback([arena_B, arena_A_val, res]() mutable {
111 arena_B.adj().array() += arena_A_val.transpose().array() * res.adj();
112 });
113 return res;
114 } else {
115 arena_t<promote_scalar_t<var, T1>> arena_A = A;
116 arena_t<promote_scalar_t<double, T2>> arena_B_val = value_of(B);
117 var res = value_of(arena_A).dot(arena_B_val);
118 reverse_pass_callback([arena_A, arena_B_val, res]() mutable {
119 arena_A.adj().array() += res.adj() * arena_B_val.transpose().array();
120 });
121 return res;
122 }
123 }
124
125 /**
126 * Return specified matrix multiplied by specified scalar where at least one
127 * input has a scalar type of a `var_value`.
128 *
129 * @tparam T1 type of the scalar
130 * @tparam T2 type of the matrix or expression
131 *
132 * @param A scalar
133 * @param B matrix
134 * @return product of matrix and scalar
135 */
136 template <typename T1, typename T2, require_not_matrix_t<T1>* = nullptr,
137 require_matrix_t<T2>* = nullptr,
138 require_return_type_t<is_var, T1, T2>* = nullptr,
139 require_not_row_and_col_vector_t<T1, T2>* = nullptr>
multiply(const T1 & A,const T2 & B)140 inline auto multiply(const T1& A, const T2& B) {
141 if (!is_constant<T2>::value && !is_constant<T1>::value) {
142 arena_t<promote_scalar_t<var, T1>> arena_A = A;
143 arena_t<promote_scalar_t<var, T2>> arena_B = B;
144 using return_t = return_var_matrix_t<T2, T1, T2>;
145 arena_t<return_t> res = arena_A.val() * arena_B.val().array();
146 reverse_pass_callback([arena_A, arena_B, res]() mutable {
147 const auto a_val = arena_A.val();
148 for (Eigen::Index j = 0; j < res.cols(); ++j) {
149 for (Eigen::Index i = 0; i < res.rows(); ++i) {
150 const auto res_adj = res.adj().coeffRef(i, j);
151 arena_A.adj() += res_adj * arena_B.val().coeff(i, j);
152 arena_B.adj().coeffRef(i, j) += a_val * res_adj;
153 }
154 }
155 });
156 return return_t(res);
157 } else if (!is_constant<T2>::value) {
158 arena_t<promote_scalar_t<double, T1>> arena_A = value_of(A);
159 arena_t<promote_scalar_t<var, T2>> arena_B = B;
160 using return_t = return_var_matrix_t<T2, T1, T2>;
161 arena_t<return_t> res = arena_A * arena_B.val().array();
162 reverse_pass_callback([arena_A, arena_B, res]() mutable {
163 arena_B.adj().array() += arena_A * res.adj().array();
164 });
165 return return_t(res);
166 } else {
167 arena_t<promote_scalar_t<var, T1>> arena_A = A;
168 arena_t<promote_scalar_t<double, T2>> arena_B = value_of(B);
169 using return_t = return_var_matrix_t<T2, T1, T2>;
170 arena_t<return_t> res = arena_A.val() * arena_B.array();
171 reverse_pass_callback([arena_A, arena_B, res]() mutable {
172 arena_A.adj() += (res.adj().array() * arena_B.array()).sum();
173 });
174 return return_t(res);
175 }
176 }
177
178 /**
179 * Return specified matrix multiplied by specified scalar where at least one
180 * input has a scalar type of a `var_value`.
181 *
182 * @tparam T1 type of the matrix or expression
183 * @tparam T2 type of the scalar
184 *
185 * @param A matrix
186 * @param B scalar
187 * @return product of matrix and scalar
188 */
189 template <typename T1, typename T2, require_matrix_t<T1>* = nullptr,
190 require_not_matrix_t<T2>* = nullptr,
191 require_any_st_var<T1, T2>* = nullptr,
192 require_not_row_and_col_vector_t<T1, T2>* = nullptr>
multiply(const T1 & A,const T2 & B)193 inline auto multiply(const T1& A, const T2& B) {
194 return multiply(B, A);
195 }
196
197 } // namespace math
198 } // namespace stan
199 #endif
200