1 #ifndef STAN_MATH_PRIM_FUN_MATRIX_EXP_PADE_HPP
2 #define STAN_MATH_PRIM_FUN_MATRIX_EXP_PADE_HPP
3
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/prim/fun/MatrixExponential.h>
7
8 namespace stan {
9 namespace math {
10
11 /**
12 * Computes the matrix exponential, using a Pade
13 * approximation, coupled with scaling and
14 * squaring.
15 *
16 * @tparam MatrixType type of the matrix
17 * @param[in] arg matrix to exponentiate.
18 * @return Matrix exponential of input matrix.
19 */
20 template <typename EigMat, require_eigen_t<EigMat>* = nullptr>
21 Eigen::Matrix<value_type_t<EigMat>, EigMat::RowsAtCompileTime,
22 EigMat::ColsAtCompileTime>
matrix_exp_pade(const EigMat & arg)23 matrix_exp_pade(const EigMat& arg) {
24 using MatrixType
25 = Eigen::Matrix<value_type_t<EigMat>, EigMat::RowsAtCompileTime,
26 EigMat::ColsAtCompileTime>;
27 check_square("matrix_exp_pade", "arg", arg);
28 if (arg.size() == 0) {
29 return {};
30 }
31
32 MatrixType U, V;
33 int squarings;
34
35 Eigen::matrix_exp_computeUV<MatrixType>::run(arg, U, V, squarings, arg(0, 0));
36 // Pade approximant is
37 // (U+V) / (-U+V)
38 MatrixType numer = U + V;
39 MatrixType denom = -U + V;
40 MatrixType pade_approximation = denom.partialPivLu().solve(numer);
41 for (int i = 0; i < squarings; ++i) {
42 pade_approximation *= pade_approximation; // undo scaling by
43 }
44 // repeated squaring
45 return pade_approximation;
46 }
47
48 } // namespace math
49 } // namespace stan
50
51 #endif
52