1 #ifndef STAN_MATH_REV_FUN_TRACE_QUAD_FORM_HPP
2 #define STAN_MATH_REV_FUN_TRACE_QUAD_FORM_HPP
3
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/rev/fun/value_of.hpp>
7 #include <stan/math/rev/fun/to_var_value.hpp>
8 #include <stan/math/prim/meta.hpp>
9 #include <stan/math/prim/err.hpp>
10 #include <stan/math/prim/fun/Eigen.hpp>
11 #include <stan/math/prim/fun/trace_quad_form.hpp>
12 #include <stan/math/prim/fun/typedefs.hpp>
13 #include <stan/math/prim/fun/value_of.hpp>
14 #include <type_traits>
15
16 namespace stan {
17 namespace math {
18 namespace internal {
19 template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
20 class trace_quad_form_vari_alloc : public chainable_alloc {
21 public:
trace_quad_form_vari_alloc(const Eigen::Matrix<Ta,Ra,Ca> & A,const Eigen::Matrix<Tb,Rb,Cb> & B)22 trace_quad_form_vari_alloc(const Eigen::Matrix<Ta, Ra, Ca>& A,
23 const Eigen::Matrix<Tb, Rb, Cb>& B)
24 : A_(A), B_(B) {}
25
compute()26 double compute() { return trace_quad_form(value_of(A_), value_of(B_)); }
27
28 Eigen::Matrix<Ta, Ra, Ca> A_;
29 Eigen::Matrix<Tb, Rb, Cb> B_;
30 };
31
32 template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
33 class trace_quad_form_vari : public vari {
34 protected:
chainA(Eigen::Matrix<double,Ra,Ca> & A,const Eigen::Matrix<double,Rb,Cb> & Bd,double adjC)35 static inline void chainA(Eigen::Matrix<double, Ra, Ca>& A,
36 const Eigen::Matrix<double, Rb, Cb>& Bd,
37 double adjC) {}
chainB(Eigen::Matrix<double,Rb,Cb> & B,const Eigen::Matrix<double,Ra,Ca> & Ad,const Eigen::Matrix<double,Rb,Cb> & Bd,double adjC)38 static inline void chainB(Eigen::Matrix<double, Rb, Cb>& B,
39 const Eigen::Matrix<double, Ra, Ca>& Ad,
40 const Eigen::Matrix<double, Rb, Cb>& Bd,
41 double adjC) {}
42
chainA(Eigen::Matrix<var,Ra,Ca> & A,const Eigen::Matrix<double,Rb,Cb> & Bd,double adjC)43 static inline void chainA(Eigen::Matrix<var, Ra, Ca>& A,
44 const Eigen::Matrix<double, Rb, Cb>& Bd,
45 double adjC) {
46 A.adj() += adjC * Bd * Bd.transpose();
47 }
chainB(Eigen::Matrix<var,Rb,Cb> & B,const Eigen::Matrix<double,Ra,Ca> & Ad,const Eigen::Matrix<double,Rb,Cb> & Bd,double adjC)48 static inline void chainB(Eigen::Matrix<var, Rb, Cb>& B,
49 const Eigen::Matrix<double, Ra, Ca>& Ad,
50 const Eigen::Matrix<double, Rb, Cb>& Bd,
51 double adjC) {
52 B.adj() += adjC * (Ad + Ad.transpose()) * Bd;
53 }
54
chainAB(Eigen::Matrix<Ta,Ra,Ca> & A,Eigen::Matrix<Tb,Rb,Cb> & B,const Eigen::Matrix<double,Ra,Ca> & Ad,const Eigen::Matrix<double,Rb,Cb> & Bd,double adjC)55 inline void chainAB(Eigen::Matrix<Ta, Ra, Ca>& A,
56 Eigen::Matrix<Tb, Rb, Cb>& B,
57 const Eigen::Matrix<double, Ra, Ca>& Ad,
58 const Eigen::Matrix<double, Rb, Cb>& Bd, double adjC) {
59 chainA(A, Bd, adjC);
60 chainB(B, Ad, Bd, adjC);
61 }
62
63 public:
trace_quad_form_vari(trace_quad_form_vari_alloc<Ta,Ra,Ca,Tb,Rb,Cb> * impl)64 explicit trace_quad_form_vari(
65 trace_quad_form_vari_alloc<Ta, Ra, Ca, Tb, Rb, Cb>* impl)
66 : vari(impl->compute()), impl_(impl) {}
67
chain()68 virtual void chain() {
69 chainAB(impl_->A_, impl_->B_, value_of(impl_->A_), value_of(impl_->B_),
70 adj_);
71 }
72
73 trace_quad_form_vari_alloc<Ta, Ra, Ca, Tb, Rb, Cb>* impl_;
74 };
75 } // namespace internal
76
77 template <typename EigMat1, typename EigMat2,
78 require_all_eigen_t<EigMat1, EigMat2>* = nullptr,
79 require_any_st_var<EigMat1, EigMat2>* = nullptr>
trace_quad_form(const EigMat1 & A,const EigMat2 & B)80 inline return_type_t<EigMat1, EigMat2> trace_quad_form(const EigMat1& A,
81 const EigMat2& B) {
82 using Ta = value_type_t<EigMat1>;
83 using Tb = value_type_t<EigMat2>;
84 constexpr int Ra = EigMat1::RowsAtCompileTime;
85 constexpr int Ca = EigMat1::ColsAtCompileTime;
86 constexpr int Rb = EigMat2::RowsAtCompileTime;
87 constexpr int Cb = EigMat2::ColsAtCompileTime;
88 check_square("trace_quad_form", "A", A);
89 check_multiplicable("trace_quad_form", "A", A, "B", B);
90
91 auto* baseVari
92 = new internal::trace_quad_form_vari_alloc<Ta, Ra, Ca, Tb, Rb, Cb>(A, B);
93
94 return var(
95 new internal::trace_quad_form_vari<Ta, Ra, Ca, Tb, Rb, Cb>(baseVari));
96 }
97
98 /**
99 * Compute trace(B^T A B).
100 *
101 * This overload handles arguments where one of Mat1 or Mat2 are
102 * `var_value<T>` where `T` is an Eigen type. The other type can
103 * also be a `var_value` or it can be a type that inherits
104 * from EigenBase
105 *
106 * @tparam Mat1 type of the first matrix
107 * @tparam Mat2 type of the second matrix
108 *
109 * @param A matrix
110 * @param B matrix
111 * @return The trace of B^T A B
112 * @throw std::domain_error if A is not square
113 * @throw std::domain_error if A cannot be multiplied by B
114 */
115 template <typename Mat1, typename Mat2,
116 require_all_matrix_t<Mat1, Mat2>* = nullptr,
117 require_any_var_matrix_t<Mat1, Mat2>* = nullptr>
trace_quad_form(const Mat1 & A,const Mat2 & B)118 inline var trace_quad_form(const Mat1& A, const Mat2& B) {
119 check_square("trace_quad_form", "A", A);
120 check_multiplicable("trace_quad_form", "A", A, "B", B);
121
122 var res;
123
124 if (!is_constant<Mat1>::value && !is_constant<Mat2>::value) {
125 arena_t<promote_scalar_t<var, Mat1>> arena_A = A;
126 arena_t<promote_scalar_t<var, Mat2>> arena_B = B;
127
128 res = (value_of(arena_B).transpose() * value_of(arena_A)
129 * value_of(arena_B))
130 .trace();
131
132 reverse_pass_callback([arena_A, arena_B, res]() mutable {
133 if (is_var_matrix<Mat1>::value) {
134 arena_A.adj().noalias()
135 += res.adj() * value_of(arena_B) * value_of(arena_B).transpose();
136 } else {
137 arena_A.adj()
138 += res.adj() * value_of(arena_B) * value_of(arena_B).transpose();
139 }
140
141 if (is_var_matrix<Mat2>::value) {
142 arena_B.adj().noalias()
143 += res.adj() * (value_of(arena_A) + value_of(arena_A).transpose())
144 * value_of(arena_B);
145 } else {
146 arena_B.adj() += res.adj()
147 * (value_of(arena_A) + value_of(arena_A).transpose())
148 * value_of(arena_B);
149 }
150 });
151 } else if (!is_constant<Mat2>::value) {
152 arena_t<promote_scalar_t<double, Mat1>> arena_A = value_of(A);
153 arena_t<promote_scalar_t<var, Mat2>> arena_B = B;
154
155 res = (value_of(arena_B).transpose() * value_of(arena_A)
156 * value_of(arena_B))
157 .trace();
158
159 reverse_pass_callback([arena_A, arena_B, res]() mutable {
160 if (is_var_matrix<Mat2>::value) {
161 arena_B.adj().noalias()
162 += res.adj() * (arena_A + arena_A.transpose()) * value_of(arena_B);
163 } else {
164 arena_B.adj()
165 += res.adj() * (arena_A + arena_A.transpose()) * value_of(arena_B);
166 }
167 });
168 } else {
169 arena_t<promote_scalar_t<var, Mat1>> arena_A = A;
170 arena_t<promote_scalar_t<double, Mat2>> arena_B = value_of(B);
171
172 res = (arena_B.transpose() * value_of(arena_A) * arena_B).trace();
173
174 reverse_pass_callback([arena_A, arena_B, res]() mutable {
175 if (is_var_matrix<Mat1>::value) {
176 arena_A.adj().noalias() += res.adj() * arena_B * arena_B.transpose();
177 } else {
178 arena_A.adj() += res.adj() * arena_B * arena_B.transpose();
179 }
180 });
181 }
182
183 return res;
184 }
185
186 } // namespace math
187 } // namespace stan
188 #endif
189