1"""
2Copyright (C) 2010 David Fong and Michael Saunders
3
4LSMR uses an iterative method.
5
607 Jun 2010: Documentation updated
703 Jun 2010: First release version in Python
8
9David Chin-lung Fong            clfong@stanford.edu
10Institute for Computational and Mathematical Engineering
11Stanford University
12
13Michael Saunders                saunders@stanford.edu
14Systems Optimization Laboratory
15Dept of MS&E, Stanford University.
16
17"""
18
# Names exported by ``from <this module> import *``: only the LSMR solver.
__all__ = ['lsmr']
20
from numpy import zeros, infty, atleast_1d, result_type
from numpy.linalg import norm
from math import inf, sqrt
from scipy.sparse.linalg.interface import aslinearoperator

from .lsqr import _sym_ortho
27
28
def lsmr(A, b, damp=0.0, atol=1e-6, btol=1e-6, conlim=1e8,
         maxiter=None, show=False, x0=None):
    """Iterative solver for least-squares problems.

    lsmr solves the system of linear equations ``Ax = b``. If the system
    is inconsistent, it solves the least-squares problem ``min ||b - Ax||_2``.
    ``A`` is a rectangular matrix of dimension m-by-n, where all cases are
    allowed: m = n, m > n, or m < n. ``b`` is a vector of length m.
    The matrix A may be dense or sparse (usually sparse).

    Parameters
    ----------
    A : {matrix, sparse matrix, ndarray, LinearOperator}
        Matrix A in the linear system.
        Alternatively, ``A`` can be a linear operator which can
        produce ``Ax`` and ``A^H x`` using, e.g.,
        ``scipy.sparse.linalg.LinearOperator``.
    b : array_like, shape (m,)
        Vector ``b`` in the linear system.
    damp : float, optional
        Damping factor for regularized least-squares. `lsmr` solves
        the regularized least-squares problem::

         min ||(b) - (  A   )x||
             ||(0)   (damp*I) ||_2

        where damp is a scalar.  If damp is None or 0, the system
        is solved without regularization.
    atol, btol : float, optional
        Stopping tolerances. `lsmr` continues iterations until a
        certain backward error estimate is smaller than some quantity
        depending on atol and btol.  Let ``r = b - Ax`` be the
        residual vector for the current approximate solution ``x``.
        If ``Ax = b`` seems to be consistent, ``lsmr`` terminates
        when ``norm(r) <= atol * norm(A) * norm(x) + btol * norm(b)``.
        Otherwise, lsmr terminates when ``norm(A^H r) <=
        atol * norm(A) * norm(r)``.  If both tolerances are 1.0e-6 (say),
        the final ``norm(r)`` should be accurate to about 6
        digits. (The final ``x`` will usually have fewer correct digits,
        depending on ``cond(A)`` and the size of `damp`.)  If `atol`
        or `btol` is None, a default value of 1.0e-6 will be used.
        Ideally, they should be estimates of the relative error in the
        entries of ``A`` and ``b`` respectively.  For example, if the entries
        of ``A`` have 7 correct digits, set ``atol = 1e-7``. This prevents
        the algorithm from doing unnecessary work beyond the
        uncertainty of the input data.
    conlim : float, optional
        `lsmr` terminates if an estimate of ``cond(A)`` exceeds
        `conlim`.  For compatible systems ``Ax = b``, conlim could be
        as large as 1.0e+12 (say).  For least-squares problems,
        `conlim` should be less than 1.0e+8. If `conlim` is None, the
        default value is 1e+8.  Maximum precision can be obtained by
        setting ``atol = btol = conlim = 0``, but the number of
        iterations may then be excessive.
    maxiter : int, optional
        `lsmr` terminates if the number of iterations reaches
        `maxiter`.  The default is ``maxiter = min(m, n)``.  For
        ill-conditioned systems, a larger value of `maxiter` may be
        needed.
    show : bool, optional
        Print iteration logs if ``show=True``.
    x0 : array_like, shape (n,), optional
        Initial guess of ``x``, if None zeros are used. The caller's
        array is never modified. If ``b`` is zero, the exact solution
        ``x = 0`` is returned and `x0` is ignored.

        .. versionadded:: 1.0.0

    Returns
    -------
    x : ndarray of float
        Least-squares solution returned.
    istop : int
        istop gives the reason for stopping::

          istop   = 0 means x=0 is a solution.  If x0 was given, then x=x0 is a
                      solution.
                  = 1 means x is an approximate solution to A*x = b,
                      according to atol and btol.
                  = 2 means x approximately solves the least-squares problem
                      according to atol.
                  = 3 means COND(A) seems to be greater than CONLIM.
                  = 4 is the same as 1 with atol = btol = eps (machine
                      precision)
                  = 5 is the same as 2 with atol = eps.
                  = 6 is the same as 3 with CONLIM = 1/eps.
                  = 7 means ITN reached maxiter before the other stopping
                      conditions were satisfied (including maxiter = 0).

    itn : int
        Number of iterations used.
    normr : float
        ``norm(b-Ax)``
    normar : float
        ``norm(A^H (b - Ax))``
    norma : float
        ``norm(A)``
    conda : float
        Condition number of A.
    normx : float
        ``norm(x)``

    Notes
    -----

    .. versionadded:: 0.11.0

    References
    ----------
    .. [1] D. C.-L. Fong and M. A. Saunders,
           "LSMR: An iterative algorithm for sparse least-squares problems",
           SIAM J. Sci. Comput., vol. 33, pp. 2950-2971, 2011.
           :arxiv:`1006.0758`
    .. [2] LSMR Software, https://web.stanford.edu/group/SOL/software/lsmr/

    Examples
    --------
    >>> from scipy.sparse import csc_matrix
    >>> from scipy.sparse.linalg import lsmr
    >>> A = csc_matrix([[1., 0.], [1., 1.], [0., 1.]], dtype=float)

    The first example has the trivial solution `[0, 0]`

    >>> b = np.array([0., 0., 0.], dtype=float)
    >>> x, istop, itn, normr = lsmr(A, b)[:4]
    >>> istop
    0
    >>> x
    array([0., 0.])

    The stopping code `istop=0` returned indicates that a vector of zeros was
    found as a solution. The returned solution `x` indeed contains `[0., 0.]`.
    The next example has a non-trivial solution:

    >>> b = np.array([1., 0., -1.], dtype=float)
    >>> x, istop, itn, normr = lsmr(A, b)[:4]
    >>> istop
    1
    >>> x
    array([ 1., -1.])
    >>> itn
    1
    >>> normr
    4.440892098500627e-16

    As indicated by `istop=1`, `lsmr` found a solution obeying the tolerance
    limits. The given solution `[1., -1.]` obviously solves the equation. The
    remaining return values include information about the number of iterations
    (`itn=1`) and the remaining difference of left and right side of the solved
    equation.
    The final example demonstrates the behavior in the case where there is no
    solution for the equation:

    >>> b = np.array([1., 0.01, -1.], dtype=float)
    >>> x, istop, itn, normr = lsmr(A, b)[:4]
    >>> istop
    2
    >>> x
    array([ 1.00333333, -0.99666667])
    >>> A.dot(x)-b
    array([ 0.00333333, -0.00333333,  0.00333333])
    >>> normr
    0.005773502691896255

    `istop` indicates that the system is inconsistent and thus `x` is rather an
    approximate solution to the corresponding least-squares problem. `normr`
    contains the minimal distance that was found.
    """

    A = aslinearoperator(A)
    b = atleast_1d(b)
    if b.ndim > 1:
        b = b.squeeze()

    # Honor the documented ``None`` defaults: the arithmetic and comparisons
    # below cannot operate on ``None``.
    if damp is None:
        damp = 0.0
    if atol is None:
        atol = 1e-6
    if btol is None:
        btol = 1e-6
    if conlim is None:
        conlim = 1e8

    msg = ('The exact solution is x = 0, or x = x0, if x0 was given  ',
         'Ax - b is small enough, given atol, btol                  ',
         'The least-squares solution is good enough, given atol     ',
         'The estimate of cond(Abar) has exceeded conlim            ',
         'Ax - b is small enough for this machine                   ',
         'The least-squares solution is good enough for this machine',
         'Cond(Abar) seems to be too large for this machine         ',
         'The iteration limit has been reached                      ')

    hdg1 = '   itn      x(1)       norm r    norm Ar'
    hdg2 = ' compatible   LS      norm A   cond A'
    pfreq = 20   # print frequency (for repeating the heading)
    pcount = 0   # print counter

    m, n = A.shape

    # stores the num of singular values
    minDim = min(m, n)

    if maxiter is None:
        maxiter = minDim

    if x0 is None:
        dtype = result_type(A, b, float)
    else:
        dtype = result_type(A, b, x0, float)

    if show:
        print(' ')
        print('LSMR            Least-squares solution of  Ax = b\n')
        print(f'The matrix A has {m} rows and {n} columns')
        print('damp = %20.14e\n' % (damp))
        print('atol = %8.2e                 conlim = %8.2e\n' % (atol, conlim))
        print('btol = %8.2e             maxiter = %8g\n' % (btol, maxiter))

    u = b
    normb = norm(b)
    if x0 is None or normb == 0:
        # With b == 0 the exact solution (of the damped problem as well) is
        # x = 0, so any x0 is ignored; beta == 0 triggers the early return
        # below.  Scaling the stopping tests by normb would otherwise divide
        # by zero and report spurious convergence.
        x = zeros(n, dtype)
        beta = normb
    else:
        # Private copy in the common dtype: x is updated in place in the
        # main loop, which must neither modify the caller's x0 nor fail for
        # an integer-typed x0.
        x = atleast_1d(x0).astype(dtype)
        if x.ndim > 1:
            x = x.squeeze()
        u = u - A.matvec(x)
        beta = norm(u)

    if beta > 0:
        u = (1 / beta) * u
        v = A.rmatvec(u)
        alpha = norm(v)
    else:
        v = zeros(n, dtype)
        alpha = 0

    if alpha > 0:
        v = (1 / alpha) * v

    # Initialize variables for 1st iteration.

    itn = 0
    zetabar = alpha * beta
    alphabar = alpha
    rho = 1
    rhobar = 1
    cbar = 1
    sbar = 0

    h = v.copy()
    hbar = zeros(n, dtype)

    # Initialize variables for estimation of ||r||.

    betadd = beta
    betad = 0
    rhodold = 1
    tautildeold = 0
    thetatilde = 0
    zeta = 0
    d = 0

    # Initialize variables for estimation of ||A|| and cond(A)

    normA2 = alpha * alpha
    maxrbar = 0
    minrbar = 1e+100
    normA = sqrt(normA2)
    condA = 1
    normx = 0

    # Items for use in stopping rules, normb set earlier
    istop = 0
    ctol = 0
    if conlim > 0:
        ctol = 1 / conlim
    normr = beta

    # Reverse the order here from the original matlab code because
    # there was an error on return when arnorm==0
    normar = alpha * beta
    if normar == 0:
        if show:
            print(msg[0])
        return x, istop, itn, normr, normar, normA, condA, normx

    if show:
        print(' ')
        print(hdg1, hdg2)
        test1 = 1
        test2 = alpha / beta
        str1 = '%6g %12.5e' % (itn, x[0])
        str2 = ' %10.3e %10.3e' % (normr, normar)
        str3 = '  %8.1e %8.1e' % (test1, test2)
        # Also defined here so the final summary can print it even when the
        # main loop never runs (maxiter == 0).
        str4 = ' %8.1e %8.1e' % (normA, condA)
        print(''.join([str1, str2, str3]))

    # Main iteration loop.
    while itn < maxiter:
        itn = itn + 1

        # Perform the next step of the bidiagonalization to obtain the
        # next  beta, u, alpha, v.  These satisfy the relations
        #         beta*u  =  A*v   -  alpha*u,
        #        alpha*v  =  A'*u  -  beta*v.

        u *= -alpha
        u += A.matvec(v)
        beta = norm(u)

        if beta > 0:
            u *= (1 / beta)
            v *= -beta
            v += A.rmatvec(u)
            alpha = norm(v)
            if alpha > 0:
                v *= (1 / alpha)

        # At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.

        # Construct rotation Qhat_{k,2k+1}.

        chat, shat, alphahat = _sym_ortho(alphabar, damp)

        # Use a plane rotation (Q_i) to turn B_i to R_i

        rhoold = rho
        c, s, rho = _sym_ortho(alphahat, beta)
        thetanew = s*alpha
        alphabar = c*alpha

        # Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar

        rhobarold = rhobar
        zetaold = zeta
        thetabar = sbar * rho
        rhotemp = cbar * rho
        cbar, sbar, rhobar = _sym_ortho(cbar * rho, thetanew)
        zeta = cbar * zetabar
        zetabar = - sbar * zetabar

        # Update h, h_hat, x.

        hbar *= - (thetabar * rho / (rhoold * rhobarold))
        hbar += h
        x += (zeta / (rho * rhobar)) * hbar
        h *= - (thetanew / rho)
        h += v

        # Estimate of ||r||.

        # Apply rotation Qhat_{k,2k+1}.
        betaacute = chat * betadd
        betacheck = -shat * betadd

        # Apply rotation Q_{k,k+1}.
        betahat = c * betaacute
        betadd = -s * betaacute

        # Apply rotation Qtilde_{k-1}.
        # betad = betad_{k-1} here.

        thetatildeold = thetatilde
        ctildeold, stildeold, rhotildeold = _sym_ortho(rhodold, thetabar)
        thetatilde = stildeold * rhobar
        rhodold = ctildeold * rhobar
        betad = - stildeold * betad + ctildeold * betahat

        # betad   = betad_k here.
        # rhodold = rhod_k  here.

        tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold
        taud = (zeta - thetatilde * tautildeold) / rhodold
        d = d + betacheck * betacheck
        normr = sqrt(d + (betad - taud)**2 + betadd * betadd)

        # Estimate ||A||.
        normA2 = normA2 + beta * beta
        normA = sqrt(normA2)
        normA2 = normA2 + alpha * alpha

        # Estimate cond(A).
        maxrbar = max(maxrbar, rhobarold)
        if itn > 1:
            minrbar = min(minrbar, rhobarold)
        condA = max(maxrbar, rhotemp) / min(minrbar, rhotemp)

        # Test for convergence.

        # Compute norms for convergence testing.
        normar = abs(zetabar)
        normx = norm(x)

        # Now use these norms to estimate certain other quantities,
        # some of which will be small near a solution.
        # (normb > 0 is guaranteed here: b == 0 returned early above.)

        test1 = normr / normb
        if (normA * normr) != 0:
            test2 = normar / (normA * normr)
        else:
            test2 = inf
        test3 = 1 / condA
        t1 = test1 / (1 + normA * normx / normb)
        rtol = btol + atol * normA * normx / normb

        # The following tests guard against extremely small values of
        # atol, btol or ctol.  (The user may have set any or all of
        # the parameters atol, btol, conlim  to 0.)
        # The effect is equivalent to the normal tests using
        # atol = eps,  btol = eps,  conlim = 1/eps.

        if itn >= maxiter:
            istop = 7
        if 1 + test3 <= 1:
            istop = 6
        if 1 + test2 <= 1:
            istop = 5
        if 1 + t1 <= 1:
            istop = 4

        # Allow for tolerances set by the user.

        if test3 <= ctol:
            istop = 3
        if test2 <= atol:
            istop = 2
        if test1 <= rtol:
            istop = 1

        # See if it is time to print something.

        if show:
            if (n <= 40) or (itn <= 10) or (itn >= maxiter - 10) or \
               (itn % 10 == 0) or (test3 <= 1.1 * ctol) or \
               (test2 <= 1.1 * atol) or (test1 <= 1.1 * rtol) or \
               (istop != 0):

                if pcount >= pfreq:
                    pcount = 0
                    print(' ')
                    print(hdg1, hdg2)
                pcount = pcount + 1
                str1 = '%6g %12.5e' % (itn, x[0])
                str2 = ' %10.3e %10.3e' % (normr, normar)
                str3 = '  %8.1e %8.1e' % (test1, test2)
                str4 = ' %8.1e %8.1e' % (normA, condA)
                print(''.join([str1, str2, str3, str4]))

        if istop > 0:
            break

    # Every executed iteration leaves istop > 0 before the loop exits, so
    # istop == 0 here means the loop never ran (maxiter <= 0): report the
    # iteration limit rather than claiming x = 0 is a solution.
    if istop == 0:
        istop = 7

    # Print the stopping condition.

    if show:
        print(' ')
        print('LSMR finished')
        print(msg[istop])
        print('istop =%8g    normr =%8.1e' % (istop, normr))
        print('    normA =%8.1e    normAr =%8.1e' % (normA, normar))
        print('itn   =%8g    condA =%8.1e' % (itn, condA))
        print('    normx =%8.1e' % (normx))
        print(str1, str2)
        print(str3, str4)

    return x, istop, itn, normr, normar, normA, condA, normx
482