1 #ifndef STAN_MATH_REV_MAT_FUN_QUAD_FORM_SYM_HPP
2 #define STAN_MATH_REV_MAT_FUN_QUAD_FORM_SYM_HPP
3 
4 #include <stan/math/rev/core.hpp>
5 #include <stan/math/prim/mat/fun/Eigen.hpp>
6 #include <stan/math/prim/mat/fun/typedefs.hpp>
7 #include <stan/math/rev/mat/fun/typedefs.hpp>
8 #include <stan/math/prim/mat/fun/value_of.hpp>
9 #include <stan/math/prim/mat/fun/quad_form.hpp>
10 #include <stan/math/prim/mat/err/check_multiplicable.hpp>
11 #include <stan/math/prim/mat/err/check_square.hpp>
12 #include <stan/math/prim/mat/err/check_symmetric.hpp>
13 #include <stan/math/rev/mat/fun/quad_form.hpp>
14 #include <type_traits>
15 
16 namespace stan {
17 namespace math {
18 
19 template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
20 inline typename std::enable_if<std::is_same<Ta, var>::value
21                                    || std::is_same<Tb, var>::value,
22                                Eigen::Matrix<var, Cb, Cb> >::type
quad_form_sym(const Eigen::Matrix<Ta,Ra,Ca> & A,const Eigen::Matrix<Tb,Rb,Cb> & B)23 quad_form_sym(const Eigen::Matrix<Ta, Ra, Ca>& A,
24               const Eigen::Matrix<Tb, Rb, Cb>& B) {
25   check_square("quad_form", "A", A);
26   check_symmetric("quad_form_sym", "A", A);
27   check_multiplicable("quad_form_sym", "A", A, "B", B);
28 
29   internal::quad_form_vari<Ta, Ra, Ca, Tb, Rb, Cb>* baseVari
30       = new internal::quad_form_vari<Ta, Ra, Ca, Tb, Rb, Cb>(A, B, true);
31 
32   return baseVari->impl_->C_;
33 }
34 
35 template <typename Ta, int Ra, int Ca, typename Tb, int Rb>
36 inline typename std::enable_if<
37     std::is_same<Ta, var>::value || std::is_same<Tb, var>::value, var>::type
quad_form_sym(const Eigen::Matrix<Ta,Ra,Ca> & A,const Eigen::Matrix<Tb,Rb,1> & B)38 quad_form_sym(const Eigen::Matrix<Ta, Ra, Ca>& A,
39               const Eigen::Matrix<Tb, Rb, 1>& B) {
40   check_square("quad_form", "A", A);
41   check_symmetric("quad_form_sym", "A", A);
42   check_multiplicable("quad_form_sym", "A", A, "B", B);
43 
44   internal::quad_form_vari<Ta, Ra, Ca, Tb, Rb, 1>* baseVari
45       = new internal::quad_form_vari<Ta, Ra, Ca, Tb, Rb, 1>(A, B, true);
46 
47   return baseVari->impl_->C_(0, 0);
48 }
49 
50 }  // namespace math
51 }  // namespace stan
52 #endif
53