1 #ifndef STAN_MATH_REV_FUN_INVERSE_HPP
2 #define STAN_MATH_REV_FUN_INVERSE_HPP
3 
4 #include <stan/math/prim/fun/Eigen.hpp>
5 #include <stan/math/prim/fun/typedefs.hpp>
6 #include <stan/math/rev/core.hpp>
7 #include <stan/math/rev/meta.hpp>
8 #include <stan/math/rev/fun/value_of.hpp>
9 #include <stan/math/prim/err.hpp>
10 
11 namespace stan {
12 namespace math {
13 
14 /**
15  * Reverse mode specialization of calculating the inverse of the matrix.
16  *
17  * @param m specified matrix
18  * @return Inverse of the matrix (an empty matrix if the specified matrix has
19  * size zero).
20  * @throw std::invalid_argument if the matrix is not square.
21  */
22 template <typename T, require_rev_matrix_t<T>* = nullptr>
inverse(const T & m)23 inline auto inverse(const T& m) {
24   check_square("inverse", "m", m);
25 
26   using ret_type = return_var_matrix_t<T>;
27   if (unlikely(m.size() == 0)) {
28     return ret_type(m);
29   }
30 
31   arena_t<T> arena_m = m;
32   arena_t<promote_scalar_t<double, T>> res_val = arena_m.val().inverse();
33   arena_t<ret_type> res = res_val;
34 
35   reverse_pass_callback([res, res_val, arena_m]() mutable {
36     arena_m.adj() -= res_val.transpose() * res.adj_op() * res_val.transpose();
37   });
38 
39   return ret_type(res);
40 }
41 
42 }  // namespace math
43 }  // namespace stan
44 #endif
45