1 #ifndef STAN_MATH_REV_FUN_MATRIX_POWER_HPP
2 #define STAN_MATH_REV_FUN_MATRIX_POWER_HPP
3
4 #include <stan/math/rev/core.hpp>
5 #include <stan/math/rev/fun/value_of.hpp>
6 #include <stan/math/rev/core/typedefs.hpp>
7 #include <stan/math/prim/err.hpp>
8 #include <stan/math/prim/fun/Eigen.hpp>
9 #include <stan/math/prim/fun/to_ref.hpp>
10 #include <stan/math/prim/fun/typedefs.hpp>
11 #include <vector>
12
13 namespace stan {
14 namespace math {
15
16 /**
17 * Returns the nth power of the specific matrix. M^n = M * M * ... * M.
18 *
19 * @tparam R number of rows, can be Eigen::Dynamic
20 * @tparam C number of columns, can be Eigen::Dynamic
21 * @param[in] M a square matrix
22 * @param[in] n exponent
23 * @return nth power of M
24 * @throw std::domain_error if the matrix contains NaNs or infinities.
25 * @throw std::invalid_argument if the exponent is negative or the matrix is not
26 * square.
27 */
28 template <typename T, require_rev_matrix_t<T>* = nullptr>
matrix_power(const T & M,const int n)29 inline plain_type_t<T> matrix_power(const T& M, const int n) {
30 check_square("matrix_power", "M", M);
31 check_nonnegative("matrix_power", "n", n);
32
33 if (M.size() == 0)
34 return M;
35
36 const auto& M_ref = to_ref(M);
37 check_finite("matrix_power", "M", M_ref);
38
39 size_t N = M.rows();
40
41 if (n == 0)
42 return Eigen::MatrixXd::Identity(N, N);
43
44 if (n == 1)
45 return M_ref;
46
47 arena_t<std::vector<Eigen::MatrixXd>> arena_powers(n + 1);
48 arena_t<plain_type_t<T>> arena_M = M_ref;
49
50 arena_powers[0] = Eigen::MatrixXd::Identity(N, N);
51 arena_powers[1] = M_ref.val();
52 for (size_t i = 2; i <= n; ++i) {
53 arena_powers[i] = arena_powers[1] * arena_powers[i - 1];
54 }
55 using ret_type = return_var_matrix_t<T>;
56 arena_t<ret_type> res = arena_powers[arena_powers.size() - 1];
57
58 reverse_pass_callback([arena_M, n, res, arena_powers]() mutable {
59 const auto& M_val = arena_powers[1];
60 Eigen::MatrixXd adj_C = res.adj();
61 Eigen::MatrixXd adj_M = Eigen::MatrixXd::Zero(M_val.rows(), M_val.cols());
62 for (size_t i = n; i > 1; --i) {
63 adj_M += adj_C * arena_powers[i - 1].transpose();
64 adj_C = M_val.transpose() * adj_C;
65 }
66 arena_M.adj() += adj_M + adj_C;
67 });
68
69 return ret_type(res);
70 }
71
72 } // namespace math
73 } // namespace stan
74 #endif
75