1 #ifndef STAN_MATH_REV_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP
2 #define STAN_MATH_REV_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP
3
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/rev/fun/LDLT_factor.hpp>
7 #include <stan/math/rev/core/typedefs.hpp>
8 #include <stan/math/prim/meta.hpp>
9 #include <stan/math/prim/err.hpp>
10 #include <stan/math/prim/fun/Eigen.hpp>
11 #include <stan/math/prim/fun/typedefs.hpp>
12 #include <type_traits>
13
14 namespace stan {
15 namespace math {
16
17 /**
18 * Compute the trace of an inverse quadratic form premultiplied by a
19 * square matrix. This computes
20 * trace(B^T A^-1 B)
21 * where the LDLT_factor of A is provided.
22 *
23 * @tparam T1 type of elements in the LDLT_factor
24 * @tparam T2 type of the second matrix
25 *
26 * @param A an LDLT_factor
27 * @param B a matrix
28 * @return The trace of the inverse quadratic form.
29 */
30 template <typename T1, typename T2, require_all_matrix_t<T1, T2>* = nullptr,
31 require_any_st_var<T1, T2>* = nullptr>
trace_inv_quad_form_ldlt(LDLT_factor<T1> & A,const T2 & B)32 inline var trace_inv_quad_form_ldlt(LDLT_factor<T1>& A, const T2& B) {
33 check_multiplicable("trace_quad_form", "A", A.matrix(), "B", B);
34
35 if (A.matrix().size() == 0)
36 return 0.0;
37
38 if (!is_constant<T1>::value && !is_constant<T2>::value) {
39 arena_t<promote_scalar_t<var, T1>> arena_A = A.matrix();
40 arena_t<promote_scalar_t<var, T2>> arena_B = B;
41 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
42
43 var res = (arena_B.val_op().transpose() * AsolveB).trace();
44
45 reverse_pass_callback([arena_A, AsolveB, arena_B, res]() mutable {
46 arena_A.adj() += -res.adj() * AsolveB * AsolveB.transpose();
47 arena_B.adj() += 2 * res.adj() * AsolveB;
48 });
49
50 return res;
51 } else if (!is_constant<T1>::value) {
52 arena_t<promote_scalar_t<var, T1>> arena_A = A.matrix();
53 const auto& B_ref = to_ref(B);
54
55 auto AsolveB = to_arena(A.ldlt().solve(value_of(B_ref)));
56
57 var res = (value_of(B_ref).transpose() * AsolveB).trace();
58
59 reverse_pass_callback([arena_A, AsolveB, res]() mutable {
60 arena_A.adj() += -res.adj() * AsolveB * AsolveB.transpose();
61 });
62
63 return res;
64 } else {
65 arena_t<promote_scalar_t<var, T2>> arena_B = B;
66 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
67
68 var res = (arena_B.val_op().transpose() * AsolveB).trace();
69
70 reverse_pass_callback([AsolveB, arena_B, res]() mutable {
71 arena_B.adj() += 2 * res.adj() * AsolveB;
72 });
73
74 return res;
75 }
76 }
77
78 } // namespace math
79 } // namespace stan
80 #endif
81