1 #ifndef STAN_MATH_PRIM_FUN_MDIVIDE_RIGHT_TRI_HPP
2 #define STAN_MATH_PRIM_FUN_MDIVIDE_RIGHT_TRI_HPP
3
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/prim/fun/Eigen.hpp>
7 #include <stan/math/prim/fun/to_ref.hpp>
8
9 namespace stan {
10 namespace math {
11
12 /**
13 * Returns the solution of the system xA=b when A is triangular
14 *
15 * @tparam TriView Specifies whether A is upper (Eigen::Upper)
16 * or lower triangular (Eigen::Lower).
17 * @tparam EigMat1 type of the right-hand side matrix or vector
18 * @tparam EigMat2 type of the triangular matrix
19 *
20 * @param A Triangular matrix. Specify upper or lower with TriView
21 * being Eigen::Upper or Eigen::Lower.
22 * @param b Right hand side matrix or vector.
23 * @return x = b A^-1, solution of the linear system.
24 * @throws std::domain_error if A is not square or the rows of b don't
25 * match the size of A.
26 */
27 template <Eigen::UpLoType TriView, typename EigMat1, typename EigMat2,
28 require_all_eigen_t<EigMat1, EigMat2>* = nullptr>
29 inline Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
30 EigMat1::RowsAtCompileTime, EigMat2::ColsAtCompileTime>
mdivide_right_tri(const EigMat1 & b,const EigMat2 & A)31 mdivide_right_tri(const EigMat1& b, const EigMat2& A) {
32 check_square("mdivide_right_tri", "A", A);
33 check_multiplicable("mdivide_right_tri", "b", b, "A", A);
34 if (TriView != Eigen::Lower && TriView != Eigen::Upper) {
35 throw_domain_error("mdivide_right_tri",
36 "triangular view must be Eigen::Lower or Eigen::Upper",
37 "", "");
38 }
39 if (A.rows() == 0) {
40 return {b.rows(), 0};
41 }
42
43 return Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
44 EigMat2::RowsAtCompileTime, EigMat2::ColsAtCompileTime>(
45 A)
46 .template triangularView<TriView>()
47 .transpose()
48 .solve(
49 Eigen::Matrix<return_type_t<EigMat1, EigMat2>,
50 EigMat1::RowsAtCompileTime, EigMat1::ColsAtCompileTime>(
51 b)
52 .transpose())
53 .transpose();
54 }
55
56 } // namespace math
57 } // namespace stan
58
59 #endif
60