1 #pragma once 2 #include <anyode/anyode_util.hpp> 3 #include <anyode/anyode_decomposition.hpp> 4 5 BEGIN_NAMESPACE(AnyODE) 6 template<typename Real_t = double> 7 struct BandedLU : public DecompositionBase<Real_t> { // operates inplace 8 BandedMatrix<Real_t> * m_view; 9 buffer_t<int> m_ipiv; BandedLUBandedLU10 BandedLU(BandedMatrix<Real_t> * view) : 11 m_view(view), 12 m_ipiv(buffer_factory<int>(view->m_nr)) 13 {} factorizeBandedLU14 int factorize() override final { 15 int info; 16 constexpr gbtrf_callback<Real_t> gbtrf{}; 17 gbtrf(&(m_view->m_nr), &(m_view->m_nc), &(m_view->m_kl), &(m_view->m_ku), m_view->m_data, 18 &(m_view->m_ld), buffer_get_raw_ptr(m_ipiv), &info); 19 return info; 20 } solveBandedLU21 int solve(const Real_t * const b, Real_t * const x) override final { 22 char trans = 'N'; 23 int nrhs = 1; 24 int info; 25 std::copy(b, b + m_view->m_nr, x); 26 constexpr gbtrs_callback<Real_t> gbtrs{}; 27 gbtrs(&trans, &(m_view->m_nr), &(m_view->m_kl), &(m_view->m_ku), &nrhs, m_view->m_data, 28 &(m_view->m_ld), buffer_get_raw_ptr(m_ipiv), x, &(m_view->m_nr), &info); 29 return info; 30 } 31 }; 32 33 template<typename Real_t = double> 34 struct SVD : public DecompositionBase<Real_t> { 35 // SVD_callbacks<Real_t> m_cbs; 36 DenseMatrix<Real_t> * m_view; 37 buffer_t<Real_t> m_s; 38 int m_ldu; 39 buffer_t<Real_t> m_u; 40 int m_ldvt; 41 buffer_t<Real_t> m_vt; 42 buffer_t<Real_t> m_work; 43 int m_lwork = -1; // Query 44 Real_t m_condition_number = -1; 45 SVDSVD46 SVD(DenseMatrix<Real_t> * view) : 47 m_view(view), m_s(buffer_factory<Real_t>(std::min(view->m_nr, view->m_nc))), 48 m_ldu(view->m_nr), m_u(buffer_factory<Real_t>(m_ldu*(view->m_nr))), 49 m_ldvt(view->m_nc), m_vt(buffer_factory<Real_t>(m_ldvt*(view->m_nc))) 50 { 51 int info; 52 Real_t optim_work_size; 53 char mode = 'A'; 54 constexpr gesvd_callback<Real_t> gesvd{}; 55 gesvd(&mode, &mode, &(m_view->m_nr), &(m_view->m_nc), m_view->m_data, &(m_view->m_ld), 56 buffer_get_raw_ptr(m_s), buffer_get_raw_ptr(m_u), &m_ldu, 57 buffer_get_raw_ptr(m_vt), &m_ldvt, &optim_work_size, &m_lwork, &info); 58 m_lwork = static_cast<int>(optim_work_size); 59 m_work = buffer_factory<Real_t>(m_lwork); 60 } factorizeSVD61 int factorize() override final { 62 int info; 63 char mode = 'A'; 64 constexpr gesvd_callback<Real_t> gesvd{}; 65 gesvd(&mode, &mode, &(m_view->m_nr), &(m_view->m_nc), m_view->m_data, &(m_view->m_ld), 66 buffer_get_raw_ptr(m_s), buffer_get_raw_ptr(m_u), &m_ldu, 67 buffer_get_raw_ptr(m_vt), &m_ldvt, buffer_get_raw_ptr(m_work), &m_lwork, &info); 68 m_condition_number = std::fabs(m_s[0]/m_s[std::min(m_view->m_nr, m_view->m_nc) - 1]); 69 return info; 70 } solveSVD71 int solve(const Real_t* const b, Real_t * const x) override final { 72 Real_t alpha=1, beta=0; 73 int incx=1, incy=1; 74 char trans = 'T'; 75 auto y1 = buffer_factory<Real_t>(m_view->m_nr); 76 constexpr gemv_callback<Real_t> gemv{}; 77 gemv(&trans, &(m_view->m_nr), &(m_view->m_nr), &alpha, buffer_get_raw_ptr(m_u), &(m_ldu), 78 const_cast<Real_t*>(b), &incx, &beta, buffer_get_raw_ptr(y1), &incy); 79 for (int i=0; i < m_view->m_nr; ++i) 80 y1[i] /= m_s[i]; 81 gemv(&trans, &(m_view->m_nc), &(m_view->m_nc), &alpha, buffer_get_raw_ptr(m_vt), &m_ldvt, 82 buffer_get_raw_ptr(y1), &incx, &beta, x, &incy); 83 return 0; 84 } 85 86 }; 87 END_NAMESPACE(AnyODE) 88