1 #pragma once
2 
#include <algorithm>
#include <cmath>

#include "anyode/anyode_matrix.hpp"
#include "anyode/anyode_buffer.hpp"

#if ANYODE_NO_LAPACK == 1
#include "anyode/anyode_blasless.hpp"
#else
#include "anyode/anyode_blas_lapack.hpp"
#endif
13 
14 namespace AnyODE {
15 
16     template<typename Real_t>
17     struct DecompositionBase {
~DecompositionBaseAnyODE::DecompositionBase18         virtual ~DecompositionBase() {};
19         virtual int factorize() = 0;
20         virtual int solve(const Real_t * const, Real_t * const) = 0;
21     };
22 
23     template<typename Real_t = double>
24     struct DenseLU : public DecompositionBase<Real_t> {
25         // DenseLU_callbacks<Real_t> m_cbs;
26         DenseMatrix<Real_t> * m_view;
27         buffer_t<int> m_ipiv;
28 
DenseLUAnyODE::DenseLU29         DenseLU(DenseMatrix<Real_t> * view) :
30             m_view(view),
31             m_ipiv(buffer_factory<int>(view->m_nr))
32         {}
factorizeAnyODE::DenseLU33         int factorize() override final {
34             int info;
35             constexpr getrf_callback<Real_t> getrf{};
36             getrf(&(m_view->m_nr), &(m_view->m_nc), m_view->m_data, &(m_view->m_ld),
37                           buffer_get_raw_ptr(m_ipiv), &info);
38             return info;
39         }
solveAnyODE::DenseLU40         int solve(const Real_t * const b, Real_t * const x) override final {
41             char trans = 'N';
42             int nrhs = 1;
43             int info;
44             std::copy(b, b + m_view->m_nr, x);
45             constexpr getrs_callback<Real_t> getrs{};
46             getrs(&trans, &(m_view->m_nr), &nrhs, m_view->m_data, &(m_view->m_ld),
47                           buffer_get_raw_ptr(m_ipiv), x, &(m_view->m_nr), &info);
48             return info;
49         }
50     };
51 
52     template<typename Real_t = double>
53     struct DiagonalInv : public DecompositionBase<Real_t> {
54         DiagonalMatrix<Real_t> * m_view;
DiagonalInvAnyODE::DiagonalInv55         DiagonalInv(DiagonalMatrix<Real_t> * view) : m_view(view)
56         {
57         }
factorizeAnyODE::DiagonalInv58         int factorize() final {
59             for (int i=0; i < m_view->m_nc; ++i)
60                 m_view->m_data[i] = 1/m_view->m_data[i];
61             return 0;
62         }
solveAnyODE::DiagonalInv63         int solve(const Real_t * const b, Real_t * const x) final {
64             for (int i=0; i < m_view->m_nc; ++i)
65                 x[i] = m_view->m_data[i]*b[i];
66             return 0;
67         }
68     };
69 
70 }
71