1 #ifndef STAN_MATH_PRIM_FUN_MDIVIDE_LEFT_TRI_HPP
2 #define STAN_MATH_PRIM_FUN_MDIVIDE_LEFT_TRI_HPP
3 
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/prim/fun/Eigen.hpp>
7 #include <stan/math/prim/fun/to_ref.hpp>
8 
9 namespace stan {
10 namespace math {
11 
12 /**
13  * Returns the solution of the system Ax=b when A is triangular.
14  *
15  * @tparam TriView Specifies whether A is upper (Eigen::Upper)
16  * or lower triangular (Eigen::Lower).
17  * @tparam T1 type of the triangular matrix
18  * @tparam T2 type of the right-hand side matrix or vector
19  *
20  * @param A Triangular matrix.
21  * @param b Right hand side matrix or vector.
22  * @return x = A^-1 b, solution of the linear system.
23  * @throws std::domain_error if A is not square or the rows of b don't
24  * match the size of A.
25  */
26 template <Eigen::UpLoType TriView, typename T1, typename T2,
27           require_all_eigen_t<T1, T2> * = nullptr,
28           require_all_not_eigen_vt<is_var, T1, T2> * = nullptr>
29 inline Eigen::Matrix<return_type_t<T1, T2>, T1::RowsAtCompileTime,
30                      T2::ColsAtCompileTime>
mdivide_left_tri(const T1 & A,const T2 & b)31 mdivide_left_tri(const T1 &A, const T2 &b) {
32   using T_return = return_type_t<T1, T2>;
33   check_square("mdivide_left_tri", "A", A);
34   check_multiplicable("mdivide_left_tri", "A", A, "b", b);
35   if (A.rows() == 0) {
36     return {0, b.cols()};
37   }
38 
39   return A.template cast<T_return>()
40       .eval()
41       .template triangularView<TriView>()
42       .solve(b.template cast<T_return>().eval());
43 }
44 
45 /**
46  * Returns the solution of the system Ax=b when A is triangular and b=I.
47  *
48  * @tparam T type of the matrix
49  *
50  * @param A Triangular matrix.
51  * @return x = A^-1 .
52  * @throws std::domain_error if A is not square
53  */
54 template <Eigen::UpLoType TriView, typename T, require_eigen_t<T> * = nullptr>
mdivide_left_tri(const T & A)55 inline plain_type_t<T> mdivide_left_tri(const T &A) {
56   check_square("mdivide_left_tri", "A", A);
57   if (A.rows() == 0) {
58     return {};
59   }
60 
61   int n = A.rows();
62   plain_type_t<T> b = plain_type_t<T>::Identity(n, n);
63   A.template triangularView<TriView>().solveInPlace(b);
64   return b;
65 }
66 
67 }  // namespace math
68 }  // namespace stan
69 
70 #endif
71