1 #ifndef STAN_MATH_REV_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP
2 #define STAN_MATH_REV_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/rev/fun/LDLT_factor.hpp>
7 #include <stan/math/rev/core/typedefs.hpp>
8 #include <stan/math/prim/meta.hpp>
9 #include <stan/math/prim/err.hpp>
10 #include <stan/math/prim/fun/Eigen.hpp>
11 #include <stan/math/prim/fun/typedefs.hpp>
12 #include <type_traits>
13 
14 namespace stan {
15 namespace math {
16 
17 /**
18  * Compute the trace of an inverse quadratic form premultiplied by a
19  * square matrix. This computes
20  *       trace(B^T A^-1 B)
21  * where the LDLT_factor of A is provided.
22  *
23  * @tparam T1 type of elements in the LDLT_factor
24  * @tparam T2 type of the second matrix
25  *
26  * @param A an LDLT_factor
27  * @param B a matrix
28  * @return The trace of the inverse quadratic form.
29  */
30 template <typename T1, typename T2, require_all_matrix_t<T1, T2>* = nullptr,
31           require_any_st_var<T1, T2>* = nullptr>
trace_inv_quad_form_ldlt(LDLT_factor<T1> & A,const T2 & B)32 inline var trace_inv_quad_form_ldlt(LDLT_factor<T1>& A, const T2& B) {
33   check_multiplicable("trace_quad_form", "A", A.matrix(), "B", B);
34 
35   if (A.matrix().size() == 0)
36     return 0.0;
37 
38   if (!is_constant<T1>::value && !is_constant<T2>::value) {
39     arena_t<promote_scalar_t<var, T1>> arena_A = A.matrix();
40     arena_t<promote_scalar_t<var, T2>> arena_B = B;
41     auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
42 
43     var res = (arena_B.val_op().transpose() * AsolveB).trace();
44 
45     reverse_pass_callback([arena_A, AsolveB, arena_B, res]() mutable {
46       arena_A.adj() += -res.adj() * AsolveB * AsolveB.transpose();
47       arena_B.adj() += 2 * res.adj() * AsolveB;
48     });
49 
50     return res;
51   } else if (!is_constant<T1>::value) {
52     arena_t<promote_scalar_t<var, T1>> arena_A = A.matrix();
53     const auto& B_ref = to_ref(B);
54 
55     auto AsolveB = to_arena(A.ldlt().solve(value_of(B_ref)));
56 
57     var res = (value_of(B_ref).transpose() * AsolveB).trace();
58 
59     reverse_pass_callback([arena_A, AsolveB, res]() mutable {
60       arena_A.adj() += -res.adj() * AsolveB * AsolveB.transpose();
61     });
62 
63     return res;
64   } else {
65     arena_t<promote_scalar_t<var, T2>> arena_B = B;
66     auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
67 
68     var res = (arena_B.val_op().transpose() * AsolveB).trace();
69 
70     reverse_pass_callback([AsolveB, arena_B, res]() mutable {
71       arena_B.adj() += 2 * res.adj() * AsolveB;
72     });
73 
74     return res;
75   }
76 }
77 
78 }  // namespace math
79 }  // namespace stan
80 #endif
81