1 #ifndef STAN_MATH_REV_FUN_TRACE_HPP
2 #define STAN_MATH_REV_FUN_TRACE_HPP
3 
4 #include <stan/math/prim/fun/Eigen.hpp>
5 #include <stan/math/prim/meta.hpp>
6 #include <stan/math/rev/core.hpp>
7 #include <stan/math/rev/meta.hpp>
8 
9 namespace stan {
10 namespace math {
11 
12 /**
13  * Returns the trace of the specified matrix.  The trace
14  * is defined as the sum of the elements on the diagonal.
15  * The matrix is not required to be square.  Returns 0 if
16  * matrix is empty.
17  *
18  * @tparam T type of the elements in the matrix
19  * @param[in] m Specified matrix.
20  * @return Trace of the matrix.
21  */
22 template <typename T, require_rev_matrix_t<T>* = nullptr>
trace(const T & m)23 inline auto trace(const T& m) {
24   arena_t<T> arena_m = m;
25 
26   return make_callback_var(arena_m.val_op().trace(),
27                            [arena_m](const auto& vi) mutable {
28                              arena_m.adj().diagonal().array() += vi.adj();
29                            });
30 }
31 
32 }  // namespace math
33 }  // namespace stan
34 
35 #endif
36