///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// The Template Matrix/Vector Library for C++ was created by Mike Jarvis     //
// Copyright (C) 1998 - 2016                                                 //
// All rights reserved                                                       //
//                                                                           //
// The project is hosted at https://code.google.com/p/tmv-cpp/               //
// where you can find the current version and current documentation.         //
//                                                                           //
// For concerns or problems with the software, Mike may be contacted at      //
// mike_jarvis17 [at] gmail.                                                 //
//                                                                           //
// This software is licensed under a FreeBSD license. The file               //
// TMV_LICENSE should have been included with this distribution.             //
// If not, you can get a copy from https://code.google.com/p/tmv-cpp/.       //
//                                                                           //
// Essentially, you can use this software however you want provided that     //
// you include the TMV_LICENSE file in any distribution that uses it.        //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////


//#define XDEBUG


#include "TMV_Blas.h"
#include "tmv/TMV_BandMatrixArithFunc.h"
#include "tmv/TMV_BandMatrix.h"
#include "tmv/TMV_VectorArith.h"
#include "tmv/TMV_DiagMatrix.h"
#include "tmv/TMV_DiagMatrixArithFunc.h"
#ifdef BLAS
#include "tmv/TMV_MatrixArith.h"
#include "tmv/TMV_BandMatrixArith.h"
#endif
#include <iostream>

