1 // Copyright (c) 2017-2021, Lawrence Livermore National Security, LLC and
2 // other Axom Project Developers. See the top-level LICENSE file for details.
3 //
4 // SPDX-License-Identifier: (BSD-3-Clause)
5
6 #ifndef AXOM_NUMERICS_LINEAR_SOLVE_HPP_
7 #define AXOM_NUMERICS_LINEAR_SOLVE_HPP_
8
#include "axom/core/numerics/Determinants.hpp"  // for Determinants
#include "axom/core/numerics/LU.hpp"            // for lu_decompose()/lu_solve()
#include "axom/core/numerics/Matrix.hpp"        // for Matrix

// C/C++ includes
#include <cassert>  // for assert()
#include <vector>   // for std::vector
15
16 namespace axom
17 {
18 namespace numerics
19 {
20 /*!
21 * \brief Solves a linear system of the form \f$ Ax=b \f$.
22 *
23 * \param [in] A a square input matrix
24 * \param [in] b the right-hand side
25 * \param [out] x the solution vector (computed)
26 * \return rc return value, 0 if the solve is successful.
27 *
28 * \pre A.isSquare() == true
29 * \pre b != nullptr
30 * \pre x != nullptr
31 *
32 * \note The input matrix is destroyed (modified) in the process.
33 */
34 template <typename T>
35 int linear_solve(Matrix<T>& A, const T* b, T* x);
36
37 } /* end namespace numerics */
38 } /* end namespace axom */
39
40 //------------------------------------------------------------------------------
41 // Implementation
42 //------------------------------------------------------------------------------
43 namespace axom
44 {
45 namespace numerics
46 {
47 template <typename T>
linear_solve(Matrix<T> & A,const T * b,T * x)48 int linear_solve(Matrix<T>& A, const T* b, T* x)
49 {
50 assert("pre: input matrix must be square" && A.isSquare());
51 assert("pre: solution vector is null" && (x != nullptr));
52 assert("pre: right-hand side vector is null" && (b != nullptr));
53
54 if(!A.isSquare())
55 {
56 return LU_NONSQUARE_MATRIX;
57 }
58
59 int N = A.getNumColumns();
60
61 if(N == 1)
62 {
63 if(utilities::isNearlyEqual(A(0, 0), 0.0))
64 {
65 return -1;
66 }
67
68 x[0] = b[0] / A(0, 0);
69 }
70 else if(N == 2)
71 {
72 // trivial solve
73 T det = numerics::determinant(A);
74
75 if(utilities::isNearlyEqual(det, 0.0))
76 {
77 return -1;
78 }
79
80 T invdet = 1 / det;
81 x[0] = (A(1, 1) * b[0] - A(0, 1) * b[1]) * invdet;
82 x[1] = (-A(1, 0) * b[0] + A(0, 0) * b[1]) * invdet;
83 }
84 else
85 {
86 // non-trivial system, use LU
87 int* pivots = new int[N];
88
89 int rc = lu_decompose(A, pivots);
90 if(rc == LU_SUCCESS)
91 {
92 rc = lu_solve(A, pivots, b, x);
93 }
94
95 delete[] pivots;
96 if(rc != LU_SUCCESS)
97 {
98 return -1;
99 }
100 }
101
102 return 0;
103 }
104
105 } /* end namespace numerics */
106 } /* end namespace axom */
107
108 #endif
109