1 /////////////////////////////////////////////////////////////////////////////// 2 // // 3 // The Template Matrix/Vector Library for C++ was created by Mike Jarvis // 4 // Copyright (C) 1998 - 2016 // 5 // All rights reserved // 6 // // 7 // The project is hosted at https://code.google.com/p/tmv-cpp/ // 8 // where you can find the current version and current documention. // 9 // // 10 // For concerns or problems with the software, Mike may be contacted at // 11 // mike_jarvis17 [at] gmail. // 12 // // 13 // This software is licensed under a FreeBSD license. The file // 14 // TMV_LICENSE should have bee included with this distribution. // 15 // It not, you can get a copy from https://code.google.com/p/tmv-cpp/. // 16 // // 17 // Essentially, you can use this software however you want provided that // 18 // you include the TMV_LICENSE file in any distribution that uses it. // 19 // // 20 /////////////////////////////////////////////////////////////////////////////// 21 22 23 //#define XDEBUG 24 25 26 #include "TMV_Blas.h" 27 #include "tmv/TMV_SymMatrixArithFunc.h" 28 #include "tmv/TMV_SymMatrix.h" 29 #include "tmv/TMV_Vector.h" 30 #include "tmv/TMV_VectorArith.h" 31 #include "tmv/TMV_TriMatrixArith.h" 32 #include "tmv/TMV_MatrixArithFunc.h" 33 #ifdef BLAS 34 #include "tmv/TMV_SymMatrixArith.h" 35 #endif 36 37 #ifdef XDEBUG 38 #include <iostream> 39 #include "tmv/TMV_MatrixArith.h" 40 using std::cout; 41 using std::cerr; 42 using std::endl; 43 #endif 44 45 namespace tmv { 46 47 template <class T> cptr() const48 const T* SymMatrixComposite<T>::cptr() const 49 { 50 if (!itsm.get()) { 51 ptrdiff_t s = this->size(); 52 ptrdiff_t len = s*s; 53 itsm.resize(len); 54 this->assignToS(SymMatrixView<T>( 55 itsm.get(),s,stepi(),stepj(),Sym,uplo(),NonConj 56 TMV_FIRSTLAST1(itsm.get(),itsm.get()+len) )); 57 } 58 return itsm.get(); 59 } 60 61 template <class T> stepi() const62 ptrdiff_t SymMatrixComposite<T>::stepi() const 63 { return 1; } 64 65 template <class T> stepj() const66 ptrdiff_t SymMatrixComposite<T>::stepj() const 67 { return this->size(); } 68 69 // 70 // MultMV 71 // 72 73 template <bool add, class T, class Ta, class Tx> DoUnitAMultMV(const GenSymMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)74 static void DoUnitAMultMV( 75 const GenSymMatrix<Ta>& A, const GenVector<Tx>& x, 76 VectorView<T> y) 77 { 78 const ptrdiff_t N = A.size(); 79 if (add) y += A.lowerTri() * x; 80 else y = A.lowerTri() * x; 81 82 if (N > 1) 83 y.subVector(0,N-1) += A.upperTri().offDiag() * x.subVector(1,N); 84 } 85 86 template <bool add, class T, class Ta, class Tx> UnitAMultMV(const GenSymMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)87 static void UnitAMultMV( 88 const GenSymMatrix<Ta>& A, const GenVector<Tx>& x, 89 VectorView<T> y) 90 { 91 // Check for 0's in the beginning or end of x: 92 // [ A11 A12 A13 ] [ 0 ] [ A12 ] 93 // y += [ A21 A22 A23 ] [ x ] --> y += [ A22 ] x 94 // [ A31 A32 A33 ] [ 0 ] [ A32 ] 95 96 const ptrdiff_t N = x.size(); // == A.size() 97 ptrdiff_t j2 = N; 98 for(const Tx* x2=x.cptr()+N-1; j2>0 && *x2==Tx(0); --j2,--x2); 99 if (j2 == 0) { 100 if (!add) y.setZero(); 101 return; 102 } 103 ptrdiff_t j1 = 0; 104 for(const Tx* x1=x.cptr(); *x1==Tx(0); ++j1,++x1); 105 if (j1 == 0 && j2 == N) DoUnitAMultMV<add>(A,x,y); 106 else { 107 if (j1 > 0) 108 MultMV<add>(T(1),A.subMatrix(0,j1,j1,j2),x.subVector(j1,j2), 109 y.subVector(0,j1)); 110 TMVAssert(j1 != j2); 111 DoUnitAMultMV<add>(A.subSymMatrix(j1,j2),x.subVector(j1,j2), 112 y.subVector(j1,j2)); 113 if (j2 < N) 114 MultMV<add>(T(1),A.subMatrix(j2,N,j1,j2),x.subVector(j1,j2), 115 y.subVector(j2,N)); 116 } 117 } 118 119 template <bool add, class T, class Ta, class Tx> NonBlasMultMV(const T alpha,const GenSymMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)120 static void NonBlasMultMV( 121 const T alpha, const GenSymMatrix<Ta>& A, const GenVector<Tx>& x, 122 VectorView<T> y) 123 // y (+)= alpha * A * x 124 { 125 TMVAssert(A.size() == x.size()); 126 TMVAssert(A.size() == y.size()); 127 TMVAssert(alpha != T(0)); 128 TMVAssert(y.size() > 0); 129 TMVAssert(!SameStorage(x,y)); 130 131 if (A.uplo() == Upper) 132 if (A.isherm()) NonBlasMultMV<add>(alpha,A.adjoint(),x,y); 133 else NonBlasMultMV<add>(alpha,A.transpose(),x,y); 134 else if (y.isconj()) 135 NonBlasMultMV<add>(TMV_CONJ(alpha),A.conjugate(),x.conjugate(), 136 y.conjugate()); 137 else { 138 if (x.step() != 1) { 139 if (TMV_IMAG(alpha) == TMV_RealType(T)(0)) { 140 Vector<Tx> xx = TMV_REAL(alpha)*x; 141 if (y.step() != 1) { 142 Vector<T> yy(y.size()); 143 UnitAMultMV<false>(A,xx,yy.view()); 144 if (!add) y = yy; 145 else y += yy; 146 } 147 else 148 UnitAMultMV<add>(A,xx,y); 149 } else { 150 Vector<T> xx = alpha*x; 151 if (y.step()!=1) { 152 Vector<T> yy(y.size()); 153 UnitAMultMV<false>(A,xx,yy.view()); 154 if (add) y += yy; 155 else y = yy; 156 } 157 else 158 UnitAMultMV<add>(A,xx,y); 159 } 160 } else if (y.step()!=1 || alpha!=T(1)) { 161 Vector<T> yy(y.size()); 162 UnitAMultMV<false>(A,x,yy.view()); 163 if (add) y += alpha * yy; 164 else y = alpha * yy; 165 } else { 166 TMVAssert(alpha == T(1)); 167 TMVAssert(y.step() == 1); 168 TMVAssert(x.step() == 1); 169 UnitAMultMV<add>(A,x,y); 170 } 171 } 172 } 173 174 #ifdef BLAS 175 template <class T, class Ta, class Tx> BlasMultMV(const T alpha,const GenSymMatrix<Ta> & A,const GenVector<Tx> & x,int beta,VectorView<T> y)176 static inline void BlasMultMV( 177 const T alpha, const GenSymMatrix<Ta>& A, 178 const GenVector<Tx>& x, int beta, VectorView<T> y) 179 { 180 if (beta==1) NonBlasMultMV<true>(alpha,A,x,y); 181 else NonBlasMultMV<false>(alpha,A,x,y); 182 } 183 #ifdef INST_DOUBLE 184 template <> BlasMultMV(const double alpha,const GenSymMatrix<double> & A,const GenVector<double> & x,int beta,VectorView<double> y)185 void BlasMultMV( 186 const double alpha, 187 const GenSymMatrix<double>& A, const GenVector<double>& x, 188 int beta, VectorView<double> y) 189 { 190 int n = A.size(); 191 int lda = A.stepj(); 192 int xs = x.step(); 193 int ys = y.step(); 194 const double* xp = x.cptr(); 195 if (xs < 0) xp += (n-1)*xs; 196 double* yp = y.ptr(); 197 if (ys < 0) yp += (n-1)*ys; 198 if (beta == 0) y.setZero(); 199 double xbeta(1); 200 BLASNAME(dsymv) ( 201 BLASCM A.uplo() == Upper?BLASCH_UP:BLASCH_LO, 202 BLASV(n),BLASV(alpha),BLASP(A.cptr()),BLASV(lda), 203 BLASP(xp),BLASV(xs),BLASV(xbeta), 204 BLASP(yp),BLASV(ys) BLAS1); 205 } 206 template <> BlasMultMV(const std::complex<double> alpha,const GenSymMatrix<std::complex<double>> & A,const GenVector<std::complex<double>> & x,int beta,VectorView<std::complex<double>> y)207 void BlasMultMV( 208 const std::complex<double> alpha, 209 const GenSymMatrix<std::complex<double> >& A, 210 const GenVector<std::complex<double> >& x, 211 int beta, VectorView<std::complex<double> > y) 212 { 213 if (A.isherm()) { 214 int n = A.size(); 215 int lda = A.stepj(); 216 int xs = x.step(); 217 int ys = y.step(); 218 const std::complex<double>* xp = x.cptr(); 219 if (xs < 0) xp += (n-1)*xs; 220 std::complex<double>* yp = y.ptr(); 221 if (ys < 0) yp += (n-1)*ys; 222 if (beta == 0) y.setZero(); 223 std::complex<double> xbeta(1); 224 BLASNAME(zhemv) ( 225 BLASCM A.uplo() == Upper?BLASCH_UP:BLASCH_LO, 226 BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda), 227 BLASP(xp),BLASV(xs),BLASP(&xbeta), 228 BLASP(yp),BLASV(ys) BLAS1); 229 } else { 230 #ifdef ELAP 231 int n = A.size(); 232 int lda = A.stepj(); 233 int xs = x.step(); 234 int ys = y.step(); 235 const std::complex<double>* xp = x.cptr(); 236 if (xs < 0) xp += (n-1)*xs; 237 std::complex<double>* yp = y.ptr(); 238 if (ys < 0) yp += (n-1)*ys; 239 if (beta == 0) y.setZero(); 240 std::complex<double> xbeta(1); 241 LAPNAME(zsymv) ( 242 LAPCM A.uplo()==Upper ? LAPCH_UP : LAPCH_LO, 243 LAPV(n),LAPP(&alpha),LAPP(A.cptr()),LAPV(lda), 244 LAPP(xp),LAPV(xs),LAPP(&xbeta), 245 LAPP(yp),LAPV(ys) LAP1); 246 #else 247 if (beta==1) 248 NonBlasMultMV<true>(alpha,A,x,y); 249 else 250 NonBlasMultMV<false>(alpha,A,x,y); 251 #endif 252 } 253 } 254 template <> BlasMultMV(const std::complex<double> alpha,const GenSymMatrix<std::complex<double>> & A,const GenVector<double> & x,int beta,VectorView<std::complex<double>> y)255 void BlasMultMV( 256 const std::complex<double> alpha, 257 const GenSymMatrix<std::complex<double> >& A, 258 const GenVector<double>& x, 259 int beta, VectorView<std::complex<double> > y) 260 { BlasMultMV(alpha,A,Vector<std::complex<double> >(x),beta,y); } 261 template <> BlasMultMV(const std::complex<double> alpha,const GenSymMatrix<double> & A,const GenVector<std::complex<double>> & x,int beta,VectorView<std::complex<double>> y)262 void BlasMultMV( 263 const std::complex<double> alpha, 264 const GenSymMatrix<double>& A, 265 const GenVector<std::complex<double> >& x, 266 int beta, VectorView<std::complex<double> > y) 267 { 268 if (beta == 0) { 269 int n = A.size(); 270 int lda = A.stepj(); 271 int xs = 2*x.step(); 272 int ys = 2*y.step(); 273 const double* xp = (const double*) x.cptr(); 274 if (xs < 0) xp += (n-1)*xs; 275 double* yp = (double*) y.ptr(); 276 if (ys < 0) yp += (n-1)*ys; 277 double xalpha(1); 278 y.setZero(); 279 double xbeta(1); 280 BLASNAME(dsymv) ( 281 BLASCM A.uplo() == Upper?BLASCH_UP:BLASCH_LO, 282 BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda), 283 BLASP(xp),BLASV(xs),BLASV(xbeta), 284 BLASP(yp),BLASV(ys) BLAS1); 285 BLASNAME(dsymv) ( 286 BLASCM A.uplo() == Upper?BLASCH_UP:BLASCH_LO, 287 BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda), 288 BLASP(xp+1),BLASV(xs),BLASV(xbeta), 289 BLASP(yp+1),BLASV(ys) BLAS1); 290 if (x.isconj()) y.conjugateSelf(); 291 y *= alpha; 292 } else if (TMV_IMAG(alpha) == 0. && !x.isconj()) { 293 int n = A.size(); 294 int lda = A.stepj(); 295 int xs = 2*x.step(); 296 int ys = 2*y.step(); 297 const double* xp = (const double*) x.cptr(); 298 if (xs < 0) xp += (n-1)*xs; 299 double* yp = (double*) y.ptr(); 300 if (ys < 0) yp += (n-1)*ys; 301 double xalpha(TMV_REAL(alpha)); 302 double xbeta(1); 303 BLASNAME(dsymv) ( 304 BLASCM A.uplo() == Upper?BLASCH_UP:BLASCH_LO, 305 BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda), 306 BLASP(xp),BLASV(xs),BLASV(xbeta), 307 BLASP(yp),BLASV(ys) BLAS1); 308 BLASNAME(dsymv) ( 309 BLASCM A.uplo() == Upper?BLASCH_UP:BLASCH_LO, 310 BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda), 311 BLASP(xp+1),BLASV(xs),BLASV(xbeta), 312 BLASP(yp+1),BLASV(ys) BLAS1); 313 } else { 314 Vector<std::complex<double> > xx = alpha*x; 315 BlasMultMV(std::complex<double>(1),A,xx,1,y); 316 } 317 } 318 template <> BlasMultMV(const std::complex<double> alpha,const GenSymMatrix<double> & A,const GenVector<double> & x,int beta,VectorView<std::complex<double>> y)319 void BlasMultMV( 320 const std::complex<double> alpha, 321 const GenSymMatrix<double>& A, 322 const GenVector<double>& x, 323 int beta, VectorView<std::complex<double> > y) 324 { 325 int n = A.size(); 326 int lda = A.stepj(); 327 int xs = x.step(); 328 int ys = 2*y.step(); 329 const double* xp = x.cptr(); 330 if (xs < 0) xp += (n-1)*xs; 331 double* yp = (double*) y.ptr(); 332 if (ys < 0) yp += (n-1)*ys; 333 double ar(TMV_REAL(alpha)); 334 double ai(TMV_IMAG(alpha)); 335 if (beta == 0) y.setZero(); 336 double xbeta(1); 337 if (ar != 0.) { 338 BLASNAME(dsymv) ( 339 BLASCM A.uplo() == Upper?BLASCH_UP:BLASCH_LO, 340 BLASV(n),BLASV(ar),BLASP(A.cptr()),BLASV(lda), 341 BLASP(xp),BLASV(xs),BLASV(xbeta), 342 BLASP(yp),BLASV(ys) BLAS1); 343 } 344 if (ai != 0.) { 345 BLASNAME(dsymv) ( 346 BLASCM A.uplo() == Upper?BLASCH_UP:BLASCH_LO, 347 BLASV(n),BLASV(ai),BLASP(A.cptr()),BLASV(lda), 348 BLASP(xp),BLASV(xs),BLASV(xbeta), 349 BLASP(yp+1),BLASV(ys) BLAS1); 350 } 351 } 352 #endif 353 #ifdef INST_FLOAT 354 template <> BlasMultMV(const float alpha,const GenSymMatrix<float> & A,const GenVector<float> & x,int beta,VectorView<float> y)355 void BlasMultMV( 356 const float alpha, 357 const GenSymMatrix<float>& A, const GenVector<float>& x, 358 int beta, VectorView<float> y) 359 { 360 int n = A.size(); 361 int lda = A.stepj(); 362 int xs = x.step(); 363 int ys = y.step(); 364 const float* xp = x.cptr(); 365 if (xs < 0) xp += (n-1)*xs; 366 float* yp = y.ptr(); 367 if (ys < 0) yp += (n-1)*ys; 368 if (beta == 0) y.setZero(); 369 float xbeta(1); 370 BLASNAME(ssymv) ( 371 BLASCM A.uplo() == Upper?BLASCH_UP:BLASCH_LO, 372 BLASV(n),BLASV(alpha),BLASP(A.cptr()),BLASV(lda), 373 BLASP(xp),BLASV(xs),BLASV(xbeta), 374 BLASP(yp),BLASV(ys) BLAS1); 375 } 376 template <> BlasMultMV(const std::complex<float> alpha,const GenSymMatrix<std::complex<float>> & A,const GenVector<std::complex<float>> & x,int beta,VectorView<std::complex<float>> y)377 void BlasMultMV( 378 const std::complex<float> alpha, 379 const GenSymMatrix<std::complex<float> >& A, 380 const GenVector<std::complex<float> >& x, 381 int beta, VectorView<std::complex<float> > y) 382 { 383 if (A.isherm()) { 384 int n = A.size(); 385 int lda = A.stepj(); 386 int xs = x.step(); 387 int ys = y.step(); 388 const std::complex<float>* xp = x.cptr(); 389 if (xs < 0) xp += (n-1)*xs; 390 std::complex<float>* yp = y.ptr(); 391 if (ys < 0) yp += (n-1)*ys; 392 if (beta == 0) y.setZero(); 393 std::complex<float> xbeta(1); 394 BLASNAME(chemv) ( 395 BLASCM A.uplo() == Upper?BLASCH_UP:BLASCH_LO, 396 BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda), 397 BLASP(xp),BLASV(xs),BLASP(&xbeta), 398 BLASP(yp),BLASV(ys) BLAS1); 399 } else { 400 #ifdef ELAP 401 int n = A.size(); 402 int lda = A.stepj(); 403 int xs = x.step(); 404 int ys = y.step(); 405 const std::complex<float>* xp = x.cptr(); 406 if (xs < 0) xp += (n-1)*xs; 407 std::complex<float>* yp = y.ptr(); 408 if (ys < 0) yp += (n-1)*ys; 409 if (beta == 0) y.setZero(); 410 std::complex<float> xbeta(1); 411 LAPNAME(csymv) ( 412 LAPCM A.uplo()==Upper ? LAPCH_UP : LAPCH_LO, 413 LAPV(n),LAPP(&alpha),LAPP(A.cptr()),LAPV(lda), 414 LAPP(xp),LAPV(xs),LAPP(&xbeta), 415 LAPP(yp),LAPV(ys) LAP1); 416 #else 417 if (beta==1) 418 NonBlasMultMV<true>(alpha,A,x,y); 419 else 420 NonBlasMultMV<false>(alpha,A,x,y); 421 #endif 422 } 423 } 424 template <> BlasMultMV(const std::complex<float> alpha,const GenSymMatrix<std::complex<float>> & A,const GenVector<float> & x,int beta,VectorView<std::complex<float>> y)425 void BlasMultMV( 426 const std::complex<float> alpha, 427 const GenSymMatrix<std::complex<float> >& A, 428 const GenVector<float>& x, 429 int beta, VectorView<std::complex<float> > y) 430 { BlasMultMV(alpha,A,Vector<std::complex<float> >(x),beta,y); } 431 template <> BlasMultMV(const std::complex<float> alpha,const GenSymMatrix<float> & A,const GenVector<std::complex<float>> & x,int beta,VectorView<std::complex<float>> y)432 void BlasMultMV( 433 const std::complex<float> alpha, 434 const GenSymMatrix<float>& A, 435 const GenVector<std::complex<float> >& x, 436 int beta, VectorView<std::complex<float> > y) 437 { 438 if (beta == 0) { 439 int n = A.size(); 440 int lda = A.stepj(); 441 int xs = 2*x.step(); 442 int ys = 2*y.step(); 443 const float* xp = (const float*) x.cptr(); 444 if (xs < 0) xp += (n-1)*xs; 445 float* yp = (float*) y.ptr(); 446 if (ys < 0) yp += (n-1)*ys; 447 float xalpha(1); 448 y.setZero(); 449 float xbeta(1); 450 BLASNAME(ssymv) ( 451 BLASCM A.uplo() == Upper?BLASCH_UP:BLASCH_LO, 452 BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda), 453 BLASP(xp),BLASV(xs),BLASV(xbeta), 454 BLASP(yp),BLASV(ys) BLAS1); 455 BLASNAME(ssymv) ( 456 BLASCM A.uplo() == Upper?BLASCH_UP:BLASCH_LO, 457 BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda), 458 BLASP(xp+1),BLASV(xs),BLASV(xbeta), 459 BLASP(yp+1),BLASV(ys) BLAS1); 460 if (x.isconj()) y.conjugateSelf(); 461 y *= alpha; 462 } else if (TMV_IMAG(alpha) == 0.F && !x.isconj()) { 463 int n = A.size(); 464 int lda = A.stepj(); 465 int xs = 2*x.step(); 466 int ys = 2*y.step(); 467 const float* xp = (const float*) x.cptr(); 468 if (xs < 0) xp += (n-1)*xs; 469 float* yp = (float*) y.ptr(); 470 if (ys < 0) yp += (n-1)*ys; 471 float xalpha(TMV_REAL(alpha)); 472 float xbeta(1); 473 BLASNAME(ssymv) ( 474 BLASCM A.uplo() == Upper?BLASCH_UP:BLASCH_LO, 475 BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda), 476 BLASP(xp),BLASV(xs),BLASV(xbeta), 477 BLASP(yp),BLASV(ys) BLAS1); 478 BLASNAME(ssymv) ( 479 BLASCM A.uplo() == Upper?BLASCH_UP:BLASCH_LO, 480 BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda), 481 BLASP(xp+1),BLASV(xs),BLASV(xbeta), 482 BLASP(yp+1),BLASV(ys) BLAS1); 483 } else { 484 Vector<std::complex<float> > xx = alpha*x; 485 BlasMultMV(std::complex<float>(1),A,xx,1,y); 486 } 487 } 488 template <> BlasMultMV(const std::complex<float> alpha,const GenSymMatrix<float> & A,const GenVector<float> & x,int beta,VectorView<std::complex<float>> y)489 void BlasMultMV( 490 const std::complex<float> alpha, 491 const GenSymMatrix<float>& A, 492 const GenVector<float>& x, 493 int beta, VectorView<std::complex<float> > y) 494 { 495 int n = A.size(); 496 int lda = A.stepj(); 497 int xs = x.step(); 498 int ys = 2*y.step(); 499 const float* xp = x.cptr(); 500 if (xs < 0) xp += (n-1)*xs; 501 float* yp = (float*) y.ptr(); 502 if (ys < 0) yp += (n-1)*ys; 503 float ar(TMV_REAL(alpha)); 504 float ai(TMV_IMAG(alpha)); 505 if (beta == 0) y.setZero(); 506 float xbeta(1); 507 if (ar != 0.F) { 508 BLASNAME(ssymv) ( 509 BLASCM A.uplo() == Upper?BLASCH_UP:BLASCH_LO, 510 BLASV(n),BLASV(ar),BLASP(A.cptr()),BLASV(lda), 511 BLASP(xp),BLASV(xs),BLASV(xbeta), 512 BLASP(yp),BLASV(ys) BLAS1); 513 } 514 if (ai != 0.F) { 515 BLASNAME(ssymv) ( 516 BLASCM A.uplo() == Upper?BLASCH_UP:BLASCH_LO, 517 BLASV(n),BLASV(ai),BLASP(A.cptr()),BLASV(lda), 518 BLASP(xp),BLASV(xs),BLASV(xbeta), 519 BLASP(yp+1),BLASV(ys) BLAS1); 520 } 521 } 522 #endif 523 #endif // BLAS 524 525 template <bool add, class T, class Ta, class Tx> DoMultMV(const T alpha,const GenSymMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)526 static void DoMultMV( 527 const T alpha, const GenSymMatrix<Ta>& A, 528 const GenVector<Tx>& x, VectorView<T> y) 529 { 530 TMVAssert(A.rowsize() == x.size()); 531 TMVAssert(A.colsize() == y.size()); 532 TMVAssert(alpha != T(0)); 533 TMVAssert(x.size() > 0); 534 TMVAssert(y.size() > 0); 535 TMVAssert(!SameStorage(x,y)); 536 537 #ifdef BLAS 538 if (A.isconj()) 539 DoMultMV<add>( 540 TMV_CONJ(alpha),A.conjugate(),x.conjugate(),y.conjugate()); 541 else if (!A.iscm() && A.isrm()) 542 if (A.isherm()) DoMultMV<add>(alpha,A.adjoint(),x,y); 543 else DoMultMV<add>(alpha,A.transpose(),x,y); 544 else if (x.step() == 0) { 545 if (x.size() <= 1) 546 DoMultMV<add>( 547 alpha,A,ConstVectorView<Tx>(x.cptr(),x.size(),1,x.ct()),y); 548 else 549 DoMultMV<add>(alpha,A,Vector<Tx>(x),y); 550 } else if (y.step() == 0) { 551 TMVAssert(y.size() <= 1); 552 DoMultMV<add>(alpha,A,x,VectorView<T>(y.ptr(),y.size(),1,y.ct())); 553 } else if (A.iscm()&&A.stepj()>0) { 554 if (!y.isconj() && y.step() != 1) { 555 if (!x.isconj() && x.step() != 1) { 556 BlasMultMV(alpha,A,x,add?1:0,y); 557 } else { 558 Vector<T> xx = alpha*x; 559 BlasMultMV(T(1),A,xx,add?1:0,y); 560 } 561 } else { 562 Vector<T> yy(y.size()); 563 if (!x.isconj() && x.step() != 1) { 564 BlasMultMV(T(1),A,x,0,yy.view()); 565 if (add) y += alpha*yy; 566 else y = alpha*yy; 567 } else { 568 Vector<T> xx = alpha*x; 569 BlasMultMV(T(1),A,xx,0,yy.view()); 570 if (add) y += yy; 571 else y = yy; 572 } 573 } 574 } else if (TMV_IMAG(alpha) == TMV_RealType(T)(0)) { 575 if (A.isherm()) { 576 if (A.uplo() == Upper) { 577 HermMatrix<Ta,Upper|ColMajor> A2 = 578 TMV_REAL(alpha)*A; 579 DoMultMV<add>(T(1),A2,x,y); 580 } else { 581 HermMatrix<Ta,Lower|ColMajor> A2 = 582 TMV_REAL(alpha)*A; 583 DoMultMV<add>(T(1),A2,x,y); 584 } 585 } else { 586 if (A.uplo() == Upper) { 587 SymMatrix<Ta,Upper|ColMajor> A2 = 588 TMV_REAL(alpha)*A; 589 DoMultMV<add>(T(1),A2,x,y); 590 } else { 591 SymMatrix<Ta,Lower|ColMajor> A2 = 592 TMV_REAL(alpha)*A; 593 DoMultMV<add>(T(1),A2,x,y); 594 } 595 } 596 } else { 597 if (!A.issym()) { 598 if (A.uplo() == Upper) { 599 HermMatrix<Ta,Upper|ColMajor> A2 = A; 600 DoMultMV<add>(alpha,A2,x,y); 601 } else { 602 HermMatrix<Ta,Lower|ColMajor> A2 = A; 603 DoMultMV<add>(alpha,A2,x,y); 604 } 605 } else { 606 if (A.uplo() == Upper) { 607 SymMatrix<T,Upper|ColMajor> A2 = alpha*A; 608 DoMultMV<add>(T(1),A2,x,y); 609 } else { 610 SymMatrix<T,Lower|ColMajor> A2 = alpha*A; 611 DoMultMV<add>(T(1),A2,x,y); 612 } 613 } 614 } 615 #else 616 NonBlasMultMV<add>(alpha,A,x,y); 617 #endif 618 } 619 620 template <bool add, class T, class Ta, class Tx> MultMV(const T alpha,const GenSymMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)621 void MultMV( 622 const T alpha, const GenSymMatrix<Ta>& A, const GenVector<Tx>& x, 623 VectorView<T> y) 624 // y (+)= alpha * A * x 625 { 626 TMVAssert(A.rowsize() == x.size()); 627 TMVAssert(A.colsize() == y.size()); 628 #ifdef XDEBUG 629 cout<<"Start MultMV: alpha = "<<alpha<<endl; 630 cout<<"A = "<<TMV_Text(A)<<" "<<A<<endl; 631 cout<<"x = "<<TMV_Text(x)<<" step "<<x.step()<<" "<<x<<endl; 632 cout<<"y = "<<TMV_Text(y)<<" step "<<y.step()<<" "<<y<<endl; 633 Matrix<Ta> A0 = A; 634 Vector<Tx> x0 = x; 635 Vector<T> y0 = y; 636 Vector<T> y2 = alpha*A0*x0; 637 if (add) y2 += y0; 638 #endif 639 640 if (y.size() > 0) { 641 if (x.size()==0 || alpha==T(0)) { 642 if (!add) y.setZero(); 643 } else if (SameStorage(x,y)) { 644 Vector<T> yy(y.size()); 645 DoMultMV<false>(T(1),A,x,yy.view()); 646 if (add) y += alpha*yy; 647 else y = alpha*yy; 648 } else { 649 DoMultMV<add>(alpha,A,x,y); 650 } 651 } 652 #ifdef XDEBUG 653 cout<<"--> y = "<<y<<endl; 654 cout<<"y2 = "<<y2<<endl; 655 if (!(Norm(y-y2) <= 656 0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(x0)+ 657 (add?Norm(y0):TMV_RealType(T)(0))))) { 658 cerr<<"MultMV: alpha = "<<alpha<<endl; 659 cerr<<"A = "<<TMV_Text(A)<<" "<<A0<<endl; 660 cerr<<"x = "<<TMV_Text(x)<<" step "<<x.step()<<" "<<x0<<endl; 661 cerr<<"y = "<<TMV_Text(y)<<" step "<<y.step()<<" "<<y0<<endl; 662 cerr<<"--> y = "<<y<<endl; 663 cerr<<"y2 = "<<y2<<endl; 664 abort(); 665 } 666 #endif 667 } 668 669 #define InstFile "TMV_MultSV.inst" 670 #include "TMV_Inst.h" 671 #undef InstFile 672 673 } // namespace tmv 674 675 676