1 // SPDX-License-Identifier: Apache-2.0 2 // 3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) 4 // Copyright 2008-2016 National ICT Australia (NICTA) 5 // 6 // Licensed under the Apache License, Version 2.0 (the "License"); 7 // you may not use this file except in compliance with the License. 8 // You may obtain a copy of the License at 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ------------------------------------------------------------------------ 17 18 19 //! \addtogroup sp_auxlib 20 //! @{ 21 22 23 //! wrapper for accesing external functions in ARPACK and SuperLU 24 class sp_auxlib 25 { 26 public: 27 28 enum form_type 29 { 30 form_none, form_lm, form_sm, form_lr, form_la, form_sr, form_li, form_si, form_sa, form_sigma 31 }; 32 33 inline static form_type interpret_form_str(const char* form_str); 34 35 // 36 // eigs_sym() for real matrices 37 38 template<typename eT, typename T1> 39 inline static bool eigs_sym(Col<eT>& eigval, Mat<eT>& eigvec, const SpBase<eT, T1>& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts); 40 41 template<typename eT, typename T1> 42 inline static bool eigs_sym(Col<eT>& eigval, Mat<eT>& eigvec, const SpBase<eT, T1>& X, const uword n_eigvals, const eT sigma, const eigs_opts& opts); 43 44 template<typename eT> 45 inline static bool eigs_sym_newarp(Col<eT>& eigval, Mat<eT>& eigvec, const SpMat<eT>& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts); 46 47 template<typename eT> 48 inline static bool eigs_sym_newarp(Col<eT>& eigval, Mat<eT>& eigvec, const SpMat<eT>& X, const uword n_eigvals, const eT sigma, const eigs_opts& opts); 49 50 template<typename eT, bool use_sigma> 51 inline static bool eigs_sym_arpack(Col<eT>& eigval, Mat<eT>& eigvec, const SpMat<eT>& X, const uword n_eigvals, const form_type form_val, const eT sigma, const eigs_opts& opts); 52 53 // 54 // eigs_gen() for real matrices 55 56 template<typename T, typename T1> 57 inline static bool eigs_gen(Col< std::complex<T> >& eigval, Mat< std::complex<T> >& eigvec, const SpBase<T, T1>& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts); 58 59 template<typename T, typename T1> 60 inline static bool eigs_gen(Col< std::complex<T> >& eigval, Mat< std::complex<T> >& eigvec, const SpBase<T, T1>& X, const uword n_eigvals, const std::complex<T> sigma, const eigs_opts& opts); 61 62 template<typename T> 63 inline static bool eigs_gen_newarp(Col< std::complex<T> >& eigval, Mat< std::complex<T> >& eigvec, const SpMat<T>& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts); 64 65 template<typename T, bool use_sigma> 66 inline static bool eigs_gen_arpack(Col< std::complex<T> >& eigval, Mat< std::complex<T> >& eigvec, const SpMat<T>& X, const uword n_eigvals, const form_type form_val, const std::complex<T> sigma, const eigs_opts& opts); 67 68 // 69 // eigs_gen() for complex matrices 70 71 template<typename T, typename T1> 72 inline static bool eigs_gen(Col< std::complex<T> >& eigval, Mat< std::complex<T> >& eigvec, const SpBase< std::complex<T>, T1>& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts); 73 74 template<typename T, typename T1> 75 inline static bool eigs_gen(Col< std::complex<T> >& eigval, Mat< std::complex<T> >& eigvec, const SpBase< std::complex<T>, T1>& X, const uword n_eigvals, const std::complex<T> sigma, const eigs_opts& opts); 76 77 template<typename T, bool use_sigma> 78 inline static bool eigs_gen(Col< std::complex<T> >& eigval, Mat< std::complex<T> >& eigvec, const SpMat< std::complex<T> >& X, const uword n_eigvals, const form_type form_val, const std::complex<T> sigma, const eigs_opts& opts); 79 80 // 81 // spsolve() via SuperLU 82 83 template<typename T1, typename T2> 84 inline static bool spsolve_simple(Mat<typename T1::elem_type>& out, const SpBase<typename T1::elem_type, T1>& A, const Base<typename T1::elem_type, T2>& B, const superlu_opts& user_opts); 85 86 template<typename T1, typename T2> 87 inline static bool spsolve_refine(Mat<typename T1::elem_type>& out, typename T1::pod_type& out_rcond, const SpBase<typename T1::elem_type, T1>& A, const Base<typename T1::elem_type, T2>& B, const superlu_opts& user_opts); 88 89 // // 90 // // rcond() via SuperLU 91 // 92 // template<typename T1> 93 // sinline static typename T1::pod_type rcond(const SpBase<typename T1::elem_type, T1>& A); 94 95 // 96 // support functions 97 98 #if defined(ARMA_USE_SUPERLU) 99 100 template<typename eT> 101 inline static typename get_pod_type<eT>::result norm1(superlu::SuperMatrix* A); 102 103 template<typename eT> 104 inline static typename get_pod_type<eT>::result lu_rcond(superlu::SuperMatrix* L, superlu::SuperMatrix* U, typename get_pod_type<eT>::result norm_val); 105 106 inline static void set_superlu_opts(superlu::superlu_options_t& options, const superlu_opts& user_opts); 107 108 template<typename eT> 109 inline static bool copy_to_supermatrix(superlu::SuperMatrix& out, const SpMat<eT>& A); 110 111 template<typename eT> 112 inline static bool copy_to_supermatrix_with_shift(superlu::SuperMatrix& out, const SpMat<eT>& A, const eT shift); 113 114 // // for debugging only 115 // template<typename eT> 116 // inline static void copy_to_spmat(SpMat<eT>& out, const superlu::SuperMatrix& A); 117 118 template<typename eT> 119 inline static bool wrap_to_supermatrix(superlu::SuperMatrix& out, const Mat<eT>& A); 120 121 inline static void destroy_supermatrix(superlu::SuperMatrix& out); 122 123 #endif 124 125 126 127 private: 128 129 // calls arpack saupd()/naupd() because the code is so similar for each 130 // all of the extra variables are later used by seupd()/neupd(), but those 131 // functions are very different and we can't combine their code 132 133 template<typename eT, typename T> 134 inline static void run_aupd_plain 135 ( 136 const uword n_eigvals, char* which, 137 const SpMat<T>& X, const bool sym, 138 blas_int& n, eT& tol, blas_int& maxiter, 139 podarray<T>& resid, blas_int& ncv, podarray<T>& v, blas_int& ldv, 140 podarray<blas_int>& iparam, podarray<blas_int>& ipntr, 141 podarray<T>& workd, podarray<T>& workl, blas_int& lworkl, podarray<eT>& rwork, 142 blas_int& info 143 ); 144 145 template<typename eT, typename T> 146 inline static void run_aupd_shiftinvert 147 ( 148 const uword n_eigvals, const T sigma, 149 const SpMat<T>& X, const bool sym, 150 blas_int& n, eT& tol, blas_int& maxiter, 151 podarray<T>& resid, blas_int& ncv, podarray<T>& v, blas_int& ldv, 152 podarray<blas_int>& iparam, podarray<blas_int>& ipntr, 153 podarray<T>& workd, podarray<T>& workl, blas_int& lworkl, podarray<eT>& rwork, 154 blas_int& info 155 ); 156 157 158 template<typename eT> 159 inline static bool rudimentary_sym_check(const SpMat<eT>& X); 160 161 template<typename T> 162 inline static bool rudimentary_sym_check(const SpMat< std::complex<T> >& X); 163 }; 164 165 166 167 #if defined(ARMA_USE_SUPERLU) 168 169 class superlu_supermatrix_wrangler 170 { 171 private: 172 173 bool used = false; 174 175 arma_aligned superlu::SuperMatrix m; 176 177 public: 178 179 inline ~superlu_supermatrix_wrangler(); 180 inline superlu_supermatrix_wrangler(); 181 182 inline superlu_supermatrix_wrangler(const superlu_supermatrix_wrangler&) = delete; 183 inline void operator= (const superlu_supermatrix_wrangler&) = delete; 184 185 inline superlu::SuperMatrix& get_ref(); 186 inline superlu::SuperMatrix* get_ptr(); 187 }; 188 189 190 class superlu_stat_wrangler 191 { 192 private: 193 194 arma_aligned superlu::SuperLUStat_t stat; 195 196 public: 197 198 inline ~superlu_stat_wrangler(); 199 inline superlu_stat_wrangler(); 200 201 inline superlu_stat_wrangler(const superlu_stat_wrangler&) = delete; 202 inline void operator= (const superlu_stat_wrangler&) = delete; 203 204 inline superlu::SuperLUStat_t* get_ptr(); 205 }; 206 207 208 template<typename eT> 209 class superlu_array_wrangler 210 { 211 private: 212 213 arma_aligned eT* mem = nullptr; 214 215 public: 216 217 inline ~superlu_array_wrangler(); 218 inline superlu_array_wrangler(const uword n_elem); 219 220 inline superlu_array_wrangler() = delete; 221 inline superlu_array_wrangler(const superlu_array_wrangler&) = delete; 222 inline void operator= (const superlu_array_wrangler&) = delete; 223 224 inline eT* get_ptr(); 225 }; 226 227 #endif 228 229 230 231 //! @} 232 233