"""
Copyright (C) 2010 David Fong and Michael Saunders

LSMR uses an iterative method.

07 Jun 2010: Documentation updated
03 Jun 2010: First release version in Python

David Chin-lung Fong            clfong@stanford.edu
Institute for Computational and Mathematical Engineering
Stanford University

Michael Saunders                saunders@stanford.edu
Systems Optimization Laboratory
Dept of MS&E, Stanford University.

"""

__all__ = ['lsmr']

from numpy import zeros, infty, atleast_1d, result_type
from numpy.linalg import norm
from math import sqrt
from scipy.sparse.linalg.interface import aslinearoperator

from .lsqr import _sym_ortho


def lsmr(A, b, damp=0.0, atol=1e-6, btol=1e-6, conlim=1e8,
         maxiter=None, show=False, x0=None):
    """Iterative solver for least-squares problems.

    lsmr solves the system of linear equations ``Ax = b``. If the system
    is inconsistent, it solves the least-squares problem ``min ||b - Ax||_2``.
    ``A`` is a rectangular matrix of dimension m-by-n, where all cases are
    allowed: m = n, m > n, or m < n. ``b`` is a vector of length m.
    The matrix A may be dense or sparse (usually sparse).

    Parameters
    ----------
    A : {matrix, sparse matrix, ndarray, LinearOperator}
        Matrix A in the linear system.
        Alternatively, ``A`` can be a linear operator which can
        produce ``Ax`` and ``A^H x`` using, e.g.,
        ``scipy.sparse.linalg.LinearOperator``.
    b : array_like, shape (m,)
        Vector ``b`` in the linear system.
    damp : float
        Damping factor for regularized least-squares. `lsmr` solves
        the regularized least-squares problem::

         min ||(b) - (  A   )x||
             ||(0)   (damp*I) ||_2

        where damp is a scalar.  If damp is None or 0, the system
        is solved without regularization.
    atol, btol : float, optional
        Stopping tolerances. `lsmr` continues iterations until a
        certain backward error estimate is smaller than some quantity
        depending on atol and btol.
        Let ``r = b - Ax`` be the
        residual vector for the current approximate solution ``x``.
        If ``Ax = b`` seems to be consistent, ``lsmr`` terminates
        when ``norm(r) <= atol * norm(A) * norm(x) + btol * norm(b)``.
        Otherwise, lsmr terminates when ``norm(A^H r) <=
        atol * norm(A) * norm(r)``.  If both tolerances are 1.0e-6 (say),
        the final ``norm(r)`` should be accurate to about 6
        digits. (The final ``x`` will usually have fewer correct digits,
        depending on ``cond(A)`` and the size of `damp`.)  If `atol`
        or `btol` is None, a default value of 1.0e-6 will be used.
        Ideally, they should be estimates of the relative error in the
        entries of ``A`` and ``b`` respectively.  For example, if the entries
        of ``A`` have 7 correct digits, set ``atol = 1e-7``. This prevents
        the algorithm from doing unnecessary work beyond the
        uncertainty of the input data.
    conlim : float, optional
        `lsmr` terminates if an estimate of ``cond(A)`` exceeds
        `conlim`.  For compatible systems ``Ax = b``, conlim could be
        as large as 1.0e+12 (say).  For least-squares problems,
        `conlim` should be less than 1.0e+8. If `conlim` is None, the
        default value is 1e+8.  Maximum precision can be obtained by
        setting ``atol = btol = conlim = 0``, but the number of
        iterations may then be excessive.
    maxiter : int, optional
        `lsmr` terminates if the number of iterations reaches
        `maxiter`.  The default is ``maxiter = min(m, n)``.  For
        ill-conditioned systems, a larger value of `maxiter` may be
        needed.
    show : bool, optional
        Print iteration logs if ``show=True``.
    x0 : array_like, shape (n,), optional
        Initial guess of ``x``. If None, zeros are used.

        .. versionadded:: 1.0.0

    Returns
    -------
    x : ndarray of float
        Least-squares solution.
    istop : int
        istop gives the reason for stopping::

          istop   = 0 means x=0 is a solution.  If x0 was given, then x=x0 is a
                      solution.
                  = 1 means x is an approximate solution to A*x = b,
                      according to atol and btol.
                  = 2 means x approximately solves the least-squares problem
                      according to atol.
                  = 3 means COND(A) seems to be greater than CONLIM.
                  = 4 is the same as 1 with atol = btol = eps (machine
                      precision)
                  = 5 is the same as 2 with atol = eps.
                  = 6 is the same as 3 with CONLIM = 1/eps.
                  = 7 means ITN reached maxiter before the other stopping
                      conditions were satisfied.

    itn : int
        Number of iterations used.
    normr : float
        ``norm(b-Ax)``
    normar : float
        ``norm(A^H (b - Ax))``
    norma : float
        ``norm(A)``
    conda : float
        Condition number of A.
    normx : float
        ``norm(x)``

    Notes
    -----

    .. versionadded:: 0.11.0

    References
    ----------
    .. [1] D. C.-L. Fong and M. A. Saunders,
           "LSMR: An iterative algorithm for sparse least-squares problems",
           SIAM J. Sci. Comput., vol. 33, pp. 2950-2971, 2011.
           :arxiv:`1006.0758`
    .. [2] LSMR Software, https://web.stanford.edu/group/SOL/software/lsmr/

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.sparse import csc_matrix
    >>> from scipy.sparse.linalg import lsmr
    >>> A = csc_matrix([[1., 0.], [1., 1.], [0., 1.]], dtype=float)

    The first example has the trivial solution `[0, 0]`:

    >>> b = np.array([0., 0., 0.], dtype=float)
    >>> x, istop, itn, normr = lsmr(A, b)[:4]
    >>> istop
    0
    >>> x
    array([0., 0.])

    The stopping code `istop=0` returned indicates that a vector of zeros was
    found as a solution. The returned solution `x` indeed contains `[0., 0.]`.
    The next example has a non-trivial solution:

    >>> b = np.array([1., 0., -1.], dtype=float)
    >>> x, istop, itn, normr = lsmr(A, b)[:4]
    >>> istop
    1
    >>> x
    array([ 1., -1.])
    >>> itn
    1
    >>> normr
    4.440892098500627e-16

    As indicated by `istop=1`, `lsmr` found a solution obeying the tolerance
    limits.
    The given solution `[1., -1.]` obviously solves the equation. The
    remaining return values include information about the number of iterations
    (`itn=1`) and the remaining difference of left and right side of the solved
    equation.
    The final example demonstrates the behavior in the case where there is no
    solution for the equation:

    >>> b = np.array([1., 0.01, -1.], dtype=float)
    >>> x, istop, itn, normr = lsmr(A, b)[:4]
    >>> istop
    2
    >>> x
    array([ 1.00333333, -0.99666667])
    >>> A.dot(x)-b
    array([ 0.00333333, -0.00333333,  0.00333333])
    >>> normr
    0.005773502691896255

    `istop` indicates that the system is inconsistent and thus `x` is rather an
    approximate solution to the corresponding least-squares problem. `normr`
    contains the minimal distance that was found.
    """

    A = aslinearoperator(A)
    b = atleast_1d(b)
    if b.ndim > 1:
        b = b.squeeze()

    msg = ('The exact solution is x = 0, or x = x0, if x0 was given ',
           'Ax - b is small enough, given atol, btol ',
           'The least-squares solution is good enough, given atol ',
           'The estimate of cond(Abar) has exceeded conlim ',
           'Ax - b is small enough for this machine ',
           'The least-squares solution is good enough for this machine',
           'Cond(Abar) seems to be too large for this machine ',
           'The iteration limit has been reached ')

    hdg1 = ' itn x(1) norm r norm Ar'
    hdg2 = ' compatible LS norm A cond A'
    pfreq = 20   # print frequency (for repeating the heading)
    pcount = 0   # print counter

    m, n = A.shape

    # stores the num of singular values
    minDim = min([m, n])

    if maxiter is None:
        maxiter = minDim

    if x0 is None:
        dtype = result_type(A, b, float)
    else:
        dtype = result_type(A, b, x0, float)

    if show:
        print(' ')
        print('LSMR Least-squares solution of Ax = b\n')
        print(f'The matrix A has {m} rows and {n} columns')
        print('damp = %20.14e\n' %
              (damp))
        print('atol = %8.2e conlim = %8.2e\n' % (atol, conlim))
        print('btol = %8.2e maxiter = %8g\n' % (btol, maxiter))

    u = b
    normb = norm(b)
    if x0 is None:
        x = zeros(n, dtype)
        beta = normb.copy()
    else:
        # Copy (and cast) x0 so the in-place updates of x below do not
        # overwrite the caller's array.
        x = atleast_1d(x0).astype(dtype)
        u = u - A.matvec(x)
        beta = norm(u)

    if beta > 0:
        u = (1 / beta) * u
        v = A.rmatvec(u)
        alpha = norm(v)
    else:
        v = zeros(n, dtype)
        alpha = 0

    if alpha > 0:
        v = (1 / alpha) * v

    # Initialize variables for 1st iteration.

    itn = 0
    zetabar = alpha * beta
    alphabar = alpha
    rho = 1
    rhobar = 1
    cbar = 1
    sbar = 0

    h = v.copy()
    hbar = zeros(n, dtype)

    # Initialize variables for estimation of ||r||.

    betadd = beta
    betad = 0
    rhodold = 1
    tautildeold = 0
    thetatilde = 0
    zeta = 0
    d = 0

    # Initialize variables for estimation of ||A|| and cond(A)

    normA2 = alpha * alpha
    maxrbar = 0
    minrbar = 1e+100
    normA = sqrt(normA2)
    condA = 1
    normx = 0

    # Items for use in stopping rules, normb set earlier
    istop = 0
    ctol = 0
    if conlim > 0:
        ctol = 1 / conlim
    normr = beta

    # Reverse the order here from the original matlab code because
    # there was an error on return when arnorm==0
    normar = alpha * beta
    if normar == 0:
        if show:
            print(msg[0])
        return x, istop, itn, normr, normar, normA, condA, normx

    if show:
        print(' ')
        print(hdg1, hdg2)
        test1 = 1
        test2 = alpha / beta
        str1 = '%6g %12.5e' % (itn, x[0])
        str2 = ' %10.3e %10.3e' % (normr, normar)
        str3 = ' %8.1e %8.1e' % (test1, test2)
        print(''.join([str1, str2, str3]))

    # Main iteration loop.
    while itn < maxiter:
        itn = itn + 1

        # Perform the next step of the bidiagonalization to obtain the
        # next  beta, u, alpha, v.  These satisfy the relations
        #         beta*u  =  A*v   -  alpha*u,
        #        alpha*v  =  A'*u  -  beta*v.

        u *= -alpha
        u += A.matvec(v)
        beta = norm(u)

        if beta > 0:
            u *= (1 / beta)
            v *= -beta
            v += A.rmatvec(u)
            alpha = norm(v)
            if alpha > 0:
                v *= (1 / alpha)

        # At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.

        # Construct rotation Qhat_{k,2k+1}.

        chat, shat, alphahat = _sym_ortho(alphabar, damp)

        # Use a plane rotation (Q_i) to turn B_i to R_i

        rhoold = rho
        c, s, rho = _sym_ortho(alphahat, beta)
        thetanew = s*alpha
        alphabar = c*alpha

        # Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar

        rhobarold = rhobar
        zetaold = zeta
        thetabar = sbar * rho
        rhotemp = cbar * rho
        cbar, sbar, rhobar = _sym_ortho(cbar * rho, thetanew)
        zeta = cbar * zetabar
        zetabar = - sbar * zetabar

        # Update h, hbar, x.

        hbar *= - (thetabar * rho / (rhoold * rhobarold))
        hbar += h
        x += (zeta / (rho * rhobar)) * hbar
        h *= - (thetanew / rho)
        h += v

        # Estimate of ||r||.

        # Apply rotation Qhat_{k,2k+1}.
        betaacute = chat * betadd
        betacheck = -shat * betadd

        # Apply rotation Q_{k,k+1}.
        betahat = c * betaacute
        betadd = -s * betaacute

        # Apply rotation Qtilde_{k-1}.
        # betad = betad_{k-1} here.

        thetatildeold = thetatilde
        ctildeold, stildeold, rhotildeold = _sym_ortho(rhodold, thetabar)
        thetatilde = stildeold * rhobar
        rhodold = ctildeold * rhobar
        betad = - stildeold * betad + ctildeold * betahat

        # betad   = betad_k here.
        # rhodold = rhod_k  here.

        tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold
        taud = (zeta - thetatilde * tautildeold) / rhodold
        d = d + betacheck * betacheck
        normr = sqrt(d + (betad - taud)**2 + betadd * betadd)

        # Estimate ||A||.
        normA2 = normA2 + beta * beta
        normA = sqrt(normA2)
        normA2 = normA2 + alpha * alpha

        # Estimate cond(A).
        maxrbar = max(maxrbar, rhobarold)
        if itn > 1:
            minrbar = min(minrbar, rhobarold)
        condA = max(maxrbar, rhotemp) / min(minrbar, rhotemp)

        # Test for convergence.

        # Compute norms for convergence testing.
        normar = abs(zetabar)
        normx = norm(x)

        # Now use these norms to estimate certain other quantities,
        # some of which will be small near a solution.

        test1 = normr / normb
        if (normA * normr) != 0:
            test2 = normar / (normA * normr)
        else:
            test2 = infty
        test3 = 1 / condA
        t1 = test1 / (1 + normA * normx / normb)
        rtol = btol + atol * normA * normx / normb

        # The following tests guard against extremely small values of
        # atol, btol or ctol.  (The user may have set any or all of
        # the parameters atol, btol, conlim  to 0.)
        # The effect is equivalent to the normal tests using
        # atol = eps,  btol = eps,  conlim = 1/eps.

        if itn >= maxiter:
            istop = 7
        if 1 + test3 <= 1:
            istop = 6
        if 1 + test2 <= 1:
            istop = 5
        if 1 + t1 <= 1:
            istop = 4

        # Allow for tolerances set by the user.

        if test3 <= ctol:
            istop = 3
        if test2 <= atol:
            istop = 2
        if test1 <= rtol:
            istop = 1

        # See if it is time to print something.

        if show:
            if (n <= 40) or (itn <= 10) or (itn >= maxiter - 10) or \
               (itn % 10 == 0) or (test3 <= 1.1 * ctol) or \
               (test2 <= 1.1 * atol) or (test1 <= 1.1 * rtol) or \
               (istop != 0):

                if pcount >= pfreq:
                    pcount = 0
                    print(' ')
                    print(hdg1, hdg2)
                pcount = pcount + 1
                str1 = '%6g %12.5e' % (itn, x[0])
                str2 = ' %10.3e %10.3e' % (normr, normar)
                str3 = ' %8.1e %8.1e' % (test1, test2)
                str4 = ' %8.1e %8.1e' % (normA, condA)
                print(''.join([str1, str2, str3, str4]))

        if istop > 0:
            break

    # Print the stopping condition.

    if show:
        print(' ')
        print('LSMR finished')
        print(msg[istop])
        print('istop =%8g normr =%8.1e' % (istop, normr))
        print(' normA =%8.1e normAr =%8.1e' % (normA, normar))
        print('itn =%8g condA =%8.1e' % (itn, condA))
        print(' normx =%8.1e' % (normx))
        print(str1, str2)
        print(str3, str4)

    return x, istop, itn, normr, normar, normA, condA, normx