1 // Copyright (c) 2017-2021, Lawrence Livermore National Security, LLC and
2 // other Axom Project Developers. See the top-level LICENSE file for details.
3 //
4 // SPDX-License-Identifier: (BSD-3-Clause)
5 
6 #ifndef AXOM_NUMERICS_LINEAR_SOLVE_HPP_
7 #define AXOM_NUMERICS_LINEAR_SOLVE_HPP_
8 
#include "axom/core/numerics/Determinants.hpp"  // for Determinants
#include "axom/core/numerics/LU.hpp"            // for lu_decompose()/lu_solve()
#include "axom/core/numerics/Matrix.hpp"        // for Matrix

// C/C++ includes
#include <cassert>  // for assert()
#include <vector>   // for std::vector
15 
16 namespace axom
17 {
18 namespace numerics
19 {
20 /*!
21  * \brief Solves a linear system of the form \f$ Ax=b \f$.
22  *
23  * \param [in] A a square input matrix
24  * \param [in] b the right-hand side
25  * \param [out] x the solution vector (computed)
26  * \return rc return value, 0 if the solve is successful.
27  *
28  * \pre A.isSquare() == true
29  * \pre b != nullptr
30  * \pre x != nullptr
31  *
32  * \note The input matrix is destroyed (modified) in the process.
33  */
34 template <typename T>
35 int linear_solve(Matrix<T>& A, const T* b, T* x);
36 
37 } /* end namespace numerics */
38 } /* end namespace axom */
39 
40 //------------------------------------------------------------------------------
41 // Implementation
42 //------------------------------------------------------------------------------
43 namespace axom
44 {
45 namespace numerics
46 {
47 template <typename T>
linear_solve(Matrix<T> & A,const T * b,T * x)48 int linear_solve(Matrix<T>& A, const T* b, T* x)
49 {
50   assert("pre: input matrix must be square" && A.isSquare());
51   assert("pre: solution vector is null" && (x != nullptr));
52   assert("pre: right-hand side vector is null" && (b != nullptr));
53 
54   if(!A.isSquare())
55   {
56     return LU_NONSQUARE_MATRIX;
57   }
58 
59   int N = A.getNumColumns();
60 
61   if(N == 1)
62   {
63     if(utilities::isNearlyEqual(A(0, 0), 0.0))
64     {
65       return -1;
66     }
67 
68     x[0] = b[0] / A(0, 0);
69   }
70   else if(N == 2)
71   {
72     // trivial solve
73     T det = numerics::determinant(A);
74 
75     if(utilities::isNearlyEqual(det, 0.0))
76     {
77       return -1;
78     }
79 
80     T invdet = 1 / det;
81     x[0] = (A(1, 1) * b[0] - A(0, 1) * b[1]) * invdet;
82     x[1] = (-A(1, 0) * b[0] + A(0, 0) * b[1]) * invdet;
83   }
84   else
85   {
86     // non-trivial system, use LU
87     int* pivots = new int[N];
88 
89     int rc = lu_decompose(A, pivots);
90     if(rc == LU_SUCCESS)
91     {
92       rc = lu_solve(A, pivots, b, x);
93     }
94 
95     delete[] pivots;
96     if(rc != LU_SUCCESS)
97     {
98       return -1;
99     }
100   }
101 
102   return 0;
103 }
104 
105 } /* end namespace numerics */
106 } /* end namespace axom */
107 
108 #endif
109