1 // Copyright (C) 2018-2021 Yixuan Qiu <yixuan.qiu@cos.name>
2 //
3 // This Source Code Form is subject to the terms of the Mozilla
4 // Public License v. 2.0. If a copy of the MPL was not distributed
5 // with this file, You can obtain one at https://mozilla.org/MPL/2.0/.
6 
7 #ifndef SPECTRA_PARTIAL_SVD_SOLVER_H
8 #define SPECTRA_PARTIAL_SVD_SOLVER_H
9 
#include <algorithm>
#include <memory>
#include <Eigen/Core>
#include "../SymEigsSolver.h"
12 
13 namespace Spectra {
14 
15 // Abstract class for matrix operation
16 template <typename Scalar_>
17 class SVDMatOp
18 {
19 public:
20     using Scalar = Scalar_;
21 
22 private:
23     using Index = Eigen::Index;
24 
25 public:
26     virtual Index rows() const = 0;
27     virtual Index cols() const = 0;
28 
29     // y_out = A' * A * x_in or y_out = A * A' * x_in
30     virtual void perform_op(const Scalar* x_in, Scalar* y_out) const = 0;
31 
~SVDMatOp()32     virtual ~SVDMatOp() {}
33 };
34 
35 // Operation of a tall matrix in SVD
36 // We compute the eigenvalues of A' * A
37 // MatrixType is either Eigen::Matrix<Scalar, ...> or Eigen::SparseMatrix<Scalar, ...>
38 template <typename Scalar, typename MatrixType>
39 class SVDTallMatOp : public SVDMatOp<Scalar>
40 {
41 private:
42     using Index = Eigen::Index;
43     using Vector = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>;
44     using MapConstVec = Eigen::Map<const Vector>;
45     using MapVec = Eigen::Map<Vector>;
46     using ConstGenericMatrix = const Eigen::Ref<const MatrixType>;
47 
48     ConstGenericMatrix m_mat;
49     const Index m_dim;
50     mutable Vector m_cache;
51 
52 public:
53     // Constructor
SVDTallMatOp(ConstGenericMatrix & mat)54     SVDTallMatOp(ConstGenericMatrix& mat) :
55         m_mat(mat),
56         m_dim((std::min)(mat.rows(), mat.cols())),
57         m_cache(mat.rows())
58     {}
59 
60     // These are the rows and columns of A' * A
rows()61     Index rows() const override { return m_dim; }
cols()62     Index cols() const override { return m_dim; }
63 
64     // y_out = A' * A * x_in
perform_op(const Scalar * x_in,Scalar * y_out)65     void perform_op(const Scalar* x_in, Scalar* y_out) const override
66     {
67         MapConstVec x(x_in, m_mat.cols());
68         MapVec y(y_out, m_mat.cols());
69         m_cache.noalias() = m_mat * x;
70         y.noalias() = m_mat.transpose() * m_cache;
71     }
72 };
73 
74 // Operation of a wide matrix in SVD
75 // We compute the eigenvalues of A * A'
76 // MatrixType is either Eigen::Matrix<Scalar, ...> or Eigen::SparseMatrix<Scalar, ...>
77 template <typename Scalar, typename MatrixType>
78 class SVDWideMatOp : public SVDMatOp<Scalar>
79 {
80 private:
81     using Index = Eigen::Index;
82     using Vector = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>;
83     using MapConstVec = Eigen::Map<const Vector>;
84     using MapVec = Eigen::Map<Vector>;
85     using ConstGenericMatrix = const Eigen::Ref<const MatrixType>;
86 
87     ConstGenericMatrix m_mat;
88     const Index m_dim;
89     mutable Vector m_cache;
90 
91 public:
92     // Constructor
SVDWideMatOp(ConstGenericMatrix & mat)93     SVDWideMatOp(ConstGenericMatrix& mat) :
94         m_mat(mat),
95         m_dim((std::min)(mat.rows(), mat.cols())),
96         m_cache(mat.cols())
97     {}
98 
99     // These are the rows and columns of A * A'
rows()100     Index rows() const override { return m_dim; }
cols()101     Index cols() const override { return m_dim; }
102 
103     // y_out = A * A' * x_in
perform_op(const Scalar * x_in,Scalar * y_out)104     void perform_op(const Scalar* x_in, Scalar* y_out) const override
105     {
106         MapConstVec x(x_in, m_mat.rows());
107         MapVec y(y_out, m_mat.rows());
108         m_cache.noalias() = m_mat.transpose() * x;
109         y.noalias() = m_mat * m_cache;
110     }
111 };
112 
113 // Partial SVD solver
114 // MatrixType is either Eigen::Matrix<Scalar, ...> or Eigen::SparseMatrix<Scalar, ...>
115 template <typename MatrixType = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>>
116 class PartialSVDSolver
117 {
118 private:
119     using Scalar = typename MatrixType::Scalar;
120     using Index = Eigen::Index;
121     using Matrix = Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>;
122     using Vector = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>;
123     using ConstGenericMatrix = const Eigen::Ref<const MatrixType>;
124 
125     ConstGenericMatrix m_mat;
126     const Index m_m;
127     const Index m_n;
128     SVDMatOp<Scalar>* m_op;
129     SymEigsSolver<SVDMatOp<Scalar>>* m_eigs;
130     Index m_nconv;
131     Matrix m_evecs;
132 
133 public:
134     // Constructor
PartialSVDSolver(ConstGenericMatrix & mat,Index ncomp,Index ncv)135     PartialSVDSolver(ConstGenericMatrix& mat, Index ncomp, Index ncv) :
136         m_mat(mat), m_m(mat.rows()), m_n(mat.cols()), m_evecs(0, 0)
137     {
138         // Determine the matrix type, tall or wide
139         if (m_m > m_n)
140         {
141             m_op = new SVDTallMatOp<Scalar, MatrixType>(mat);
142         }
143         else
144         {
145             m_op = new SVDWideMatOp<Scalar, MatrixType>(mat);
146         }
147 
148         // Solver object
149         m_eigs = new SymEigsSolver<SVDMatOp<Scalar>>(*m_op, ncomp, ncv);
150     }
151 
152     // Destructor
~PartialSVDSolver()153     virtual ~PartialSVDSolver()
154     {
155         delete m_eigs;
156         delete m_op;
157     }
158 
159     // Computation
160     Index compute(Index maxit = 1000, Scalar tol = 1e-10)
161     {
162         m_eigs->init();
163         m_nconv = m_eigs->compute(SortRule::LargestAlge, maxit, tol);
164 
165         return m_nconv;
166     }
167 
168     // The converged singular values
singular_values()169     Vector singular_values() const
170     {
171         Vector svals = m_eigs->eigenvalues().cwiseSqrt();
172 
173         return svals;
174     }
175 
176     // The converged left singular vectors
matrix_U(Index nu)177     Matrix matrix_U(Index nu)
178     {
179         if (m_evecs.cols() < 1)
180         {
181             m_evecs = m_eigs->eigenvectors();
182         }
183         nu = (std::min)(nu, m_nconv);
184         if (m_m <= m_n)
185         {
186             return m_evecs.leftCols(nu);
187         }
188 
189         return m_mat * (m_evecs.leftCols(nu).array().rowwise() / m_eigs->eigenvalues().head(nu).transpose().array().sqrt()).matrix();
190     }
191 
192     // The converged right singular vectors
matrix_V(Index nv)193     Matrix matrix_V(Index nv)
194     {
195         if (m_evecs.cols() < 1)
196         {
197             m_evecs = m_eigs->eigenvectors();
198         }
199         nv = (std::min)(nv, m_nconv);
200         if (m_m > m_n)
201         {
202             return m_evecs.leftCols(nv);
203         }
204 
205         return m_mat.transpose() * (m_evecs.leftCols(nv).array().rowwise() / m_eigs->eigenvalues().head(nv).transpose().array().sqrt()).matrix();
206     }
207 };
208 
209 }  // namespace Spectra
210 
211 #endif  // SPECTRA_PARTIAL_SVD_SOLVER_H
212