1 #ifndef STAN_MATH_REV_FUN_MATRIX_POWER_HPP
2 #define STAN_MATH_REV_FUN_MATRIX_POWER_HPP
3 
4 #include <stan/math/rev/core.hpp>
5 #include <stan/math/rev/fun/value_of.hpp>
6 #include <stan/math/rev/core/typedefs.hpp>
7 #include <stan/math/prim/err.hpp>
8 #include <stan/math/prim/fun/Eigen.hpp>
9 #include <stan/math/prim/fun/to_ref.hpp>
10 #include <stan/math/prim/fun/typedefs.hpp>
11 #include <vector>
12 
13 namespace stan {
14 namespace math {
15 
16 /**
17  * Returns the nth power of the specific matrix. M^n = M * M * ... * M.
18  *
19  * @tparam R number of rows, can be Eigen::Dynamic
20  * @tparam C number of columns, can be Eigen::Dynamic
21  * @param[in] M a square matrix
22  * @param[in] n exponent
23  * @return nth power of M
24  * @throw std::domain_error if the matrix contains NaNs or infinities.
25  * @throw std::invalid_argument if the exponent is negative or the matrix is not
26  * square.
27  */
28 template <typename T, require_rev_matrix_t<T>* = nullptr>
matrix_power(const T & M,const int n)29 inline plain_type_t<T> matrix_power(const T& M, const int n) {
30   check_square("matrix_power", "M", M);
31   check_nonnegative("matrix_power", "n", n);
32 
33   if (M.size() == 0)
34     return M;
35 
36   const auto& M_ref = to_ref(M);
37   check_finite("matrix_power", "M", M_ref);
38 
39   size_t N = M.rows();
40 
41   if (n == 0)
42     return Eigen::MatrixXd::Identity(N, N);
43 
44   if (n == 1)
45     return M_ref;
46 
47   arena_t<std::vector<Eigen::MatrixXd>> arena_powers(n + 1);
48   arena_t<plain_type_t<T>> arena_M = M_ref;
49 
50   arena_powers[0] = Eigen::MatrixXd::Identity(N, N);
51   arena_powers[1] = M_ref.val();
52   for (size_t i = 2; i <= n; ++i) {
53     arena_powers[i] = arena_powers[1] * arena_powers[i - 1];
54   }
55   using ret_type = return_var_matrix_t<T>;
56   arena_t<ret_type> res = arena_powers[arena_powers.size() - 1];
57 
58   reverse_pass_callback([arena_M, n, res, arena_powers]() mutable {
59     const auto& M_val = arena_powers[1];
60     Eigen::MatrixXd adj_C = res.adj();
61     Eigen::MatrixXd adj_M = Eigen::MatrixXd::Zero(M_val.rows(), M_val.cols());
62     for (size_t i = n; i > 1; --i) {
63       adj_M += adj_C * arena_powers[i - 1].transpose();
64       adj_C = M_val.transpose() * adj_C;
65     }
66     arena_M.adj() += adj_M + adj_C;
67   });
68 
69   return ret_type(res);
70 }
71 
72 }  // namespace math
73 }  // namespace stan
74 #endif
75