1 #ifndef STAN_MATH_PRIM_FUN_TRACE_GEN_INV_QUAD_FORM_LDLT_HPP
2 #define STAN_MATH_PRIM_FUN_TRACE_GEN_INV_QUAD_FORM_LDLT_HPP
3 
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/prim/fun/Eigen.hpp>
7 #include <stan/math/prim/fun/LDLT_factor.hpp>
8 #include <stan/math/prim/fun/mdivide_left_ldlt.hpp>
9 #include <stan/math/prim/fun/trace.hpp>
10 #include <stan/math/prim/fun/transpose.hpp>
11 #include <stan/math/prim/fun/multiply.hpp>
12 
13 namespace stan {
14 namespace math {
15 
16 /**
17  * Compute the trace of an inverse quadratic form.  I.E., this computes
18  *       trace(D B^T A^-1 B)
19  * where D is a square matrix and the LDLT_factor of A is provided.
20  *
21  * @tparam EigMat1 type of the first matrix
22  * @tparam T2 type of matrix in the LDLT_factor
23  * @tparam EigMat3 type of the third matrix
24  *
25  * @param D multiplier
26  * @param A LDLT_factor
27  * @param B inner term in quadratic form
28  * @return trace(D * B^T * A^-1 * B)
29  * @throw std::domain_error if D is not square
30  * @throw std::domain_error if A cannot be multiplied by B or B cannot
31  * be multiplied by D.
32  */
33 template <typename EigMat1, typename T2, typename EigMat3,
34           require_not_col_vector_t<EigMat1>* = nullptr,
35           require_all_not_st_var<EigMat1, T2, EigMat3>* = nullptr>
trace_gen_inv_quad_form_ldlt(const EigMat1 & D,LDLT_factor<T2> & A,const EigMat3 & B)36 inline return_type_t<EigMat1, T2, EigMat3> trace_gen_inv_quad_form_ldlt(
37     const EigMat1& D, LDLT_factor<T2>& A, const EigMat3& B) {
38   check_square("trace_gen_inv_quad_form_ldlt", "D", D);
39   check_multiplicable("trace_gen_inv_quad_form_ldlt", "A", A.matrix(), "B", B);
40   check_multiplicable("trace_gen_inv_quad_form_ldlt", "B", B, "D", D);
41 
42   if (D.size() == 0 || A.matrix().size() == 0) {
43     return 0;
44   }
45 
46   return multiply(B, D.transpose()).cwiseProduct(mdivide_left_ldlt(A, B)).sum();
47 }
48 
49 /**
50  * Compute the trace of an inverse quadratic form.  I.E., this computes
51  *       `trace(diag(D) B^T A^-1 B)`
52  * where D is the diagonal of a diagonal matrix (`diag(D)` is the diagonal
53  * matrix itself) and the LDLT_factor of A is provided.
54  *
55  * @tparam EigVec type of the diagonal of first matrix
56  * @tparam T type of matrix in the LDLT_factor
57  * @tparam EigMat type of the B matrix
58  *
59  * @param D diagonal of multiplier
60  * @param A LDLT_factor
61  * @param B inner term in quadratic form
62  * @return trace(diag(D) * B^T * A^-1 * B)
63  * @throw std::domain_error if A cannot be multiplied by B or B cannot
64  * be multiplied by diag(D).
65  */
66 template <typename EigVec, typename T, typename EigMat,
67           require_col_vector_t<EigVec>* = nullptr,
68           require_all_not_st_var<EigVec, T, EigMat>* = nullptr>
trace_gen_inv_quad_form_ldlt(const EigVec & D,LDLT_factor<T> & A,const EigMat & B)69 inline return_type_t<EigVec, T, EigMat> trace_gen_inv_quad_form_ldlt(
70     const EigVec& D, LDLT_factor<T>& A, const EigMat& B) {
71   check_multiplicable("trace_gen_inv_quad_form_ldlt", "A", A.matrix(), "B", B);
72   check_multiplicable("trace_gen_inv_quad_form_ldlt", "B", B, "D", D);
73 
74   if (D.size() == 0 || A.matrix().size() == 0) {
75     return 0;
76   }
77 
78   return (B * D.asDiagonal()).cwiseProduct(mdivide_left_ldlt(A, B)).sum();
79 }
80 
81 }  // namespace math
82 }  // namespace stan
83 
84 #endif
85