1 #ifndef STAN_MATH_PRIM_FUN_TRACE_GEN_QUAD_FORM_HPP
2 #define STAN_MATH_PRIM_FUN_TRACE_GEN_QUAD_FORM_HPP
3 
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/prim/fun/Eigen.hpp>
7 #include <stan/math/prim/fun/trace.hpp>
8 #include <stan/math/prim/fun/multiply.hpp>
9 #include <stan/math/prim/fun/to_ref.hpp>
10 #include <stan/math/prim/fun/transpose.hpp>
11 #include <exception>
12 
13 namespace stan {
14 namespace math {
15 
16 /**
17  * Return the trace of D times the quadratic form of B and A.
18  * That is, `trace_gen_quad_form(D, A, B) = trace(D * B' * A * B).`
19  *
20  * @tparam TD type of the first matrix or expression
21  * @tparam TA type of the second matrix or expression
22  * @tparam TB type of the third matrix or expression
23  *
24  * @param D multiplier
25  * @param A outside term in quadratic form
26  * @param B inner term in quadratic form
27  * @return trace(D * B' * A * B)
28  * @throw std::domain_error if A or D is not square
29  * @throw std::domain_error if A cannot be multiplied by B or B cannot
30  * be multiplied by D.
31  */
32 template <typename TD, typename TA, typename TB,
33           typename = require_all_eigen_t<TD, TA, TB>,
34           typename = require_all_not_vt_var<TD, TA, TB>,
35           typename = require_any_not_vt_arithmetic<TD, TA, TB>>
trace_gen_quad_form(const TD & D,const TA & A,const TB & B)36 inline auto trace_gen_quad_form(const TD& D, const TA& A, const TB& B) {
37   check_square("trace_gen_quad_form", "A", A);
38   check_square("trace_gen_quad_form", "D", D);
39   check_multiplicable("trace_gen_quad_form", "A", A, "B", B);
40   check_multiplicable("trace_gen_quad_form", "B", B, "D", D);
41   const auto& B_ref = to_ref(B);
42   return multiply(B_ref, D.transpose()).cwiseProduct(multiply(A, B_ref)).sum();
43 }
44 
45 /**
46  * Return the trace of D times the quadratic form of B and A.
47  * That is, `trace_gen_quad_form(D, A, B) = trace(D * B' * A * B).`
48  * This is the overload for arithmetic types to allow Eigen's expression
49  * templates to be used for efficiency.
50  *
51  * @tparam EigMatD type of the first matrix or expression
52  * @tparam EigMatA type of the second matrix or expression
53  * @tparam EigMatB type of the third matrix or expression
54  *
55  * @param D multiplier
56  * @param A outside term in quadratic form
57  * @param B inner term in quadratic form
58  * @return trace(D * B' * A * B)
59  * @throw std::domain_error if A or D is not square
60  * @throw std::domain_error if A cannot be multiplied by B or B cannot
61  * be multiplied by D.
62  */
63 template <typename EigMatD, typename EigMatA, typename EigMatB,
64           require_all_eigen_vt<std::is_arithmetic, EigMatD, EigMatA,
65                                EigMatB>* = nullptr>
trace_gen_quad_form(const EigMatD & D,const EigMatA & A,const EigMatB & B)66 inline double trace_gen_quad_form(const EigMatD& D, const EigMatA& A,
67                                   const EigMatB& B) {
68   check_square("trace_gen_quad_form", "A", A);
69   check_square("trace_gen_quad_form", "D", D);
70   check_multiplicable("trace_gen_quad_form", "A", A, "B", B);
71   check_multiplicable("trace_gen_quad_form", "B", B, "D", D);
72   const auto& B_ref = to_ref(B);
73   return (B_ref * D.transpose()).cwiseProduct(A * B_ref).sum();
74 }
75 
76 }  // namespace math
77 }  // namespace stan
78 
79 #endif
80