///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// The Template Matrix/Vector Library for C++ was created by Mike Jarvis     //
// Copyright (C) 1998 - 2016                                                 //
// All rights reserved                                                       //
//                                                                           //
// The project is hosted at https://code.google.com/p/tmv-cpp/               //
// where you can find the current version and current documentation.         //
//                                                                           //
// For concerns or problems with the software, Mike may be contacted at      //
// mike_jarvis17 [at] gmail.                                                 //
//                                                                           //
// This software is licensed under a FreeBSD license. The file               //
// TMV_LICENSE should have been included with this distribution.             //
// If not, you can get a copy from https://code.google.com/p/tmv-cpp/.       //
//                                                                           //
// Essentially, you can use this software however you want provided that     //
// you include the TMV_LICENSE file in any distribution that uses it.        //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////


//#define XDEBUG


#include "TMV_Blas.h"
#include "tmv/TMV_TriMatrixArithFunc.h"
#include "tmv/TMV_TriMatrix.h"
#include "tmv/TMV_Vector.h"
#include "tmv/TMV_VectorArith.h"
#include "TMV_MultMV.h"

// The CBLAS trick of using RowMajor with ConjTrans when we have a
// case of A.conjugate() * x doesn't seem to be working with MKL 10.2.2.
// I haven't been able to figure out why. (e.g. Is it a bug in the MKL
// code, or am I doing something wrong?) So for now, just disable it.
37 #ifdef CBLAS 38 #undef CBLAS 39 #endif 40 41 #ifdef XDEBUG 42 #include "tmv/TMV_MatrixArith.h" 43 #include <iostream> 44 using std::cout; 45 using std::cerr; 46 using std::endl; 47 #endif 48 49 namespace tmv { 50 51 template <class T> cptr() const52 const T* UpperTriMatrixComposite<T>::cptr() const 53 { 54 if (!itsm.get()) { 55 itsm.resize(this->size()*this->size()); 56 UpperTriMatrixView<T>( 57 itsm.get(),this->size(),stepi(),stepj(),this->dt(),NonConj) = 58 *this; 59 } 60 return itsm.get(); 61 } 62 63 template <class T> stepi() const64 ptrdiff_t UpperTriMatrixComposite<T>::stepi() const 65 { return 1; } 66 67 template <class T> stepj() const68 ptrdiff_t UpperTriMatrixComposite<T>::stepj() const 69 { return this->size(); } 70 71 template <class T> cptr() const72 const T* LowerTriMatrixComposite<T>::cptr() const 73 { 74 if (!itsm.get()) { 75 itsm.resize(this->size()*this->size()); 76 LowerTriMatrixView<T>( 77 itsm.get(),this->size(),stepi(),stepj(),this->dt(),NonConj) = 78 *this; 79 } 80 return itsm.get(); 81 } 82 83 template <class T> stepi() const84 ptrdiff_t LowerTriMatrixComposite<T>::stepi() const 85 { return 1; } 86 87 template <class T> stepj() const88 ptrdiff_t LowerTriMatrixComposite<T>::stepj() const 89 { return this->size(); } 90 91 // 92 // MultEqMV 93 // 94 95 template <bool rm, bool ca, bool ua, class T, class Ta> DoRowMultEqMV(const GenUpperTriMatrix<Ta> & A,VectorView<T> x)96 static void DoRowMultEqMV( 97 const GenUpperTriMatrix<Ta>& A, VectorView<T> x) 98 { 99 //cout<<"RowMultEqMV Upper\n"; 100 TMVAssert(x.step()==1); 101 TMVAssert(A.size() == x.size()); 102 TMVAssert(x.size() > 0); 103 TMVAssert(x.ct() == NonConj); 104 TMVAssert(rm == A.isrm()); 105 TMVAssert(ca == A.isconj()); 106 TMVAssert(ua == A.isunit()); 107 108 const ptrdiff_t N = x.size(); 109 110 const ptrdiff_t sj = rm ? 
1 : A.stepj(); 111 const ptrdiff_t ds = A.stepi()+sj; 112 T* xi = x.ptr(); 113 const Ta* Aii = A.cptr(); 114 ptrdiff_t len = N-1; 115 116 for(; len>0; --len,++xi,Aii+=ds) { 117 // i = 0..N-2 118 // x(i) = A.row(i,i,N) * x.subVector(i,N); 119 if (!ua) *xi *= (ca ? TMV_CONJ(*Aii) : *Aii); 120 const T* xj = xi+1; 121 const Ta* Aij = Aii+sj; 122 for(ptrdiff_t j=len;j>0;--j,++xj,(rm?++Aij:Aij+=sj)) { 123 #ifdef TMVFLDEBUG 124 TMVAssert(xi >= x._first); 125 TMVAssert(xi < x._last); 126 #endif 127 *xi += (*xj) * (ca ? TMV_CONJ(*Aij) : *Aij); 128 } 129 } 130 #ifdef TMVFLDEBUG 131 TMVAssert(xi >= x._first); 132 TMVAssert(xi < x._last); 133 #endif 134 if (!ua) *xi *= (ca ? TMV_CONJ(*Aii) : *Aii); 135 } 136 137 template <bool rm, class T, class Ta> RowMultEqMV(const GenUpperTriMatrix<Ta> & A,VectorView<T> x)138 static void RowMultEqMV( 139 const GenUpperTriMatrix<Ta>& A, VectorView<T> x) 140 { 141 if (A.isconj()) 142 if (A.isunit()) 143 DoRowMultEqMV<rm,true,true>(A,x); 144 else 145 DoRowMultEqMV<rm,true,false>(A,x); 146 else 147 if (A.isunit()) 148 DoRowMultEqMV<rm,false,true>(A,x); 149 else 150 DoRowMultEqMV<rm,false,false>(A,x); 151 } 152 153 template <bool cm, bool ca, bool ua, class T, class Ta> DoColMultEqMV(const GenUpperTriMatrix<Ta> & A,VectorView<T> x)154 static void DoColMultEqMV( 155 const GenUpperTriMatrix<Ta>& A, VectorView<T> x) 156 { 157 //cout<<"ColMultEqMV Upper\n"; 158 TMVAssert(x.step()==1); 159 TMVAssert(A.size() == x.size()); 160 TMVAssert(x.size() > 0); 161 TMVAssert(x.ct() == NonConj); 162 TMVAssert(cm == A.iscm()); 163 TMVAssert(ca == A.isconj()); 164 TMVAssert(ua == A.isunit()); 165 166 const ptrdiff_t N = x.size(); 167 168 T* x0 = x.ptr(); 169 const T* xj = x0+1; 170 const ptrdiff_t si = cm ? 1 : A.stepi(); 171 const Ta* A0j = A.cptr(); 172 173 #ifdef TMVFLDEBUG 174 TMVAssert(x0 >= x._first); 175 TMVAssert(x0 < x._last); 176 #endif 177 if (!ua) *x0 *= (ca ? 
TMV_CONJ(*A0j) : *A0j); 178 A0j += A.stepj(); 179 180 for(ptrdiff_t len=1; len<N; ++len,++xj,A0j+=A.stepj()) if (*xj != T(0)) { 181 // j = 1..N-1 182 // x.subVector(0,j) += *xj * A.col(j,0,j); 183 const Ta* Aij = A0j; 184 T* xi = x0; 185 for(ptrdiff_t i=len;i>0;--i,++xi,(cm?++Aij:Aij+=si)) { 186 #ifdef TMVFLDEBUG 187 TMVAssert(xi >= x._first); 188 TMVAssert(xi < x._last); 189 #endif 190 *xi += *xj * (ca ? TMV_CONJ(*Aij) : *Aij); 191 } 192 // Now Aij == Ajj, xi == xj 193 // so this next statement is really *xj *= *Ajj 194 #ifdef TMVFLDEBUG 195 TMVAssert(xi >= x._first); 196 TMVAssert(xi < x._last); 197 #endif 198 if (!ua) *xi *= (ca ? TMV_CONJ(*Aij) : *Aij); 199 } 200 } 201 202 template <bool cm, class T, class Ta> ColMultEqMV(const GenUpperTriMatrix<Ta> & A,VectorView<T> x)203 static void ColMultEqMV( 204 const GenUpperTriMatrix<Ta>& A, VectorView<T> x) 205 { 206 if (A.isconj()) 207 if (A.isunit()) 208 DoColMultEqMV<cm,true,true>(A,x); 209 else 210 DoColMultEqMV<cm,true,false>(A,x); 211 else 212 if (A.isunit()) 213 DoColMultEqMV<cm,false,true>(A,x); 214 else 215 DoColMultEqMV<cm,false,false>(A,x); 216 } 217 218 template <bool rm, bool ca, bool ua, class T, class Ta> DoRowMultEqMV(const GenLowerTriMatrix<Ta> & A,VectorView<T> x)219 static void DoRowMultEqMV( 220 const GenLowerTriMatrix<Ta>& A, VectorView<T> x) 221 { 222 //cout<<"RowMultEqMV Lower\n"; 223 TMVAssert(x.step()==1); 224 TMVAssert(A.size() == x.size()); 225 TMVAssert(x.size() > 0); 226 TMVAssert(x.ct() == NonConj); 227 TMVAssert(rm == A.isrm()); 228 TMVAssert(ca == A.isconj()); 229 TMVAssert(ua == A.isunit()); 230 231 const ptrdiff_t N = x.size(); 232 const ptrdiff_t si = A.stepi(); 233 const ptrdiff_t sj = rm ? 
1 : A.stepj(); 234 const ptrdiff_t ds = si+sj; 235 236 const T* x0 = x.cptr(); 237 T* xi = x.ptr() + N-1; 238 const Ta* Ai0 = A.cptr()+(N-1)*si; 239 const Ta* Aii = Ai0 + (N-1)*sj; 240 241 for(ptrdiff_t len=N-1; len>0; --len,--xi,Ai0-=si,Aii-=ds) { 242 // i = N-1..1 243 // x(i) = A.row(i,0,i+1) * x.subVector(0,i+1); 244 T xx = *xi; 245 if (!ua) xx *= (ca ? TMV_CONJ(*Aii) : *Aii); 246 const Ta* Aij = Ai0; 247 const T* xj = x0; 248 for(ptrdiff_t j=len;j>0;--j,++xj,(rm?++Aij:Aij+=sj)) 249 xx += *xj * (ca ? TMV_CONJ(*Aij) : *Aij); 250 #ifdef TMVFLDEBUG 251 TMVAssert(xi >= x._first); 252 TMVAssert(xi < x._last); 253 #endif 254 *xi = xx; 255 } 256 #ifdef TMVFLDEBUG 257 TMVAssert(xi >= x._first); 258 TMVAssert(xi < x._last); 259 #endif 260 if (!ua) *xi *= (ca ? TMV_CONJ(*Aii) : *Aii); 261 } 262 263 template <bool rm, class T, class Ta> RowMultEqMV(const GenLowerTriMatrix<Ta> & A,VectorView<T> x)264 static void RowMultEqMV( 265 const GenLowerTriMatrix<Ta>& A, VectorView<T> x) 266 { 267 if (A.isconj()) 268 if (A.isunit()) 269 DoRowMultEqMV<rm,true,true>(A,x); 270 else 271 DoRowMultEqMV<rm,true,false>(A,x); 272 else 273 if (A.isunit()) 274 DoRowMultEqMV<rm,false,true>(A,x); 275 else 276 DoRowMultEqMV<rm,false,false>(A,x); 277 } 278 279 template <bool cm, bool ca, bool ua, class T, class Ta> DoColMultEqMV(const GenLowerTriMatrix<Ta> & A,VectorView<T> x)280 static void DoColMultEqMV( 281 const GenLowerTriMatrix<Ta>& A, VectorView<T> x) 282 { 283 //cout<<"ColMultEqMV Lower\n"; 284 TMVAssert(A.size() == x.size()); 285 TMVAssert(x.size() > 0); 286 TMVAssert(x.ct() == NonConj); 287 TMVAssert(x.step() == 1); 288 TMVAssert(cm == A.iscm()); 289 TMVAssert(ca == A.isconj()); 290 TMVAssert(ua == A.isunit()); 291 292 const ptrdiff_t N = x.size(); 293 294 T* xj = x.ptr() + N-2; 295 const ptrdiff_t si = cm ? 
1 : A.stepi(); 296 const ptrdiff_t ds = A.stepj()+si; 297 const Ta* Ajj = A.cptr()+(N-2)*ds; 298 299 #ifdef TMVFLDEBUG 300 TMVAssert(xj+1 >= x._first); 301 TMVAssert(xj+1 < x._last); 302 #endif 303 if (!ua) *(xj+1) *= (ca ? TMV_CONJ(*(Ajj+ds)) : *(Ajj+ds)); 304 for(ptrdiff_t jj=N-1,len=1;jj>0;--jj,++len,--xj,Ajj-=ds) if (*xj!=T(0)) { 305 // j = N-2..0, jj = j+1 306 // x.subVector(j+1,N) += *xj * A.col(j,j+1,N); 307 T* xi = xj+1; 308 const Ta* Aij = Ajj+si; 309 for (ptrdiff_t i=len;i>0;--i,++xi,(cm?++Aij:Aij+=si)) { 310 #ifdef TMVFLDEBUG 311 TMVAssert(xi >= x._first); 312 TMVAssert(xi < x._last); 313 #endif 314 *xi += *xj * (ca ? TMV_CONJ(*Aij) : *Aij); 315 } 316 #ifdef TMVFLDEBUG 317 TMVAssert(xj >= x._first); 318 TMVAssert(xj < x._last); 319 #endif 320 if (!ua) *xj *= (ca ? TMV_CONJ(*Ajj) : *Ajj); 321 } 322 } 323 324 template <bool cm, class T, class Ta> ColMultEqMV(const GenLowerTriMatrix<Ta> & A,VectorView<T> x)325 static void ColMultEqMV( 326 const GenLowerTriMatrix<Ta>& A, VectorView<T> x) 327 { 328 if (A.isconj()) 329 if (A.isunit()) 330 DoColMultEqMV<cm,true,true>(A,x); 331 else 332 DoColMultEqMV<cm,true,false>(A,x); 333 else 334 if (A.isunit()) 335 DoColMultEqMV<cm,false,true>(A,x); 336 else 337 DoColMultEqMV<cm,false,false>(A,x); 338 } 339 340 template <class T, class Ta> DoMultEqMV(const GenUpperTriMatrix<Ta> & A,VectorView<T> x)341 static inline void DoMultEqMV( 342 const GenUpperTriMatrix<Ta>& A, VectorView<T> x) 343 // x = A * x 344 { 345 if (A.isrm()) RowMultEqMV<true>(A,x); 346 else if (A.iscm()) ColMultEqMV<true>(A,x); 347 else RowMultEqMV<false>(A,x); 348 } 349 350 template <class T, class Ta> NonBlasMultEqMV(const GenUpperTriMatrix<Ta> & A,VectorView<T> x)351 static void NonBlasMultEqMV( 352 const GenUpperTriMatrix<Ta>& A, VectorView<T> x) 353 { 354 TMVAssert(A.size() == x.size()); 355 TMVAssert(x.size() > 0); 356 TMVAssert(x.step() == 1); 357 TMVAssert(x.ct() == NonConj); 358 359 // [ A11 A12 A13 ] [ 0 ] [ A12 x2 ] 360 // x = [ 0 A22 A23 ] [ x2 ] 
= [ A22 x2 ] 361 // [ 0 0 A33 ] [ 0 ] [ 0 ] 362 363 const ptrdiff_t N = x.size(); // = A.size() 364 ptrdiff_t j2 = N; 365 for(const T* x2=x.cptr()+N-1; j2>0 && *x2==T(0); --j2,--x2); 366 if (j2 == 0) return; 367 ptrdiff_t j1 = 0; 368 for(const T* x1=x.cptr(); *x1==T(0); ++j1,++x1); 369 if (j1 == 0 && j2 == N) { 370 DoMultEqMV(A,x); 371 } else { 372 TMVAssert(j1 < j2); 373 ConstUpperTriMatrixView<Ta> A22 = A.subTriMatrix(j1,j2); 374 VectorView<T> x2 = x.subVector(j1,j2); 375 376 if (j1 != 0) { 377 ConstMatrixView<Ta> A12 = A.subMatrix(0,j1,j1,j2); 378 VectorView<T> x1 = x.subVector(0,j1); 379 UnitAMultMV1<true,false>(A12,x2,x1); 380 } 381 DoMultEqMV(A22,x2); 382 } 383 } 384 385 template <class T, class Ta> DoMultEqMV(const GenLowerTriMatrix<Ta> & A,VectorView<T> x)386 static inline void DoMultEqMV( 387 const GenLowerTriMatrix<Ta>& A, VectorView<T> x) 388 { 389 if (A.isrm()) RowMultEqMV<true>(A,x); 390 else if (A.iscm() && !SameStorage(A,x)) 391 ColMultEqMV<true>(A,x); 392 else RowMultEqMV<false>(A,x); 393 } 394 395 template <class T, class Ta> NonBlasMultEqMV(const GenLowerTriMatrix<Ta> & A,VectorView<T> x)396 static void NonBlasMultEqMV( 397 const GenLowerTriMatrix<Ta>& A, VectorView<T> x) 398 // x = A * x 399 { 400 TMVAssert(A.size() == x.size()); 401 TMVAssert(x.size() > 0); 402 TMVAssert(x.step() == 1); 403 TMVAssert(x.ct() == NonConj); 404 405 // [ A11 0 0 ] [ 0 ] [ 0 ] 406 // x = [ A21 A22 0 ] [ x2 ] = [ A22 x2 ] 407 // [ A31 A32 A33 ] [ 0 ] [ A32 x2 ] 408 409 const ptrdiff_t N = x.size(); // = A.size() 410 ptrdiff_t j2 = N; 411 for(const T* x2=x.cptr()+N-1; j2>0 && *x2==T(0); --j2,--x2); 412 if (j2 == 0) return; 413 ptrdiff_t j1 = 0; 414 for(const T* x1=x.cptr(); *x1==T(0); ++j1,++x1); 415 if (j1 == 0 && j2 == N) { 416 DoMultEqMV(A,x); 417 } else { 418 TMVAssert(j1 < j2); 419 ConstLowerTriMatrixView<Ta> A22 = A.subTriMatrix(j1,j2); 420 VectorView<T> x2 = x.subVector(j1,j2); 421 422 if (j2 != N) { 423 ConstMatrixView<Ta> A32 = A.subMatrix(j2,N,j1,j2); 424 
VectorView<T> x3 = x.subVector(j2,N); 425 UnitAMultMV1<true,false>(A32,x2,x3); 426 } 427 DoMultEqMV(A22,x2); 428 } 429 } 430 431 #ifdef BLAS 432 template <class T, class Ta> BlasMultEqMV(const GenUpperTriMatrix<Ta> & A,VectorView<T> x)433 static inline void BlasMultEqMV( 434 const GenUpperTriMatrix<Ta>& A, VectorView<T> x) 435 { NonBlasMultEqMV(A,x); } 436 template <class T, class Ta> BlasMultEqMV(const GenLowerTriMatrix<Ta> & A,VectorView<T> x)437 static inline void BlasMultEqMV( 438 const GenLowerTriMatrix<Ta>& A, VectorView<T> x) 439 { NonBlasMultEqMV(A,x); } 440 #ifdef INST_DOUBLE 441 template <> BlasMultEqMV(const GenUpperTriMatrix<double> & A,VectorView<double> x)442 void BlasMultEqMV( 443 const GenUpperTriMatrix<double>& A, VectorView<double> x) 444 { 445 int n=A.size(); 446 int lda=A.isrm()?A.stepi():A.stepj(); 447 int xs=x.step(); 448 double* xp = x.ptr(); 449 BLASNAME(dtrmv) ( 450 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 451 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 452 BLASV(n),BLASP(A.cptr()),BLASV(lda), 453 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 454 } 455 template <> BlasMultEqMV(const GenLowerTriMatrix<double> & A,VectorView<double> x)456 void BlasMultEqMV( 457 const GenLowerTriMatrix<double>& A, VectorView<double> x) 458 { 459 int n=A.size(); 460 int lda=A.isrm()?A.stepi():A.stepj(); 461 int xs=x.step(); 462 double* xp = x.ptr(); 463 BLASNAME(dtrmv) ( 464 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 465 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 466 BLASV(n),BLASP(A.cptr()),BLASV(lda), 467 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 468 } 469 template <> BlasMultEqMV(const GenUpperTriMatrix<std::complex<double>> & A,VectorView<std::complex<double>> x)470 void BlasMultEqMV( 471 const GenUpperTriMatrix<std::complex<double> >& A, 472 VectorView<std::complex<double> > x) 473 { 474 int n=A.size(); 475 int lda=A.isrm()?A.stepi():A.stepj(); 476 int xs=x.step(); 477 std::complex<double>* xp = x.ptr(); 478 if (A.iscm() && A.isconj()) { 479 
#ifdef CBLAS 480 BLASNAME(ztrmv) ( 481 BLASRM BLASCH_LO, BLASCH_CT, 482 A.isunit()?BLASCH_U:BLASCH_NU, 483 BLASV(n),BLASP(A.cptr()),BLASV(lda), 484 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 485 #else 486 x.conjugateSelf(); 487 BLASNAME(ztrmv) ( 488 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 489 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 490 BLASV(n),BLASP(A.cptr()),BLASV(lda), 491 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 492 x.conjugateSelf(); 493 #endif 494 } else { 495 BLASNAME(ztrmv) ( 496 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 497 A.iscm()?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T, 498 A.isunit()?BLASCH_U:BLASCH_NU, 499 BLASV(n),BLASP(A.cptr()),BLASV(lda), 500 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 501 } 502 } 503 template <> BlasMultEqMV(const GenLowerTriMatrix<std::complex<double>> & A,VectorView<std::complex<double>> x)504 void BlasMultEqMV( 505 const GenLowerTriMatrix<std::complex<double> >& A, 506 VectorView<std::complex<double> > x) 507 { 508 int n=A.size(); 509 int lda=A.isrm()?A.stepi():A.stepj(); 510 int xs=x.step(); 511 std::complex<double>* xp = x.ptr(); 512 if (A.iscm() && A.isconj()) { 513 #ifdef CBLAS 514 BLASNAME(ztrmv) ( 515 BLASRM BLASCH_UP, BLASCH_CT, 516 A.isunit()?BLASCH_U:BLASCH_NU, 517 BLASV(n),BLASP(A.cptr()),BLASV(lda), 518 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 519 #else 520 x.conjugateSelf(); 521 BLASNAME(ztrmv) ( 522 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 523 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 524 BLASV(n),BLASP(A.cptr()),BLASV(lda), 525 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 526 x.conjugateSelf(); 527 #endif 528 } else { 529 BLASNAME(ztrmv) ( 530 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 531 A.iscm()?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T, 532 A.isunit()?BLASCH_U:BLASCH_NU, 533 BLASV(n),BLASP(A.cptr()),BLASV(lda), 534 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 535 } 536 } 537 template <> BlasMultEqMV(const GenUpperTriMatrix<double> & A,VectorView<std::complex<double>> x)538 void BlasMultEqMV( 539 const 
GenUpperTriMatrix<double>& A, 540 VectorView<std::complex<double> > x) 541 { 542 int n=A.size(); 543 int lda=A.isrm()?A.stepi():A.stepj(); 544 int xs=2*x.step(); 545 double* xp = (double*) x.ptr(); 546 BLASNAME(dtrmv) ( 547 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 548 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 549 BLASV(n),BLASP(A.cptr()),BLASV(lda), 550 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 551 BLASNAME(dtrmv) ( 552 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 553 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 554 BLASV(n),BLASP(A.cptr()),BLASV(lda), 555 BLASP(xp+1),BLASV(xs) BLAS1 BLAS1 BLAS1); 556 } 557 template <> BlasMultEqMV(const GenLowerTriMatrix<double> & A,VectorView<std::complex<double>> x)558 void BlasMultEqMV( 559 const GenLowerTriMatrix<double>& A, 560 VectorView<std::complex<double> > x) 561 { 562 int n=A.size(); 563 int lda=A.isrm()?A.stepi():A.stepj(); 564 int xs=2*x.step(); 565 double* xp = (double*) x.ptr(); 566 BLASNAME(dtrmv) ( 567 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 568 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 569 BLASV(n),BLASP(A.cptr()),BLASV(lda), 570 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 571 BLASNAME(dtrmv) ( 572 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 573 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 574 BLASV(n),BLASP(A.cptr()),BLASV(lda), 575 BLASP(xp+1),BLASV(xs) BLAS1 BLAS1 BLAS1); 576 } 577 #endif 578 #ifdef INST_FLOAT 579 template <> BlasMultEqMV(const GenUpperTriMatrix<float> & A,VectorView<float> x)580 void BlasMultEqMV( 581 const GenUpperTriMatrix<float>& A, VectorView<float> x) 582 { 583 int n=A.size(); 584 int lda=A.isrm()?A.stepi():A.stepj(); 585 int xs=x.step(); 586 float* xp = x.ptr(); 587 BLASNAME(strmv) ( 588 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 589 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 590 BLASV(n),BLASP(A.cptr()),BLASV(lda), 591 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 592 } 593 template <> BlasMultEqMV(const GenLowerTriMatrix<float> & 
A,VectorView<float> x)594 void BlasMultEqMV( 595 const GenLowerTriMatrix<float>& A, VectorView<float> x) 596 { 597 int n=A.size(); 598 int lda=A.isrm()?A.stepi():A.stepj(); 599 int xs=x.step(); 600 float* xp = x.ptr(); 601 BLASNAME(strmv) ( 602 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 603 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 604 BLASV(n),BLASP(A.cptr()),BLASV(lda), 605 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 606 } 607 template <> BlasMultEqMV(const GenUpperTriMatrix<std::complex<float>> & A,VectorView<std::complex<float>> x)608 void BlasMultEqMV( 609 const GenUpperTriMatrix<std::complex<float> >& A, 610 VectorView<std::complex<float> > x) 611 { 612 int n=A.size(); 613 int lda=A.isrm()?A.stepi():A.stepj(); 614 int xs=x.step(); 615 std::complex<float>* xp = x.ptr(); 616 if (A.iscm() && A.isconj()) { 617 #ifdef CBLAS 618 BLASNAME(ctrmv) ( 619 BLASRM BLASCH_LO, BLASCH_CT, 620 A.isunit()?BLASCH_U:BLASCH_NU, 621 BLASV(n),BLASP(A.cptr()),BLASV(lda), 622 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 623 #else 624 x.conjugateSelf(); 625 BLASNAME(ctrmv) ( 626 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 627 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 628 BLASV(n),BLASP(A.cptr()),BLASV(lda), 629 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 630 x.conjugateSelf(); 631 #endif 632 } else { 633 BLASNAME(ctrmv) ( 634 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 635 A.iscm()?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T, 636 A.isunit()?BLASCH_U:BLASCH_NU, 637 BLASV(n),BLASP(A.cptr()),BLASV(lda), 638 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 639 } 640 } 641 template <> BlasMultEqMV(const GenLowerTriMatrix<std::complex<float>> & A,VectorView<std::complex<float>> x)642 void BlasMultEqMV( 643 const GenLowerTriMatrix<std::complex<float> >& A, 644 VectorView<std::complex<float> > x) 645 { 646 int n=A.size(); 647 int lda=A.isrm()?A.stepi():A.stepj(); 648 int xs=x.step(); 649 std::complex<float>* xp = x.ptr(); 650 if (A.iscm() && A.isconj()) { 651 #ifdef CBLAS 652 BLASNAME(ctrmv) ( 653 
BLASRM BLASCH_UP, BLASCH_CT, 654 A.isunit()?BLASCH_U:BLASCH_NU, 655 BLASV(n),BLASP(A.cptr()),BLASV(lda), 656 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 657 #else 658 x.conjugateSelf(); 659 BLASNAME(ctrmv) ( 660 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 661 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 662 BLASV(n),BLASP(A.cptr()),BLASV(lda), 663 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 664 x.conjugateSelf(); 665 #endif 666 } else { 667 BLASNAME(ctrmv) ( 668 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 669 A.iscm()?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T, 670 A.isunit()?BLASCH_U:BLASCH_NU, 671 BLASV(n),BLASP(A.cptr()),BLASV(lda), 672 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 673 } 674 } 675 template <> BlasMultEqMV(const GenUpperTriMatrix<float> & A,VectorView<std::complex<float>> x)676 void BlasMultEqMV( 677 const GenUpperTriMatrix<float>& A, 678 VectorView<std::complex<float> > x) 679 { 680 int n=A.size(); 681 int lda=A.isrm()?A.stepi():A.stepj(); 682 int xs=2*x.step(); 683 float* xp = (float*) x.ptr(); 684 BLASNAME(strmv) ( 685 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 686 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 687 BLASV(n),BLASP(A.cptr()),BLASV(lda), 688 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 689 BLASNAME(strmv) ( 690 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 691 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 692 BLASV(n),BLASP(A.cptr()),BLASV(lda), 693 BLASP(xp+1),BLASV(xs) BLAS1 BLAS1 BLAS1); 694 } 695 template <> BlasMultEqMV(const GenLowerTriMatrix<float> & A,VectorView<std::complex<float>> x)696 void BlasMultEqMV( 697 const GenLowerTriMatrix<float>& A, 698 VectorView<std::complex<float> > x) 699 { 700 int n=A.size(); 701 int lda=A.isrm()?A.stepi():A.stepj(); 702 int xs=2*x.step(); 703 float* xp = (float*) x.ptr(); 704 BLASNAME(strmv) ( 705 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 706 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 707 BLASV(n),BLASP(A.cptr()),BLASV(lda), 708 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1); 709 
BLASNAME(strmv) ( 710 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 711 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 712 BLASV(n),BLASP(A.cptr()),BLASV(lda), 713 BLASP(xp+1),BLASV(xs) BLAS1 BLAS1 BLAS1); 714 } 715 #endif // FLOAT 716 #endif // BLAS 717 718 template <class T, class Ta> MultEqMV(const GenUpperTriMatrix<Ta> & A,VectorView<T> x)719 static void MultEqMV( 720 const GenUpperTriMatrix<Ta>& A, VectorView<T> x) 721 { 722 #ifdef XDEBUG 723 Vector<T> x0 = x; 724 Matrix<Ta> A0 = A; 725 Vector<T> x2 = A0 * x0; 726 //cout<<"MultEqMV: A = "<<A<<"x = "<<x<<endl; 727 #endif 728 TMVAssert(A.size() == x.size()); 729 TMVAssert(x.size() > 0); 730 TMVAssert(x.step() == 1); 731 732 if (x.isconj()) MultEqMV(A.conjugate(),x.conjugate()); 733 else { 734 #ifdef BLAS 735 if ((A.isrm() && A.stepi()>0) || (A.iscm() && A.stepj()>0)) 736 BlasMultEqMV(A,x); 737 else { 738 if (A.isunit()) { 739 UpperTriMatrix<Ta,UnitDiag|RowMajor> A2(A); 740 BlasMultEqMV(A2,x); 741 } else { 742 UpperTriMatrix<Ta,NonUnitDiag|RowMajor> A2(A); 743 BlasMultEqMV(A2,x); 744 } 745 } 746 #else 747 NonBlasMultEqMV(A,x); 748 #endif 749 } 750 #ifdef XDEBUG 751 //cout<<"-> x = "<<x<<endl<<"x2 = "<<x2<<endl; 752 if (!(Norm(x-x2) <= 0.001*(Norm(A0)*Norm(x0)))) { 753 cerr<<"MultEqMV: \n"; 754 cerr<<"A = "<<A.cptr()<<" "<<TMV_Text(A)<<" "<<A0<<endl; 755 cerr<<"x = "<<x.cptr()<<" "<<TMV_Text(x)<<" step "<<x.step()<<" "<<x0<<endl; 756 cerr<<"-> x = "<<x<<endl; 757 cerr<<"x2 = "<<x2<<endl; 758 abort(); 759 } 760 #endif 761 } 762 763 template <class T, class Ta> MultEqMV(const GenLowerTriMatrix<Ta> & A,VectorView<T> x)764 static void MultEqMV( 765 const GenLowerTriMatrix<Ta>& A, VectorView<T> x) 766 { 767 #ifdef XDEBUG 768 Vector<T> x0 = x; 769 Matrix<Ta> A0 = A; 770 Vector<T> x2 = A0 * x0; 771 #endif 772 TMVAssert(A.size() == x.size()); 773 TMVAssert(x.size() > 0); 774 TMVAssert(x.step() == 1); 775 776 if (x.isconj()) MultEqMV(A.conjugate(),x.conjugate()); 777 else { 778 #ifdef BLAS 779 if ( (A.isrm() && 
A.stepi()>0) || (A.iscm() && A.stepj()>0) ) 780 BlasMultEqMV(A,x); 781 else { 782 if (A.isunit()) { 783 LowerTriMatrix<Ta,UnitDiag|RowMajor> A2(A); 784 BlasMultEqMV(A2,x); 785 } else { 786 LowerTriMatrix<Ta,NonUnitDiag|RowMajor> A2(A); 787 BlasMultEqMV(A2,x); 788 } 789 } 790 #else 791 NonBlasMultEqMV(A,x); 792 #endif 793 } 794 795 #ifdef XDEBUG 796 if (!(Norm(x-x2) <= 0.001*(Norm(A0)*Norm(x0)))) { 797 cerr<<"MultEqMV: \n"; 798 cerr<<"A = "<<A.cptr()<<" "<<TMV_Text(A)<<" "<<A0<<endl; 799 cerr<<"x = "<<x.cptr()<<" "<<TMV_Text(x)<<" step "<<x.step()<<" "<<x0<<endl; 800 cerr<<"-> x = "<<x<<endl; 801 cerr<<"x2 = "<<x2<<endl; 802 abort(); 803 } 804 #endif 805 } 806 807 template <bool add, class T, class Ta, class Tx> MultMV(const T alpha,const GenUpperTriMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)808 void MultMV( 809 const T alpha, const GenUpperTriMatrix<Ta>& A, 810 const GenVector<Tx>& x, VectorView<T> y) 811 // y (+)= alpha * A * x 812 { 813 #ifdef XDEBUG 814 Vector<Tx> x0 = x; 815 Vector<T> y0 = y; 816 Matrix<Ta> A0 = A; 817 Vector<T> y2 = alpha*A0*x0; 818 if (add) y2 += y0; 819 #endif 820 TMVAssert(A.size() == x.size()); 821 TMVAssert(A.size() == y.size()); 822 823 if (y.size() > 0) { 824 if (alpha==T(0)) { 825 if (!add) y.setZero(); 826 } else if (!add && y.step() == 1) { 827 y = x; 828 MultEqMV(A,y); 829 y *= alpha; 830 } else { 831 Vector<T> xx = alpha*x; 832 MultEqMV(A,xx.view()); 833 if (add) y += xx; 834 else y = xx; 835 } 836 } 837 #ifdef XDEBUG 838 if (!(Norm(y-y2) <= 839 0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(x0)+ 840 (add?Norm(y0):TMV_RealType(T)(0))))) { 841 cerr<<"MultMV: alpha = "<<alpha<<endl; 842 cerr<<"add = "<<add<<endl; 843 cerr<<"A = "<<TMV_Text(A)<<" "<<A0<<endl; 844 cerr<<"x = "<<TMV_Text(x)<<" step "<<x.step()<<" "<<x0<<endl; 845 cerr<<"y = "<<TMV_Text(y)<<" step "<<y.step()<<" "<<y0<<endl; 846 cerr<<"-> y = "<<y<<endl; 847 cerr<<"y2 = "<<y2<<endl; 848 abort(); 849 } 850 #endif 851 } 852 853 template <bool add, class T, class Ta, class 
Tx> MultMV(const T alpha,const GenLowerTriMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)854 void MultMV( 855 const T alpha, const GenLowerTriMatrix<Ta>& A, 856 const GenVector<Tx>& x, VectorView<T> y) 857 // y (+)= alpha * A * x 858 { 859 #ifdef XDEBUG 860 Vector<T> y0 = y; 861 Vector<Tx> x0 = x; 862 Matrix<Ta> A0 = A; 863 Vector<T> y2 = alpha*A0*x0; 864 if (add) y2 += y0; 865 #endif 866 867 TMVAssert(A.size() == x.size()); 868 TMVAssert(A.size() == y.size()); 869 870 if (y.size() > 0) { 871 if (alpha==T(0)) { 872 if (!add) y.setZero(); 873 } else if (!add && y.step() == 1) { 874 y = x; 875 MultEqMV(A,y); 876 if (alpha != T(1)) y *= alpha; 877 } else { 878 Vector<T> xx = alpha*x; 879 MultEqMV(A,xx.view()); 880 if (add) y += xx; 881 else y = xx; 882 } 883 } 884 #ifdef XDEBUG 885 if (!(Norm(y-y2) <= 886 0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(x0)+ 887 (add?Norm(y0):TMV_RealType(T)(0))))) { 888 cerr<<"MultMV: alpha = "<<alpha<<endl; 889 cerr<<"add = "<<add<<endl; 890 cerr<<"A = "<<A.cptr()<<" "<<TMV_Text(A)<<" "<<A0<<endl; 891 cerr<<"x = "<<x.cptr()<<" "<<TMV_Text(x)<<" step "<<x.step()<<" "<<x0<<endl; 892 cerr<<"y = "<<y.cptr()<<" "<<TMV_Text(y)<<" step "<<y.step()<<" "<<y0<<endl; 893 cerr<<"-> y = "<<y<<endl; 894 cerr<<"y2 = "<<y2<<endl; 895 abort(); 896 } 897 #endif 898 } 899 900 #define InstFile "TMV_MultUV.inst" 901 #include "TMV_Inst.h" 902 #undef InstFile 903 904 } // namespace tmv 905 906 907