1 #pragma once
2 #include <anyode/anyode_util.hpp>
3 #include <anyode/anyode_decomposition.hpp>
4 
5 BEGIN_NAMESPACE(AnyODE)
6 template<typename Real_t = double>
7 struct BandedLU : public DecompositionBase<Real_t> { // operates inplace
8     BandedMatrix<Real_t> * m_view;
9     buffer_t<int> m_ipiv;
BandedLUBandedLU10     BandedLU(BandedMatrix<Real_t> * view) :
11         m_view(view),
12         m_ipiv(buffer_factory<int>(view->m_nr))
13     {}
factorizeBandedLU14     int factorize() override final {
15         int info;
16         constexpr gbtrf_callback<Real_t> gbtrf{};
17         gbtrf(&(m_view->m_nr), &(m_view->m_nc), &(m_view->m_kl), &(m_view->m_ku), m_view->m_data,
18               &(m_view->m_ld), buffer_get_raw_ptr(m_ipiv), &info);
19         return info;
20     }
solveBandedLU21     int solve(const Real_t * const b, Real_t * const x) override final {
22         char trans = 'N';
23         int nrhs = 1;
24         int info;
25         std::copy(b, b + m_view->m_nr, x);
26         constexpr gbtrs_callback<Real_t> gbtrs{};
27         gbtrs(&trans, &(m_view->m_nr), &(m_view->m_kl), &(m_view->m_ku), &nrhs, m_view->m_data,
28               &(m_view->m_ld), buffer_get_raw_ptr(m_ipiv), x, &(m_view->m_nr), &info);
29         return info;
30     }
31 };
32 
33 template<typename Real_t = double>
34 struct SVD : public DecompositionBase<Real_t> {
35     // SVD_callbacks<Real_t> m_cbs;
36     DenseMatrix<Real_t> * m_view;
37     buffer_t<Real_t> m_s;
38     int m_ldu;
39     buffer_t<Real_t> m_u;
40     int m_ldvt;
41     buffer_t<Real_t> m_vt;
42     buffer_t<Real_t> m_work;
43     int m_lwork = -1; // Query
44     Real_t m_condition_number = -1;
45 
SVDSVD46     SVD(DenseMatrix<Real_t> * view) :
47         m_view(view), m_s(buffer_factory<Real_t>(std::min(view->m_nr, view->m_nc))),
48         m_ldu(view->m_nr), m_u(buffer_factory<Real_t>(m_ldu*(view->m_nr))),
49         m_ldvt(view->m_nc), m_vt(buffer_factory<Real_t>(m_ldvt*(view->m_nc)))
50     {
51         int info;
52         Real_t optim_work_size;
53         char mode = 'A';
54         constexpr gesvd_callback<Real_t> gesvd{};
55         gesvd(&mode, &mode, &(m_view->m_nr), &(m_view->m_nc), m_view->m_data, &(m_view->m_ld),
56               buffer_get_raw_ptr(m_s), buffer_get_raw_ptr(m_u), &m_ldu,
57               buffer_get_raw_ptr(m_vt), &m_ldvt, &optim_work_size, &m_lwork, &info);
58         m_lwork = static_cast<int>(optim_work_size);
59         m_work = buffer_factory<Real_t>(m_lwork);
60     }
factorizeSVD61     int factorize() override final {
62         int info;
63         char mode = 'A';
64         constexpr gesvd_callback<Real_t> gesvd{};
65         gesvd(&mode, &mode, &(m_view->m_nr), &(m_view->m_nc), m_view->m_data, &(m_view->m_ld),
66               buffer_get_raw_ptr(m_s), buffer_get_raw_ptr(m_u), &m_ldu,
67               buffer_get_raw_ptr(m_vt), &m_ldvt, buffer_get_raw_ptr(m_work), &m_lwork, &info);
68         m_condition_number = std::fabs(m_s[0]/m_s[std::min(m_view->m_nr, m_view->m_nc) - 1]);
69         return info;
70     }
solveSVD71     int solve(const Real_t* const b, Real_t * const x) override final {
72         Real_t alpha=1, beta=0;
73         int incx=1, incy=1;
74         char trans = 'T';
75         auto y1 = buffer_factory<Real_t>(m_view->m_nr);
76         constexpr gemv_callback<Real_t> gemv{};
77         gemv(&trans, &(m_view->m_nr), &(m_view->m_nr), &alpha, buffer_get_raw_ptr(m_u), &(m_ldu),
78              const_cast<Real_t*>(b), &incx, &beta, buffer_get_raw_ptr(y1), &incy);
79         for (int i=0; i < m_view->m_nr; ++i)
80             y1[i] /= m_s[i];
81         gemv(&trans, &(m_view->m_nc), &(m_view->m_nc), &alpha, buffer_get_raw_ptr(m_vt), &m_ldvt,
82              buffer_get_raw_ptr(y1), &incx, &beta, x, &incy);
83         return 0;
84     }
85 
86 };
87 END_NAMESPACE(AnyODE)
88