1 #ifndef STAN_MATH_PRIM_FUN_TRACE_GEN_INV_QUAD_FORM_LDLT_HPP
2 #define STAN_MATH_PRIM_FUN_TRACE_GEN_INV_QUAD_FORM_LDLT_HPP
3
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/prim/fun/Eigen.hpp>
7 #include <stan/math/prim/fun/LDLT_factor.hpp>
8 #include <stan/math/prim/fun/mdivide_left_ldlt.hpp>
9 #include <stan/math/prim/fun/trace.hpp>
10 #include <stan/math/prim/fun/transpose.hpp>
11 #include <stan/math/prim/fun/multiply.hpp>
12
13 namespace stan {
14 namespace math {
15
16 /**
17 * Compute the trace of an inverse quadratic form. I.E., this computes
18 * trace(D B^T A^-1 B)
19 * where D is a square matrix and the LDLT_factor of A is provided.
20 *
21 * @tparam EigMat1 type of the first matrix
22 * @tparam T2 type of matrix in the LDLT_factor
23 * @tparam EigMat3 type of the third matrix
24 *
25 * @param D multiplier
26 * @param A LDLT_factor
27 * @param B inner term in quadratic form
28 * @return trace(D * B^T * A^-1 * B)
29 * @throw std::domain_error if D is not square
30 * @throw std::domain_error if A cannot be multiplied by B or B cannot
31 * be multiplied by D.
32 */
33 template <typename EigMat1, typename T2, typename EigMat3,
34 require_not_col_vector_t<EigMat1>* = nullptr,
35 require_all_not_st_var<EigMat1, T2, EigMat3>* = nullptr>
trace_gen_inv_quad_form_ldlt(const EigMat1 & D,LDLT_factor<T2> & A,const EigMat3 & B)36 inline return_type_t<EigMat1, T2, EigMat3> trace_gen_inv_quad_form_ldlt(
37 const EigMat1& D, LDLT_factor<T2>& A, const EigMat3& B) {
38 check_square("trace_gen_inv_quad_form_ldlt", "D", D);
39 check_multiplicable("trace_gen_inv_quad_form_ldlt", "A", A.matrix(), "B", B);
40 check_multiplicable("trace_gen_inv_quad_form_ldlt", "B", B, "D", D);
41
42 if (D.size() == 0 || A.matrix().size() == 0) {
43 return 0;
44 }
45
46 return multiply(B, D.transpose()).cwiseProduct(mdivide_left_ldlt(A, B)).sum();
47 }
48
49 /**
50 * Compute the trace of an inverse quadratic form. I.E., this computes
51 * `trace(diag(D) B^T A^-1 B)`
52 * where D is the diagonal of a diagonal matrix (`diag(D)` is the diagonal
53 * matrix itself) and the LDLT_factor of A is provided.
54 *
55 * @tparam EigVec type of the diagonal of first matrix
56 * @tparam T type of matrix in the LDLT_factor
57 * @tparam EigMat type of the B matrix
58 *
59 * @param D diagonal of multiplier
60 * @param A LDLT_factor
61 * @param B inner term in quadratic form
62 * @return trace(diag(D) * B^T * A^-1 * B)
63 * @throw std::domain_error if A cannot be multiplied by B or B cannot
64 * be multiplied by diag(D).
65 */
66 template <typename EigVec, typename T, typename EigMat,
67 require_col_vector_t<EigVec>* = nullptr,
68 require_all_not_st_var<EigVec, T, EigMat>* = nullptr>
trace_gen_inv_quad_form_ldlt(const EigVec & D,LDLT_factor<T> & A,const EigMat & B)69 inline return_type_t<EigVec, T, EigMat> trace_gen_inv_quad_form_ldlt(
70 const EigVec& D, LDLT_factor<T>& A, const EigMat& B) {
71 check_multiplicable("trace_gen_inv_quad_form_ldlt", "A", A.matrix(), "B", B);
72 check_multiplicable("trace_gen_inv_quad_form_ldlt", "B", B, "D", D);
73
74 if (D.size() == 0 || A.matrix().size() == 0) {
75 return 0;
76 }
77
78 return (B * D.asDiagonal()).cwiseProduct(mdivide_left_ldlt(A, B)).sum();
79 }
80
81 } // namespace math
82 } // namespace stan
83
84 #endif
85