1 // Copyright (C) 2019-2021 Yixuan Qiu <yixuan.qiu@cos.name> 2 // 3 // This Source Code Form is subject to the terms of the Mozilla 4 // Public License v. 2.0. If a copy of the MPL was not distributed 5 // with this file, You can obtain one at https://mozilla.org/MPL/2.0/. 6 7 #ifndef SPECTRA_BK_LDLT_H 8 #define SPECTRA_BK_LDLT_H 9 10 #include <Eigen/Core> 11 #include <vector> 12 #include <stdexcept> 13 14 #include "../Util/CompInfo.h" 15 16 namespace Spectra { 17 18 // Bunch-Kaufman LDLT decomposition 19 // References: 20 // 1. Bunch, J. R., & Kaufman, L. (1977). Some stable methods for calculating inertia and solving symmetric linear systems. 21 // Mathematics of computation, 31(137), 163-179. 22 // 2. Golub, G. H., & Van Loan, C. F. (2012). Matrix computations (Vol. 3). JHU press. Section 4.4. 23 // 3. Bunch-Parlett diagonal pivoting <http://oz.nthu.edu.tw/~d947207/Chap13_GE3.ppt> 24 // 4. Ashcraft, C., Grimes, R. G., & Lewis, J. G. (1998). Accurate symmetric indefinite linear equation solvers. 25 // SIAM Journal on Matrix Analysis and Applications, 20(2), 513-561. 26 template <typename Scalar = double> 27 class BKLDLT 28 { 29 private: 30 using Index = Eigen::Index; 31 using Vector = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>; 32 using MapVec = Eigen::Map<Vector>; 33 using MapConstVec = Eigen::Map<const Vector>; 34 using IntVector = Eigen::Matrix<Index, Eigen::Dynamic, 1>; 35 using GenericVector = Eigen::Ref<Vector>; 36 using ConstGenericVector = const Eigen::Ref<const Vector>; 37 38 Index m_n; 39 Vector m_data; // storage for a lower-triangular matrix 40 std::vector<Scalar*> m_colptr; // pointers to columns 41 IntVector m_perm; // [-2, -1, 3, 1, 4, 5]: 0 <-> 2, 1 <-> 1, 2 <-> 3, 3 <-> 1, 4 <-> 4, 5 <-> 5 42 std::vector<std::pair<Index, Index>> m_permc; // compressed version of m_perm: [(0, 2), (2, 3), (3, 1)] 43 44 bool m_computed; 45 CompInfo m_info; 46 47 // Access to elements 48 // Pointer to the k-th column col_pointer(Index k)49 Scalar* col_pointer(Index k) { return m_colptr[k]; } 50 // A[i, j] -> m_colptr[j][i - j], i >= j coeff(Index i,Index j)51 Scalar& coeff(Index i, Index j) { return m_colptr[j][i - j]; } coeff(Index i,Index j)52 const Scalar& coeff(Index i, Index j) const { return m_colptr[j][i - j]; } 53 // A[i, i] -> m_colptr[i][0] diag_coeff(Index i)54 Scalar& diag_coeff(Index i) { return m_colptr[i][0]; } diag_coeff(Index i)55 const Scalar& diag_coeff(Index i) const { return m_colptr[i][0]; } 56 57 // Compute column pointers compute_pointer()58 void compute_pointer() 59 { 60 m_colptr.clear(); 61 m_colptr.reserve(m_n); 62 Scalar* head = m_data.data(); 63 64 for (Index i = 0; i < m_n; i++) 65 { 66 m_colptr.push_back(head); 67 head += (m_n - i); 68 } 69 } 70 71 // Copy mat - shift * I to m_data 72 template <typename Derived> copy_data(const Eigen::MatrixBase<Derived> & mat,int uplo,const Scalar & shift)73 void copy_data(const Eigen::MatrixBase<Derived>& mat, int uplo, const Scalar& shift) 74 { 75 // If mat is an expression, first evaluate it into a temporary object 76 // This can be achieved by assigning mat to a const Eigen::Ref<const Matrix>& 77 // If mat is a plain object, no temporary object is created 78 const Eigen::Ref<const typename Derived::PlainObject>& src(mat); 79 80 // Efficient copying for column-major matrices with lower triangular part 81 if ((!Derived::PlainObject::IsRowMajor) && uplo == Eigen::Lower) 82 { 83 for (Index j = 0; j < m_n; j++) 84 { 85 const Scalar* begin = &src.coeffRef(j, j); 86 const Index len = m_n - j; 87 std::copy(begin, begin + len, col_pointer(j)); 88 diag_coeff(j) -= shift; 89 } 90 return; 91 } 92 93 Scalar* dest = m_data.data(); 94 for (Index j = 0; j < m_n; j++) 95 { 96 for (Index i = j; i < m_n; i++, dest++) 97 { 98 if (uplo == Eigen::Lower) 99 *dest = src.coeff(i, j); 100 else 101 *dest = src.coeff(j, i); 102 } 103 diag_coeff(j) -= shift; 104 } 105 } 106 107 // Compute compressed permutations compress_permutation()108 void compress_permutation() 109 { 110 for (Index i = 0; i < m_n; i++) 111 { 112 // Recover the permutation action 113 const Index perm = (m_perm[i] >= 0) ? (m_perm[i]) : (-m_perm[i] - 1); 114 if (perm != i) 115 m_permc.push_back(std::make_pair(i, perm)); 116 } 117 } 118 119 // Working on the A[k:end, k:end] submatrix 120 // Exchange k <-> r 121 // Assume r >= k pivoting_1x1(Index k,Index r)122 void pivoting_1x1(Index k, Index r) 123 { 124 // No permutation 125 if (k == r) 126 { 127 m_perm[k] = r; 128 return; 129 } 130 131 // A[k, k] <-> A[r, r] 132 std::swap(diag_coeff(k), diag_coeff(r)); 133 134 // A[(r+1):end, k] <-> A[(r+1):end, r] 135 std::swap_ranges(&coeff(r + 1, k), col_pointer(k + 1), &coeff(r + 1, r)); 136 137 // A[(k+1):(r-1), k] <-> A[r, (k+1):(r-1)] 138 Scalar* src = &coeff(k + 1, k); 139 for (Index j = k + 1; j < r; j++, src++) 140 { 141 std::swap(*src, coeff(r, j)); 142 } 143 144 m_perm[k] = r; 145 } 146 147 // Working on the A[k:end, k:end] submatrix 148 // Exchange [k+1, k] <-> [r, p] 149 // Assume p >= k, r >= k+1 pivoting_2x2(Index k,Index r,Index p)150 void pivoting_2x2(Index k, Index r, Index p) 151 { 152 pivoting_1x1(k, p); 153 pivoting_1x1(k + 1, r); 154 155 // A[k+1, k] <-> A[r, k] 156 std::swap(coeff(k + 1, k), coeff(r, k)); 157 158 // Use negative signs to indicate a 2x2 block 159 // Also minus one to distinguish a negative zero from a positive zero 160 m_perm[k] = -m_perm[k] - 1; 161 m_perm[k + 1] = -m_perm[k + 1] - 1; 162 } 163 164 // A[r1, c1:c2] <-> A[r2, c1:c2] 165 // Assume r2 >= r1 > c2 >= c1 interchange_rows(Index r1,Index r2,Index c1,Index c2)166 void interchange_rows(Index r1, Index r2, Index c1, Index c2) 167 { 168 if (r1 == r2) 169 return; 170 171 for (Index j = c1; j <= c2; j++) 172 { 173 std::swap(coeff(r1, j), coeff(r2, j)); 174 } 175 } 176 177 // lambda = |A[r, k]| = max{|A[k+1, k]|, ..., |A[end, k]|} 178 // Largest (in magnitude) off-diagonal element in the first column of the current reduced matrix 179 // r is the row index 180 // Assume k < end find_lambda(Index k,Index & r)181 Scalar find_lambda(Index k, Index& r) 182 { 183 using std::abs; 184 185 const Scalar* head = col_pointer(k); // => A[k, k] 186 const Scalar* end = col_pointer(k + 1); 187 // Start with r=k+1, lambda=A[k+1, k] 188 r = k + 1; 189 Scalar lambda = abs(head[1]); 190 // Scan remaining elements 191 for (const Scalar* ptr = head + 2; ptr < end; ptr++) 192 { 193 const Scalar abs_elem = abs(*ptr); 194 if (lambda < abs_elem) 195 { 196 lambda = abs_elem; 197 r = k + (ptr - head); 198 } 199 } 200 201 return lambda; 202 } 203 204 // sigma = |A[p, r]| = max {|A[k, r]|, ..., |A[end, r]|} \ {A[r, r]} 205 // Largest (in magnitude) off-diagonal element in the r-th column of the current reduced matrix 206 // p is the row index 207 // Assume k < r < end find_sigma(Index k,Index r,Index & p)208 Scalar find_sigma(Index k, Index r, Index& p) 209 { 210 using std::abs; 211 212 // First search A[r+1, r], ..., A[end, r], which has the same task as find_lambda() 213 // If r == end, we skip this search 214 Scalar sigma = Scalar(-1); 215 if (r < m_n - 1) 216 sigma = find_lambda(r, p); 217 218 // Then search A[k, r], ..., A[r-1, r], which maps to A[r, k], ..., A[r, r-1] 219 for (Index j = k; j < r; j++) 220 { 221 const Scalar abs_elem = abs(coeff(r, j)); 222 if (sigma < abs_elem) 223 { 224 sigma = abs_elem; 225 p = j; 226 } 227 } 228 229 return sigma; 230 } 231 232 // Generate permutations and apply to A 233 // Return true if the resulting pivoting is 1x1, and false if 2x2 permutate_mat(Index k,const Scalar & alpha)234 bool permutate_mat(Index k, const Scalar& alpha) 235 { 236 using std::abs; 237 238 Index r = k, p = k; 239 const Scalar lambda = find_lambda(k, r); 240 241 // If lambda=0, no need to interchange 242 if (lambda > Scalar(0)) 243 { 244 const Scalar abs_akk = abs(diag_coeff(k)); 245 // If |A[k, k]| >= alpha * lambda, no need to interchange 246 if (abs_akk < alpha * lambda) 247 { 248 const Scalar sigma = find_sigma(k, r, p); 249 250 // If sigma * |A[k, k]| >= alpha * lambda^2, no need to interchange 251 if (sigma * abs_akk < alpha * lambda * lambda) 252 { 253 if (abs_akk >= alpha * sigma) 254 { 255 // Permutation on A 256 pivoting_1x1(k, r); 257 258 // Permutation on L 259 interchange_rows(k, r, 0, k - 1); 260 return true; 261 } 262 else 263 { 264 // There are two versions of permutation here 265 // 1. A[k+1, k] <-> A[r, k] 266 // 2. A[k+1, k] <-> A[r, p], where p >= k and r >= k+1 267 // 268 // Version 1 and 2 are used by Ref[1] and Ref[2], respectively 269 270 // Version 1 implementation 271 p = k; 272 273 // Version 2 implementation 274 // [r, p] and [p, r] are symmetric, but we need to make sure 275 // p >= k and r >= k+1, so it is safe to always make r > p 276 // One exception is when min{r,p} == k+1, in which case we make 277 // r = k+1, so that only one permutation needs to be performed 278 /* const Index rp_min = std::min(r, p); 279 const Index rp_max = std::max(r, p); 280 if(rp_min == k + 1) 281 { 282 r = rp_min; p = rp_max; 283 } else { 284 r = rp_max; p = rp_min; 285 } */ 286 287 // Right now we use Version 1 since it reduces the overhead of interchange 288 289 // Permutation on A 290 pivoting_2x2(k, r, p); 291 // Permutation on L 292 interchange_rows(k, p, 0, k - 1); 293 interchange_rows(k + 1, r, 0, k - 1); 294 return false; 295 } 296 } 297 } 298 } 299 300 return true; 301 } 302 303 // E = [e11, e12] 304 // [e21, e22] 305 // Overwrite E with inv(E) inverse_inplace_2x2(Scalar & e11,Scalar & e21,Scalar & e22)306 void inverse_inplace_2x2(Scalar& e11, Scalar& e21, Scalar& e22) const 307 { 308 // inv(E) = [d11, d12], d11 = e22/delta, d21 = -e21/delta, d22 = e11/delta 309 // [d21, d22] 310 const Scalar delta = e11 * e22 - e21 * e21; 311 std::swap(e11, e22); 312 e11 /= delta; 313 e22 /= delta; 314 e21 = -e21 / delta; 315 } 316 317 // Return value is the status, CompInfo::Successful/NumericalIssue gaussian_elimination_1x1(Index k)318 CompInfo gaussian_elimination_1x1(Index k) 319 { 320 // D = 1 / A[k, k] 321 const Scalar akk = diag_coeff(k); 322 // Return CompInfo::NumericalIssue if not invertible 323 if (akk == Scalar(0)) 324 return CompInfo::NumericalIssue; 325 326 diag_coeff(k) = Scalar(1) / akk; 327 328 // B -= l * l' / A[k, k], B := A[(k+1):end, (k+1):end], l := L[(k+1):end, k] 329 Scalar* lptr = col_pointer(k) + 1; 330 const Index ldim = m_n - k - 1; 331 MapVec l(lptr, ldim); 332 for (Index j = 0; j < ldim; j++) 333 { 334 MapVec(col_pointer(j + k + 1), ldim - j).noalias() -= (lptr[j] / akk) * l.tail(ldim - j); 335 } 336 337 // l /= A[k, k] 338 l /= akk; 339 340 return CompInfo::Successful; 341 } 342 343 // Return value is the status, CompInfo::Successful/NumericalIssue gaussian_elimination_2x2(Index k)344 CompInfo gaussian_elimination_2x2(Index k) 345 { 346 // D = inv(E) 347 Scalar& e11 = diag_coeff(k); 348 Scalar& e21 = coeff(k + 1, k); 349 Scalar& e22 = diag_coeff(k + 1); 350 // Return CompInfo::NumericalIssue if not invertible 351 if (e11 * e22 - e21 * e21 == Scalar(0)) 352 return CompInfo::NumericalIssue; 353 354 inverse_inplace_2x2(e11, e21, e22); 355 356 // X = l * inv(E), l := L[(k+2):end, k:(k+1)] 357 Scalar* l1ptr = &coeff(k + 2, k); 358 Scalar* l2ptr = &coeff(k + 2, k + 1); 359 const Index ldim = m_n - k - 2; 360 MapVec l1(l1ptr, ldim), l2(l2ptr, ldim); 361 362 Eigen::Matrix<Scalar, Eigen::Dynamic, 2> X(ldim, 2); 363 X.col(0).noalias() = l1 * e11 + l2 * e21; 364 X.col(1).noalias() = l1 * e21 + l2 * e22; 365 366 // B -= l * inv(E) * l' = X * l', B = A[(k+2):end, (k+2):end] 367 for (Index j = 0; j < ldim; j++) 368 { 369 MapVec(col_pointer(j + k + 2), ldim - j).noalias() -= (X.col(0).tail(ldim - j) * l1ptr[j] + X.col(1).tail(ldim - j) * l2ptr[j]); 370 } 371 372 // l = X 373 l1.noalias() = X.col(0); 374 l2.noalias() = X.col(1); 375 376 return CompInfo::Successful; 377 } 378 379 public: BKLDLT()380 BKLDLT() : 381 m_n(0), m_computed(false), m_info(CompInfo::NotComputed) 382 {} 383 384 // Factorize mat - shift * I 385 template <typename Derived> 386 BKLDLT(const Eigen::MatrixBase<Derived>& mat, int uplo = Eigen::Lower, const Scalar& shift = Scalar(0)) : 387 m_n(mat.rows()), m_computed(false), m_info(CompInfo::NotComputed) 388 { 389 compute(mat, uplo, shift); 390 } 391 392 template <typename Derived> 393 void compute(const Eigen::MatrixBase<Derived>& mat, int uplo = Eigen::Lower, const Scalar& shift = Scalar(0)) 394 { 395 using std::abs; 396 397 m_n = mat.rows(); 398 if (m_n != mat.cols()) 399 throw std::invalid_argument("BKLDLT: matrix must be square"); 400 401 m_perm.setLinSpaced(m_n, 0, m_n - 1); 402 m_permc.clear(); 403 404 // Copy data 405 m_data.resize((m_n * (m_n + 1)) / 2); 406 compute_pointer(); 407 copy_data(mat, uplo, shift); 408 409 const Scalar alpha = (1.0 + std::sqrt(17.0)) / 8.0; 410 Index k = 0; 411 for (k = 0; k < m_n - 1; k++) 412 { 413 // 1. Interchange rows and columns of A, and save the result to m_perm 414 bool is_1x1 = permutate_mat(k, alpha); 415 416 // 2. Gaussian elimination 417 if (is_1x1) 418 { 419 m_info = gaussian_elimination_1x1(k); 420 } 421 else 422 { 423 m_info = gaussian_elimination_2x2(k); 424 k++; 425 } 426 427 // 3. Check status 428 if (m_info != CompInfo::Successful) 429 break; 430 } 431 // Invert the last 1x1 block if it exists 432 if (k == m_n - 1) 433 { 434 const Scalar akk = diag_coeff(k); 435 if (akk == Scalar(0)) 436 m_info = CompInfo::NumericalIssue; 437 438 diag_coeff(k) = Scalar(1) / diag_coeff(k); 439 } 440 441 compress_permutation(); 442 443 m_computed = true; 444 } 445 446 // Solve Ax=b solve_inplace(GenericVector b)447 void solve_inplace(GenericVector b) const 448 { 449 if (!m_computed) 450 throw std::logic_error("BKLDLT: need to call compute() first"); 451 452 // PAP' = LDL' 453 // 1. b -> Pb 454 Scalar* x = b.data(); 455 MapVec res(x, m_n); 456 Index npermc = m_permc.size(); 457 for (Index i = 0; i < npermc; i++) 458 { 459 std::swap(x[m_permc[i].first], x[m_permc[i].second]); 460 } 461 462 // 2. Lz = Pb 463 // If m_perm[end] < 0, then end with m_n - 3, otherwise end with m_n - 2 464 const Index end = (m_perm[m_n - 1] < 0) ? (m_n - 3) : (m_n - 2); 465 for (Index i = 0; i <= end; i++) 466 { 467 const Index b1size = m_n - i - 1; 468 const Index b2size = b1size - 1; 469 if (m_perm[i] >= 0) 470 { 471 MapConstVec l(&coeff(i + 1, i), b1size); 472 res.segment(i + 1, b1size).noalias() -= l * x[i]; 473 } 474 else 475 { 476 MapConstVec l1(&coeff(i + 2, i), b2size); 477 MapConstVec l2(&coeff(i + 2, i + 1), b2size); 478 res.segment(i + 2, b2size).noalias() -= (l1 * x[i] + l2 * x[i + 1]); 479 i++; 480 } 481 } 482 483 // 3. Dw = z 484 for (Index i = 0; i < m_n; i++) 485 { 486 const Scalar e11 = diag_coeff(i); 487 if (m_perm[i] >= 0) 488 { 489 x[i] *= e11; 490 } 491 else 492 { 493 const Scalar e21 = coeff(i + 1, i), e22 = diag_coeff(i + 1); 494 const Scalar wi = x[i] * e11 + x[i + 1] * e21; 495 x[i + 1] = x[i] * e21 + x[i + 1] * e22; 496 x[i] = wi; 497 i++; 498 } 499 } 500 501 // 4. L'y = w 502 // If m_perm[end] < 0, then start with m_n - 3, otherwise start with m_n - 2 503 Index i = (m_perm[m_n - 1] < 0) ? (m_n - 3) : (m_n - 2); 504 for (; i >= 0; i--) 505 { 506 const Index ldim = m_n - i - 1; 507 MapConstVec l(&coeff(i + 1, i), ldim); 508 x[i] -= res.segment(i + 1, ldim).dot(l); 509 510 if (m_perm[i] < 0) 511 { 512 MapConstVec l2(&coeff(i + 1, i - 1), ldim); 513 x[i - 1] -= res.segment(i + 1, ldim).dot(l2); 514 i--; 515 } 516 } 517 518 // 5. x = P'y 519 for (Index i = npermc - 1; i >= 0; i--) 520 { 521 std::swap(x[m_permc[i].first], x[m_permc[i].second]); 522 } 523 } 524 solve(ConstGenericVector & b)525 Vector solve(ConstGenericVector& b) const 526 { 527 Vector res = b; 528 solve_inplace(res); 529 return res; 530 } 531 info()532 CompInfo info() const { return m_info; } 533 }; 534 535 } // namespace Spectra 536 537 #endif // SPECTRA_BK_LDLT_H 538