1 #ifndef STAN_MATH_FWD_FUN_MULTIPLY_HPP
2 #define STAN_MATH_FWD_FUN_MULTIPLY_HPP
3
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/prim/fun/Eigen.hpp>
7 #include <stan/math/prim/fun/dot_product.hpp>
8 #include <stan/math/fwd/core.hpp>
9 #include <stan/math/fwd/fun/typedefs.hpp>
10
11 namespace stan {
12 namespace math {
13
14 template <typename Mat1, typename Mat2,
15 require_all_eigen_vt<is_fvar, Mat1, Mat2>* = nullptr,
16 require_vt_same<Mat1, Mat2>* = nullptr,
17 require_not_eigen_row_and_col_t<Mat1, Mat2>* = nullptr>
multiply(const Mat1 & m1,const Mat2 & m2)18 inline auto multiply(const Mat1& m1, const Mat2& m2) {
19 check_multiplicable("multiply", "m1", m1, "m2", m2);
20 return (m1 * m2).eval();
21 }
22
23 template <typename Mat1, typename Mat2,
24 require_eigen_vt<is_fvar, Mat1>* = nullptr,
25 require_eigen_vt<std::is_floating_point, Mat2>* = nullptr,
26 require_not_eigen_row_and_col_t<Mat1, Mat2>* = nullptr>
multiply(const Mat1 & m1,const Mat2 & m2)27 inline auto multiply(const Mat1& m1, const Mat2& m2) {
28 check_multiplicable("multiply", "m1", m1, "m2", m2);
29 Eigen::Matrix<value_type_t<Mat1>, Mat1::RowsAtCompileTime,
30 Mat2::ColsAtCompileTime>
31 result(m1.rows(), m2.cols());
32 for (size_type i = 0; i < m1.rows(); i++) {
33 Eigen::Matrix<value_type_t<Mat1>, 1, Mat1::ColsAtCompileTime> crow
34 = m1.row(i);
35 for (size_type j = 0; j < m2.cols(); j++) {
36 result(i, j) = dot_product(crow, m2.col(j));
37 }
38 }
39 return result;
40 }
41
42 template <typename Mat1, typename Mat2,
43 require_eigen_vt<std::is_floating_point, Mat1>* = nullptr,
44 require_eigen_vt<is_fvar, Mat2>* = nullptr,
45 require_not_eigen_row_and_col_t<Mat1, Mat2>* = nullptr>
multiply(const Mat1 & m1,const Mat2 & m2)46 inline auto multiply(const Mat1& m1, const Mat2& m2) {
47 check_multiplicable("multiply", "m1", m1, "m2", m2);
48 Eigen::Matrix<value_type_t<Mat2>, Mat1::RowsAtCompileTime,
49 Mat2::ColsAtCompileTime>
50 result(m1.rows(), m2.cols());
51 for (size_type i = 0; i < m1.rows(); i++) {
52 Eigen::Matrix<double, 1, Mat1::ColsAtCompileTime> crow = m1.row(i);
53 for (size_type j = 0; j < m2.cols(); j++) {
54 auto ccol = m2.col(j);
55 result(i, j) = dot_product(crow, ccol);
56 }
57 }
58 return result;
59 }
60
61 } // namespace math
62 } // namespace stan
63 #endif
64