// CBLAS trick of using RowMajor with ConjTrans when we have a
// case of A.conjugate() * x doesn't seem to be working with MKL 10.2.2.
// I haven't been able to figure out why. (e.g. Is it a bug in the MKL
// code, or am I doing something wrong?) So for now, just disable it.
42 #ifdef CBLAS 43 #undef CBLAS 44 #endif 45 46 #ifdef XDEBUG 47 #include "tmv/TMV_MatrixArith.h" 48 #include <iostream> 49 using std::cout; 50 using std::cerr; 51 using std::endl; 52 #endif 53 54 namespace tmv { 55 56 // 57 // BandMatrixComposite 58 // 59 60 template <class T> ls() const61 ptrdiff_t BandMatrixComposite<T>::ls() const 62 { 63 return BandStorageLength( 64 ColMajor,this->colsize(),this->rowsize(), 65 this->nlo(),this->nhi()); 66 } 67 68 template <class T> constLinearView() const69 ConstVectorView<T> BandMatrixComposite<T>::constLinearView() const 70 { 71 cptr(); // This makes the instantiation, but we don't need the result. 72 return ConstVectorView<T>(itsm1.get(),ls(),1,NonConj); 73 } 74 75 template <class T> cptr() const76 const T* BandMatrixComposite<T>::cptr() const 77 { 78 if (!itsm1.get()) { 79 ptrdiff_t cs = this->colsize(); 80 ptrdiff_t rs = this->rowsize(); 81 ptrdiff_t lo = this->nlo(); 82 ptrdiff_t hi = this->nhi(); 83 ptrdiff_t len = ls(); 84 itsm1.resize(len); 85 ptrdiff_t si = stepi(); 86 ptrdiff_t sj = stepj(); 87 ptrdiff_t ds = si + sj; 88 itsm = this->isdm() ? 
itsm1.get()-lo*si : itsm1.get(); 89 this->assignToB(BandMatrixView<T>( 90 itsm,cs,rs,lo,hi,si,sj,ds,NonConj,len 91 TMV_FIRSTLAST1(itsm1.get(),itsm1.get()+len) ) ); 92 } 93 return itsm; 94 } 95 96 template <class T> stepi() const97 ptrdiff_t BandMatrixComposite<T>::stepi() const 98 { return 1; } 99 100 template <class T> stepj() const101 ptrdiff_t BandMatrixComposite<T>::stepj() const 102 { return this->nlo()+this->nhi(); } 103 104 template <class T> diagstep() const105 ptrdiff_t BandMatrixComposite<T>::diagstep() const 106 { return this->nlo()+this->nhi()+1; } 107 108 109 // 110 // MultMV 111 // 112 113 template <bool add, bool cx, bool ca, bool rm, class T, class Ta, class Tx> RowMultMV(const GenBandMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)114 static void RowMultMV( 115 const GenBandMatrix<Ta>& A, const GenVector<Tx>& x, 116 VectorView<T> y) 117 { 118 TMVAssert(A.rowsize()==x.size()); 119 TMVAssert(A.colsize()==y.size()); 120 TMVAssert(x.size() > 0); 121 TMVAssert(y.size() > 0); 122 TMVAssert(y.ct() == NonConj); 123 TMVAssert(x.step()==1); 124 TMVAssert(y.step()==1); 125 TMVAssert(!SameStorage(x,y)); 126 TMVAssert(cx == x.isconj()); 127 TMVAssert(ca == A.isconj()); 128 TMVAssert(rm == A.isrm()); 129 130 const ptrdiff_t si = A.stepi(); 131 const ptrdiff_t sj = (rm ? 
1 : A.stepj()); 132 const ptrdiff_t ds = A.diagstep(); 133 const ptrdiff_t M = A.colsize(); 134 const ptrdiff_t N = A.rowsize(); 135 136 const Ta* Aij1 = A.cptr(); 137 const Tx* xj1 = x.cptr(); 138 T* yi = y.ptr(); 139 140 ptrdiff_t k=A.nlo(); 141 ptrdiff_t i=0; 142 ptrdiff_t j1=0; 143 ptrdiff_t j2=A.nhi()+1; 144 ptrdiff_t len = j2; // len = j2-j1 145 for(; i<M; ++i, ++yi) { 146 #ifdef TMVFLDEBUG 147 TMVAssert(yi >= y._first); 148 TMVAssert(yi < y._last); 149 #endif 150 if (!add) *yi = T(0); 151 152 // *yi += A.row(i,j1,j2) * x.subVector(j1,j2); 153 const Ta* Aij = Aij1; 154 const Tx* xj = xj1; 155 for(ptrdiff_t j=len;j>0;--j,++xj,(rm?++Aij:Aij+=sj)) { 156 #ifdef TMVFLDEBUG 157 TMVAssert(yi >= y._first); 158 TMVAssert(yi < y._last); 159 #endif 160 *yi += (cx ? TMV_CONJ(*xj) : *xj) * (ca ? TMV_CONJ(*Aij) : *Aij); 161 } 162 163 if (k>0) { --k; ++len; Aij1+=si; } 164 else { ++j1; ++xj1; Aij1+=ds; } 165 if (j2<N) ++j2; 166 else { --len; if (j1==N) { ++i, ++yi; break; } } 167 } 168 if (!add) for(;i<M; ++i, ++yi) { 169 #ifdef TMVFLDEBUG 170 TMVAssert(yi >= y._first); 171 TMVAssert(yi < y._last); 172 #endif 173 *yi = T(0); 174 } 175 } 176 177 template <bool add, bool cx, bool ca, bool cm, class T, class Ta, class Tx> ColMultMV(const GenBandMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)178 static void ColMultMV( 179 const GenBandMatrix<Ta>& A, const GenVector<Tx>& x, VectorView<T> y) 180 { 181 TMVAssert(A.rowsize() == x.size()); 182 TMVAssert(A.colsize() == y.size()); 183 TMVAssert(x.size() > 0); 184 TMVAssert(y.size() > 0); 185 TMVAssert(y.ct() == NonConj); 186 TMVAssert(x.step()==1); 187 TMVAssert(y.step()==1); 188 TMVAssert(!SameStorage(x,y)); 189 TMVAssert(cx == x.isconj()); 190 TMVAssert(ca == A.isconj()); 191 TMVAssert(cm == A.iscm()); 192 193 const ptrdiff_t N = A.rowsize(); 194 const ptrdiff_t M = A.colsize(); 195 const ptrdiff_t si = (cm ? 
1 : A.stepi()); 196 const ptrdiff_t sj = A.stepj(); 197 const ptrdiff_t ds = A.diagstep(); 198 199 const Ta* Ai1j = A.cptr(); 200 const Tx* xj = x.cptr(); 201 T* yi1 = y.ptr(); 202 203 ptrdiff_t k=A.nhi(); 204 ptrdiff_t i1=0; 205 ptrdiff_t i2=A.nlo()+1; 206 ptrdiff_t len = i2; // = i2-i1 207 208 if (!add) y.setZero(); 209 210 for(ptrdiff_t j=N; j>0; --j,++xj) { 211 if (*xj != Tx(0)) { 212 // y.subVector(i1,i2) += *xj * A.col(j,i1,i2); 213 const Ta* Aij = Ai1j; 214 T* yi = yi1; 215 for(ptrdiff_t i=len;i>0;--i,++yi,(cm?++Aij:Aij+=si)) { 216 #ifdef TMVFLDEBUG 217 TMVAssert(yi >= y._first); 218 TMVAssert(yi < y._last); 219 #endif 220 *yi += 221 (cx ? TMV_CONJ(*xj) : *xj) * 222 (ca ? TMV_CONJ(*Aij) : *Aij); 223 } 224 } 225 if (k>0) { --k; Ai1j+=sj; ++len; } 226 else { ++i1; ++yi1; Ai1j+=ds; } 227 if (i2<M) ++i2; 228 else { --len; if (i1==M) break; } 229 } 230 } 231 232 template <bool add, bool cx, bool ca, bool dm, class T, class Ta, class Tx> DiagMultMV(const GenBandMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)233 static void DiagMultMV( 234 const GenBandMatrix<Ta>& A, const GenVector<Tx>& x, 235 VectorView<T> y) 236 { 237 TMVAssert(A.rowsize() == x.size()); 238 TMVAssert(A.colsize() == y.size()); 239 TMVAssert(x.size() > 0); 240 TMVAssert(y.size() > 0); 241 TMVAssert(y.ct() == NonConj); 242 TMVAssert(x.step()==1); 243 TMVAssert(y.step()==1); 244 TMVAssert(!SameStorage(x,y)); 245 TMVAssert(cx == x.isconj()); 246 TMVAssert(ca == A.isconj()); 247 TMVAssert(dm == A.isdm()); 248 249 const ptrdiff_t si = A.stepi(); 250 const ptrdiff_t sj = A.stepj(); 251 const ptrdiff_t ds = A.diagstep(); 252 const ptrdiff_t lo = A.nlo(); 253 const ptrdiff_t hi = A.nhi(); 254 const ptrdiff_t M = A.colsize(); 255 const ptrdiff_t N = A.rowsize(); 256 257 const Ta* Ai1j1 = A.cptr() + lo*si; 258 const Tx* xj1 = x.cptr(); 259 T* yi1 = y.ptr() + lo; 260 261 ptrdiff_t j2=TMV_MIN(M-lo,N); 262 ptrdiff_t len=j2; // == j2-j1 263 264 if (!add) y.setZero(); 265 266 for(ptrdiff_t k=-A.nlo(); 
k<=hi; ++k) { 267 // y.subVector(i1,i2) += DiagMatrixViewOf(A.diag(k)) * 268 // x.subVector(j1,j2); 269 const Ta* Aij = Ai1j1; 270 const Tx* xj = xj1; 271 T* yi = yi1; 272 for(ptrdiff_t i=len;i>0;--i,++yi,++xj,(dm?++Aij:Aij+=ds)) { 273 #ifdef TMVFLDEBUG 274 TMVAssert(yi >= y._first); 275 TMVAssert(yi < y._last); 276 #endif 277 *yi += 278 (cx ? TMV_CONJ(*xj) : *xj) * 279 (ca ? TMV_CONJ(*Aij) : *Aij); 280 } 281 if (k<0) { --yi1; ++len; Ai1j1-=si; } 282 else { ++xj1; Ai1j1+=sj; } 283 if (j2 < N) ++j2; else --len; 284 } 285 } 286 287 template <bool add, bool cx, class T, class Ta, class Tx> UnitAMultMV1(const GenBandMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)288 static void UnitAMultMV1( 289 const GenBandMatrix<Ta>& A, const GenVector<Tx>& x, 290 VectorView<T> y) 291 { 292 TMVAssert(A.rowsize() == x.size()); 293 TMVAssert(A.colsize() == y.size()); 294 TMVAssert(x.size() > 0); 295 TMVAssert(y.size() > 0); 296 TMVAssert(y.ct() == NonConj); 297 TMVAssert(x.step()==1); 298 TMVAssert(y.step()==1); 299 TMVAssert(!SameStorage(x,y)); 300 TMVAssert(cx == x.isconj()); 301 302 if (A.isrm()) 303 if (A.isconj()) 304 RowMultMV<add,cx,true,true>(A,x,y); 305 else 306 RowMultMV<add,cx,false,true>(A,x,y); 307 else if (A.iscm()) 308 if (A.isconj()) 309 ColMultMV<add,cx,true,true>(A,x,y); 310 else 311 ColMultMV<add,cx,false,true>(A,x,y); 312 else if (A.isdm()) 313 if (A.isconj()) 314 DiagMultMV<add,cx,true,true>(A,x,y); 315 else 316 DiagMultMV<add,cx,false,true>(A,x,y); 317 else 318 if (A.isconj()) 319 DiagMultMV<add,cx,true,false>(A,x,y); 320 else 321 DiagMultMV<add,cx,false,false>(A,x,y); 322 } 323 324 template <bool add, bool cx, class T, class Ta, class Tx> UnitAMultMV(const GenBandMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)325 static void UnitAMultMV( 326 const GenBandMatrix<Ta>& A, const GenVector<Tx>& x, 327 VectorView<T> y) 328 { 329 // Check for 0's in beginning or end of x: 330 // y += [ A1 A2 A3 ] [ 0 ] --> y += A2 x 331 // [ x ] 332 // [ 0 ] 333 334 
const ptrdiff_t N = x.size(); // = A.rowsize() 335 ptrdiff_t j2 = N; 336 for(const Tx* x2=x.cptr()+N-1; j2>0 && *x2==Tx(0); --j2,--x2); 337 if (j2 == 0) { 338 if (!add) y.setZero(); 339 return; 340 } 341 ptrdiff_t j1 = 0; 342 for(const Tx* x1=x.cptr(); *x1==Tx(0); ++j1,++x1); 343 if (j1 == 0 && j2 == N) UnitAMultMV1<add,cx>(A,x,y); 344 else { 345 const ptrdiff_t hi = A.nhi(); 346 const ptrdiff_t lo = A.nlo(); 347 const ptrdiff_t M = y.size(); // = A.colsize() 348 // This next bit is copied from the BandMatrix colRange function 349 ptrdiff_t i1 = j1 > hi ? j1-hi : 0; 350 ptrdiff_t i2 = TMV_MIN(j2+lo,M); 351 ptrdiff_t newhi = j1 < hi ? hi-j1 : 0; 352 ptrdiff_t newlo = lo+hi-newhi; 353 ptrdiff_t newM = i2-i1; 354 ptrdiff_t newN = j2-j1; 355 TMVAssert(newM > 0); 356 TMVAssert(newN > 0); 357 if (newhi >= newN) newhi = newN-1; 358 if (newlo >= newM) newlo = newM-1; 359 TMVAssert(A.hasSubBandMatrix(i1,i2,j1,j2,newlo,newhi,1,1)); 360 const Ta* p = A.cptr()+i1*A.stepi()+j1*A.stepj(); 361 ConstBandMatrixView<Ta> Acols( 362 p,newM,newN,newlo,newhi, 363 A.stepi(),A.stepj(),A.diagstep(),A.ct()); 364 UnitAMultMV1<add,cx>(Acols,x.subVector(j1,j2),y.subVector(i1,i2)); 365 if (!add) { 366 y.subVector(0,i1).setZero(); 367 y.subVector(i2,M).setZero(); 368 } 369 } 370 } 371 372 template <bool add, class T, class Ta, class Tx> NonBlasMultMV(const T alpha,const GenBandMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)373 static void NonBlasMultMV( 374 const T alpha, const GenBandMatrix<Ta>& A, const GenVector<Tx>& x, 375 VectorView<T> y) 376 // y (+)= alpha * A * x 377 { 378 TMVAssert(A.rowsize() == x.size()); 379 TMVAssert(A.colsize() == y.size()); 380 TMVAssert(alpha != T(0)); 381 TMVAssert(x.size() > 0); 382 TMVAssert(y.size() > 0); 383 TMVAssert(y.ct() == NonConj); 384 385 #ifdef XDEBUG 386 cout<<"NonBlasMultMV: A = "<<A<<endl; 387 Vector<T> y0 = y; 388 Vector<Tx> x0 = x; 389 Matrix<Ta> A0 = A; 390 Vector<T> y2 = alpha*A0*x0; 391 if (add) y2 += y0; 392 #endif 393 394 if 
(x.step() != 1 || SameStorage(x,y)) { 395 if (TMV_IMAG(alpha) == TMV_RealType(T)(0)) { 396 Vector<Tx> xx = TMV_REAL(alpha) * x; 397 if (y.step() != 1) { 398 Vector<T> yy(y.size()); 399 UnitAMultMV<false,false>(A,xx,yy.view()); 400 if (add) y += yy; 401 else y = yy; 402 } else { 403 UnitAMultMV<add,false>(A,xx,y); 404 } 405 } else { 406 Vector<T> xx = alpha * x; 407 if (y.step() != 1) { 408 Vector<T> yy(y.size()); 409 UnitAMultMV<false,false>(A,xx,yy.view()); 410 if (add) y += yy; 411 else y = yy; 412 } else { 413 UnitAMultMV<add,false>(A,xx,y); 414 } 415 } 416 } else if (y.step() != 1 || alpha != TMV_RealType(T)(1)) { 417 Vector<T> yy(y.size()); 418 if (x.isconj()) 419 UnitAMultMV<false,true>(A,x,yy.view()); 420 else 421 UnitAMultMV<false,false>(A,x,yy.view()); 422 if (add) y += alpha * yy; 423 else y = alpha * yy; 424 } else { 425 TMVAssert(alpha == T(1)); 426 TMVAssert(y.step() == 1); 427 TMVAssert(x.step() == 1); 428 TMVAssert(!SameStorage(x,y)); 429 if (x.isconj()) 430 UnitAMultMV<add,true>(A,x,y); 431 else 432 UnitAMultMV<add,false>(A,x,y); 433 } 434 435 #ifdef XDEBUG 436 if (!(Norm(y2-y) <= 0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(x0)+ 437 (add?Norm(y0):TMV_RealType(T)(0))))) { 438 cerr<<"NonBlas MultMV: alpha = "<<alpha<<endl; 439 cerr<<"add = "<<add<<endl; 440 cerr<<"A = "<<TMV_Text(A)<<" "<<A.cptr()<<" "<<A0<<endl; 441 cerr<<"x = "<<TMV_Text(x)<<" "<<x.cptr()<< 442 " step "<<x.step()<<" "<<x0<<endl; 443 cerr<<"y = "<<TMV_Text(y)<<" "<<y.cptr()<< 444 " step "<<y.step()<<" "<<y0<<endl; 445 cerr<<"--> y = "<<y<<endl; 446 cerr<<"y2 = "<<y2<<endl; 447 abort(); 448 } 449 #endif 450 } 451 452 #ifdef BLAS 453 template <class T, class Ta, class Tx> BlasMultMV(const T alpha,const GenBandMatrix<Ta> & A,const GenVector<Tx> & x,int beta,VectorView<T> y)454 static inline void BlasMultMV( 455 const T alpha, const GenBandMatrix<Ta>& A, 456 const GenVector<Tx>& x, int beta, VectorView<T> y) 457 { 458 if (beta == 1) NonBlasMultMV<true>(alpha,A,x,y); 459 else 
NonBlasMultMV<false>(alpha,A,x,y); 460 } 461 #ifdef INST_DOUBLE 462 template <> BlasMultMV(const double alpha,const GenBandMatrix<double> & A,const GenVector<double> & x,int beta,VectorView<double> y)463 void BlasMultMV( 464 const double alpha, 465 const GenBandMatrix<double>& A, const GenVector<double>& x, 466 int beta, VectorView<double> y) 467 { 468 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 469 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 470 int lo = BlasIsCM(A) ? A.nlo() : A.nhi(); 471 int hi = BlasIsCM(A) ? A.nhi() : A.nlo(); 472 int ds = A.diagstep(); 473 int xs = x.step(); 474 int ys = y.step(); 475 const double* xp = x.cptr(); 476 if (xs < 0) xp += (x.size()-1)*xs; 477 double* yp = y.ptr(); 478 if (ys < 0) yp += (y.size()-1)*ys; 479 if (beta == 0) y.setZero(); 480 double xbeta(1); 481 BLASNAME(dgbmv) ( 482 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 483 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 484 BLASV(alpha),BLASP(A.cptr()-hi),BLASV(ds), 485 BLASP(xp),BLASV(xs),BLASV(xbeta), 486 BLASP(yp),BLASV(ys) BLAS1); 487 } 488 template <> BlasMultMV(const std::complex<double> alpha,const GenBandMatrix<std::complex<double>> & A,const GenVector<std::complex<double>> & x,int beta,VectorView<std::complex<double>> y)489 void BlasMultMV( 490 const std::complex<double> alpha, 491 const GenBandMatrix<std::complex<double> >& A, 492 const GenVector<std::complex<double> >& x, 493 int beta, VectorView<std::complex<double> > y) 494 { 495 if (x.isconj() 496 #ifndef CBLAS 497 && !(A.isconj() && BlasIsCM(A)) 498 #endif 499 ) { 500 Vector<std::complex<double> > xx = alpha*x; 501 return BlasMultMV(std::complex<double>(1),A,xx,beta,y); 502 } 503 504 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 505 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 506 int lo = BlasIsCM(A) ? A.nlo() : A.nhi(); 507 int hi = BlasIsCM(A) ? 
A.nhi() : A.nlo(); 508 int ds = A.diagstep(); 509 int xs = x.step(); 510 int ys = y.step(); 511 const std::complex<double>* xp = x.cptr(); 512 if (xs < 0) xp += (x.size()-1)*xs; 513 std::complex<double>* yp = y.ptr(); 514 if (ys < 0) yp += (y.size()-1)*ys; 515 if (beta == 0) y.setZero(); 516 std::complex<double> xbeta(1); 517 518 if (A.isconj() && BlasIsCM(A)) { 519 #ifdef CBLAS 520 TMV_SWAP(m,n); 521 TMV_SWAP(lo,hi); 522 BLASNAME(zgbmv) ( 523 BLASRM BLASCH_CT, BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 524 BLASP(&alpha),BLASP(A.cptr()-lo),BLASV(ds), 525 BLASP(xp),BLASV(xs),BLASP(&xbeta), 526 BLASP(yp),BLASV(ys) BLAS1); 527 #else 528 std::complex<double> ca = TMV_CONJ(alpha); 529 if (x.isconj()) { 530 y.conjugateSelf(); 531 BLASNAME(zgbmv) ( 532 BLASCM BLASCH_NT, BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 533 BLASP(&ca),BLASP(A.cptr()-hi),BLASV(ds), 534 BLASP(xp),BLASV(xs),BLASP(&xbeta), 535 BLASP(yp),BLASV(ys) BLAS1); 536 y.conjugateSelf(); 537 } else { 538 Vector<std::complex<double> > xx=ca*x.conjugate(); 539 ca = std::complex<double>(1); 540 xs = 1; 541 xp = xx.cptr(); 542 y.conjugateSelf(); 543 BLASNAME(zgbmv) ( 544 BLASCM BLASCH_NT, BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 545 BLASP(&ca),BLASP(A.cptr()-hi),BLASV(ds), 546 BLASP(xp),BLASV(xs),BLASP(&xbeta), 547 BLASP(yp),BLASV(ys) BLAS1); 548 y.conjugateSelf(); 549 } 550 #endif 551 } else { 552 BLASNAME(zgbmv) ( 553 BLASCM BlasIsCM(A)?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T, 554 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 555 BLASP(&alpha),BLASP(A.cptr()-hi),BLASV(ds), 556 BLASP(xp),BLASV(xs),BLASP(&xbeta), 557 BLASP(yp),BLASV(ys) BLAS1); 558 } 559 } 560 template <> BlasMultMV(const std::complex<double> alpha,const GenBandMatrix<std::complex<double>> & A,const GenVector<double> & x,int beta,VectorView<std::complex<double>> y)561 void BlasMultMV( 562 const std::complex<double> alpha, 563 const GenBandMatrix<std::complex<double> >& A, 564 const GenVector<double>& x, 565 int beta, VectorView<std::complex<double> > y) 566 { 
BlasMultMV(alpha,A,Vector<std::complex<double> >(x),beta,y); } 567 template <> BlasMultMV(const std::complex<double> alpha,const GenBandMatrix<double> & A,const GenVector<std::complex<double>> & x,int beta,VectorView<std::complex<double>> y)568 void BlasMultMV( 569 const std::complex<double> alpha, 570 const GenBandMatrix<double>& A, 571 const GenVector<std::complex<double> >& x, 572 int beta, VectorView<std::complex<double> > y) 573 { 574 if (beta == 0) { 575 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 576 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 577 int lo = BlasIsCM(A) ? A.nlo() : A.nhi(); 578 int hi = BlasIsCM(A) ? A.nhi() : A.nlo(); 579 int ds = A.diagstep(); 580 int xs = 2*x.step(); 581 int ys = 2*y.step(); 582 const double* xp = (const double*) x.cptr(); 583 if (xs < 0) xp += (x.size()-1)*xs; 584 double* yp = (double*) y.ptr(); 585 if (ys < 0) yp += (y.size()-1)*ys; 586 double xalpha(1); 587 y.setZero(); 588 double xbeta(1); 589 BLASNAME(dgbmv) ( 590 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 591 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 592 BLASV(xalpha),BLASP(A.cptr()-hi),BLASV(ds), 593 BLASP(xp),BLASV(xs),BLASV(xbeta), 594 BLASP(yp),BLASV(ys) BLAS1); 595 BLASNAME(dgbmv) ( 596 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 597 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 598 BLASV(xalpha),BLASP(A.cptr()-hi),BLASV(ds), 599 BLASP(xp+1),BLASV(xs),BLASV(xbeta), 600 BLASP(yp+1),BLASV(ys) BLAS1); 601 if (x.isconj()) y.conjugateSelf(); 602 y *= alpha; 603 } else if (TMV_IMAG(alpha) == 0. && !x.isconj()) { 604 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 605 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 606 int lo = BlasIsCM(A) ? A.nlo() : A.nhi(); 607 int hi = BlasIsCM(A) ? 
A.nhi() : A.nlo(); 608 int ds = A.diagstep(); 609 int xs = 2*x.step(); 610 int ys = 2*y.step(); 611 const double* xp = (const double*) x.cptr(); 612 if (xs < 0) xp += (x.size()-1)*xs; 613 double* yp = (double*) y.ptr(); 614 if (ys < 0) yp += (y.size()-1)*ys; 615 double xalpha(TMV_REAL(alpha)); 616 double xbeta(beta); 617 BLASNAME(dgbmv) ( 618 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 619 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 620 BLASV(xalpha),BLASP(A.cptr()-hi),BLASV(ds), 621 BLASP(xp),BLASV(xs),BLASV(xbeta), 622 BLASP(yp),BLASV(ys) BLAS1); 623 BLASNAME(dgbmv) ( 624 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 625 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 626 BLASV(xalpha),BLASP(A.cptr()-hi),BLASV(ds), 627 BLASP(xp+1),BLASV(xs),BLASV(xbeta), 628 BLASP(yp+1),BLASV(ys) BLAS1); 629 } else { 630 Vector<std::complex<double> > xx = alpha*x; 631 BlasMultMV(std::complex<double>(1),A,xx,1,y); 632 } 633 } 634 template <> BlasMultMV(const std::complex<double> alpha,const GenBandMatrix<double> & A,const GenVector<double> & x,int beta,VectorView<std::complex<double>> y)635 void BlasMultMV( 636 const std::complex<double> alpha, 637 const GenBandMatrix<double>& A, 638 const GenVector<double>& x, 639 int beta, VectorView<std::complex<double> > y) 640 { 641 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 642 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 643 int lo = BlasIsCM(A) ? A.nlo() : A.nhi(); 644 int hi = BlasIsCM(A) ? A.nhi() : A.nlo(); 645 int ds = A.diagstep(); 646 int xs = x.step(); 647 int ys = 2*y.step(); 648 const double* xp = x.cptr(); 649 if (xs < 0) xp += (x.size()-1)*xs; 650 double* yp = (double*) y.ptr(); 651 if (ys < 0) yp += (y.size()-1)*ys; 652 double ar(TMV_REAL(alpha)); 653 double ai(TMV_IMAG(alpha)); 654 if (beta == 0) y.setZero(); 655 double xbeta(1); 656 if (ar != 0.) 
{ 657 BLASNAME(dgbmv) ( 658 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 659 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 660 BLASV(ar),BLASP(A.cptr()-hi),BLASV(ds), 661 BLASP(xp),BLASV(xs),BLASV(xbeta), 662 BLASP(yp),BLASV(ys) BLAS1); 663 } 664 if (ai != 0.) { 665 BLASNAME(dgbmv) ( 666 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 667 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 668 BLASV(ai),BLASP(A.cptr()-hi),BLASV(ds), 669 BLASP(xp),BLASV(xs),BLASV(xbeta), 670 BLASP(yp+1),BLASV(ys) BLAS1); 671 } 672 } 673 #endif 674 #ifdef INST_FLOAT 675 template <> BlasMultMV(const float alpha,const GenBandMatrix<float> & A,const GenVector<float> & x,int beta,VectorView<float> y)676 void BlasMultMV( 677 const float alpha, 678 const GenBandMatrix<float>& A, const GenVector<float>& x, 679 int beta, VectorView<float> y) 680 { 681 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 682 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 683 int lo = BlasIsCM(A) ? A.nlo() : A.nhi(); 684 int hi = BlasIsCM(A) ? A.nhi() : A.nlo(); 685 int ds = A.diagstep(); 686 int xs = x.step(); 687 int ys = y.step(); 688 const float* xp = x.cptr(); 689 if (xs < 0) xp += (x.size()-1)*xs; 690 float* yp = y.ptr(); 691 if (ys < 0) yp += (y.size()-1)*ys; 692 if (beta == 0) y.setZero(); 693 float xbeta(1); 694 BLASNAME(sgbmv) ( 695 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 696 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 697 BLASV(alpha),BLASP(A.cptr()-hi),BLASV(ds), 698 BLASP(xp),BLASV(xs),BLASV(xbeta), 699 BLASP(yp),BLASV(ys) BLAS1); 700 } 701 template <> BlasMultMV(const std::complex<float> alpha,const GenBandMatrix<std::complex<float>> & A,const GenVector<std::complex<float>> & x,int beta,VectorView<std::complex<float>> y)702 void BlasMultMV( 703 const std::complex<float> alpha, 704 const GenBandMatrix<std::complex<float> >& A, 705 const GenVector<std::complex<float> >& x, 706 int beta, VectorView<std::complex<float> > y) 707 { 708 if (x.isconj() 709 #ifndef CBLAS 710 && !(A.isconj() && BlasIsCM(A)) 711 #endif 712 ) { 713 
Vector<std::complex<float> > xx = alpha*x; 714 return BlasMultMV(std::complex<float>(1),A,xx,beta,y); 715 } 716 717 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 718 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 719 int lo = BlasIsCM(A) ? A.nlo() : A.nhi(); 720 int hi = BlasIsCM(A) ? A.nhi() : A.nlo(); 721 int ds = A.diagstep(); 722 int xs = x.step(); 723 int ys = y.step(); 724 const std::complex<float>* xp = x.cptr(); 725 if (xs < 0) xp += (x.size()-1)*xs; 726 std::complex<float>* yp = y.ptr(); 727 if (ys < 0) yp += (y.size()-1)*ys; 728 if (beta == 0) y.setZero(); 729 std::complex<float> xbeta(1); 730 731 if (A.isconj() && BlasIsCM(A)) { 732 #ifdef CBLAS 733 TMV_SWAP(m,n); 734 TMV_SWAP(lo,hi); 735 BLASNAME(cgbmv) ( 736 BLASRM BLASCH_CT, BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 737 BLASP(&alpha),BLASP(A.cptr()-lo),BLASV(ds), 738 BLASP(xp),BLASV(xs),BLASP(&xbeta), 739 BLASP(yp),BLASV(ys) BLAS1); 740 #else 741 std::complex<float> ca = TMV_CONJ(alpha); 742 if (x.isconj()) { 743 y.conjugateSelf(); 744 BLASNAME(cgbmv) ( 745 BLASCM BLASCH_NT, BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 746 BLASP(&ca),BLASP(A.cptr()-hi),BLASV(ds), 747 BLASP(xp),BLASV(xs),BLASP(&xbeta), 748 BLASP(yp),BLASV(ys) BLAS1); 749 y.conjugateSelf(); 750 } else { 751 Vector<std::complex<float> > xx=ca*x.conjugate(); 752 ca = std::complex<float>(1); 753 xs = 1; 754 xp = xx.cptr(); 755 y.conjugateSelf(); 756 BLASNAME(cgbmv) ( 757 BLASCM BLASCH_NT, BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 758 BLASP(&ca),BLASP(A.cptr()-hi),BLASV(ds), 759 BLASP(xp),BLASV(xs),BLASP(&xbeta), 760 BLASP(yp),BLASV(ys) BLAS1); 761 y.conjugateSelf(); 762 } 763 #endif 764 } else { 765 BLASNAME(cgbmv) ( 766 BLASCM BlasIsCM(A)?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T, 767 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 768 BLASP(&alpha),BLASP(A.cptr()-hi),BLASV(ds), 769 BLASP(xp),BLASV(xs),BLASP(&xbeta), 770 BLASP(yp),BLASV(ys) BLAS1); 771 } 772 } 773 template <> BlasMultMV(const std::complex<float> alpha,const 
GenBandMatrix<std::complex<float>> & A,const GenVector<float> & x,int beta,VectorView<std::complex<float>> y)774 void BlasMultMV( 775 const std::complex<float> alpha, 776 const GenBandMatrix<std::complex<float> >& A, 777 const GenVector<float>& x, 778 int beta, VectorView<std::complex<float> > y) 779 { BlasMultMV(alpha,A,Vector<std::complex<float> >(x),beta,y); } 780 template <> BlasMultMV(const std::complex<float> alpha,const GenBandMatrix<float> & A,const GenVector<std::complex<float>> & x,int beta,VectorView<std::complex<float>> y)781 void BlasMultMV( 782 const std::complex<float> alpha, 783 const GenBandMatrix<float>& A, 784 const GenVector<std::complex<float> >& x, 785 int beta, VectorView<std::complex<float> > y) 786 { 787 if (beta == 0) { 788 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 789 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 790 int lo = BlasIsCM(A) ? A.nlo() : A.nhi(); 791 int hi = BlasIsCM(A) ? A.nhi() : A.nlo(); 792 int ds = A.diagstep(); 793 int xs = 2*x.step(); 794 int ys = 2*y.step(); 795 const float* xp = (const float*) x.cptr(); 796 if (xs < 0) xp += (x.size()-1)*xs; 797 float* yp = (float*) y.ptr(); 798 if (ys < 0) yp += (y.size()-1)*ys; 799 float xalpha(1); 800 y.setZero(); 801 float xbeta(1); 802 BLASNAME(sgbmv) ( 803 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 804 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 805 BLASV(xalpha),BLASP(A.cptr()-hi),BLASV(ds), 806 BLASP(xp),BLASV(xs),BLASV(xbeta), 807 BLASP(yp),BLASV(ys) BLAS1); 808 BLASNAME(sgbmv) ( 809 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 810 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 811 BLASV(xalpha),BLASP(A.cptr()-hi),BLASV(ds), 812 BLASP(xp+1),BLASV(xs),BLASV(xbeta), 813 BLASP(yp+1),BLASV(ys) BLAS1); 814 if (x.isconj()) y.conjugateSelf(); 815 y *= alpha; 816 } else if (TMV_IMAG(alpha) == 0.F && !x.isconj()) { 817 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 818 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 819 int lo = BlasIsCM(A) ? A.nlo() : A.nhi(); 820 int hi = BlasIsCM(A) ? 
A.nhi() : A.nlo(); 821 int ds = A.diagstep(); 822 int xs = 2*x.step(); 823 int ys = 2*y.step(); 824 const float* xp = (const float*) x.cptr(); 825 if (xs < 0) xp += (x.size()-1)*xs; 826 float* yp = (float*) y.ptr(); 827 if (ys < 0) yp += (y.size()-1)*ys; 828 float xalpha(TMV_REAL(alpha)); 829 float xbeta(1); 830 BLASNAME(sgbmv) ( 831 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 832 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 833 BLASV(xalpha),BLASP(A.cptr()-hi),BLASV(ds), 834 BLASP(xp),BLASV(xs),BLASV(xbeta), 835 BLASP(yp),BLASV(ys) BLAS1); 836 BLASNAME(sgbmv) ( 837 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 838 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 839 BLASV(xalpha),BLASP(A.cptr()-hi),BLASV(ds), 840 BLASP(xp+1),BLASV(xs),BLASV(xbeta), 841 BLASP(yp+1),BLASV(ys) BLAS1); 842 } else { 843 Vector<std::complex<float> > xx = alpha*x; 844 BlasMultMV(std::complex<float>(1),A,xx,1,y); 845 } 846 } 847 template <> BlasMultMV(const std::complex<float> alpha,const GenBandMatrix<float> & A,const GenVector<float> & x,int beta,VectorView<std::complex<float>> y)848 void BlasMultMV( 849 const std::complex<float> alpha, 850 const GenBandMatrix<float>& A, 851 const GenVector<float>& x, 852 int beta, VectorView<std::complex<float> > y) 853 { 854 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 855 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 856 int lo = BlasIsCM(A) ? A.nlo() : A.nhi(); 857 int hi = BlasIsCM(A) ? 
A.nhi() : A.nlo(); 858 int ds = A.diagstep(); 859 int xs = x.step(); 860 int ys = 2*y.step(); 861 const float* xp = x.cptr(); 862 if (xs < 0) xp += (x.size()-1)*xs; 863 float* yp = (float*) y.ptr(); 864 if (ys < 0) yp += (y.size()-1)*ys; 865 float ar(TMV_REAL(alpha)); 866 float ai(TMV_IMAG(alpha)); 867 if (beta == 0) y.setZero(); 868 float xbeta(1); 869 if (ar != 0.F) { 870 BLASNAME(sgbmv) ( 871 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 872 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 873 BLASV(ar),BLASP(A.cptr()-hi),BLASV(ds), 874 BLASP(xp),BLASV(xs),BLASV(xbeta), 875 BLASP(yp),BLASV(ys) BLAS1); 876 } 877 if (ai != 0.F) { 878 BLASNAME(sgbmv) ( 879 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 880 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi), 881 BLASV(ai),BLASP(A.cptr()-hi),BLASV(ds), 882 BLASP(xp),BLASV(xs),BLASV(xbeta), 883 BLASP(yp+1),BLASV(ys) BLAS1); 884 } 885 } 886 #endif 887 #endif // BLAS 888 889 template <bool add, class T, class Ta, class Tx> DoMultMV(const T alpha,const GenBandMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)890 static void DoMultMV( 891 const T alpha, const GenBandMatrix<Ta>& A, const GenVector<Tx>& x, 892 VectorView<T> y) 893 { 894 //cout<<"Start DoMultMV\n"; 895 //cout<<"A = "<<TMV_Text(A)<<" "<<A.cptr()<<" "<<A<<endl; 896 //cout<<"x = "<<TMV_Text(x)<<" "<<x.cptr()<<" step "<<x.step()<<" "<<x<<endl; 897 //cout<<"y = "<<TMV_Text(y)<<" "<<y.cptr()<<" step "<<y.step()<<" "<<y<<endl; 898 //cout<<"alpha = "<<alpha<<", add = "<<add<<endl; 899 TMVAssert(A.rowsize() == x.size()); 900 TMVAssert(A.colsize() == y.size()); 901 TMVAssert(alpha != T(0)); 902 TMVAssert(x.size() > 0); 903 TMVAssert(y.size() > 0); 904 905 if (y.isconj()) { 906 DoMultMV<add>( 907 TMV_CONJ(alpha),A.conjugate(),x.conjugate(),y.conjugate()); 908 } else { 909 #ifdef BLAS 910 if (x.step() == 0) { 911 if (x.size() <= 1) 912 DoMultMV<add>( 913 alpha,A, 914 ConstVectorView<Tx>(x.cptr(),x.size(),1,x.ct()),y); 915 else 916 DoMultMV<add>(alpha,A,Vector<Tx>(x),y); 917 } else if (y.step() == 0) { 
918 TMVAssert(y.size() <= 1); 919 DoMultMV<add>( 920 alpha,A,x,VectorView<T>(y.ptr(),y.size(),1,y.ct())); 921 } else if (BlasIsRM(A) || BlasIsCM(A)) { 922 if (!SameStorage(A,y)) { 923 if (!SameStorage(x,y) && !SameStorage(A,x)) { 924 BlasMultMV(alpha,A,x,add?1:0,y); 925 } else { 926 Vector<T> xx = alpha*x; 927 BlasMultMV(T(1),A,xx,add?1:0,y); 928 } 929 } else { 930 Vector<T> yy(y.size(),T(0)); 931 if (!SameStorage(A,x)) { 932 BlasMultMV(T(1),A,x,0,yy.view()); 933 if (add) y += alpha*yy; 934 else y = alpha*yy; 935 } else { 936 Vector<T> xx = alpha*x; 937 BlasMultMV(T(1),A,xx,0,yy.view()); 938 if (add) y += yy; 939 else y = yy; 940 } 941 } 942 } else if ((A.isrm() && A.stepi() < A.nlo()+A.nhi()) || 943 (A.iscm() && A.stepj() < A.nlo()+A.nhi())) { 944 if (SameStorage(A,y)) { 945 Vector<T> yy(y.size(),T(0)); 946 DoMultMV<false>(T(1),A,x,yy.view()); 947 if (add) y += alpha*yy; 948 else y = alpha*yy; 949 } else if (SameStorage(x,y)) { 950 DoMultMV<add>(T(1),A,Vector<T>(alpha*x),y); 951 } else if (A.nlo()+1 == A.colsize()) { 952 if (A.nhi()+1 == A.rowsize()) { 953 ConstMatrixView<Ta> A1 = 954 A.subMatrix(0,A.colsize(),0,A.rowsize()); 955 MultMV<add>(alpha,A1,x,y); 956 } else { 957 ConstMatrixView<Ta> A1 = 958 A.subMatrix(0,A.colsize(),0,A.nhi()); 959 MultMV<add>(alpha,A1,x.subVector(0,A.nhi()),y); 960 ConstBandMatrixView<Ta> A2 = A.colRange(A.nhi(),A.rowsize()); 961 BlasMultMV(alpha,A2,x.subVector(A.nhi(),A.rowsize()),1,y); 962 } 963 } else { 964 TMVAssert(A.nlo()>0); 965 if (A.nhi()+1 == A.rowsize()) { 966 ConstMatrixView<Ta> A1 = 967 A.subMatrix(0,A.nlo(),0,A.rowsize()); 968 MultMV<add>(alpha,A1,x,y.subVector(0,A.nlo())); 969 } else { 970 ConstBandMatrixView<Ta> A1 = A.rowRange(0,A.nlo()); 971 BlasMultMV(alpha,A1,x.subVector(0,A1.rowsize()), 972 add?1:0,y.subVector(0,A.nlo())); 973 } 974 ConstBandMatrixView<Ta> A2 = A.rowRange(A.nlo(),A.colsize()); 975 BlasMultMV(alpha,A2,x,add?1:0,y.subVector(A.nlo(),A.colsize())); 976 } 977 } else { 978 if (TMV_IMAG(alpha) == T(0)) { 
979 BandMatrix<Ta,RowMajor> A2 = TMV_REAL(alpha)*A; 980 DoMultMV<add>(T(1),A2,x,y); 981 } else { 982 BandMatrix<T,RowMajor> A2 = alpha*A; 983 DoMultMV<add>(T(1),A2,x,y); 984 } 985 } 986 #else 987 NonBlasMultMV<add>(alpha,A,x,y); 988 #endif 989 } 990 //std::cout<<"Done DoMultMV\n"; 991 } 992 993 // 994 // MultEqMV 995 // 996 997 template <bool rm, bool ca, class T, class Ta> DoRowUpperMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)998 static void DoRowUpperMultEqMV( 999 const GenBandMatrix<Ta>& A, VectorView<T> x) 1000 { 1001 TMVAssert(A.isSquare()); 1002 TMVAssert(A.colsize() == x.size()); 1003 TMVAssert(x.size() > 0); 1004 TMVAssert(x.step()==1); 1005 TMVAssert(x.ct() == NonConj); 1006 TMVAssert(rm == A.isrm()); 1007 TMVAssert(ca == A.isconj()); 1008 1009 const ptrdiff_t N = x.size(); 1010 const ptrdiff_t sj = (rm ? 1 : A.stepj()); 1011 const ptrdiff_t ds = A.diagstep(); 1012 1013 T* xi = x.ptr(); 1014 const Ta* Aii = A.cptr(); 1015 ptrdiff_t j2=A.nhi()+1; 1016 ptrdiff_t len = j2-1; 1017 1018 for(; len>0; ++xi,Aii+=ds) { 1019 // i = 0..N-2 1020 // x(i) = A.row(i,i,j2) * x.subVector(i,j2); 1021 #ifdef TMVFLDEBUG 1022 TMVAssert(xi >= x._first); 1023 TMVAssert(xi < x._last); 1024 #endif 1025 *xi *= (ca ? TMV_CONJ(*Aii) : *Aii); 1026 const T* xj = xi+1; 1027 const Ta* Aij = Aii+sj; 1028 for(ptrdiff_t j=len;j>0;--j,++xj,(rm?++Aij:Aij+=sj)) { 1029 #ifdef TMVFLDEBUG 1030 TMVAssert(xi >= x._first); 1031 TMVAssert(xi < x._last); 1032 #endif 1033 *xi += (*xj) * (ca ? TMV_CONJ(*Aij) : *Aij); 1034 } 1035 1036 if (j2<N) ++j2; 1037 else --len; 1038 } 1039 #ifdef TMVFLDEBUG 1040 TMVAssert(xi >= x._first); 1041 TMVAssert(xi < x._last); 1042 #endif 1043 *xi *= (ca ? 
TMV_CONJ(*Aii) : *Aii); 1044 } 1045 1046 template <bool rm, class T, class Ta> RowUpperMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1047 static inline void RowUpperMultEqMV( 1048 const GenBandMatrix<Ta>& A, VectorView<T> x) 1049 { 1050 if (A.isconj()) 1051 DoRowUpperMultEqMV<rm,true>(A,x); 1052 else 1053 DoRowUpperMultEqMV<rm,false>(A,x); 1054 } 1055 1056 template <bool cm, bool ca, class T, class Ta> DoColUpperMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1057 static void DoColUpperMultEqMV( 1058 const GenBandMatrix<Ta>& A, VectorView<T> x) 1059 { 1060 TMVAssert(A.isSquare()); 1061 TMVAssert(A.colsize() == x.size()); 1062 TMVAssert(x.size() > 0); 1063 TMVAssert(x.step()==1); 1064 TMVAssert(x.ct() == NonConj); 1065 TMVAssert(cm == A.iscm()); 1066 TMVAssert(ca == A.isconj()); 1067 1068 const ptrdiff_t N = x.size(); 1069 const ptrdiff_t si = cm ? 1 : A.stepi(); 1070 const ptrdiff_t sj = A.stepj(); 1071 const ptrdiff_t ds = A.diagstep(); 1072 1073 const Ta* Ai1j = A.cptr(); 1074 T* xi1 = x.ptr(); 1075 1076 *xi1 *= (ca ? TMV_CONJ(*Ai1j) : *Ai1j); 1077 Ai1j += sj; 1078 const T* xj = x.cptr()+1; 1079 1080 ptrdiff_t k=A.nhi()-1; 1081 ptrdiff_t len = 1; 1082 for(ptrdiff_t j=1; j<N; ++j,++xj) { 1083 if (*xj != T(0)) { 1084 // j = 1..N-1 1085 // x.subVector(i1,j) += x(j) * A.col(j,i1,j); 1086 const Ta* Aij = Ai1j; 1087 T* xi = xi1; 1088 for(ptrdiff_t i=len;i>0;--i,++xi,(cm?++Aij:Aij+=si)) { 1089 #ifdef TMVFLDEBUG 1090 TMVAssert(xi >= x._first); 1091 TMVAssert(xi < x._last); 1092 #endif 1093 *xi += *xj * (ca ? TMV_CONJ(*Aij) : *Aij); 1094 } 1095 // Now Aij == Ajj, xi == xj 1096 // so this next statement is really *xj *= *Ajj 1097 #ifdef TMVFLDEBUG 1098 TMVAssert(xi >= x._first); 1099 TMVAssert(xi < x._last); 1100 #endif 1101 *xi *= (ca ? 
TMV_CONJ(*Aij) : *Aij); 1102 } 1103 1104 if (k>0) { --k; Ai1j+=sj; ++len; } 1105 else { ++xi1; Ai1j+=ds; } 1106 } 1107 } 1108 1109 template <bool cm, class T, class Ta> ColUpperMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1110 static inline void ColUpperMultEqMV( 1111 const GenBandMatrix<Ta>& A, VectorView<T> x) 1112 { 1113 if (A.isconj()) 1114 DoColUpperMultEqMV<cm,true>(A,x); 1115 else 1116 DoColUpperMultEqMV<cm,false>(A,x); 1117 } 1118 1119 template <bool rm, bool ca, class T, class Ta> DoRowLowerMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1120 static void DoRowLowerMultEqMV( 1121 const GenBandMatrix<Ta>& A, VectorView<T> x) 1122 { 1123 TMVAssert(A.isSquare()); 1124 TMVAssert(A.colsize() == x.size()); 1125 TMVAssert(x.size() > 0); 1126 TMVAssert(x.step()==1); 1127 TMVAssert(x.ct() == NonConj); 1128 TMVAssert(rm == A.isrm()); 1129 TMVAssert(ca == A.isconj()); 1130 1131 const ptrdiff_t N = x.size(); 1132 const ptrdiff_t si = A.stepi(); 1133 const ptrdiff_t sj = (rm ? 1 : A.stepj()); 1134 const ptrdiff_t ds = A.diagstep(); 1135 1136 ptrdiff_t j1 = N-1-A.nlo(); 1137 const T* xj1 = x.cptr() + j1; 1138 T* xi = x.ptr() + N-1; 1139 const Ta* Aii = A.cptr() + (N-1)*ds; 1140 const Ta* Aij1 = Aii - A.nlo()*sj; 1141 ptrdiff_t len = A.nlo(); 1142 1143 for(; len>0; --xi,Aii-=ds) { 1144 // i = N-1..1 1145 // x(i) = A.row(i,j1,i+1) * x.subVector(j1,i+1); 1146 #ifdef TMVFLDEBUG 1147 TMVAssert(xi >= x._first); 1148 TMVAssert(xi < x._last); 1149 #endif 1150 *xi *= (ca ? TMV_CONJ(*Aii) : *Aii); 1151 const Ta* Aij = Aij1; 1152 const T* xj = xj1; 1153 for(ptrdiff_t j=len;j>0;--j,++xj,(rm?++Aij:Aij+=sj)) { 1154 #ifdef TMVFLDEBUG 1155 TMVAssert(xi >= x._first); 1156 TMVAssert(xi < x._last); 1157 #endif 1158 *xi += *xj * (ca ? TMV_CONJ(*Aij) : *Aij); 1159 } 1160 1161 if (j1>0) { --j1; Aij1-=ds; --xj1; } 1162 else { --len; Aij1-=si; } 1163 } 1164 #ifdef TMVFLDEBUG 1165 TMVAssert(xi >= x._first); 1166 TMVAssert(xi < x._last); 1167 #endif 1168 *xi *= (ca ? 
TMV_CONJ(*Aii) : *Aii); 1169 } 1170 1171 template <bool rm, class T, class Ta> RowLowerMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1172 static inline void RowLowerMultEqMV( 1173 const GenBandMatrix<Ta>& A, VectorView<T> x) 1174 { 1175 if (A.isconj()) 1176 DoRowLowerMultEqMV<rm,true>(A,x); 1177 else 1178 DoRowLowerMultEqMV<rm,false>(A,x); 1179 } 1180 1181 template <bool cm, bool ca, class T, class Ta> DoColLowerMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1182 static void DoColLowerMultEqMV( 1183 const GenBandMatrix<Ta>& A, VectorView<T> x) 1184 { 1185 TMVAssert(A.isSquare()); 1186 TMVAssert(A.colsize() == x.size()); 1187 TMVAssert(x.size() > 0); 1188 TMVAssert(x.step() == 1); 1189 TMVAssert(x.ct() == NonConj); 1190 TMVAssert(cm == A.iscm()); 1191 TMVAssert(ca == A.isconj()); 1192 1193 const ptrdiff_t N = x.size(); 1194 const ptrdiff_t si = cm ? 1 : A.stepi(); 1195 const ptrdiff_t ds = A.diagstep(); 1196 1197 T* xj = x.ptr() + N-1; 1198 const Ta* Ajj = A.cptr()+(N-1)*ds; 1199 1200 #ifdef TMVFLDEBUG 1201 TMVAssert(xj >= x._first); 1202 TMVAssert(xj < x._last); 1203 #endif 1204 *xj *= (ca ? TMV_CONJ(*Ajj) : *Ajj); 1205 --xj; 1206 Ajj -= ds; 1207 1208 ptrdiff_t k=A.nlo()-1; 1209 for(ptrdiff_t j=N-1,len=1;j>0;--j,--xj,Ajj-=ds) { 1210 if (*xj!=T(0)) { 1211 // Actual j = N-2..0 1212 // x.subVector(j+1,N) += *xj * A.col(j,j+1,N); 1213 T* xi = xj+1; 1214 const Ta* Aij = Ajj+si; 1215 for (ptrdiff_t i=len;i>0;--i,++xi,(cm?++Aij:Aij+=si)) { 1216 #ifdef TMVFLDEBUG 1217 TMVAssert(xi >= x._first); 1218 TMVAssert(xi < x._last); 1219 #endif 1220 *xi += *xj * (ca ? TMV_CONJ(*Aij) : *Aij); 1221 } 1222 #ifdef TMVFLDEBUG 1223 TMVAssert(xj >= x._first); 1224 TMVAssert(xj < x._last); 1225 #endif 1226 *xj *= (ca ? 
TMV_CONJ(*Ajj) : *Ajj); 1227 1228 } 1229 if (k>0) { --k; ++len; } 1230 } 1231 } 1232 1233 template <bool cm, class T, class Ta> ColLowerMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1234 static inline void ColLowerMultEqMV( 1235 const GenBandMatrix<Ta>& A, VectorView<T> x) 1236 { 1237 if (A.isconj()) 1238 DoColLowerMultEqMV<cm,true>(A,x); 1239 else 1240 DoColLowerMultEqMV<cm,false>(A,x); 1241 } 1242 1243 template <class T, class Ta> DoUpperMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1244 static inline void DoUpperMultEqMV( 1245 const GenBandMatrix<Ta>& A, VectorView<T> x) 1246 // x = A * x 1247 { 1248 if (A.isrm()) RowUpperMultEqMV<true>(A,x); 1249 else if (A.iscm()) ColUpperMultEqMV<true>(A,x); 1250 else RowUpperMultEqMV<false>(A,x); 1251 } 1252 1253 template <class T, class Ta> DoLowerMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1254 static inline void DoLowerMultEqMV( 1255 const GenBandMatrix<Ta>& A, VectorView<T> x) 1256 { 1257 if (A.isrm()) RowLowerMultEqMV<true>(A,x); 1258 else if (A.iscm() && !SameStorage(A,x)) 1259 ColLowerMultEqMV<true>(A,x); 1260 else RowLowerMultEqMV<false>(A,x); 1261 } 1262 1263 template <class T, class Ta> NonBlasUpperMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1264 static void NonBlasUpperMultEqMV( 1265 const GenBandMatrix<Ta>& A, VectorView<T> x) 1266 { 1267 TMVAssert(A.isSquare()); 1268 TMVAssert(A.colsize() == x.size()); 1269 TMVAssert(x.size() > 0); 1270 TMVAssert(x.step() == 1); 1271 TMVAssert(x.ct() == NonConj); 1272 1273 // [ A11 A12 0 ] [ 0 ] [ A12 x2 ] 1274 // x = [ 0 A22 A23 ] [ x2 ] = [ A22 x2 ] 1275 // [ 0 0 A33 ] [ 0 ] [ 0 ] 1276 1277 const ptrdiff_t N = x.size(); // = A.size() 1278 ptrdiff_t j2 = N; 1279 for(const T* x2=x.cptr()+N-1; j2>0 && *x2==T(0); --j2,--x2); 1280 if (j2 == 0) return; 1281 ptrdiff_t j1 = 0; 1282 for(const T* x1=x.cptr(); *x1==T(0); ++j1,++x1); 1283 if (j1 == 0 && j2 == N) DoUpperMultEqMV(A,x); 1284 else { 1285 TMVAssert(j1 < j2); 1286 const Ta* p22 = A.cptr() + 
j1*A.diagstep(); 1287 const ptrdiff_t N22 = j2-j1; 1288 VectorView<T> x2 = x.subVector(j1,j2); 1289 if (N22 > A.nhi()) { 1290 ConstBandMatrixView<Ta> A22( 1291 p22,N22,N22,0,A.nhi(), 1292 A.stepi(),A.stepj(),A.diagstep(),A.ct()); 1293 if (j1 > 0) { 1294 const ptrdiff_t jx = j1+A.nhi(); 1295 if (j1 < A.nhi()) { 1296 const Ta* p12 = A.cptr() + j1*A.stepj(); 1297 ConstBandMatrixView<Ta> A12( 1298 p12,j1,A.nhi(),j1-1,A.nhi()-j1, 1299 A.stepi(),A.stepj(),A.diagstep(),A.ct()); 1300 UnitAMultMV1<false,false>( 1301 A12,x.subVector(j1,jx),x.subVector(0,j1)); 1302 } else { 1303 const Ta* p12 = p22 - A.nhi()*A.stepi(); 1304 ConstBandMatrixView<Ta> A12( 1305 p12,A.nhi(),A.nhi(),A.nhi()-1,0, 1306 A.stepi(),A.stepj(),A.diagstep(),A.ct()); 1307 VectorView<T> x1x = x.subVector(j1-A.nhi(),j1); 1308 x1x = x.subVector(j1,jx); 1309 DoLowerMultEqMV(A12,x1x); 1310 } 1311 } 1312 DoUpperMultEqMV(A22,x2); 1313 } else { 1314 ConstBandMatrixView<Ta> A22( 1315 p22,N22,N22,0,N22-1, 1316 A.stepi(),A.stepj(),A.diagstep(),A.ct()); 1317 if (j1 > 0) { 1318 const ptrdiff_t M12 = (j1 < A.nhi()) ? 
j1 : A.nhi(); 1319 const Ta* p12 = p22 - M12*A.stepi(); 1320 ptrdiff_t newhi = A.nhi()-M12; 1321 if (newhi >= N22) newhi = N22-1; 1322 ConstBandMatrixView<Ta> A12( 1323 p12,M12,N22,M12-1,newhi, 1324 A.stepi(),A.stepj(),A.diagstep(),A.ct()); 1325 UnitAMultMV1<false,false>(A12,x2,x.subVector(j1-M12,j1)); 1326 } 1327 DoUpperMultEqMV(A22,x2); 1328 } 1329 } 1330 } 1331 1332 template <class T, class Ta> NonBlasLowerMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1333 static void NonBlasLowerMultEqMV( 1334 const GenBandMatrix<Ta>& A, VectorView<T> x) 1335 // x = A * x 1336 { 1337 TMVAssert(A.isSquare()); 1338 TMVAssert(A.colsize() == x.size()); 1339 TMVAssert(x.size() > 0); 1340 TMVAssert(x.step() == 1); 1341 TMVAssert(x.ct() == NonConj); 1342 1343 // [ A11 0 0 ] [ 0 ] [ 0 ] 1344 // x = [ A21 A22 0 ] [ x2 ] = [ A22 x2 ] 1345 // [ 0 A32 A33 ] [ 0 ] [ A32 x2 ] 1346 1347 const ptrdiff_t N = x.size(); // = A.size() 1348 ptrdiff_t j2 = N; 1349 for(const T* x2=x.cptr()+N-1; j2>0 && *x2==T(0); --j2,--x2); 1350 if (j2 == 0) return; 1351 ptrdiff_t j1 = 0; 1352 for(const T* x1=x.cptr(); *x1==T(0); ++j1,++x1); 1353 if (j1 == 0 && j2 == N) DoLowerMultEqMV(A,x); 1354 else { 1355 TMVAssert(j1 < j2); 1356 const Ta* p22 = A.cptr() + j1*A.diagstep(); 1357 const ptrdiff_t N22 = j2-j1; 1358 VectorView<T> x2 = x.subVector(j1,j2); 1359 if (N22 > A.nlo()) { 1360 ConstBandMatrixView<Ta> A22( 1361 p22,N22,N22,A.nlo(),0, 1362 A.stepi(),A.stepj(),A.diagstep(),A.ct()); 1363 if (j2 < N) { 1364 const ptrdiff_t jx = j2-A.nlo(); 1365 const Ta* p32 = A.cptr() + 1366 j2*A.diagstep() - A.nlo()*A.stepj(); 1367 if (j2+A.nlo() > N) { 1368 ConstBandMatrixView<Ta> A32( 1369 p32,N-j2,A.nlo(),0,A.nlo()-1, 1370 A.stepi(),A.stepj(),A.diagstep(),A.ct()); 1371 UnitAMultMV1<false,false>( 1372 A32,x.subVector(jx,j2),x.subVector(j2,N)); 1373 } else { 1374 ConstBandMatrixView<Ta> A32( 1375 p32,A.nlo(),A.nlo(),0,A.nlo()-1, 1376 A.stepi(),A.stepj(),A.diagstep(),A.ct()); 1377 VectorView<T> x3x = 
x.subVector(j2,j2+A.nlo()); 1378 x3x = x.subVector(jx,j2); 1379 DoUpperMultEqMV(A32,x3x); 1380 } 1381 } 1382 DoLowerMultEqMV(A22,x2); 1383 } else { 1384 ConstBandMatrixView<Ta> A22( 1385 p22,N22,N22,N22-1,0, 1386 A.stepi(),A.stepj(),A.diagstep(),A.ct()); 1387 if (j2 < N) { 1388 const Ta* p32 = p22 + N22*A.stepi(); 1389 const ptrdiff_t M32 = (j2+A.nlo() > N) ? N-j2 : A.nlo(); 1390 ptrdiff_t newlo = A.nlo()-N22; 1391 if (newlo >= M32) newlo = M32-1; 1392 ConstBandMatrixView<Ta> A32( 1393 p32,M32,N22,newlo,N22-1, 1394 A.stepi(),A.stepj(),A.diagstep(),A.ct()); 1395 UnitAMultMV1<false,false>(A32,x2,x.subVector(j2,j2+M32)); 1396 } 1397 DoLowerMultEqMV(A22,x2); 1398 } 1399 } 1400 } 1401 1402 #ifdef BLAS 1403 template <class T, class Ta> BlasMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1404 static inline void BlasMultEqMV( 1405 const GenBandMatrix<Ta>& A, VectorView<T> x) 1406 { 1407 if (A.nlo() == 0) NonBlasUpperMultEqMV(A,x); 1408 else NonBlasLowerMultEqMV(A,x); 1409 } 1410 #ifdef INST_DOUBLE 1411 template <> BlasMultEqMV(const GenBandMatrix<double> & A,VectorView<double> x)1412 void BlasMultEqMV( 1413 const GenBandMatrix<double>& A, VectorView<double> x) 1414 { 1415 bool up = A.nlo()==0; 1416 int n=A.colsize(); 1417 int lohi = up ? A.nhi() : A.nlo(); 1418 int aoffset = 1419 up && BlasIsCM(A) ? A.nhi() : 1420 !up && !BlasIsCM(A) ? A.nlo() : 0; 1421 int ds = A.diagstep(); 1422 int xs = x.step(); 1423 BLASNAME(dtbmv) ( 1424 BLASCM BlasIsCM(A) == up?BLASCH_UP:BLASCH_LO, 1425 BlasIsCM(A) ? BLASCH_NT : BLASCH_T, BLASCH_NU, 1426 BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds), 1427 BLASP(x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1); 1428 } 1429 template <> BlasMultEqMV(const GenBandMatrix<std::complex<double>> & A,VectorView<std::complex<double>> x)1430 void BlasMultEqMV( 1431 const GenBandMatrix<std::complex<double> >& A, 1432 VectorView<std::complex<double> > x) 1433 { 1434 bool up = A.nlo()==0; 1435 int n=A.colsize(); 1436 int lohi = up ? 
A.nhi() : A.nlo(); 1437 int aoffset = 1438 up && BlasIsCM(A) ? A.nhi() : 1439 !up && !BlasIsCM(A) ? A.nlo() : 0; 1440 int ds = A.diagstep(); 1441 int xs = x.step(); 1442 if (BlasIsCM(A) && A.isconj()) { 1443 #ifdef CBLAS 1444 BLASNAME(ztbmv) ( 1445 BLASRM up ? BLASCH_LO : BLASCH_UP, BLASCH_CT, BLASCH_NU, 1446 BLASV(n),BLASV(lohi),BLASP(A.cptr()-A.nhi()), 1447 BLASV(ds),BLASP(x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1); 1448 #else 1449 x.conjugateSelf(); 1450 BLASNAME(ztbmv) ( 1451 BLASCM up?BLASCH_UP:BLASCH_LO, BLASCH_NT, BLASCH_NU, 1452 BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds), 1453 BLASP(x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1); 1454 x.conjugateSelf(); 1455 #endif 1456 } else { 1457 BLASNAME(ztbmv) ( 1458 BLASCM BlasIsCM(A) == up?BLASCH_UP:BLASCH_LO, 1459 BlasIsCM(A) ? BLASCH_NT : A.isconj() ? BLASCH_CT : BLASCH_T, 1460 BLASCH_NU, 1461 BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds), 1462 BLASP(x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1); 1463 } 1464 } 1465 template <> BlasMultEqMV(const GenBandMatrix<double> & A,VectorView<std::complex<double>> x)1466 void BlasMultEqMV( 1467 const GenBandMatrix<double>& A, 1468 VectorView<std::complex<double> > x) 1469 { 1470 bool up = A.nlo()==0; 1471 int n=A.colsize(); 1472 int lohi = up ? A.nhi() : A.nlo(); 1473 int aoffset = 1474 up && BlasIsCM(A) ? A.nhi() : 1475 !up && !BlasIsCM(A) ? A.nlo() : 0; 1476 int ds = A.diagstep(); 1477 int xs = 2*x.step(); 1478 BLASNAME(dtbmv) ( 1479 BLASCM BlasIsCM(A) == up?BLASCH_UP:BLASCH_LO, 1480 BlasIsCM(A) ? BLASCH_NT : BLASCH_T, BLASCH_NU, 1481 BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds), 1482 BLASP((double*)x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1); 1483 BLASNAME(dtbmv) ( 1484 BLASCM BlasIsCM(A) == up?BLASCH_UP:BLASCH_LO, 1485 BlasIsCM(A) ? 
BLASCH_NT : BLASCH_T, BLASCH_NU, 1486 BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds), 1487 BLASP((double*)x.ptr()+1),BLASV(xs) BLAS1 BLAS1 BLAS1); 1488 } 1489 #endif 1490 #ifdef INST_FLOAT 1491 template <> BlasMultEqMV(const GenBandMatrix<float> & A,VectorView<float> x)1492 void BlasMultEqMV( 1493 const GenBandMatrix<float>& A, VectorView<float> x) 1494 { 1495 bool up = A.nlo()==0; 1496 int n=A.colsize(); 1497 int lohi = up ? A.nhi() : A.nlo(); 1498 int aoffset = 1499 up && BlasIsCM(A) ? A.nhi() : 1500 !up && !BlasIsCM(A) ? A.nlo() : 0; 1501 int ds = A.diagstep(); 1502 int xs = x.step(); 1503 BLASNAME(stbmv) ( 1504 BLASCM BlasIsCM(A) == up?BLASCH_UP:BLASCH_LO, 1505 BlasIsCM(A) ? BLASCH_NT : BLASCH_T, BLASCH_NU, 1506 BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds), 1507 BLASP(x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1); 1508 } 1509 template <> BlasMultEqMV(const GenBandMatrix<std::complex<float>> & A,VectorView<std::complex<float>> x)1510 void BlasMultEqMV( 1511 const GenBandMatrix<std::complex<float> >& A, 1512 VectorView<std::complex<float> > x) 1513 { 1514 bool up = A.nlo()==0; 1515 int n=A.colsize(); 1516 int lohi = up ? A.nhi() : A.nlo(); 1517 int aoffset = 1518 up && BlasIsCM(A) ? A.nhi() : 1519 !up && !BlasIsCM(A) ? A.nlo() : 0; 1520 int ds = A.diagstep(); 1521 int xs = x.step(); 1522 //cout<<"Before ctbmv\n"; 1523 //cout<<"up = "<<up<<std::endl; 1524 //cout<<"n = "<<n<<std::endl; 1525 //cout<<"lohi = "<<lohi<<std::endl; 1526 //cout<<"aoffset = "<<aoffset<<std::endl; 1527 //cout<<"ds = "<<ds<<std::endl; 1528 //cout<<"xs = "<<xs<<std::endl; 1529 if (BlasIsCM(A) && A.isconj()) { 1530 //cout<<"cm && conj\n"; 1531 #ifdef CBLAS 1532 BLASNAME(ctbmv) ( 1533 BLASRM up ? 
BLASCH_LO : BLASCH_UP, BLASCH_CT, BLASCH_NU, 1534 BLASV(n),BLASV(lohi),BLASP(A.cptr()-A.nhi()), 1535 BLASV(ds),BLASP(x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1); 1536 #else 1537 x.conjugateSelf(); 1538 BLASNAME(ctbmv) ( 1539 BLASCM up?BLASCH_UP:BLASCH_LO, BLASCH_NT, BLASCH_NU, 1540 BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds), 1541 BLASP(x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1); 1542 x.conjugateSelf(); 1543 #endif 1544 } else { 1545 //cout<<"!cm || !conj\n"; 1546 BLASNAME(ctbmv) ( 1547 BLASCM BlasIsCM(A) == up?BLASCH_UP:BLASCH_LO, 1548 BlasIsCM(A) ? BLASCH_NT : A.isconj() ? BLASCH_CT : BLASCH_T, 1549 BLASCH_NU, 1550 BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds), 1551 BLASP(x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1); 1552 } 1553 } 1554 template <> BlasMultEqMV(const GenBandMatrix<float> & A,VectorView<std::complex<float>> x)1555 void BlasMultEqMV( 1556 const GenBandMatrix<float>& A, 1557 VectorView<std::complex<float> > x) 1558 { 1559 bool up = A.nlo()==0; 1560 int n=A.colsize(); 1561 int lohi = up ? A.nhi() : A.nlo(); 1562 int aoffset = 1563 up && BlasIsCM(A) ? A.nhi() : 1564 !up && !BlasIsCM(A) ? A.nlo() : 0; 1565 int ds = A.diagstep(); 1566 int xs = 2*x.step(); 1567 BLASNAME(stbmv) ( 1568 BLASCM BlasIsCM(A) == up?BLASCH_UP:BLASCH_LO, 1569 BlasIsCM(A) ? BLASCH_NT : BLASCH_T, BLASCH_NU, 1570 BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds), 1571 BLASP((float*)x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1); 1572 BLASNAME(stbmv) ( 1573 BLASCM BlasIsCM(A) == up?BLASCH_UP:BLASCH_LO, 1574 BlasIsCM(A) ? 
BLASCH_NT : BLASCH_T, BLASCH_NU, 1575 BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds), 1576 BLASP((float*)x.ptr()+1),BLASV(xs) BLAS1 BLAS1 BLAS1); 1577 } 1578 #endif 1579 #endif // BLAS 1580 1581 template <class T, class Ta> MultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1582 static void MultEqMV( 1583 const GenBandMatrix<Ta>& A, VectorView<T> x) 1584 { 1585 #ifdef XDEBUG 1586 cout<<"Start MultEqMV\n"; 1587 Vector<T> x0 = x; 1588 Matrix<Ta> A0 = A; 1589 Vector<T> x2 = A0 * x0; 1590 #endif 1591 TMVAssert(A.isSquare()); 1592 TMVAssert(A.colsize() == x.size()); 1593 TMVAssert(x.size() > 0); 1594 TMVAssert(x.step() == 1); 1595 TMVAssert(A.nlo() == 0 || A.nhi() == 0); 1596 1597 if (x.isconj()) MultEqMV(A.conjugate(),x.conjugate()); 1598 else { 1599 #ifdef BLAS 1600 if ( !(BlasIsRM(A) || BlasIsCM(A)) ) { 1601 //cout<<"!(rm or cm)\n"; 1602 BandMatrix<Ta,ColMajor> AA = A; 1603 BlasMultEqMV(AA,x); 1604 } else if (SameStorage(A,x) || x.step() != 1) { 1605 //cout<<"copy x\n"; 1606 Vector<T> xx = x; 1607 BlasMultEqMV(A,xx.view()); 1608 x = xx; 1609 } else { 1610 //cout<<"normal\n"; 1611 BlasMultEqMV(A,x); 1612 } 1613 #else 1614 if (A.nlo() == 0) NonBlasUpperMultEqMV(A,x); 1615 else NonBlasLowerMultEqMV(A,x); 1616 #endif 1617 } 1618 1619 #ifdef XDEBUG 1620 cout<<"Done MultEqMV\n"; 1621 if (!(Norm(x-x2) <= 0.001*(Norm(A0)*Norm(x0)))) { 1622 cerr<<"MultEqMV: \n"; 1623 cerr<<"A = "<<TMV_Text(A)<<" "<<A0<<endl; 1624 cerr<<"x = "<<TMV_Text(x)<<" step "<<x.step()<<" "<<x0<<endl; 1625 cerr<<"-> x = "<<x<<endl; 1626 cerr<<"x2 = "<<x2<<endl; 1627 abort(); 1628 } 1629 #endif 1630 } 1631 1632 template <bool add, class T, class Ta, class Tx> MultMV(const T alpha,const GenBandMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)1633 void MultMV( 1634 const T alpha, const GenBandMatrix<Ta>& A, const GenVector<Tx>& x, 1635 VectorView<T> y) 1636 // y (+)= alpha * A * x 1637 { 1638 TMVAssert(A.rowsize() == x.size()); 1639 TMVAssert(A.colsize() == y.size()); 1640 #ifdef XDEBUG 
1641 cout<<"Start Band: MultMV\n"; 1642 cout<<"A = "<<TMV_Text(A)<<" "<<A.cptr()<<" "<<A<<endl; 1643 cout<<"x = "<<TMV_Text(x)<<" "<<x.cptr()<<" step "<<x.step()<<" "<<x<<endl; 1644 cout<<"y = "<<TMV_Text(y)<<" "<<y.cptr()<<" step "<<y.step()<<" "<<y<<endl; 1645 cout<<"alpha = "<<alpha<<", add = "<<add<<endl; 1646 Vector<T> y0 = y; 1647 cout<<"y0 = "<<y0<<std::endl; 1648 Vector<Tx> x0 = x; 1649 cout<<"x0 = "<<x0<<std::endl; 1650 Matrix<Ta> A0 = A; 1651 cout<<"A0 = "<<A0<<std::endl; 1652 Vector<T> y2 = alpha*A0*x0; 1653 cout<<"y2 = "<<y2<<std::endl; 1654 if (add) y2 += y0; 1655 cout<<"y2 => "<<y2<<std::endl; 1656 #endif 1657 1658 if (y.size() > 0) { 1659 if (x.size()==0 || alpha==T(0)) { 1660 if (!add) y.setZero(); 1661 } else if (A.rowsize() > A.colsize()+A.nhi()) { 1662 MultMV<add>( 1663 alpha,A.colRange(0,A.colsize()+A.nhi()), 1664 x.subVector(0,A.colsize()+A.nhi()),y); 1665 } else if (A.colsize() > A.rowsize()+A.nlo()) { 1666 MultMV<add>( 1667 alpha,A.rowRange(0,A.rowsize()+A.nlo()), 1668 x,y.subVector(0,A.rowsize()+A.nlo())); 1669 if (!add) 1670 y.subVector(A.rowsize()+A.nlo(),A.colsize()).setZero(); 1671 } else if (A.isSquare() && (A.nlo() == 0 || A.nhi() == 0)) { 1672 if (A.nlo() == 0 && A.nhi() == 0) 1673 MultMV<add>(alpha,DiagMatrixViewOf(A.diag()),x,y); 1674 else if (!add && y.step() == 1) { 1675 y = alpha * x; 1676 MultEqMV(A,y); 1677 } else { 1678 Vector<T> xx = alpha*x; 1679 MultEqMV(A,xx.view()); 1680 if (add) y += xx; 1681 else y = xx; 1682 } 1683 } else { 1684 if (SameStorage(y,A)) { 1685 Vector<T> yy(y.size()); 1686 DoMultMV<false>(T(1),A,x,yy.view()); 1687 if (add) y += alpha*yy; 1688 else y = alpha*yy; 1689 } else { 1690 DoMultMV<add>(alpha,A,x,y); 1691 } 1692 } 1693 } 1694 #ifdef XDEBUG 1695 cout<<"->y = "<<y<<endl; 1696 if (!(Norm(y2-y) <= 0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(x0)+ 1697 (add?Norm(y0):TMV_RealType(T)(0))))) { 1698 cerr<<"MultMV: alpha = "<<alpha<<endl; 1699 cerr<<"add = "<<add<<endl; 1700 cerr<<"A = "<<TMV_Text(A)<<" "<<A.cptr()<<" 
"<<A0<<endl; 1701 cerr<<"x = "<<TMV_Text(x)<<" "<<x.cptr()<<" step "<<x.step()<<" "<<x0<<endl; 1702 cerr<<"y = "<<TMV_Text(y)<<" "<<y.cptr()<<" step "<<y.step()<<" "<<y0<<endl; 1703 cerr<<"--> y = "<<y<<endl; 1704 cerr<<"y2 = "<<y2<<endl; 1705 cerr<<"Norm(diff) = "<<Norm(y2-y)<<endl; 1706 cerr<<"abs(alpha)*Norm(A0)*Norm(x0) = "<<TMV_ABS(alpha)*Norm(A0)*Norm(x0)<<endl; 1707 cerr<<"Norm(y0) = "<<Norm(y0)<<endl; 1708 abort(); 1709 } 1710 #endif 1711 } 1712 1713 #define InstFile "TMV_MultBV.inst" 1714 #include "TMV_Inst.h" 1715 #undef InstFile 1716 1717 } // namespace tmv 1718 1719 1720