1 /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 #include "lsmrBase.h"
19 #include "vcl_compiler_detection.h"
20 
21 #include <algorithm>
22 #include <cmath>
23 #include <iostream>
24 #include <numeric>
25 #include <vector>
26 
/**
 * In-place update y[i] += alpha * x[i] for i in [0, n).
 *
 * \param n      number of elements to process
 * \param alpha  scalar multiplier applied to every element of x
 * \param x      input array, read only
 * \param y      array that is updated in place
 */
inline void daxpy( unsigned int n, double alpha, const double * x, double * y )
{
  for ( unsigned int i = 0; i < n; ++i )
    {
    y[i] += alpha * x[i];
    }
}
34 
// Squares its argument. The argument text is expanded twice, so it must not
// have side effects (e.g. do not pass i++ or a function call that mutates state).
#define Sqr(x) ((x)*(x))
36 
lsmrBase()37 lsmrBase::lsmrBase()
38 {
39   this->eps = 1e-16;
40   this->atol = 1e-6;
41   this->btol = 1e-6;
42   this->conlim = 1.0 / ( 10 * sqrt( this->eps ) );
43   this->itnlim = 10;
44   this->nout = nullptr;
45   this->istop = 0;
46   this->itn = 0;
47   this->normA = 0.0;
48   this->condA = 0.0;
49   this->normr = 0.0;
50   this->normAr = 0.0;
51   this->normx = 0.0;
52   this->normb = 0.0;
53   this->dxmax = 0.0;
54   this->maxdx = 0;
55   this->damp = 0.0;
56   this->damped = false;
57   this->localSize = 0;
58 }
59 
60 
/** Compiler-generated destructor; nothing is released here (the stream stored by SetOutputStream is not deleted). */
lsmrBase::~lsmrBase() = default;
62 
63 
64 unsigned int
GetStoppingReason() const65 lsmrBase::GetStoppingReason() const
66 {
67   return this->istop;
68 }
69 
70 
71 unsigned int
GetNumberOfIterationsPerformed() const72 lsmrBase::GetNumberOfIterationsPerformed() const
73 {
74   return this->itn;
75 }
76 
77 
78 double
GetFrobeniusNormEstimateOfAbar() const79 lsmrBase::GetFrobeniusNormEstimateOfAbar() const
80 {
81   return this->normA;
82 }
83 
84 
85 double
GetConditionNumberEstimateOfAbar() const86 lsmrBase::GetConditionNumberEstimateOfAbar() const
87 {
88   return this->condA;
89 }
90 
91 
92 double
GetFinalEstimateOfNormRbar() const93 lsmrBase::GetFinalEstimateOfNormRbar() const
94 {
95   return this->normr;
96 }
97 
98 
99 double
GetFinalEstimateOfNormOfResiduals() const100 lsmrBase::GetFinalEstimateOfNormOfResiduals() const
101 {
102   return this->normAr;
103 }
104 
105 
106 double
GetFinalEstimateOfNormOfX() const107 lsmrBase::GetFinalEstimateOfNormOfX() const
108 {
109   return this->normx;
110 }
111 
112 
113 void
SetLocalSize(unsigned int n)114 lsmrBase::SetLocalSize( unsigned int n )
115 {
116   this->localSize = n;
117 }
118 
119 
120 void
SetEpsilon(double value)121 lsmrBase::SetEpsilon( double value )
122 {
123   this->eps = value;
124 }
125 
126 
127 void
SetDamp(double value)128 lsmrBase::SetDamp( double value )
129 {
130   this->damp = value;
131 }
132 
133 
134 void
SetToleranceA(double value)135 lsmrBase::SetToleranceA( double value )
136 {
137   this->atol = value;
138 }
139 
140 
141 void
SetToleranceB(double value)142 lsmrBase::SetToleranceB( double value )
143 {
144   this->btol = value;
145 }
146 
147 
148 void
SetMaximumNumberOfIterations(unsigned int value)149 lsmrBase::SetMaximumNumberOfIterations( unsigned int value )
150 {
151   this->itnlim = value;
152 }
153 
154 
155 void
SetUpperLimitOnConditional(double value)156 lsmrBase::SetUpperLimitOnConditional( double value )
157 {
158   this->conlim = value;
159 }
160 
161 
162 void
SetOutputStream(std::ostream & os)163 lsmrBase::SetOutputStream( std::ostream & os )
164 {
165   this->nout = &os;
166 }
167 
168 
169 /**
170  *  returns sqrt( a**2 + b**2 )
171  *  with precautions to avoid overflow.
172  */
173 double
D2Norm(double a,double b) const174 lsmrBase::D2Norm( double a, double b ) const
175 {
176   const double scale = std::abs(a) + std::abs(b);
177   const double zero = 0.0;
178 
179   if( scale == zero )
180     {
181       return zero;
182     }
183 
184   const double sa = a / scale;
185   const double sb = b / scale;
186 
187   return scale * sqrt( sa * sa + sb * sb );
188 }
189 
190 
191 /** Simplified for this use from the BLAS version. */
192 void
Scale(unsigned int n,double factor,double * x) const193 lsmrBase::Scale( unsigned int n, double factor, double *x ) const
194 {
195   double * xend = x + n;
196   while( x != xend )
197     {
198       *x++ *= factor;
199     }
200 }
201 
202 double
Dnrm2(unsigned int n,const double * x) const203 lsmrBase::Dnrm2( unsigned int n, const double *x ) const
204 {
205   double magnitudeOfLargestElement = 0.0;
206 
207   double sumOfSquaresScaled = 1.0;
208 
209   for ( unsigned int i = 0; i < n; i++ )
210     {
211       if ( x[i] != 0.0 )
212     {
213       double dx = x[i];
214       const double absxi = std::abs(dx);
215 
216       if ( magnitudeOfLargestElement < absxi )
217         {
218           // rescale the sum to the range of the new element
219           dx = magnitudeOfLargestElement / absxi;
220           sumOfSquaresScaled = sumOfSquaresScaled * (dx * dx) + 1.0;
221           magnitudeOfLargestElement = absxi;
222         }
223       else
224         {
225           // rescale the new element to the range of the sum
226           dx = absxi / magnitudeOfLargestElement;
227           sumOfSquaresScaled += dx * dx;
228         }
229     }
230     }
231 
232   const double norm = magnitudeOfLargestElement * sqrt( sumOfSquaresScaled );
233 
234   return norm;
235 }
236 
237 /**
238  *
239  *  The array b must have size m
240  *
241  */
242 void lsmrBase::
Solve(unsigned int m,unsigned int n,const double * b,double * x)243 Solve( unsigned int m, unsigned int n, const double * b, double * x )
244 {
245   const double zero = 0.0;
246   const double one = 1.0;
247 
248   double test1;
249   double test2;
250 
251   // Initialize.
252 
253   unsigned int localVecs = std::min( localSize, std::min( m,n ) );
254 
255   if( this->nout )
256     {
257     (*this->nout) << " Enter LSMR.       Least-squares solution of  Ax = b\n" << std::endl;
258     (*this->nout) << " The matrix  A  has " << m << " rows   and " << n << " columns" << std::endl;
259     (*this->nout) << " damp   = " << this->damp << std::endl;
260     (*this->nout) << " atol   = " << this->atol << ", conlim = " << this->conlim << std::endl;
261     (*this->nout) << " btol   = " << this->btol << ", itnlim = " << this->itnlim << std::endl;
262     (*this->nout) << " localSize (no. of vectors for local reorthogonalization) = " << this->localSize << std::endl;
263     }
264 
265   int pfreq = 20;
266   int pcount = 0;
267   this->damped = ( this->damp > zero );
268 
269   std::vector<double> workBuffer( m+5*n+n*localVecs );
270   double * u = &workBuffer[0];
271   double * v = u+m;
272   double * w = v+n;
273   double * h = w+n;
274   double * hbar = h+n;
275   double * localV = hbar+n;
276 
277   //-------------------------------------------------------------------
278   //  Set up the first vectors u and v for the bidiagonalization.
279   //  These satisfy  beta*u = b,  alpha*v = A(transpose)*u.
280   //-------------------------------------------------------------------
281   std::copy( b, b+m, u );
282   std::fill( v, v+n, zero);
283   std::fill( w, v+n, zero);
284   std::fill( x, x+n, zero);
285   this->Scale( m, (-1.0), u );
286   this->Aprod1( m, n, x, u );
287   this->Scale( m, (-1.0), u );
288 
289   double alpha = zero;
290 
291   double beta =  this->Dnrm2( m, u );
292 
293   if( beta > zero ) {
294     this->Scale( m, ( one / beta ), u );
295     this->Aprod2( m, n, v, u );   //     v = A'*u
296     alpha = this->Dnrm2( n, v );
297   }
298 
299   if( alpha > zero )
300     {
301       this->Scale( n, ( one / alpha ), v );
302       std::copy( v, v+n, w );
303     }
304 
305   this->normAr = alpha * beta;
306 
307   if ( this->normAr == zero )
308     {
309       this->TerminationPrintOut();
310       return;
311     }
312 
313   // Initialization for local reorthogonalization.
314   bool localOrtho = false;
315   bool localVQueueFull = false;
316   unsigned int localPointer = 0;
317   if ( localVecs > 0 ) {
318     localOrtho      = true;
319     std::copy( v, v+n, localV );
320   }
321 
322   // Initialize variables for 1st iteration.
323   this->itn = 0;
324   double zetabar  = alpha*beta;
325   double alphabar = alpha;
326   double rho      = one;
327   double rhobar   = one;
328   double cbar     = one;
329   double sbar     = zero;
330 
331   std::copy( v, v+n, h );
332   std::fill( hbar, hbar+n, zero);
333 
334   // Initialize variables for estimation of ||r||.
335   double betadd      = beta;
336   double betad       = zero;
337   double rhodold     = one;
338   double tautildeold = zero;
339   double thetatilde  = zero;
340   double zeta        = zero;
341   double d           = zero;
342 
343   // Initialize variables for estimation of ||A|| and cond(A).
344 
345   double normA2  = alpha*alpha;
346   double maxrbar = zero;
347   double minrbar = 1e+100;
348 
349   // Items for use in stopping rules.
350   this->normb  = beta;
351   this->istop  = 0;
352   double ctol   = zero;
353 
354   if (this->conlim > zero) {
355     ctol = one/this->conlim;
356   }
357   this->normr  = beta;
358 
359   if ( this->nout )
360     {
361       if ( damped )
362     {
363       (*this->nout) << "   Itn       x(1)           norm rbar    Abar'rbar"
364         " Compatible    LS    norm Abar cond Abar\n";
365     }
366       else
367     {
368       (*this->nout) << "   Itn       x(1)            norm r         A'r   "
369         " Compatible    LS      norm A    cond A\n";
370     }
371 
372       test1 = one;
373       test2 = alpha / beta;
374 
375       (*this->nout) << this->itn << ", " << x[0] << ", " << this->normr << ", " << this->normA << ", " << test1 << ", " << test2 << std::endl;
376     }
377 
378   //  Main iteration loop
379   do {
380     this->itn++;
381 
382     //----------------------------------------------------------------
383     //  Perform the next step of the bidiagonalization to obtain the
384     //  next beta, u, alpha, v.  These satisfy
385     //      beta*u = A*v  - alpha*u,
386     //     alpha*v = A'*u -  beta*v.
387     //----------------------------------------------------------------
388     this->Scale( m, (-alpha), u );
389 
390     this->Aprod1( m, n, v, u );   //   u = A * v
391 
392     beta = this->Dnrm2( m, u );
393 
394     if ( beta > zero )
395       {
396     this->Scale( m, (one/beta), u );
397     if ( localOrtho ) {
398       if (localPointer+1 < localVecs) {
399         localPointer = localPointer + 1;
400       } else {
401         localPointer = 0;
402         localVQueueFull = true;
403       }
404       std::copy( v, v+n, localV+localPointer*n );
405     }
406     this->Scale( n, (- beta), v );
407     this->Aprod2( m, n, v, u );    // v = A'*u
408     if ( localOrtho ) {
409       unsigned int localOrthoLimit = localVQueueFull ? localVecs : localPointer+1;
410 
411       for( unsigned int localOrthoCount =0; localOrthoCount<localOrthoLimit;
412            ++localOrthoCount) {
413         double d = std::inner_product(v,v+n,localV+n*localOrthoCount,0.0);
414         daxpy( n, -d, localV+localOrthoCount*n, v );
415       }
416     }
417 
418     alpha  = this->Dnrm2( n, v );
419 
420     if ( alpha > zero )
421       {
422             this->Scale( n, (one/alpha), v );
423       }
424       }
425 
426     // At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.
427 
428 
429     //----------------------------------------------------------------
430     // Construct rotation Qhat_{k,2k+1}.
431 
432     double alphahat = this->D2Norm( alphabar, damp );
433     double chat     = alphabar/alphahat;
434     double shat     = damp/alphahat;
435 
436     // Use a plane rotation (Q_i) to turn B_i to R_i.
437 
438     double rhoold   = rho;
439     rho      = D2Norm(alphahat, beta);
440     double c        = alphahat/rho;
441     double s        = beta/rho;
442     double thetanew = s*alpha;
443     alphabar = c*alpha;
444 
445     // Use a plane rotation (Qbar_i) to turn R_i^T into R_i^bar.
446 
447     double rhobarold = rhobar;
448     double zetaold   = zeta;
449     double thetabar  = sbar*rho;
450     double rhotemp   = cbar*rho;
451     rhobar    = this->D2Norm(cbar*rho, thetanew);
452     cbar      = cbar*rho/rhobar;
453     sbar      = thetanew/rhobar;
454     zeta      =   cbar*zetabar;
455     zetabar   = - sbar*zetabar;
456 
457     // Update h, h_hat, x.
458 
459     for( unsigned int i=0;i<n;++i) {
460       hbar[i] = h[i] - (thetabar*rho/(rhoold*rhobarold))*hbar[i];
461       x[i] = x[i] + (zeta/(rho*rhobar))*hbar[i];
462       h[i] = v[i] - (thetanew/rho)*h[i];
463     }
464 
465     // Estimate ||r||.
466 
467     // Apply rotation Qhat_{k,2k+1}.
468     double betaacute =   chat* betadd;
469     double betacheck = - shat* betadd;
470 
471     // Apply rotation Q_{k,k+1}.
472     double betahat   =   c*betaacute;
473     betadd    = - s*betaacute;
474 
475     // Apply rotation Qtilde_{k-1}.
476     // betad = betad_{k-1} here.
477 
478     double thetatildeold = thetatilde;
479     double rhotildeold   = this->D2Norm(rhodold, thetabar);
480     double ctildeold     = rhodold/rhotildeold;
481     double stildeold     = thetabar/rhotildeold;
482     thetatilde    = stildeold* rhobar;
483     rhodold       =   ctildeold* rhobar;
484     betad         = - stildeold*betad + ctildeold*betahat;
485 
486     // betad   = betad_k here.
487     // rhodold = rhod_k  here.
488 
489     tautildeold   = (zetaold - thetatildeold*tautildeold)/rhotildeold;
490     double taud          = (zeta - thetatilde*tautildeold)/rhodold;
491     d             = d + betacheck*betacheck;
492     this->normr         = sqrt(d + Sqr(betad - taud) + Sqr(betadd));
493 
494     // Estimate ||A||.
495     normA2        = normA2 + Sqr(beta);
496     this->normA   = sqrt(normA2);
497     normA2        = normA2 + Sqr(alpha);
498 
499     // Estimate cond(A).
500     maxrbar       = std::max(maxrbar,rhobarold);
501     if (this->itn > 1) {
502       minrbar    = std::min(minrbar,rhobarold);
503     }
504     this->condA   = std::max(maxrbar,rhotemp)/std::min(minrbar,rhotemp);
505 
506     //----------------------------------------------------------------
507     //Test for convergence.
508     //---------------------------------------------------------------
509 
510     // Compute norms for convergence testing.
511     this->normAr  = std::abs(zetabar);
512     this->normx   = this->Dnrm2(n, x);
513 
514     // Now use these norms to estimate certain other quantities,
515     // some of which will be small near a solution.
516 
517     test1   = this->normr / this->normb;
518     test2   = this->normAr/(this->normA*this->normr);
519     double test3   = one/this->condA;
520     double t1      = test1/(one + this->normA*this->normx/this->normb);
521     double rtol    = this->btol + this->atol*this->normA*normx/this->normb;
522 
523     // The following tests guard against extremely small values of
524     // atol, btol or ctol.  (The user may have set any or all of
525     // the parameters atol, btol, conlim  to 0.)
526     // The effect is equivalent to the normAl tests using
527     // atol = eps,  btol = eps,  conlim = 1/eps.
528 
529     if ( this->itn >= this->itnlim ) this->istop = 7;
530     if (one+test3 <=  one) this->istop = 6;
531     if (one+test2 <=  one) this->istop = 5;
532     if (one+t1    <=  one) this->istop = 4;
533 
534     // Allow for tolerances set by the user.
535 
536     if ( test3   <= ctol ) this->istop = 3;
537     if ( test2   <= this->atol ) this->istop = 2;
538     if ( test1   <= rtol ) this->istop = 1;
539 
540     //----------------------------------------------------------------
541     // See if it is time to print something.
542     //----------------------------------------------------------------
543     if ( this->nout ) {
544       bool prnt = false;
545       if ( n<=40 ) prnt = true;
546       if ( this->itn <= 10 ) prnt = true;
547       if ( this->itn >= this->itnlim-10 ) prnt = true;
548       if ( (this->itn % 10)  ==  0 ) prnt = true;
549       if ( test3 <=  1.1*ctol ) prnt = true;
550       if ( test2 <=  1.1*this->atol ) prnt = true;
551       if ( test1 <=  1.1*rtol ) prnt = true;
552       if ( this->istop!=0 ) prnt = true;
553 
554       if ( prnt ) { // Print a line for this iteration
555     if ( pcount >= pfreq ) { // Print a heading first
556       pcount = 0;
557       if ( damped )
558         {
559           (*this->nout) << "   Itn       x(1)           norm rbar    Abar'rbar"
560         " Compatible    LS    norm Abar cond Abar\n";
561         } else {
562         (*this->nout) << "   Itn       x(1)            norm r         A'r   "
563           " Compatible    LS      norm A    cond A\n";
564       }
565     }
566     pcount = pcount + 1;
567     (*this->nout)
568       << this->itn << ", " << x[0] << ", " <<this->normr << ", " << this->normAr << ", " << test1 << ", " << test2
569       << ", " << this->normA << ", " << this->condA << std::endl;
570       }
571     }
572 
573   } while ( this->istop == 0);
574 
575   this->TerminationPrintOut();
576 }
577 
578 
579 void lsmrBase::
TerminationPrintOut()580 TerminationPrintOut()
581 {
582   const char * msg[] = {
583     "The exact solution is  x = 0                         ",
584     "Ax - b is small enough, given atol, btol             ",
585     "The least-squares solution is good enough, given atol",
586     "The estimate of cond(Abar) has exceeded conlim       ",
587     "Ax - b is small enough for this machine              ",
588     "The LS solution is good enough for this machine      ",
589     "Cond(Abar) seems to be too large for this machine    ",
590     "The iteration limit has been reached                 " };
591 
592   if ( this->damped && this->istop==2 ) this->istop=3;
593 
594   if ( this->nout ) {
595     (*this->nout) << " Exit  LSMR.       istop  = " << this->istop << "     ,itn    = " << this->itn << std::endl
596           << " Exit  LSMR.       normA  = " << this->normA << "     ,condA  = " << this->condA << std::endl
597           << " Exit  LSMR.       normb  = " << this->normb << "     ,normx  = " << this->normx << std::endl
598           << " Exit  LSMR.       normr  = " << this->normr << "     ,normAr = " << this->normAr << std::endl
599           << " Exit  LSMR.       " << msg[this->istop] << std::endl;
600   }
601 }
602