1 #ifndef STAN_MATH_PRIM_FUN_MATRIX_EXP_PADE_HPP
2 #define STAN_MATH_PRIM_FUN_MATRIX_EXP_PADE_HPP
3 
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/prim/fun/MatrixExponential.h>
7 
8 namespace stan {
9 namespace math {
10 
11 /**
12  * Computes the matrix exponential, using a Pade
13  * approximation, coupled with scaling and
14  * squaring.
15  *
16  * @tparam MatrixType type of the matrix
17  * @param[in] arg matrix to exponentiate.
18  * @return Matrix exponential of input matrix.
19  */
20 template <typename EigMat, require_eigen_t<EigMat>* = nullptr>
21 Eigen::Matrix<value_type_t<EigMat>, EigMat::RowsAtCompileTime,
22               EigMat::ColsAtCompileTime>
matrix_exp_pade(const EigMat & arg)23 matrix_exp_pade(const EigMat& arg) {
24   using MatrixType
25       = Eigen::Matrix<value_type_t<EigMat>, EigMat::RowsAtCompileTime,
26                       EigMat::ColsAtCompileTime>;
27   check_square("matrix_exp_pade", "arg", arg);
28   if (arg.size() == 0) {
29     return {};
30   }
31 
32   MatrixType U, V;
33   int squarings;
34 
35   Eigen::matrix_exp_computeUV<MatrixType>::run(arg, U, V, squarings, arg(0, 0));
36   // Pade approximant is
37   // (U+V) / (-U+V)
38   MatrixType numer = U + V;
39   MatrixType denom = -U + V;
40   MatrixType pade_approximation = denom.partialPivLu().solve(numer);
41   for (int i = 0; i < squarings; ++i) {
42     pade_approximation *= pade_approximation;  // undo scaling by
43   }
44   // repeated squaring
45   return pade_approximation;
46 }
47 
48 }  // namespace math
49 }  // namespace stan
50 
51 #endif
52