1 #ifndef STAN_MATH_REV_FUN_COLUMNS_DOT_SELF_HPP
2 #define STAN_MATH_REV_FUN_COLUMNS_DOT_SELF_HPP
3 
4 #include <stan/math/prim/fun/to_ref.hpp>
5 #include <stan/math/rev/meta.hpp>
6 #include <stan/math/rev/core.hpp>
7 #include <stan/math/rev/fun/dot_self.hpp>
8 #include <stan/math/prim/fun/Eigen.hpp>
9 
10 namespace stan {
11 namespace math {
12 
13 /**
14  * Returns the dot product of each column of a matrix with itself.
15  *
16  * @tparam Mat An Eigen matrix with a `var` scalar type.
17  * @param x Matrix.
18  */
19 template <typename Mat, require_eigen_vt<is_var, Mat>* = nullptr>
columns_dot_self(const Mat & x)20 inline Eigen::Matrix<var, 1, Mat::ColsAtCompileTime> columns_dot_self(
21     const Mat& x) {
22   Eigen::Matrix<var, 1, Mat::ColsAtCompileTime> ret(1, x.cols());
23   for (size_type i = 0; i < x.cols(); i++) {
24     ret(i) = dot_self(x.col(i));
25   }
26   return ret;
27 }
28 
29 /**
30  * Returns the dot product of each column of a matrix with itself.
31  *
32  * @tparam Mat A `var_value<>` with an inner matrix type.
33  * @param x Matrix.
34  */
35 template <typename Mat, require_var_matrix_t<Mat>* = nullptr>
columns_dot_self(const Mat & x)36 inline auto columns_dot_self(const Mat& x) {
37   using ret_type
38       = return_var_matrix_t<decltype(x.val().colwise().squaredNorm()), Mat>;
39   arena_t<ret_type> res = x.val().colwise().squaredNorm();
40   if (x.size() >= 0) {
41     reverse_pass_callback([res, x]() mutable {
42       x.adj() += x.val() * (2 * res.adj()).asDiagonal();
43     });
44   }
45   return res;
46 }
47 
48 }  // namespace math
49 }  // namespace stan
50 #endif
51