1 #ifndef STAN_MATH_PRIM_FUN_QUAD_FORM_SYM_HPP
2 #define STAN_MATH_PRIM_FUN_QUAD_FORM_SYM_HPP
3 
4 #include <stan/math/prim/err.hpp>
5 #include <stan/math/prim/fun/Eigen.hpp>
6 #include <stan/math/prim/fun/to_ref.hpp>
7 
8 namespace stan {
9 namespace math {
10 
11 /**
12  * Return the quadratic form \f$ B^T A B \f$ of a symmetric matrix.
13  *
14  * Symmetry of the resulting matrix is guaranteed.
15  *
16  * @tparam EigMat1 type of the first (symmetric) matrix
17  * @tparam EigMat2 type of the second matrix
18  *
19  * @param A symmetric matrix
20  * @param B second matrix
21  * @return The quadratic form, which is a symmetric matrix.
22  * @throws std::invalid_argument if A is not symmetric, or if A cannot be
23  * multiplied by B
24  */
25 template <typename EigMat1, typename EigMat2,
26           require_all_eigen_t<EigMat1, EigMat2>* = nullptr,
27           require_not_eigen_col_vector_t<EigMat2>* = nullptr,
28           require_vt_same<EigMat1, EigMat2>* = nullptr,
29           require_all_vt_arithmetic<EigMat1, EigMat2>* = nullptr>
quad_form_sym(const EigMat1 & A,const EigMat2 & B)30 inline plain_type_t<EigMat2> quad_form_sym(const EigMat1& A, const EigMat2& B) {
31   check_multiplicable("quad_form_sym", "A", A, "B", B);
32   const auto& A_ref = to_ref(A);
33   const auto& B_ref = to_ref(B);
34   check_symmetric("quad_form_sym", "A", A_ref);
35   return make_holder(
36       [](const auto& ret) { return 0.5 * (ret + ret.transpose()); },
37       (B_ref.transpose() * A_ref * B_ref).eval());
38 }
39 
40 /**
41  * Return the quadratic form \f$ B^T A B \f$ of a symmetric matrix.
42  *
43  * @tparam EigMat type of the (symmetric) matrix
44  * @tparam ColVec type of the vector
45  *
46  * @param A symmetric matrix
47  * @param B vector
48  * @return The quadratic form (a scalar).
49  * @throws std::invalid_argument if A is not symmetric, or if A cannot be
50  * multiplied by B
51  */
52 template <typename EigMat, typename ColVec, require_eigen_t<EigMat>* = nullptr,
53           require_eigen_col_vector_t<ColVec>* = nullptr,
54           require_vt_same<EigMat, ColVec>* = nullptr,
55           require_all_vt_arithmetic<EigMat, ColVec>* = nullptr>
quad_form_sym(const EigMat & A,const ColVec & B)56 inline value_type_t<EigMat> quad_form_sym(const EigMat& A, const ColVec& B) {
57   check_multiplicable("quad_form_sym", "A", A, "B", B);
58   const auto& A_ref = to_ref(A);
59   const auto& B_ref = to_ref(B);
60   check_symmetric("quad_form_sym", "A", A_ref);
61   return B_ref.dot(A_ref * B_ref);
62 }
63 
64 }  // namespace math
65 }  // namespace stan
66 
67 #endif
68