1 #ifndef STAN_MATH_PRIM_FUN_MATRIX_EXP_HPP
2 #define STAN_MATH_PRIM_FUN_MATRIX_EXP_HPP
3 
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/fun/exp.hpp>
6 #include <stan/math/prim/fun/matrix_exp_pade.hpp>
7 #include <stan/math/prim/fun/matrix_exp_2x2.hpp>
8 #include <stan/math/prim/fun/square.hpp>
9 #include <cmath>
10 
11 namespace stan {
12 namespace math {
13 
14 /**
15  * Return the matrix exponential of the input
16  * matrix.
17  *
18  * @tparam T type of the matrix
19  * @param[in] A_in Matrix to exponentiate.
20  * @return Matrix exponential, dynamically-sized.
21  * @throw <code>std::invalid_argument</code> if the input matrix
22  * is not square.
23  */
24 template <typename T, typename = require_eigen_t<T>>
matrix_exp(const T & A_in)25 inline plain_type_t<T> matrix_exp(const T& A_in) {
26   using std::exp;
27   const auto& A = A_in.eval();
28   check_square("matrix_exp", "input matrix", A);
29   if (T::RowsAtCompileTime == 1 && T::ColsAtCompileTime == 1) {
30     plain_type_t<T> res;
31     res << exp(A(0));
32     return res;
33   }
34   if (A_in.size() == 0) {
35     return {};
36   }
37   return (A.cols() == 2
38           && square(value_of(A(0, 0)) - value_of(A(1, 1)))
39                      + 4 * value_of(A(0, 1)) * value_of(A(1, 0))
40                  > 0)
41              ? matrix_exp_2x2(A)
42              : matrix_exp_pade(A);
43 }
44 
45 }  // namespace math
46 }  // namespace stan
47 
48 #endif
49