1 #ifndef STAN_MATH_REV_FUN_LDLT_FACTOR_HPP
2 #define STAN_MATH_REV_FUN_LDLT_FACTOR_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/prim/err.hpp>
7 #include <stan/math/prim/fun/LDLT_factor.hpp>
8 
9 namespace stan {
10 namespace math {
11 
12 /**
13  * An LDLT_factor of an `Eigen::Matrix<var, Eigen::Dynamic, Eigen::Dynamic>`
14  * with `alloc_in_arena = True` holds a copy of the input matrix and the LDLT
15  * of its values, with all member variable allocations are done in the arena.
16  */
17 template <typename T>
18 class LDLT_factor<T, require_eigen_matrix_dynamic_vt<is_var, T>> {
19  private:
20   arena_t<plain_type_t<T>> matrix_;
21   Eigen::LDLT<Eigen::MatrixXd> ldlt_;
22 
23  public:
24   template <typename S,
25             require_same_t<plain_type_t<T>, plain_type_t<S>>* = nullptr>
LDLT_factor(const S & matrix)26   explicit LDLT_factor(const S& matrix)
27       : matrix_(matrix), ldlt_(matrix.val().ldlt()) {}
28 
29   /**
30    * Return a const reference to the underlying matrix
31    */
matrix() const32   const auto& matrix() const noexcept { return matrix_; }
33 
34   /**
35    * Return a const reference to the LDLT factor of the matrix values
36    */
ldlt() const37   const auto& ldlt() const noexcept { return ldlt_; }
38 };
39 
40 /**
41  * An LDLT_factor of a `var_value<Eigen::MatrixXd>`
42  * holds a copy of the input `var_value` and the LDLT of its values.
43  */
44 template <typename T>
45 class LDLT_factor<T, require_var_matrix_t<T>> {
46  private:
47   std::decay_t<T> matrix_;
48   Eigen::LDLT<Eigen::MatrixXd> ldlt_;
49 
50  public:
51   template <typename S,
52             require_same_t<plain_type_t<T>, plain_type_t<S>>* = nullptr>
LDLT_factor(const S & matrix)53   explicit LDLT_factor(const S& matrix)
54       : matrix_(matrix), ldlt_(matrix.val().ldlt()) {}
55 
56   /**
57    * Return a const reference the underlying `var_value`
58    */
matrix() const59   const auto& matrix() const noexcept { return matrix_; }
60 
61   /**
62    * Return a const reference to the LDLT factor of the matrix values
63    */
ldlt() const64   const auto& ldlt() const noexcept { return ldlt_; }
65 };
66 
67 }  // namespace math
68 }  // namespace stan
69 #endif
70