1 #ifndef STAN_MATH_REV_FUN_TRACE_QUAD_FORM_HPP
2 #define STAN_MATH_REV_FUN_TRACE_QUAD_FORM_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/rev/fun/value_of.hpp>
7 #include <stan/math/rev/fun/to_var_value.hpp>
8 #include <stan/math/prim/meta.hpp>
9 #include <stan/math/prim/err.hpp>
10 #include <stan/math/prim/fun/Eigen.hpp>
11 #include <stan/math/prim/fun/trace_quad_form.hpp>
12 #include <stan/math/prim/fun/typedefs.hpp>
13 #include <stan/math/prim/fun/value_of.hpp>
14 #include <type_traits>
15 
16 namespace stan {
17 namespace math {
18 namespace internal {
19 template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
20 class trace_quad_form_vari_alloc : public chainable_alloc {
21  public:
trace_quad_form_vari_alloc(const Eigen::Matrix<Ta,Ra,Ca> & A,const Eigen::Matrix<Tb,Rb,Cb> & B)22   trace_quad_form_vari_alloc(const Eigen::Matrix<Ta, Ra, Ca>& A,
23                              const Eigen::Matrix<Tb, Rb, Cb>& B)
24       : A_(A), B_(B) {}
25 
compute()26   double compute() { return trace_quad_form(value_of(A_), value_of(B_)); }
27 
28   Eigen::Matrix<Ta, Ra, Ca> A_;
29   Eigen::Matrix<Tb, Rb, Cb> B_;
30 };
31 
32 template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
33 class trace_quad_form_vari : public vari {
34  protected:
chainA(Eigen::Matrix<double,Ra,Ca> & A,const Eigen::Matrix<double,Rb,Cb> & Bd,double adjC)35   static inline void chainA(Eigen::Matrix<double, Ra, Ca>& A,
36                             const Eigen::Matrix<double, Rb, Cb>& Bd,
37                             double adjC) {}
chainB(Eigen::Matrix<double,Rb,Cb> & B,const Eigen::Matrix<double,Ra,Ca> & Ad,const Eigen::Matrix<double,Rb,Cb> & Bd,double adjC)38   static inline void chainB(Eigen::Matrix<double, Rb, Cb>& B,
39                             const Eigen::Matrix<double, Ra, Ca>& Ad,
40                             const Eigen::Matrix<double, Rb, Cb>& Bd,
41                             double adjC) {}
42 
chainA(Eigen::Matrix<var,Ra,Ca> & A,const Eigen::Matrix<double,Rb,Cb> & Bd,double adjC)43   static inline void chainA(Eigen::Matrix<var, Ra, Ca>& A,
44                             const Eigen::Matrix<double, Rb, Cb>& Bd,
45                             double adjC) {
46     A.adj() += adjC * Bd * Bd.transpose();
47   }
chainB(Eigen::Matrix<var,Rb,Cb> & B,const Eigen::Matrix<double,Ra,Ca> & Ad,const Eigen::Matrix<double,Rb,Cb> & Bd,double adjC)48   static inline void chainB(Eigen::Matrix<var, Rb, Cb>& B,
49                             const Eigen::Matrix<double, Ra, Ca>& Ad,
50                             const Eigen::Matrix<double, Rb, Cb>& Bd,
51                             double adjC) {
52     B.adj() += adjC * (Ad + Ad.transpose()) * Bd;
53   }
54 
chainAB(Eigen::Matrix<Ta,Ra,Ca> & A,Eigen::Matrix<Tb,Rb,Cb> & B,const Eigen::Matrix<double,Ra,Ca> & Ad,const Eigen::Matrix<double,Rb,Cb> & Bd,double adjC)55   inline void chainAB(Eigen::Matrix<Ta, Ra, Ca>& A,
56                       Eigen::Matrix<Tb, Rb, Cb>& B,
57                       const Eigen::Matrix<double, Ra, Ca>& Ad,
58                       const Eigen::Matrix<double, Rb, Cb>& Bd, double adjC) {
59     chainA(A, Bd, adjC);
60     chainB(B, Ad, Bd, adjC);
61   }
62 
63  public:
trace_quad_form_vari(trace_quad_form_vari_alloc<Ta,Ra,Ca,Tb,Rb,Cb> * impl)64   explicit trace_quad_form_vari(
65       trace_quad_form_vari_alloc<Ta, Ra, Ca, Tb, Rb, Cb>* impl)
66       : vari(impl->compute()), impl_(impl) {}
67 
chain()68   virtual void chain() {
69     chainAB(impl_->A_, impl_->B_, value_of(impl_->A_), value_of(impl_->B_),
70             adj_);
71   }
72 
73   trace_quad_form_vari_alloc<Ta, Ra, Ca, Tb, Rb, Cb>* impl_;
74 };
75 }  // namespace internal
76 
77 template <typename EigMat1, typename EigMat2,
78           require_all_eigen_t<EigMat1, EigMat2>* = nullptr,
79           require_any_st_var<EigMat1, EigMat2>* = nullptr>
trace_quad_form(const EigMat1 & A,const EigMat2 & B)80 inline return_type_t<EigMat1, EigMat2> trace_quad_form(const EigMat1& A,
81                                                        const EigMat2& B) {
82   using Ta = value_type_t<EigMat1>;
83   using Tb = value_type_t<EigMat2>;
84   constexpr int Ra = EigMat1::RowsAtCompileTime;
85   constexpr int Ca = EigMat1::ColsAtCompileTime;
86   constexpr int Rb = EigMat2::RowsAtCompileTime;
87   constexpr int Cb = EigMat2::ColsAtCompileTime;
88   check_square("trace_quad_form", "A", A);
89   check_multiplicable("trace_quad_form", "A", A, "B", B);
90 
91   auto* baseVari
92       = new internal::trace_quad_form_vari_alloc<Ta, Ra, Ca, Tb, Rb, Cb>(A, B);
93 
94   return var(
95       new internal::trace_quad_form_vari<Ta, Ra, Ca, Tb, Rb, Cb>(baseVari));
96 }
97 
98 /**
99  * Compute trace(B^T A B).
100  *
101  * This overload handles arguments where one of Mat1 or Mat2 are
102  * `var_value<T>` where `T` is an Eigen type. The other type can
103  * also be a `var_value` or it can be a type that inherits
104  * from EigenBase
105  *
106  * @tparam Mat1 type of the first matrix
107  * @tparam Mat2 type of the second matrix
108  *
109  * @param A matrix
110  * @param B matrix
111  * @return The trace of B^T A B
112  * @throw std::domain_error if A is not square
113  * @throw std::domain_error if A cannot be multiplied by B
114  */
115 template <typename Mat1, typename Mat2,
116           require_all_matrix_t<Mat1, Mat2>* = nullptr,
117           require_any_var_matrix_t<Mat1, Mat2>* = nullptr>
trace_quad_form(const Mat1 & A,const Mat2 & B)118 inline var trace_quad_form(const Mat1& A, const Mat2& B) {
119   check_square("trace_quad_form", "A", A);
120   check_multiplicable("trace_quad_form", "A", A, "B", B);
121 
122   var res;
123 
124   if (!is_constant<Mat1>::value && !is_constant<Mat2>::value) {
125     arena_t<promote_scalar_t<var, Mat1>> arena_A = A;
126     arena_t<promote_scalar_t<var, Mat2>> arena_B = B;
127 
128     res = (value_of(arena_B).transpose() * value_of(arena_A)
129            * value_of(arena_B))
130               .trace();
131 
132     reverse_pass_callback([arena_A, arena_B, res]() mutable {
133       if (is_var_matrix<Mat1>::value) {
134         arena_A.adj().noalias()
135             += res.adj() * value_of(arena_B) * value_of(arena_B).transpose();
136       } else {
137         arena_A.adj()
138             += res.adj() * value_of(arena_B) * value_of(arena_B).transpose();
139       }
140 
141       if (is_var_matrix<Mat2>::value) {
142         arena_B.adj().noalias()
143             += res.adj() * (value_of(arena_A) + value_of(arena_A).transpose())
144                * value_of(arena_B);
145       } else {
146         arena_B.adj() += res.adj()
147                          * (value_of(arena_A) + value_of(arena_A).transpose())
148                          * value_of(arena_B);
149       }
150     });
151   } else if (!is_constant<Mat2>::value) {
152     arena_t<promote_scalar_t<double, Mat1>> arena_A = value_of(A);
153     arena_t<promote_scalar_t<var, Mat2>> arena_B = B;
154 
155     res = (value_of(arena_B).transpose() * value_of(arena_A)
156            * value_of(arena_B))
157               .trace();
158 
159     reverse_pass_callback([arena_A, arena_B, res]() mutable {
160       if (is_var_matrix<Mat2>::value) {
161         arena_B.adj().noalias()
162             += res.adj() * (arena_A + arena_A.transpose()) * value_of(arena_B);
163       } else {
164         arena_B.adj()
165             += res.adj() * (arena_A + arena_A.transpose()) * value_of(arena_B);
166       }
167     });
168   } else {
169     arena_t<promote_scalar_t<var, Mat1>> arena_A = A;
170     arena_t<promote_scalar_t<double, Mat2>> arena_B = value_of(B);
171 
172     res = (arena_B.transpose() * value_of(arena_A) * arena_B).trace();
173 
174     reverse_pass_callback([arena_A, arena_B, res]() mutable {
175       if (is_var_matrix<Mat1>::value) {
176         arena_A.adj().noalias() += res.adj() * arena_B * arena_B.transpose();
177       } else {
178         arena_A.adj() += res.adj() * arena_B * arena_B.transpose();
179       }
180     });
181   }
182 
183   return res;
184 }
185 
186 }  // namespace math
187 }  // namespace stan
188 #endif
189