1 #ifndef STAN_MATH_PRIM_FUN_TRACE_GEN_QUAD_FORM_HPP
2 #define STAN_MATH_PRIM_FUN_TRACE_GEN_QUAD_FORM_HPP
3
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/prim/fun/Eigen.hpp>
7 #include <stan/math/prim/fun/trace.hpp>
8 #include <stan/math/prim/fun/multiply.hpp>
9 #include <stan/math/prim/fun/to_ref.hpp>
10 #include <stan/math/prim/fun/transpose.hpp>
11 #include <exception>
12
13 namespace stan {
14 namespace math {
15
16 /**
17 * Return the trace of D times the quadratic form of B and A.
18 * That is, `trace_gen_quad_form(D, A, B) = trace(D * B' * A * B).`
19 *
20 * @tparam TD type of the first matrix or expression
21 * @tparam TA type of the second matrix or expression
22 * @tparam TB type of the third matrix or expression
23 *
24 * @param D multiplier
25 * @param A outside term in quadratic form
26 * @param B inner term in quadratic form
27 * @return trace(D * B' * A * B)
28 * @throw std::domain_error if A or D is not square
29 * @throw std::domain_error if A cannot be multiplied by B or B cannot
30 * be multiplied by D.
31 */
32 template <typename TD, typename TA, typename TB,
33 typename = require_all_eigen_t<TD, TA, TB>,
34 typename = require_all_not_vt_var<TD, TA, TB>,
35 typename = require_any_not_vt_arithmetic<TD, TA, TB>>
trace_gen_quad_form(const TD & D,const TA & A,const TB & B)36 inline auto trace_gen_quad_form(const TD& D, const TA& A, const TB& B) {
37 check_square("trace_gen_quad_form", "A", A);
38 check_square("trace_gen_quad_form", "D", D);
39 check_multiplicable("trace_gen_quad_form", "A", A, "B", B);
40 check_multiplicable("trace_gen_quad_form", "B", B, "D", D);
41 const auto& B_ref = to_ref(B);
42 return multiply(B_ref, D.transpose()).cwiseProduct(multiply(A, B_ref)).sum();
43 }
44
45 /**
46 * Return the trace of D times the quadratic form of B and A.
47 * That is, `trace_gen_quad_form(D, A, B) = trace(D * B' * A * B).`
48 * This is the overload for arithmetic types to allow Eigen's expression
49 * templates to be used for efficiency.
50 *
51 * @tparam EigMatD type of the first matrix or expression
52 * @tparam EigMatA type of the second matrix or expression
53 * @tparam EigMatB type of the third matrix or expression
54 *
55 * @param D multiplier
56 * @param A outside term in quadratic form
57 * @param B inner term in quadratic form
58 * @return trace(D * B' * A * B)
59 * @throw std::domain_error if A or D is not square
60 * @throw std::domain_error if A cannot be multiplied by B or B cannot
61 * be multiplied by D.
62 */
63 template <typename EigMatD, typename EigMatA, typename EigMatB,
64 require_all_eigen_vt<std::is_arithmetic, EigMatD, EigMatA,
65 EigMatB>* = nullptr>
trace_gen_quad_form(const EigMatD & D,const EigMatA & A,const EigMatB & B)66 inline double trace_gen_quad_form(const EigMatD& D, const EigMatA& A,
67 const EigMatB& B) {
68 check_square("trace_gen_quad_form", "A", A);
69 check_square("trace_gen_quad_form", "D", D);
70 check_multiplicable("trace_gen_quad_form", "A", A, "B", B);
71 check_multiplicable("trace_gen_quad_form", "B", B, "D", D);
72 const auto& B_ref = to_ref(B);
73 return (B_ref * D.transpose()).cwiseProduct(A * B_ref).sum();
74 }
75
76 } // namespace math
77 } // namespace stan
78
79 #endif
80