1 #ifndef STAN_MATH_PRIM_FUN_MDIVIDE_LEFT_TRI_HPP
2 #define STAN_MATH_PRIM_FUN_MDIVIDE_LEFT_TRI_HPP
3
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/prim/fun/Eigen.hpp>
7 #include <stan/math/prim/fun/to_ref.hpp>
8
9 namespace stan {
10 namespace math {
11
12 /**
13 * Returns the solution of the system Ax=b when A is triangular.
14 *
15 * @tparam TriView Specifies whether A is upper (Eigen::Upper)
16 * or lower triangular (Eigen::Lower).
17 * @tparam T1 type of the triangular matrix
18 * @tparam T2 type of the right-hand side matrix or vector
19 *
20 * @param A Triangular matrix.
21 * @param b Right hand side matrix or vector.
22 * @return x = A^-1 b, solution of the linear system.
23 * @throws std::domain_error if A is not square or the rows of b don't
24 * match the size of A.
25 */
26 template <Eigen::UpLoType TriView, typename T1, typename T2,
27 require_all_eigen_t<T1, T2> * = nullptr,
28 require_all_not_eigen_vt<is_var, T1, T2> * = nullptr>
29 inline Eigen::Matrix<return_type_t<T1, T2>, T1::RowsAtCompileTime,
30 T2::ColsAtCompileTime>
mdivide_left_tri(const T1 & A,const T2 & b)31 mdivide_left_tri(const T1 &A, const T2 &b) {
32 using T_return = return_type_t<T1, T2>;
33 check_square("mdivide_left_tri", "A", A);
34 check_multiplicable("mdivide_left_tri", "A", A, "b", b);
35 if (A.rows() == 0) {
36 return {0, b.cols()};
37 }
38
39 return A.template cast<T_return>()
40 .eval()
41 .template triangularView<TriView>()
42 .solve(b.template cast<T_return>().eval());
43 }
44
45 /**
46 * Returns the solution of the system Ax=b when A is triangular and b=I.
47 *
48 * @tparam T type of the matrix
49 *
50 * @param A Triangular matrix.
51 * @return x = A^-1 .
52 * @throws std::domain_error if A is not square
53 */
54 template <Eigen::UpLoType TriView, typename T, require_eigen_t<T> * = nullptr>
mdivide_left_tri(const T & A)55 inline plain_type_t<T> mdivide_left_tri(const T &A) {
56 check_square("mdivide_left_tri", "A", A);
57 if (A.rows() == 0) {
58 return {};
59 }
60
61 int n = A.rows();
62 plain_type_t<T> b = plain_type_t<T>::Identity(n, n);
63 A.template triangularView<TriView>().solveInPlace(b);
64 return b;
65 }
66
67 } // namespace math
68 } // namespace stan
69
70 #endif
71