1 #pragma once 2 3 #include <cmath> 4 5 #include "anyode/anyode_matrix.hpp" 6 #include "anyode/anyode_buffer.hpp" 7 8 #if ANYODE_NO_LAPACK == 1 9 #include "anyode/anyode_blasless.hpp" 10 #else 11 #include "anyode/anyode_blas_lapack.hpp" 12 #endif 13 14 namespace AnyODE { 15 16 template<typename Real_t> 17 struct DecompositionBase { ~DecompositionBaseAnyODE::DecompositionBase18 virtual ~DecompositionBase() {}; 19 virtual int factorize() = 0; 20 virtual int solve(const Real_t * const, Real_t * const) = 0; 21 }; 22 23 template<typename Real_t = double> 24 struct DenseLU : public DecompositionBase<Real_t> { 25 // DenseLU_callbacks<Real_t> m_cbs; 26 DenseMatrix<Real_t> * m_view; 27 buffer_t<int> m_ipiv; 28 DenseLUAnyODE::DenseLU29 DenseLU(DenseMatrix<Real_t> * view) : 30 m_view(view), 31 m_ipiv(buffer_factory<int>(view->m_nr)) 32 {} factorizeAnyODE::DenseLU33 int factorize() override final { 34 int info; 35 constexpr getrf_callback<Real_t> getrf{}; 36 getrf(&(m_view->m_nr), &(m_view->m_nc), m_view->m_data, &(m_view->m_ld), 37 buffer_get_raw_ptr(m_ipiv), &info); 38 return info; 39 } solveAnyODE::DenseLU40 int solve(const Real_t * const b, Real_t * const x) override final { 41 char trans = 'N'; 42 int nrhs = 1; 43 int info; 44 std::copy(b, b + m_view->m_nr, x); 45 constexpr getrs_callback<Real_t> getrs{}; 46 getrs(&trans, &(m_view->m_nr), &nrhs, m_view->m_data, &(m_view->m_ld), 47 buffer_get_raw_ptr(m_ipiv), x, &(m_view->m_nr), &info); 48 return info; 49 } 50 }; 51 52 template<typename Real_t = double> 53 struct DiagonalInv : public DecompositionBase<Real_t> { 54 DiagonalMatrix<Real_t> * m_view; DiagonalInvAnyODE::DiagonalInv55 DiagonalInv(DiagonalMatrix<Real_t> * view) : m_view(view) 56 { 57 } factorizeAnyODE::DiagonalInv58 int factorize() final { 59 for (int i=0; i < m_view->m_nc; ++i) 60 m_view->m_data[i] = 1/m_view->m_data[i]; 61 return 0; 62 } solveAnyODE::DiagonalInv63 int solve(const Real_t * const b, Real_t * const x) final { 64 for (int i=0; i < m_view->m_nc; ++i) 65 x[i] = m_view->m_data[i]*b[i]; 66 return 0; 67 } 68 }; 69 70 } 71