1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17 
18 
19 //! \addtogroup sp_auxlib
20 //! @{
21 
22 
//! wrapper for accessing external functions in ARPACK and SuperLU
24 class sp_auxlib
25   {
26   public:
27 
28   enum form_type
29     {
30     form_none, form_lm, form_sm, form_lr, form_la, form_sr, form_li, form_si, form_sa, form_sigma
31     };
32 
33   inline static form_type interpret_form_str(const char* form_str);
34 
35   //
36   // eigs_sym() for real matrices
37 
38   template<typename eT, typename T1>
39   inline static bool eigs_sym(Col<eT>& eigval, Mat<eT>& eigvec, const SpBase<eT, T1>& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts);
40 
41   template<typename eT, typename T1>
42   inline static bool eigs_sym(Col<eT>& eigval, Mat<eT>& eigvec, const SpBase<eT, T1>& X, const uword n_eigvals, const eT sigma, const eigs_opts& opts);
43 
44   template<typename eT>
45   inline static bool eigs_sym_newarp(Col<eT>& eigval, Mat<eT>& eigvec, const SpMat<eT>& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts);
46 
47   template<typename eT>
48   inline static bool eigs_sym_newarp(Col<eT>& eigval, Mat<eT>& eigvec, const SpMat<eT>& X, const uword n_eigvals, const eT sigma, const eigs_opts& opts);
49 
50   template<typename eT, bool use_sigma>
51   inline static bool eigs_sym_arpack(Col<eT>& eigval, Mat<eT>& eigvec, const SpMat<eT>& X, const uword n_eigvals, const form_type form_val, const eT sigma, const eigs_opts& opts);
52 
53   //
54   // eigs_gen() for real matrices
55 
56   template<typename T, typename T1>
57   inline static bool eigs_gen(Col< std::complex<T> >& eigval, Mat< std::complex<T> >& eigvec, const SpBase<T, T1>& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts);
58 
59   template<typename T, typename T1>
60   inline static bool eigs_gen(Col< std::complex<T> >& eigval, Mat< std::complex<T> >& eigvec, const SpBase<T, T1>& X, const uword n_eigvals, const std::complex<T> sigma, const eigs_opts& opts);
61 
62   template<typename T>
63   inline static bool eigs_gen_newarp(Col< std::complex<T> >& eigval, Mat< std::complex<T> >& eigvec, const SpMat<T>& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts);
64 
65   template<typename T, bool use_sigma>
66   inline static bool eigs_gen_arpack(Col< std::complex<T> >& eigval, Mat< std::complex<T> >& eigvec, const SpMat<T>& X, const uword n_eigvals, const form_type form_val, const std::complex<T> sigma, const eigs_opts& opts);
67 
68   //
69   // eigs_gen() for complex matrices
70 
71   template<typename T, typename T1>
72   inline static bool eigs_gen(Col< std::complex<T> >& eigval, Mat< std::complex<T> >& eigvec, const SpBase< std::complex<T>, T1>& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts);
73 
74   template<typename T, typename T1>
75   inline static bool eigs_gen(Col< std::complex<T> >& eigval, Mat< std::complex<T> >& eigvec, const SpBase< std::complex<T>, T1>& X, const uword n_eigvals, const std::complex<T> sigma, const eigs_opts& opts);
76 
77   template<typename T, bool use_sigma>
78   inline static bool eigs_gen(Col< std::complex<T> >& eigval, Mat< std::complex<T> >& eigvec, const SpMat< std::complex<T> >& X, const uword n_eigvals, const form_type form_val, const std::complex<T> sigma, const eigs_opts& opts);
79 
80   //
81   // spsolve() via SuperLU
82 
83   template<typename T1, typename T2>
84   inline static bool spsolve_simple(Mat<typename T1::elem_type>& out, const SpBase<typename T1::elem_type, T1>& A, const Base<typename T1::elem_type, T2>& B, const superlu_opts& user_opts);
85 
86   template<typename T1, typename T2>
87   inline static bool spsolve_refine(Mat<typename T1::elem_type>& out, typename T1::pod_type& out_rcond, const SpBase<typename T1::elem_type, T1>& A, const Base<typename T1::elem_type, T2>& B, const superlu_opts& user_opts);
88 
89   // //
90   // // rcond() via SuperLU
91   //
92   // template<typename T1>
93   // sinline static typename T1::pod_type rcond(const SpBase<typename T1::elem_type, T1>& A);
94 
95   //
96   // support functions
97 
98   #if defined(ARMA_USE_SUPERLU)
99 
100     template<typename eT>
101     inline static typename get_pod_type<eT>::result norm1(superlu::SuperMatrix* A);
102 
103     template<typename eT>
104     inline static typename get_pod_type<eT>::result lu_rcond(superlu::SuperMatrix* L, superlu::SuperMatrix* U, typename get_pod_type<eT>::result norm_val);
105 
106     inline static void set_superlu_opts(superlu::superlu_options_t& options, const superlu_opts& user_opts);
107 
108     template<typename eT>
109     inline static bool copy_to_supermatrix(superlu::SuperMatrix& out, const SpMat<eT>& A);
110 
111     template<typename eT>
112     inline static bool copy_to_supermatrix_with_shift(superlu::SuperMatrix& out, const SpMat<eT>& A, const eT shift);
113 
114     // // for debugging only
115     // template<typename eT>
116     // inline static void copy_to_spmat(SpMat<eT>& out, const superlu::SuperMatrix& A);
117 
118     template<typename eT>
119     inline static bool wrap_to_supermatrix(superlu::SuperMatrix& out, const Mat<eT>& A);
120 
121     inline static void destroy_supermatrix(superlu::SuperMatrix& out);
122 
123   #endif
124 
125 
126 
127   private:
128 
129   // calls arpack saupd()/naupd() because the code is so similar for each
130   // all of the extra variables are later used by seupd()/neupd(), but those
131   // functions are very different and we can't combine their code
132 
133   template<typename eT, typename T>
134   inline static void run_aupd_plain
135     (
136     const uword n_eigvals, char* which,
137     const SpMat<T>& X, const bool sym,
138     blas_int& n, eT& tol, blas_int& maxiter,
139     podarray<T>& resid, blas_int& ncv, podarray<T>& v, blas_int& ldv,
140     podarray<blas_int>& iparam, podarray<blas_int>& ipntr,
141     podarray<T>& workd, podarray<T>& workl, blas_int& lworkl, podarray<eT>& rwork,
142     blas_int& info
143     );
144 
145   template<typename eT, typename T>
146   inline static void run_aupd_shiftinvert
147     (
148     const uword n_eigvals, const T sigma,
149     const SpMat<T>& X, const bool sym,
150     blas_int& n, eT& tol, blas_int& maxiter,
151     podarray<T>& resid, blas_int& ncv, podarray<T>& v, blas_int& ldv,
152     podarray<blas_int>& iparam, podarray<blas_int>& ipntr,
153     podarray<T>& workd, podarray<T>& workl, blas_int& lworkl, podarray<eT>& rwork,
154     blas_int& info
155     );
156 
157 
158   template<typename eT>
159   inline static bool rudimentary_sym_check(const SpMat<eT>& X);
160 
161   template<typename T>
162   inline static bool rudimentary_sym_check(const SpMat< std::complex<T> >& X);
163   };
164 
165 
166 
167 #if defined(ARMA_USE_SUPERLU)
168 
169 class superlu_supermatrix_wrangler
170   {
171   private:
172 
173   bool used = false;
174 
175   arma_aligned superlu::SuperMatrix m;
176 
177   public:
178 
179   inline ~superlu_supermatrix_wrangler();
180   inline  superlu_supermatrix_wrangler();
181 
182   inline superlu_supermatrix_wrangler(const superlu_supermatrix_wrangler&) = delete;
183   inline void operator=              (const superlu_supermatrix_wrangler&) = delete;
184 
185   inline superlu::SuperMatrix& get_ref();
186   inline superlu::SuperMatrix* get_ptr();
187   };
188 
189 
190 class superlu_stat_wrangler
191   {
192   private:
193 
194   arma_aligned superlu::SuperLUStat_t stat;
195 
196   public:
197 
198   inline ~superlu_stat_wrangler();
199   inline  superlu_stat_wrangler();
200 
201   inline superlu_stat_wrangler(const superlu_stat_wrangler&) = delete;
202   inline void operator=       (const superlu_stat_wrangler&) = delete;
203 
204   inline superlu::SuperLUStat_t* get_ptr();
205   };
206 
207 
208 template<typename eT>
209 class superlu_array_wrangler
210   {
211   private:
212 
213   arma_aligned eT* mem = nullptr;
214 
215   public:
216 
217   inline ~superlu_array_wrangler();
218   inline  superlu_array_wrangler(const uword n_elem);
219 
220   inline superlu_array_wrangler()                              = delete;
221   inline superlu_array_wrangler(const superlu_array_wrangler&) = delete;
222   inline void operator=        (const superlu_array_wrangler&) = delete;
223 
224   inline eT* get_ptr();
225   };
226 
227 #endif
228 
229 
230 
231 //! @}
232 
233