1 /*=========================================================================
2 *
3 * Copyright Insight Software Consortium
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18 #include "lsmrBase.h"
19 #include "vcl_compiler_detection.h"
20
21 #include <algorithm>
22 #include <cmath>
23 #include <iostream>
24 #include <numeric>
25 #include <vector>
26
daxpy(unsigned int n,double alpha,const double * x,double * y)27 inline void daxpy( unsigned int n, double alpha, const double * x, double * y )
28 {
29 const double * xend = x+n;
30 while ( x!=xend ) {
31 *y++ += alpha * *x++;
32 }
33 }
34
/**
 * Returns x*x. This is a function rather than a macro, so the argument is
 * evaluated exactly once and is type-checked as a double.
 */
inline double Sqr( double x )
{
  return x * x;
}
36
lsmrBase()37 lsmrBase::lsmrBase()
38 {
39 this->eps = 1e-16;
40 this->atol = 1e-6;
41 this->btol = 1e-6;
42 this->conlim = 1.0 / ( 10 * sqrt( this->eps ) );
43 this->itnlim = 10;
44 this->nout = nullptr;
45 this->istop = 0;
46 this->itn = 0;
47 this->normA = 0.0;
48 this->condA = 0.0;
49 this->normr = 0.0;
50 this->normAr = 0.0;
51 this->normx = 0.0;
52 this->normb = 0.0;
53 this->dxmax = 0.0;
54 this->maxdx = 0;
55 this->damp = 0.0;
56 this->damped = false;
57 this->localSize = 0;
58 }
59
60
61 lsmrBase::~lsmrBase() = default;
62
63
64 unsigned int
GetStoppingReason() const65 lsmrBase::GetStoppingReason() const
66 {
67 return this->istop;
68 }
69
70
71 unsigned int
GetNumberOfIterationsPerformed() const72 lsmrBase::GetNumberOfIterationsPerformed() const
73 {
74 return this->itn;
75 }
76
77
78 double
GetFrobeniusNormEstimateOfAbar() const79 lsmrBase::GetFrobeniusNormEstimateOfAbar() const
80 {
81 return this->normA;
82 }
83
84
85 double
GetConditionNumberEstimateOfAbar() const86 lsmrBase::GetConditionNumberEstimateOfAbar() const
87 {
88 return this->condA;
89 }
90
91
92 double
GetFinalEstimateOfNormRbar() const93 lsmrBase::GetFinalEstimateOfNormRbar() const
94 {
95 return this->normr;
96 }
97
98
99 double
GetFinalEstimateOfNormOfResiduals() const100 lsmrBase::GetFinalEstimateOfNormOfResiduals() const
101 {
102 return this->normAr;
103 }
104
105
106 double
GetFinalEstimateOfNormOfX() const107 lsmrBase::GetFinalEstimateOfNormOfX() const
108 {
109 return this->normx;
110 }
111
112
113 void
SetLocalSize(unsigned int n)114 lsmrBase::SetLocalSize( unsigned int n )
115 {
116 this->localSize = n;
117 }
118
119
120 void
SetEpsilon(double value)121 lsmrBase::SetEpsilon( double value )
122 {
123 this->eps = value;
124 }
125
126
127 void
SetDamp(double value)128 lsmrBase::SetDamp( double value )
129 {
130 this->damp = value;
131 }
132
133
134 void
SetToleranceA(double value)135 lsmrBase::SetToleranceA( double value )
136 {
137 this->atol = value;
138 }
139
140
141 void
SetToleranceB(double value)142 lsmrBase::SetToleranceB( double value )
143 {
144 this->btol = value;
145 }
146
147
148 void
SetMaximumNumberOfIterations(unsigned int value)149 lsmrBase::SetMaximumNumberOfIterations( unsigned int value )
150 {
151 this->itnlim = value;
152 }
153
154
155 void
SetUpperLimitOnConditional(double value)156 lsmrBase::SetUpperLimitOnConditional( double value )
157 {
158 this->conlim = value;
159 }
160
161
162 void
SetOutputStream(std::ostream & os)163 lsmrBase::SetOutputStream( std::ostream & os )
164 {
165 this->nout = &os;
166 }
167
168
169 /**
170 * returns sqrt( a**2 + b**2 )
171 * with precautions to avoid overflow.
172 */
173 double
D2Norm(double a,double b) const174 lsmrBase::D2Norm( double a, double b ) const
175 {
176 const double scale = std::abs(a) + std::abs(b);
177 const double zero = 0.0;
178
179 if( scale == zero )
180 {
181 return zero;
182 }
183
184 const double sa = a / scale;
185 const double sb = b / scale;
186
187 return scale * sqrt( sa * sa + sb * sb );
188 }
189
190
191 /** Simplified for this use from the BLAS version. */
192 void
Scale(unsigned int n,double factor,double * x) const193 lsmrBase::Scale( unsigned int n, double factor, double *x ) const
194 {
195 double * xend = x + n;
196 while( x != xend )
197 {
198 *x++ *= factor;
199 }
200 }
201
202 double
Dnrm2(unsigned int n,const double * x) const203 lsmrBase::Dnrm2( unsigned int n, const double *x ) const
204 {
205 double magnitudeOfLargestElement = 0.0;
206
207 double sumOfSquaresScaled = 1.0;
208
209 for ( unsigned int i = 0; i < n; i++ )
210 {
211 if ( x[i] != 0.0 )
212 {
213 double dx = x[i];
214 const double absxi = std::abs(dx);
215
216 if ( magnitudeOfLargestElement < absxi )
217 {
218 // rescale the sum to the range of the new element
219 dx = magnitudeOfLargestElement / absxi;
220 sumOfSquaresScaled = sumOfSquaresScaled * (dx * dx) + 1.0;
221 magnitudeOfLargestElement = absxi;
222 }
223 else
224 {
225 // rescale the new element to the range of the sum
226 dx = absxi / magnitudeOfLargestElement;
227 sumOfSquaresScaled += dx * dx;
228 }
229 }
230 }
231
232 const double norm = magnitudeOfLargestElement * sqrt( sumOfSquaresScaled );
233
234 return norm;
235 }
236
237 /**
238 *
239 * The array b must have size m
240 *
241 */
242 void lsmrBase::
Solve(unsigned int m,unsigned int n,const double * b,double * x)243 Solve( unsigned int m, unsigned int n, const double * b, double * x )
244 {
245 const double zero = 0.0;
246 const double one = 1.0;
247
248 double test1;
249 double test2;
250
251 // Initialize.
252
253 unsigned int localVecs = std::min( localSize, std::min( m,n ) );
254
255 if( this->nout )
256 {
257 (*this->nout) << " Enter LSMR. Least-squares solution of Ax = b\n" << std::endl;
258 (*this->nout) << " The matrix A has " << m << " rows and " << n << " columns" << std::endl;
259 (*this->nout) << " damp = " << this->damp << std::endl;
260 (*this->nout) << " atol = " << this->atol << ", conlim = " << this->conlim << std::endl;
261 (*this->nout) << " btol = " << this->btol << ", itnlim = " << this->itnlim << std::endl;
262 (*this->nout) << " localSize (no. of vectors for local reorthogonalization) = " << this->localSize << std::endl;
263 }
264
265 int pfreq = 20;
266 int pcount = 0;
267 this->damped = ( this->damp > zero );
268
269 std::vector<double> workBuffer( m+5*n+n*localVecs );
270 double * u = &workBuffer[0];
271 double * v = u+m;
272 double * w = v+n;
273 double * h = w+n;
274 double * hbar = h+n;
275 double * localV = hbar+n;
276
277 //-------------------------------------------------------------------
278 // Set up the first vectors u and v for the bidiagonalization.
279 // These satisfy beta*u = b, alpha*v = A(transpose)*u.
280 //-------------------------------------------------------------------
281 std::copy( b, b+m, u );
282 std::fill( v, v+n, zero);
283 std::fill( w, v+n, zero);
284 std::fill( x, x+n, zero);
285 this->Scale( m, (-1.0), u );
286 this->Aprod1( m, n, x, u );
287 this->Scale( m, (-1.0), u );
288
289 double alpha = zero;
290
291 double beta = this->Dnrm2( m, u );
292
293 if( beta > zero ) {
294 this->Scale( m, ( one / beta ), u );
295 this->Aprod2( m, n, v, u ); // v = A'*u
296 alpha = this->Dnrm2( n, v );
297 }
298
299 if( alpha > zero )
300 {
301 this->Scale( n, ( one / alpha ), v );
302 std::copy( v, v+n, w );
303 }
304
305 this->normAr = alpha * beta;
306
307 if ( this->normAr == zero )
308 {
309 this->TerminationPrintOut();
310 return;
311 }
312
313 // Initialization for local reorthogonalization.
314 bool localOrtho = false;
315 bool localVQueueFull = false;
316 unsigned int localPointer = 0;
317 if ( localVecs > 0 ) {
318 localOrtho = true;
319 std::copy( v, v+n, localV );
320 }
321
322 // Initialize variables for 1st iteration.
323 this->itn = 0;
324 double zetabar = alpha*beta;
325 double alphabar = alpha;
326 double rho = one;
327 double rhobar = one;
328 double cbar = one;
329 double sbar = zero;
330
331 std::copy( v, v+n, h );
332 std::fill( hbar, hbar+n, zero);
333
334 // Initialize variables for estimation of ||r||.
335 double betadd = beta;
336 double betad = zero;
337 double rhodold = one;
338 double tautildeold = zero;
339 double thetatilde = zero;
340 double zeta = zero;
341 double d = zero;
342
343 // Initialize variables for estimation of ||A|| and cond(A).
344
345 double normA2 = alpha*alpha;
346 double maxrbar = zero;
347 double minrbar = 1e+100;
348
349 // Items for use in stopping rules.
350 this->normb = beta;
351 this->istop = 0;
352 double ctol = zero;
353
354 if (this->conlim > zero) {
355 ctol = one/this->conlim;
356 }
357 this->normr = beta;
358
359 if ( this->nout )
360 {
361 if ( damped )
362 {
363 (*this->nout) << " Itn x(1) norm rbar Abar'rbar"
364 " Compatible LS norm Abar cond Abar\n";
365 }
366 else
367 {
368 (*this->nout) << " Itn x(1) norm r A'r "
369 " Compatible LS norm A cond A\n";
370 }
371
372 test1 = one;
373 test2 = alpha / beta;
374
375 (*this->nout) << this->itn << ", " << x[0] << ", " << this->normr << ", " << this->normA << ", " << test1 << ", " << test2 << std::endl;
376 }
377
378 // Main iteration loop
379 do {
380 this->itn++;
381
382 //----------------------------------------------------------------
383 // Perform the next step of the bidiagonalization to obtain the
384 // next beta, u, alpha, v. These satisfy
385 // beta*u = A*v - alpha*u,
386 // alpha*v = A'*u - beta*v.
387 //----------------------------------------------------------------
388 this->Scale( m, (-alpha), u );
389
390 this->Aprod1( m, n, v, u ); // u = A * v
391
392 beta = this->Dnrm2( m, u );
393
394 if ( beta > zero )
395 {
396 this->Scale( m, (one/beta), u );
397 if ( localOrtho ) {
398 if (localPointer+1 < localVecs) {
399 localPointer = localPointer + 1;
400 } else {
401 localPointer = 0;
402 localVQueueFull = true;
403 }
404 std::copy( v, v+n, localV+localPointer*n );
405 }
406 this->Scale( n, (- beta), v );
407 this->Aprod2( m, n, v, u ); // v = A'*u
408 if ( localOrtho ) {
409 unsigned int localOrthoLimit = localVQueueFull ? localVecs : localPointer+1;
410
411 for( unsigned int localOrthoCount =0; localOrthoCount<localOrthoLimit;
412 ++localOrthoCount) {
413 double d = std::inner_product(v,v+n,localV+n*localOrthoCount,0.0);
414 daxpy( n, -d, localV+localOrthoCount*n, v );
415 }
416 }
417
418 alpha = this->Dnrm2( n, v );
419
420 if ( alpha > zero )
421 {
422 this->Scale( n, (one/alpha), v );
423 }
424 }
425
426 // At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.
427
428
429 //----------------------------------------------------------------
430 // Construct rotation Qhat_{k,2k+1}.
431
432 double alphahat = this->D2Norm( alphabar, damp );
433 double chat = alphabar/alphahat;
434 double shat = damp/alphahat;
435
436 // Use a plane rotation (Q_i) to turn B_i to R_i.
437
438 double rhoold = rho;
439 rho = D2Norm(alphahat, beta);
440 double c = alphahat/rho;
441 double s = beta/rho;
442 double thetanew = s*alpha;
443 alphabar = c*alpha;
444
445 // Use a plane rotation (Qbar_i) to turn R_i^T into R_i^bar.
446
447 double rhobarold = rhobar;
448 double zetaold = zeta;
449 double thetabar = sbar*rho;
450 double rhotemp = cbar*rho;
451 rhobar = this->D2Norm(cbar*rho, thetanew);
452 cbar = cbar*rho/rhobar;
453 sbar = thetanew/rhobar;
454 zeta = cbar*zetabar;
455 zetabar = - sbar*zetabar;
456
457 // Update h, h_hat, x.
458
459 for( unsigned int i=0;i<n;++i) {
460 hbar[i] = h[i] - (thetabar*rho/(rhoold*rhobarold))*hbar[i];
461 x[i] = x[i] + (zeta/(rho*rhobar))*hbar[i];
462 h[i] = v[i] - (thetanew/rho)*h[i];
463 }
464
465 // Estimate ||r||.
466
467 // Apply rotation Qhat_{k,2k+1}.
468 double betaacute = chat* betadd;
469 double betacheck = - shat* betadd;
470
471 // Apply rotation Q_{k,k+1}.
472 double betahat = c*betaacute;
473 betadd = - s*betaacute;
474
475 // Apply rotation Qtilde_{k-1}.
476 // betad = betad_{k-1} here.
477
478 double thetatildeold = thetatilde;
479 double rhotildeold = this->D2Norm(rhodold, thetabar);
480 double ctildeold = rhodold/rhotildeold;
481 double stildeold = thetabar/rhotildeold;
482 thetatilde = stildeold* rhobar;
483 rhodold = ctildeold* rhobar;
484 betad = - stildeold*betad + ctildeold*betahat;
485
486 // betad = betad_k here.
487 // rhodold = rhod_k here.
488
489 tautildeold = (zetaold - thetatildeold*tautildeold)/rhotildeold;
490 double taud = (zeta - thetatilde*tautildeold)/rhodold;
491 d = d + betacheck*betacheck;
492 this->normr = sqrt(d + Sqr(betad - taud) + Sqr(betadd));
493
494 // Estimate ||A||.
495 normA2 = normA2 + Sqr(beta);
496 this->normA = sqrt(normA2);
497 normA2 = normA2 + Sqr(alpha);
498
499 // Estimate cond(A).
500 maxrbar = std::max(maxrbar,rhobarold);
501 if (this->itn > 1) {
502 minrbar = std::min(minrbar,rhobarold);
503 }
504 this->condA = std::max(maxrbar,rhotemp)/std::min(minrbar,rhotemp);
505
506 //----------------------------------------------------------------
507 //Test for convergence.
508 //---------------------------------------------------------------
509
510 // Compute norms for convergence testing.
511 this->normAr = std::abs(zetabar);
512 this->normx = this->Dnrm2(n, x);
513
514 // Now use these norms to estimate certain other quantities,
515 // some of which will be small near a solution.
516
517 test1 = this->normr / this->normb;
518 test2 = this->normAr/(this->normA*this->normr);
519 double test3 = one/this->condA;
520 double t1 = test1/(one + this->normA*this->normx/this->normb);
521 double rtol = this->btol + this->atol*this->normA*normx/this->normb;
522
523 // The following tests guard against extremely small values of
524 // atol, btol or ctol. (The user may have set any or all of
525 // the parameters atol, btol, conlim to 0.)
526 // The effect is equivalent to the normAl tests using
527 // atol = eps, btol = eps, conlim = 1/eps.
528
529 if ( this->itn >= this->itnlim ) this->istop = 7;
530 if (one+test3 <= one) this->istop = 6;
531 if (one+test2 <= one) this->istop = 5;
532 if (one+t1 <= one) this->istop = 4;
533
534 // Allow for tolerances set by the user.
535
536 if ( test3 <= ctol ) this->istop = 3;
537 if ( test2 <= this->atol ) this->istop = 2;
538 if ( test1 <= rtol ) this->istop = 1;
539
540 //----------------------------------------------------------------
541 // See if it is time to print something.
542 //----------------------------------------------------------------
543 if ( this->nout ) {
544 bool prnt = false;
545 if ( n<=40 ) prnt = true;
546 if ( this->itn <= 10 ) prnt = true;
547 if ( this->itn >= this->itnlim-10 ) prnt = true;
548 if ( (this->itn % 10) == 0 ) prnt = true;
549 if ( test3 <= 1.1*ctol ) prnt = true;
550 if ( test2 <= 1.1*this->atol ) prnt = true;
551 if ( test1 <= 1.1*rtol ) prnt = true;
552 if ( this->istop!=0 ) prnt = true;
553
554 if ( prnt ) { // Print a line for this iteration
555 if ( pcount >= pfreq ) { // Print a heading first
556 pcount = 0;
557 if ( damped )
558 {
559 (*this->nout) << " Itn x(1) norm rbar Abar'rbar"
560 " Compatible LS norm Abar cond Abar\n";
561 } else {
562 (*this->nout) << " Itn x(1) norm r A'r "
563 " Compatible LS norm A cond A\n";
564 }
565 }
566 pcount = pcount + 1;
567 (*this->nout)
568 << this->itn << ", " << x[0] << ", " <<this->normr << ", " << this->normAr << ", " << test1 << ", " << test2
569 << ", " << this->normA << ", " << this->condA << std::endl;
570 }
571 }
572
573 } while ( this->istop == 0);
574
575 this->TerminationPrintOut();
576 }
577
578
579 void lsmrBase::
TerminationPrintOut()580 TerminationPrintOut()
581 {
582 const char * msg[] = {
583 "The exact solution is x = 0 ",
584 "Ax - b is small enough, given atol, btol ",
585 "The least-squares solution is good enough, given atol",
586 "The estimate of cond(Abar) has exceeded conlim ",
587 "Ax - b is small enough for this machine ",
588 "The LS solution is good enough for this machine ",
589 "Cond(Abar) seems to be too large for this machine ",
590 "The iteration limit has been reached " };
591
592 if ( this->damped && this->istop==2 ) this->istop=3;
593
594 if ( this->nout ) {
595 (*this->nout) << " Exit LSMR. istop = " << this->istop << " ,itn = " << this->itn << std::endl
596 << " Exit LSMR. normA = " << this->normA << " ,condA = " << this->condA << std::endl
597 << " Exit LSMR. normb = " << this->normb << " ,normx = " << this->normx << std::endl
598 << " Exit LSMR. normr = " << this->normr << " ,normAr = " << this->normAr << std::endl
599 << " Exit LSMR. " << msg[this->istop] << std::endl;
600 }
601 }
602