1 #ifndef STAN_MATH_REV_FUN_COLUMNS_DOT_SELF_HPP
2 #define STAN_MATH_REV_FUN_COLUMNS_DOT_SELF_HPP
3
4 #include <stan/math/prim/fun/to_ref.hpp>
5 #include <stan/math/rev/meta.hpp>
6 #include <stan/math/rev/core.hpp>
7 #include <stan/math/rev/fun/dot_self.hpp>
8 #include <stan/math/prim/fun/Eigen.hpp>
9
10 namespace stan {
11 namespace math {
12
13 /**
14 * Returns the dot product of each column of a matrix with itself.
15 *
16 * @tparam Mat An Eigen matrix with a `var` scalar type.
17 * @param x Matrix.
18 */
19 template <typename Mat, require_eigen_vt<is_var, Mat>* = nullptr>
columns_dot_self(const Mat & x)20 inline Eigen::Matrix<var, 1, Mat::ColsAtCompileTime> columns_dot_self(
21 const Mat& x) {
22 Eigen::Matrix<var, 1, Mat::ColsAtCompileTime> ret(1, x.cols());
23 for (size_type i = 0; i < x.cols(); i++) {
24 ret(i) = dot_self(x.col(i));
25 }
26 return ret;
27 }
28
29 /**
30 * Returns the dot product of each column of a matrix with itself.
31 *
32 * @tparam Mat A `var_value<>` with an inner matrix type.
33 * @param x Matrix.
34 */
35 template <typename Mat, require_var_matrix_t<Mat>* = nullptr>
columns_dot_self(const Mat & x)36 inline auto columns_dot_self(const Mat& x) {
37 using ret_type
38 = return_var_matrix_t<decltype(x.val().colwise().squaredNorm()), Mat>;
39 arena_t<ret_type> res = x.val().colwise().squaredNorm();
40 if (x.size() >= 0) {
41 reverse_pass_callback([res, x]() mutable {
42 x.adj() += x.val() * (2 * res.adj()).asDiagonal();
43 });
44 }
45 return res;
46 }
47
48 } // namespace math
49 } // namespace stan
50 #endif
51