1 // Copyright (C) 2019-2021 Yixuan Qiu <yixuan.qiu@cos.name>
2 //
3 // This Source Code Form is subject to the terms of the Mozilla
4 // Public License v. 2.0. If a copy of the MPL was not distributed
5 // with this file, You can obtain one at https://mozilla.org/MPL/2.0/.
6 
7 #ifndef SPECTRA_BK_LDLT_H
8 #define SPECTRA_BK_LDLT_H
9 
10 #include <Eigen/Core>
11 #include <vector>
12 #include <stdexcept>
13 
14 #include "../Util/CompInfo.h"
15 
16 namespace Spectra {
17 
18 // Bunch-Kaufman LDLT decomposition
19 // References:
20 // 1. Bunch, J. R., & Kaufman, L. (1977). Some stable methods for calculating inertia and solving symmetric linear systems.
21 //    Mathematics of computation, 31(137), 163-179.
22 // 2. Golub, G. H., & Van Loan, C. F. (2012). Matrix computations (Vol. 3). JHU press. Section 4.4.
23 // 3. Bunch-Parlett diagonal pivoting <http://oz.nthu.edu.tw/~d947207/Chap13_GE3.ppt>
24 // 4. Ashcraft, C., Grimes, R. G., & Lewis, J. G. (1998). Accurate symmetric indefinite linear equation solvers.
25 //    SIAM Journal on Matrix Analysis and Applications, 20(2), 513-561.
26 template <typename Scalar = double>
27 class BKLDLT
28 {
private:
    // Shorthand for the Eigen types used throughout the class
    using Index = Eigen::Index;
    using Vector = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>;
    using MapVec = Eigen::Map<Vector>;
    using MapConstVec = Eigen::Map<const Vector>;
    using IntVector = Eigen::Matrix<Index, Eigen::Dynamic, 1>;
    using GenericVector = Eigen::Ref<Vector>;
    using ConstGenericVector = const Eigen::Ref<const Vector>;

    // Dimension of the factorized matrix
    Index m_n;
    // Lower triangle packed column by column: column j holds A[j, j], ..., A[n-1, j],
    // n * (n + 1) / 2 entries in total
    Vector m_data;                                 // storage for a lower-triangular matrix
    // m_colptr[j] points to A[j, j] inside m_data
    std::vector<Scalar*> m_colptr;                 // pointers to columns
    // m_perm[i] = r >= 0 means row/column i was exchanged with r and belongs to a 1x1 block
    // m_perm[i] = v <  0 means row/column i was exchanged with (-v - 1) and belongs to a 2x2 block
    IntVector m_perm;                              // [-3, -2, 3, 1, 4, 5]: 0 <-> 2, 1 <-> 1, 2 <-> 3, 3 <-> 1, 4 <-> 4, 5 <-> 5
    // Only the exchanges (i, r) with i != r, in increasing order of i
    std::vector<std::pair<Index, Index>> m_permc;  // compressed version of m_perm: [(0, 2), (2, 3), (3, 1)]

    // Set to true at the end of compute(); checked by solve_inplace()
    bool m_computed;
    // Status of the most recent factorization, returned by info()
    CompInfo m_info;
46 
47     // Access to elements
48     // Pointer to the k-th column
col_pointer(Index k)49     Scalar* col_pointer(Index k) { return m_colptr[k]; }
50     // A[i, j] -> m_colptr[j][i - j], i >= j
coeff(Index i,Index j)51     Scalar& coeff(Index i, Index j) { return m_colptr[j][i - j]; }
coeff(Index i,Index j)52     const Scalar& coeff(Index i, Index j) const { return m_colptr[j][i - j]; }
53     // A[i, i] -> m_colptr[i][0]
diag_coeff(Index i)54     Scalar& diag_coeff(Index i) { return m_colptr[i][0]; }
diag_coeff(Index i)55     const Scalar& diag_coeff(Index i) const { return m_colptr[i][0]; }
56 
57     // Compute column pointers
compute_pointer()58     void compute_pointer()
59     {
60         m_colptr.clear();
61         m_colptr.reserve(m_n);
62         Scalar* head = m_data.data();
63 
64         for (Index i = 0; i < m_n; i++)
65         {
66             m_colptr.push_back(head);
67             head += (m_n - i);
68         }
69     }
70 
71     // Copy mat - shift * I to m_data
72     template <typename Derived>
copy_data(const Eigen::MatrixBase<Derived> & mat,int uplo,const Scalar & shift)73     void copy_data(const Eigen::MatrixBase<Derived>& mat, int uplo, const Scalar& shift)
74     {
75         // If mat is an expression, first evaluate it into a temporary object
76         // This can be achieved by assigning mat to a const Eigen::Ref<const Matrix>&
77         // If mat is a plain object, no temporary object is created
78         const Eigen::Ref<const typename Derived::PlainObject>& src(mat);
79 
80         // Efficient copying for column-major matrices with lower triangular part
81         if ((!Derived::PlainObject::IsRowMajor) && uplo == Eigen::Lower)
82         {
83             for (Index j = 0; j < m_n; j++)
84             {
85                 const Scalar* begin = &src.coeffRef(j, j);
86                 const Index len = m_n - j;
87                 std::copy(begin, begin + len, col_pointer(j));
88                 diag_coeff(j) -= shift;
89             }
90             return;
91         }
92 
93         Scalar* dest = m_data.data();
94         for (Index j = 0; j < m_n; j++)
95         {
96             for (Index i = j; i < m_n; i++, dest++)
97             {
98                 if (uplo == Eigen::Lower)
99                     *dest = src.coeff(i, j);
100                 else
101                     *dest = src.coeff(j, i);
102             }
103             diag_coeff(j) -= shift;
104         }
105     }
106 
107     // Compute compressed permutations
compress_permutation()108     void compress_permutation()
109     {
110         for (Index i = 0; i < m_n; i++)
111         {
112             // Recover the permutation action
113             const Index perm = (m_perm[i] >= 0) ? (m_perm[i]) : (-m_perm[i] - 1);
114             if (perm != i)
115                 m_permc.push_back(std::make_pair(i, perm));
116         }
117     }
118 
119     // Working on the A[k:end, k:end] submatrix
120     // Exchange k <-> r
121     // Assume r >= k
pivoting_1x1(Index k,Index r)122     void pivoting_1x1(Index k, Index r)
123     {
124         // No permutation
125         if (k == r)
126         {
127             m_perm[k] = r;
128             return;
129         }
130 
131         // A[k, k] <-> A[r, r]
132         std::swap(diag_coeff(k), diag_coeff(r));
133 
134         // A[(r+1):end, k] <-> A[(r+1):end, r]
135         std::swap_ranges(&coeff(r + 1, k), col_pointer(k + 1), &coeff(r + 1, r));
136 
137         // A[(k+1):(r-1), k] <-> A[r, (k+1):(r-1)]
138         Scalar* src = &coeff(k + 1, k);
139         for (Index j = k + 1; j < r; j++, src++)
140         {
141             std::swap(*src, coeff(r, j));
142         }
143 
144         m_perm[k] = r;
145     }
146 
147     // Working on the A[k:end, k:end] submatrix
148     // Exchange [k+1, k] <-> [r, p]
149     // Assume p >= k, r >= k+1
pivoting_2x2(Index k,Index r,Index p)150     void pivoting_2x2(Index k, Index r, Index p)
151     {
152         pivoting_1x1(k, p);
153         pivoting_1x1(k + 1, r);
154 
155         // A[k+1, k] <-> A[r, k]
156         std::swap(coeff(k + 1, k), coeff(r, k));
157 
158         // Use negative signs to indicate a 2x2 block
159         // Also minus one to distinguish a negative zero from a positive zero
160         m_perm[k] = -m_perm[k] - 1;
161         m_perm[k + 1] = -m_perm[k + 1] - 1;
162     }
163 
164     // A[r1, c1:c2] <-> A[r2, c1:c2]
165     // Assume r2 >= r1 > c2 >= c1
interchange_rows(Index r1,Index r2,Index c1,Index c2)166     void interchange_rows(Index r1, Index r2, Index c1, Index c2)
167     {
168         if (r1 == r2)
169             return;
170 
171         for (Index j = c1; j <= c2; j++)
172         {
173             std::swap(coeff(r1, j), coeff(r2, j));
174         }
175     }
176 
177     // lambda = |A[r, k]| = max{|A[k+1, k]|, ..., |A[end, k]|}
178     // Largest (in magnitude) off-diagonal element in the first column of the current reduced matrix
179     // r is the row index
180     // Assume k < end
find_lambda(Index k,Index & r)181     Scalar find_lambda(Index k, Index& r)
182     {
183         using std::abs;
184 
185         const Scalar* head = col_pointer(k);  // => A[k, k]
186         const Scalar* end = col_pointer(k + 1);
187         // Start with r=k+1, lambda=A[k+1, k]
188         r = k + 1;
189         Scalar lambda = abs(head[1]);
190         // Scan remaining elements
191         for (const Scalar* ptr = head + 2; ptr < end; ptr++)
192         {
193             const Scalar abs_elem = abs(*ptr);
194             if (lambda < abs_elem)
195             {
196                 lambda = abs_elem;
197                 r = k + (ptr - head);
198             }
199         }
200 
201         return lambda;
202     }
203 
204     // sigma = |A[p, r]| = max {|A[k, r]|, ..., |A[end, r]|} \ {A[r, r]}
205     // Largest (in magnitude) off-diagonal element in the r-th column of the current reduced matrix
206     // p is the row index
207     // Assume k < r < end
find_sigma(Index k,Index r,Index & p)208     Scalar find_sigma(Index k, Index r, Index& p)
209     {
210         using std::abs;
211 
212         // First search A[r+1, r], ...,  A[end, r], which has the same task as find_lambda()
213         // If r == end, we skip this search
214         Scalar sigma = Scalar(-1);
215         if (r < m_n - 1)
216             sigma = find_lambda(r, p);
217 
218         // Then search A[k, r], ..., A[r-1, r], which maps to A[r, k], ..., A[r, r-1]
219         for (Index j = k; j < r; j++)
220         {
221             const Scalar abs_elem = abs(coeff(r, j));
222             if (sigma < abs_elem)
223             {
224                 sigma = abs_elem;
225                 p = j;
226             }
227         }
228 
229         return sigma;
230     }
231 
232     // Generate permutations and apply to A
233     // Return true if the resulting pivoting is 1x1, and false if 2x2
permutate_mat(Index k,const Scalar & alpha)234     bool permutate_mat(Index k, const Scalar& alpha)
235     {
236         using std::abs;
237 
238         Index r = k, p = k;
239         const Scalar lambda = find_lambda(k, r);
240 
241         // If lambda=0, no need to interchange
242         if (lambda > Scalar(0))
243         {
244             const Scalar abs_akk = abs(diag_coeff(k));
245             // If |A[k, k]| >= alpha * lambda, no need to interchange
246             if (abs_akk < alpha * lambda)
247             {
248                 const Scalar sigma = find_sigma(k, r, p);
249 
250                 // If sigma * |A[k, k]| >= alpha * lambda^2, no need to interchange
251                 if (sigma * abs_akk < alpha * lambda * lambda)
252                 {
253                     if (abs_akk >= alpha * sigma)
254                     {
255                         // Permutation on A
256                         pivoting_1x1(k, r);
257 
258                         // Permutation on L
259                         interchange_rows(k, r, 0, k - 1);
260                         return true;
261                     }
262                     else
263                     {
264                         // There are two versions of permutation here
265                         // 1. A[k+1, k] <-> A[r, k]
266                         // 2. A[k+1, k] <-> A[r, p], where p >= k and r >= k+1
267                         //
268                         // Version 1 and 2 are used by Ref[1] and Ref[2], respectively
269 
270                         // Version 1 implementation
271                         p = k;
272 
273                         // Version 2 implementation
274                         // [r, p] and [p, r] are symmetric, but we need to make sure
275                         // p >= k and r >= k+1, so it is safe to always make r > p
276                         // One exception is when min{r,p} == k+1, in which case we make
277                         // r = k+1, so that only one permutation needs to be performed
278                         /* const Index rp_min = std::min(r, p);
279                         const Index rp_max = std::max(r, p);
280                         if(rp_min == k + 1)
281                         {
282                             r = rp_min; p = rp_max;
283                         } else {
284                             r = rp_max; p = rp_min;
285                         } */
286 
287                         // Right now we use Version 1 since it reduces the overhead of interchange
288 
289                         // Permutation on A
290                         pivoting_2x2(k, r, p);
291                         // Permutation on L
292                         interchange_rows(k, p, 0, k - 1);
293                         interchange_rows(k + 1, r, 0, k - 1);
294                         return false;
295                     }
296                 }
297             }
298         }
299 
300         return true;
301     }
302 
303     // E = [e11, e12]
304     //     [e21, e22]
305     // Overwrite E with inv(E)
inverse_inplace_2x2(Scalar & e11,Scalar & e21,Scalar & e22)306     void inverse_inplace_2x2(Scalar& e11, Scalar& e21, Scalar& e22) const
307     {
308         // inv(E) = [d11, d12], d11 = e22/delta, d21 = -e21/delta, d22 = e11/delta
309         //          [d21, d22]
310         const Scalar delta = e11 * e22 - e21 * e21;
311         std::swap(e11, e22);
312         e11 /= delta;
313         e22 /= delta;
314         e21 = -e21 / delta;
315     }
316 
317     // Return value is the status, CompInfo::Successful/NumericalIssue
gaussian_elimination_1x1(Index k)318     CompInfo gaussian_elimination_1x1(Index k)
319     {
320         // D = 1 / A[k, k]
321         const Scalar akk = diag_coeff(k);
322         // Return CompInfo::NumericalIssue if not invertible
323         if (akk == Scalar(0))
324             return CompInfo::NumericalIssue;
325 
326         diag_coeff(k) = Scalar(1) / akk;
327 
328         // B -= l * l' / A[k, k], B := A[(k+1):end, (k+1):end], l := L[(k+1):end, k]
329         Scalar* lptr = col_pointer(k) + 1;
330         const Index ldim = m_n - k - 1;
331         MapVec l(lptr, ldim);
332         for (Index j = 0; j < ldim; j++)
333         {
334             MapVec(col_pointer(j + k + 1), ldim - j).noalias() -= (lptr[j] / akk) * l.tail(ldim - j);
335         }
336 
337         // l /= A[k, k]
338         l /= akk;
339 
340         return CompInfo::Successful;
341     }
342 
343     // Return value is the status, CompInfo::Successful/NumericalIssue
gaussian_elimination_2x2(Index k)344     CompInfo gaussian_elimination_2x2(Index k)
345     {
346         // D = inv(E)
347         Scalar& e11 = diag_coeff(k);
348         Scalar& e21 = coeff(k + 1, k);
349         Scalar& e22 = diag_coeff(k + 1);
350         // Return CompInfo::NumericalIssue if not invertible
351         if (e11 * e22 - e21 * e21 == Scalar(0))
352             return CompInfo::NumericalIssue;
353 
354         inverse_inplace_2x2(e11, e21, e22);
355 
356         // X = l * inv(E), l := L[(k+2):end, k:(k+1)]
357         Scalar* l1ptr = &coeff(k + 2, k);
358         Scalar* l2ptr = &coeff(k + 2, k + 1);
359         const Index ldim = m_n - k - 2;
360         MapVec l1(l1ptr, ldim), l2(l2ptr, ldim);
361 
362         Eigen::Matrix<Scalar, Eigen::Dynamic, 2> X(ldim, 2);
363         X.col(0).noalias() = l1 * e11 + l2 * e21;
364         X.col(1).noalias() = l1 * e21 + l2 * e22;
365 
366         // B -= l * inv(E) * l' = X * l', B = A[(k+2):end, (k+2):end]
367         for (Index j = 0; j < ldim; j++)
368         {
369             MapVec(col_pointer(j + k + 2), ldim - j).noalias() -= (X.col(0).tail(ldim - j) * l1ptr[j] + X.col(1).tail(ldim - j) * l2ptr[j]);
370         }
371 
372         // l = X
373         l1.noalias() = X.col(0);
374         l2.noalias() = X.col(1);
375 
376         return CompInfo::Successful;
377     }
378 
379 public:
BKLDLT()380     BKLDLT() :
381         m_n(0), m_computed(false), m_info(CompInfo::NotComputed)
382     {}
383 
384     // Factorize mat - shift * I
385     template <typename Derived>
386     BKLDLT(const Eigen::MatrixBase<Derived>& mat, int uplo = Eigen::Lower, const Scalar& shift = Scalar(0)) :
387         m_n(mat.rows()), m_computed(false), m_info(CompInfo::NotComputed)
388     {
389         compute(mat, uplo, shift);
390     }
391 
392     template <typename Derived>
393     void compute(const Eigen::MatrixBase<Derived>& mat, int uplo = Eigen::Lower, const Scalar& shift = Scalar(0))
394     {
395         using std::abs;
396 
397         m_n = mat.rows();
398         if (m_n != mat.cols())
399             throw std::invalid_argument("BKLDLT: matrix must be square");
400 
401         m_perm.setLinSpaced(m_n, 0, m_n - 1);
402         m_permc.clear();
403 
404         // Copy data
405         m_data.resize((m_n * (m_n + 1)) / 2);
406         compute_pointer();
407         copy_data(mat, uplo, shift);
408 
409         const Scalar alpha = (1.0 + std::sqrt(17.0)) / 8.0;
410         Index k = 0;
411         for (k = 0; k < m_n - 1; k++)
412         {
413             // 1. Interchange rows and columns of A, and save the result to m_perm
414             bool is_1x1 = permutate_mat(k, alpha);
415 
416             // 2. Gaussian elimination
417             if (is_1x1)
418             {
419                 m_info = gaussian_elimination_1x1(k);
420             }
421             else
422             {
423                 m_info = gaussian_elimination_2x2(k);
424                 k++;
425             }
426 
427             // 3. Check status
428             if (m_info != CompInfo::Successful)
429                 break;
430         }
431         // Invert the last 1x1 block if it exists
432         if (k == m_n - 1)
433         {
434             const Scalar akk = diag_coeff(k);
435             if (akk == Scalar(0))
436                 m_info = CompInfo::NumericalIssue;
437 
438             diag_coeff(k) = Scalar(1) / diag_coeff(k);
439         }
440 
441         compress_permutation();
442 
443         m_computed = true;
444     }
445 
446     // Solve Ax=b
solve_inplace(GenericVector b)447     void solve_inplace(GenericVector b) const
448     {
449         if (!m_computed)
450             throw std::logic_error("BKLDLT: need to call compute() first");
451 
452         // PAP' = LDL'
453         // 1. b -> Pb
454         Scalar* x = b.data();
455         MapVec res(x, m_n);
456         Index npermc = m_permc.size();
457         for (Index i = 0; i < npermc; i++)
458         {
459             std::swap(x[m_permc[i].first], x[m_permc[i].second]);
460         }
461 
462         // 2. Lz = Pb
463         // If m_perm[end] < 0, then end with m_n - 3, otherwise end with m_n - 2
464         const Index end = (m_perm[m_n - 1] < 0) ? (m_n - 3) : (m_n - 2);
465         for (Index i = 0; i <= end; i++)
466         {
467             const Index b1size = m_n - i - 1;
468             const Index b2size = b1size - 1;
469             if (m_perm[i] >= 0)
470             {
471                 MapConstVec l(&coeff(i + 1, i), b1size);
472                 res.segment(i + 1, b1size).noalias() -= l * x[i];
473             }
474             else
475             {
476                 MapConstVec l1(&coeff(i + 2, i), b2size);
477                 MapConstVec l2(&coeff(i + 2, i + 1), b2size);
478                 res.segment(i + 2, b2size).noalias() -= (l1 * x[i] + l2 * x[i + 1]);
479                 i++;
480             }
481         }
482 
483         // 3. Dw = z
484         for (Index i = 0; i < m_n; i++)
485         {
486             const Scalar e11 = diag_coeff(i);
487             if (m_perm[i] >= 0)
488             {
489                 x[i] *= e11;
490             }
491             else
492             {
493                 const Scalar e21 = coeff(i + 1, i), e22 = diag_coeff(i + 1);
494                 const Scalar wi = x[i] * e11 + x[i + 1] * e21;
495                 x[i + 1] = x[i] * e21 + x[i + 1] * e22;
496                 x[i] = wi;
497                 i++;
498             }
499         }
500 
501         // 4. L'y = w
502         // If m_perm[end] < 0, then start with m_n - 3, otherwise start with m_n - 2
503         Index i = (m_perm[m_n - 1] < 0) ? (m_n - 3) : (m_n - 2);
504         for (; i >= 0; i--)
505         {
506             const Index ldim = m_n - i - 1;
507             MapConstVec l(&coeff(i + 1, i), ldim);
508             x[i] -= res.segment(i + 1, ldim).dot(l);
509 
510             if (m_perm[i] < 0)
511             {
512                 MapConstVec l2(&coeff(i + 1, i - 1), ldim);
513                 x[i - 1] -= res.segment(i + 1, ldim).dot(l2);
514                 i--;
515             }
516         }
517 
518         // 5. x = P'y
519         for (Index i = npermc - 1; i >= 0; i--)
520         {
521             std::swap(x[m_permc[i].first], x[m_permc[i].second]);
522         }
523     }
524 
solve(ConstGenericVector & b)525     Vector solve(ConstGenericVector& b) const
526     {
527         Vector res = b;
528         solve_inplace(res);
529         return res;
530     }
531 
info()532     CompInfo info() const { return m_info; }
533 };
534 
535 }  // namespace Spectra
536 
537 #endif  // SPECTRA_BK_LDLT_H
538