1 #ifndef STAN_MATH_PRIM_FUN_QUAD_FORM_SYM_HPP
2 #define STAN_MATH_PRIM_FUN_QUAD_FORM_SYM_HPP
3
4 #include <stan/math/prim/err.hpp>
5 #include <stan/math/prim/fun/Eigen.hpp>
6 #include <stan/math/prim/fun/to_ref.hpp>
7
8 namespace stan {
9 namespace math {
10
11 /**
12 * Return the quadratic form \f$ B^T A B \f$ of a symmetric matrix.
13 *
14 * Symmetry of the resulting matrix is guaranteed.
15 *
16 * @tparam EigMat1 type of the first (symmetric) matrix
17 * @tparam EigMat2 type of the second matrix
18 *
19 * @param A symmetric matrix
20 * @param B second matrix
21 * @return The quadratic form, which is a symmetric matrix.
22 * @throws std::invalid_argument if A is not symmetric, or if A cannot be
23 * multiplied by B
24 */
25 template <typename EigMat1, typename EigMat2,
26 require_all_eigen_t<EigMat1, EigMat2>* = nullptr,
27 require_not_eigen_col_vector_t<EigMat2>* = nullptr,
28 require_vt_same<EigMat1, EigMat2>* = nullptr,
29 require_all_vt_arithmetic<EigMat1, EigMat2>* = nullptr>
quad_form_sym(const EigMat1 & A,const EigMat2 & B)30 inline plain_type_t<EigMat2> quad_form_sym(const EigMat1& A, const EigMat2& B) {
31 check_multiplicable("quad_form_sym", "A", A, "B", B);
32 const auto& A_ref = to_ref(A);
33 const auto& B_ref = to_ref(B);
34 check_symmetric("quad_form_sym", "A", A_ref);
35 return make_holder(
36 [](const auto& ret) { return 0.5 * (ret + ret.transpose()); },
37 (B_ref.transpose() * A_ref * B_ref).eval());
38 }
39
40 /**
41 * Return the quadratic form \f$ B^T A B \f$ of a symmetric matrix.
42 *
43 * @tparam EigMat type of the (symmetric) matrix
44 * @tparam ColVec type of the vector
45 *
46 * @param A symmetric matrix
47 * @param B vector
48 * @return The quadratic form (a scalar).
49 * @throws std::invalid_argument if A is not symmetric, or if A cannot be
50 * multiplied by B
51 */
52 template <typename EigMat, typename ColVec, require_eigen_t<EigMat>* = nullptr,
53 require_eigen_col_vector_t<ColVec>* = nullptr,
54 require_vt_same<EigMat, ColVec>* = nullptr,
55 require_all_vt_arithmetic<EigMat, ColVec>* = nullptr>
quad_form_sym(const EigMat & A,const ColVec & B)56 inline value_type_t<EigMat> quad_form_sym(const EigMat& A, const ColVec& B) {
57 check_multiplicable("quad_form_sym", "A", A, "B", B);
58 const auto& A_ref = to_ref(A);
59 const auto& B_ref = to_ref(B);
60 check_symmetric("quad_form_sym", "A", A_ref);
61 return B_ref.dot(A_ref * B_ref);
62 }
63
64 } // namespace math
65 } // namespace stan
66
67 #endif
68