1 // Copyright (C) 2018-2021 Yixuan Qiu <yixuan.qiu@cos.name> 2 // 3 // This Source Code Form is subject to the terms of the Mozilla 4 // Public License v. 2.0. If a copy of the MPL was not distributed 5 // with this file, You can obtain one at https://mozilla.org/MPL/2.0/. 6 7 #ifndef SPECTRA_PARTIAL_SVD_SOLVER_H 8 #define SPECTRA_PARTIAL_SVD_SOLVER_H 9 10 #include <Eigen/Core> 11 #include "../SymEigsSolver.h" 12 13 namespace Spectra { 14 15 // Abstract class for matrix operation 16 template <typename Scalar_> 17 class SVDMatOp 18 { 19 public: 20 using Scalar = Scalar_; 21 22 private: 23 using Index = Eigen::Index; 24 25 public: 26 virtual Index rows() const = 0; 27 virtual Index cols() const = 0; 28 29 // y_out = A' * A * x_in or y_out = A * A' * x_in 30 virtual void perform_op(const Scalar* x_in, Scalar* y_out) const = 0; 31 ~SVDMatOp()32 virtual ~SVDMatOp() {} 33 }; 34 35 // Operation of a tall matrix in SVD 36 // We compute the eigenvalues of A' * A 37 // MatrixType is either Eigen::Matrix<Scalar, ...> or Eigen::SparseMatrix<Scalar, ...> 38 template <typename Scalar, typename MatrixType> 39 class SVDTallMatOp : public SVDMatOp<Scalar> 40 { 41 private: 42 using Index = Eigen::Index; 43 using Vector = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>; 44 using MapConstVec = Eigen::Map<const Vector>; 45 using MapVec = Eigen::Map<Vector>; 46 using ConstGenericMatrix = const Eigen::Ref<const MatrixType>; 47 48 ConstGenericMatrix m_mat; 49 const Index m_dim; 50 mutable Vector m_cache; 51 52 public: 53 // Constructor SVDTallMatOp(ConstGenericMatrix & mat)54 SVDTallMatOp(ConstGenericMatrix& mat) : 55 m_mat(mat), 56 m_dim((std::min)(mat.rows(), mat.cols())), 57 m_cache(mat.rows()) 58 {} 59 60 // These are the rows and columns of A' * A rows()61 Index rows() const override { return m_dim; } cols()62 Index cols() const override { return m_dim; } 63 64 // y_out = A' * A * x_in perform_op(const Scalar * x_in,Scalar * y_out)65 void perform_op(const Scalar* x_in, Scalar* y_out) const override 66 { 67 MapConstVec x(x_in, m_mat.cols()); 68 MapVec y(y_out, m_mat.cols()); 69 m_cache.noalias() = m_mat * x; 70 y.noalias() = m_mat.transpose() * m_cache; 71 } 72 }; 73 74 // Operation of a wide matrix in SVD 75 // We compute the eigenvalues of A * A' 76 // MatrixType is either Eigen::Matrix<Scalar, ...> or Eigen::SparseMatrix<Scalar, ...> 77 template <typename Scalar, typename MatrixType> 78 class SVDWideMatOp : public SVDMatOp<Scalar> 79 { 80 private: 81 using Index = Eigen::Index; 82 using Vector = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>; 83 using MapConstVec = Eigen::Map<const Vector>; 84 using MapVec = Eigen::Map<Vector>; 85 using ConstGenericMatrix = const Eigen::Ref<const MatrixType>; 86 87 ConstGenericMatrix m_mat; 88 const Index m_dim; 89 mutable Vector m_cache; 90 91 public: 92 // Constructor SVDWideMatOp(ConstGenericMatrix & mat)93 SVDWideMatOp(ConstGenericMatrix& mat) : 94 m_mat(mat), 95 m_dim((std::min)(mat.rows(), mat.cols())), 96 m_cache(mat.cols()) 97 {} 98 99 // These are the rows and columns of A * A' rows()100 Index rows() const override { return m_dim; } cols()101 Index cols() const override { return m_dim; } 102 103 // y_out = A * A' * x_in perform_op(const Scalar * x_in,Scalar * y_out)104 void perform_op(const Scalar* x_in, Scalar* y_out) const override 105 { 106 MapConstVec x(x_in, m_mat.rows()); 107 MapVec y(y_out, m_mat.rows()); 108 m_cache.noalias() = m_mat.transpose() * x; 109 y.noalias() = m_mat * m_cache; 110 } 111 }; 112 113 // Partial SVD solver 114 // MatrixType is either Eigen::Matrix<Scalar, ...> or Eigen::SparseMatrix<Scalar, ...> 115 template <typename MatrixType = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>> 116 class PartialSVDSolver 117 { 118 private: 119 using Scalar = typename MatrixType::Scalar; 120 using Index = Eigen::Index; 121 using Matrix = Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>; 122 using Vector = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>; 123 using ConstGenericMatrix = const Eigen::Ref<const MatrixType>; 124 125 ConstGenericMatrix m_mat; 126 const Index m_m; 127 const Index m_n; 128 SVDMatOp<Scalar>* m_op; 129 SymEigsSolver<SVDMatOp<Scalar>>* m_eigs; 130 Index m_nconv; 131 Matrix m_evecs; 132 133 public: 134 // Constructor PartialSVDSolver(ConstGenericMatrix & mat,Index ncomp,Index ncv)135 PartialSVDSolver(ConstGenericMatrix& mat, Index ncomp, Index ncv) : 136 m_mat(mat), m_m(mat.rows()), m_n(mat.cols()), m_evecs(0, 0) 137 { 138 // Determine the matrix type, tall or wide 139 if (m_m > m_n) 140 { 141 m_op = new SVDTallMatOp<Scalar, MatrixType>(mat); 142 } 143 else 144 { 145 m_op = new SVDWideMatOp<Scalar, MatrixType>(mat); 146 } 147 148 // Solver object 149 m_eigs = new SymEigsSolver<SVDMatOp<Scalar>>(*m_op, ncomp, ncv); 150 } 151 152 // Destructor ~PartialSVDSolver()153 virtual ~PartialSVDSolver() 154 { 155 delete m_eigs; 156 delete m_op; 157 } 158 159 // Computation 160 Index compute(Index maxit = 1000, Scalar tol = 1e-10) 161 { 162 m_eigs->init(); 163 m_nconv = m_eigs->compute(SortRule::LargestAlge, maxit, tol); 164 165 return m_nconv; 166 } 167 168 // The converged singular values singular_values()169 Vector singular_values() const 170 { 171 Vector svals = m_eigs->eigenvalues().cwiseSqrt(); 172 173 return svals; 174 } 175 176 // The converged left singular vectors matrix_U(Index nu)177 Matrix matrix_U(Index nu) 178 { 179 if (m_evecs.cols() < 1) 180 { 181 m_evecs = m_eigs->eigenvectors(); 182 } 183 nu = (std::min)(nu, m_nconv); 184 if (m_m <= m_n) 185 { 186 return m_evecs.leftCols(nu); 187 } 188 189 return m_mat * (m_evecs.leftCols(nu).array().rowwise() / m_eigs->eigenvalues().head(nu).transpose().array().sqrt()).matrix(); 190 } 191 192 // The converged right singular vectors matrix_V(Index nv)193 Matrix matrix_V(Index nv) 194 { 195 if (m_evecs.cols() < 1) 196 { 197 m_evecs = m_eigs->eigenvectors(); 198 } 199 nv = (std::min)(nv, m_nconv); 200 if (m_m > m_n) 201 { 202 return m_evecs.leftCols(nv); 203 } 204 205 return m_mat.transpose() * (m_evecs.leftCols(nv).array().rowwise() / m_eigs->eigenvalues().head(nv).transpose().array().sqrt()).matrix(); 206 } 207 }; 208 209 } // namespace Spectra 210 211 #endif // SPECTRA_PARTIAL_SVD_SOLVER_H 212