1 #ifndef STAN_MATH_REV_FUN_ROWS_DOT_SELF_HPP
2 #define STAN_MATH_REV_FUN_ROWS_DOT_SELF_HPP
3 
4 #include <stan/math/prim/fun/to_ref.hpp>
5 #include <stan/math/rev/meta.hpp>
6 #include <stan/math/rev/core.hpp>
7 #include <stan/math/rev/fun/dot_self.hpp>
8 #include <stan/math/prim/fun/Eigen.hpp>
9 
10 namespace stan {
11 namespace math {
12 
13 /**
14  * Returns the dot product of each row of a matrix with itself.
15  *
16  * @tparam Mat An Eigen matrix with a `var` scalar type.
17  * @param x Matrix.
18  */
19 template <typename Mat, require_eigen_vt<is_var, Mat>* = nullptr>
rows_dot_self(const Mat & x)20 inline Eigen::Matrix<var, Mat::RowsAtCompileTime, 1> rows_dot_self(
21     const Mat& x) {
22   Eigen::Matrix<var, Mat::RowsAtCompileTime, 1> ret(x.rows());
23   for (size_type i = 0; i < x.rows(); i++) {
24     ret(i) = dot_self(x.row(i));
25   }
26   return ret;
27 }
28 
29 /**
30  * Returns the dot product of row row of a matrix with itself.
31  *
32  * @tparam Mat A `var_value<>` with an inner matrix type.
33  * @param x Matrix.
34  */
35 template <typename Mat, require_var_matrix_t<Mat>* = nullptr>
rows_dot_self(const Mat & x)36 inline auto rows_dot_self(const Mat& x) {
37   using ret_type = var_value<Eigen::VectorXd>;
38   arena_t<ret_type> res = x.val().rowwise().squaredNorm();
39   if (x.size() >= 0) {
40     reverse_pass_callback([res, x]() mutable {
41       x.adj() += (2 * res.adj()).asDiagonal() * x.val();
42     });
43   }
44   return res;
45 }
46 
47 }  // namespace math
48 }  // namespace stan
49 #endif
50