///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// The Template Matrix/Vector Library for C++ was created by Mike Jarvis     //
// Copyright (C) 1998 - 2016                                                 //
// All rights reserved                                                       //
//                                                                           //
// The project is hosted at https://code.google.com/p/tmv-cpp/               //
// where you can find the current version and current documentation.         //
//                                                                           //
// For concerns or problems with the software, Mike may be contacted at      //
// mike_jarvis17 [at] gmail.                                                 //
//                                                                           //
// This software is licensed under a FreeBSD license.  The file              //
// TMV_LICENSE should have been included with this distribution.             //
// If not, you can get a copy from https://code.google.com/p/tmv-cpp/.       //
//                                                                           //
// Essentially, you can use this software however you want provided that     //
// you include the TMV_LICENSE file in any distribution that uses it.        //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////


//#define XDEBUG

#include "TMV_Blas.h"
#include "tmv/TMV_MatrixArithFunc.h"
#include "TMV_MultMV.h"
#include "tmv/TMV_Matrix.h"
#include "tmv/TMV_VectorArith.h"
#include "tmv/TMV_MatrixArith.h"

#ifdef XDEBUG
#include "tmv/TMV_VIt.h"
#include <iostream>
using std::cout;
using std::cerr;
using std::endl;
#endif

// CBLAS trick of using RowMajor with ConjTrans when we have a
// case of A.conjugate() * x doesn't seem to be working with MKL 10.2.2.
// I haven't been able to figure out why.  (e.g. Is it a bug in the MKL
// code, or am I doing something wrong?)  So for now, just disable it.
44 #ifdef CBLAS 45 #undef CBLAS 46 #endif 47 48 namespace tmv { 49 cptr() const50 template <class T> const T* MatrixComposite<T>::cptr() const 51 { 52 if (!itsm.get()) { 53 ptrdiff_t len = this->colsize()*this->rowsize(); 54 itsm.resize(len); 55 MatrixView<T>(itsm.get(),this->colsize(),this->rowsize(), 56 stepi(),stepj(),NonConj,len 57 TMV_FIRSTLAST1(itsm.get(),itsm.get()+len) ) = *this; 58 } 59 return itsm.get(); 60 } 61 stepi() const62 template <class T> ptrdiff_t MatrixComposite<T>::stepi() const 63 { return 1; } 64 stepj() const65 template <class T> ptrdiff_t MatrixComposite<T>::stepj() const 66 { return this->colsize(); } 67 ls() const68 template <class T> ptrdiff_t MatrixComposite<T>::ls() const 69 { return this->rowsize() * this->colsize(); } 70 71 // 72 // 73 // MultMV 74 // 75 76 // These routines are designed to work even if y has the same storage 77 // as either x or the first row/column of A. 78 79 template <bool add, bool cx, bool ca, bool rm, class T, class Ta, class Tx> RowMultMV(const GenMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)80 static void RowMultMV( 81 const GenMatrix<Ta>& A, const GenVector<Tx>& x, VectorView<T> y) 82 { 83 TMVAssert(A.rowsize() == x.size()); 84 TMVAssert(A.colsize() == y.size()); 85 TMVAssert(x.size() > 0); 86 TMVAssert(y.size() > 0); 87 TMVAssert(y.ct()==NonConj); 88 TMVAssert(x.step() == 1); 89 TMVAssert(y.step() == 1); 90 TMVAssert(!SameStorage(x,y)); 91 TMVAssert(cx == x.isconj()); 92 TMVAssert(ca == A.isconj()); 93 94 const ptrdiff_t M = A.colsize(); 95 const ptrdiff_t N = A.rowsize(); 96 const ptrdiff_t si = A.stepi(); 97 const ptrdiff_t sj = (rm ? 1 : A.stepj()); 98 99 const Ta* Ai0 = A.cptr(); 100 const Tx*const x0 = x.cptr(); 101 T* yi = y.ptr(); 102 103 for(ptrdiff_t i=M; i>0; --i,++yi,Ai0+=si) { 104 // *yi += A.row(i) * x 105 106 const Ta* Aij = Ai0; 107 const Tx* xj = x0; 108 register T temp(0); 109 for(ptrdiff_t j=N; j>0; --j,++xj,(rm?++Aij:Aij+=sj)) 110 temp += 111 (cx ? 
TMV_CONJ(*xj) : *xj) * 112 (ca ? TMV_CONJ(*Aij) : *Aij); 113 114 #ifdef TMVFLDEBUG 115 TMVAssert(yi >= y._first); 116 TMVAssert(yi < y._last); 117 #endif 118 if (add) *yi += temp; 119 else *yi = temp; 120 } 121 } 122 123 template <bool add, bool cx, bool ca, bool cm, class T, class Ta, class Tx> ColMultMV(const GenMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)124 static void ColMultMV( 125 const GenMatrix<Ta>& A, const GenVector<Tx>& x, VectorView<T> y) 126 { 127 TMVAssert(A.rowsize() == x.size()); 128 TMVAssert(A.colsize() == y.size()); 129 TMVAssert(x.size() > 0); 130 TMVAssert(y.size() > 0); 131 TMVAssert(y.ct()==NonConj); 132 TMVAssert(x.step() == 1); 133 TMVAssert(y.step() == 1); 134 TMVAssert(!SameStorage(x,y)); 135 TMVAssert(cx == x.isconj()); 136 TMVAssert(ca == A.isconj()); 137 TMVAssert(cm == A.iscm()); 138 139 const ptrdiff_t M = A.colsize(); 140 ptrdiff_t N = A.rowsize(); 141 const ptrdiff_t si = (cm ? 1 : A.stepi()); 142 const ptrdiff_t sj = A.stepj(); 143 144 const Ta* A0j = A.cptr(); 145 const Tx* xj = x.cptr(); 146 T*const y0 = y.ptr(); 147 148 if (!add) { 149 if (*xj == Tx(0)) { 150 y.setZero(); 151 } else { 152 const Ta* Aij = A0j; 153 T* yi = y0; 154 const Tx xjval = (cx ? TMV_CONJ(*xj) : *xj); 155 for(ptrdiff_t i=M; i>0; --i,++yi,(cm?++Aij:Aij+=si)) { 156 #ifdef TMVFLDEBUG 157 TMVAssert(yi >= y._first); 158 TMVAssert(yi < y._last); 159 #endif 160 *yi = xjval * (ca ? TMV_CONJ(*Aij) : *Aij); 161 } 162 } 163 ++xj; A0j+=sj; --N; 164 } 165 166 for(; N>0; --N,++xj,A0j+=sj) { 167 // y += *xj * A.col(j) 168 if (*xj != Tx(0)) { 169 const Ta* Aij = A0j; 170 T* yi = y0; 171 const Tx xjval = (cx ? TMV_CONJ(*xj) : *xj); 172 for(ptrdiff_t i=M; i>0; --i,++yi,(cm?++Aij:Aij+=si)) { 173 #ifdef TMVFLDEBUG 174 TMVAssert(yi >= y._first); 175 TMVAssert(yi < y._last); 176 #endif 177 *yi += xjval * (ca ? 
TMV_CONJ(*Aij) : *Aij); 178 } 179 } 180 } 181 } 182 183 template <bool add, bool cx, class T, class Ta, class Tx> UnitAMultMV1(const GenMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)184 void UnitAMultMV1( 185 const GenMatrix<Ta>& A, const GenVector<Tx>& x, VectorView<T> y) 186 { 187 TMVAssert(A.rowsize() == x.size()); 188 TMVAssert(A.colsize() == y.size()); 189 TMVAssert(x.size() > 0); 190 TMVAssert(y.size() > 0); 191 TMVAssert(y.ct() == NonConj); 192 TMVAssert(x.step() == 1); 193 TMVAssert(y.step() == 1); 194 TMVAssert(!SameStorage(x,y)); 195 TMVAssert(cx == x.isconj()); 196 197 if (A.isrm()) 198 if (A.isconj()) 199 RowMultMV<add,cx,true,true>(A,x,y); 200 else 201 RowMultMV<add,cx,false,true>(A,x,y); 202 else if (A.iscm()) 203 if (A.isconj()) 204 ColMultMV<add,cx,true,true>(A,x,y); 205 else 206 ColMultMV<add,cx,false,true>(A,x,y); 207 else if ( A.rowsize() >= A.colsize() ) 208 if (A.isconj()) 209 RowMultMV<add,cx,true,false>(A,x,y); 210 else 211 RowMultMV<add,cx,false,false>(A,x,y); 212 else 213 if (A.isconj()) 214 ColMultMV<add,cx,true,false>(A,x,y); 215 else 216 ColMultMV<add,cx,false,false>(A,x,y); 217 } 218 219 template <bool add, bool cx, class T, class Ta, class Tx> UnitAMultMV(const GenMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)220 static void UnitAMultMV( 221 const GenMatrix<Ta>& A, const GenVector<Tx>& x, VectorView<T> y) 222 { 223 #ifdef XDEBUG 224 cout<<"Start UnitAMultMV: \n"; 225 cout<<"add = "<<add<<endl; 226 cout<<"A = "<<TMV_Text(A)<<" "<<A<<endl; 227 cout<<"x = "<<TMV_Text(x)<<" step "<<x.step()<<" "<<x<<endl; 228 cout<<"y = "<<TMV_Text(y)<<" step "<<y.step()<<" "<<y<<endl; 229 Vector<Tx> x0 = x; 230 Vector<T> y0 = y; 231 Matrix<Ta> A0 = A; 232 Vector<T> y2 = y; 233 for(ptrdiff_t i=0;i<y.size();i++) { 234 if (add) 235 y2(i) += (A.row(i) * x0); 236 else 237 y2(i) = (A.row(i) * x0); 238 } 239 cout<<"y2 = "<<y2<<endl; 240 #endif 241 // Check for 0's in beginning or end of x: 242 // y += [ A1 A2 A3 ] [ 0 ] --> y += A2 x 243 // [ x 
] 244 // [ 0 ] 245 246 const ptrdiff_t N = x.size(); // = A.rowsize() 247 ptrdiff_t j2 = N; 248 for(const Tx* x2=x.cptr()+N-1; j2>0 && *x2==Tx(0); --j2,--x2); 249 if (j2 == 0) { 250 if (!add) y.setZero(); 251 return; 252 } 253 ptrdiff_t j1 = 0; 254 for(const Tx* x1=x.cptr(); *x1==Tx(0); ++j1,++x1) {} 255 TMVAssert(j1 !=j2); 256 if (j1 == 0 && j2 == N) UnitAMultMV1<add,cx>(A,x,y); 257 else UnitAMultMV1<add,cx>(A.colRange(j1,j2),x.subVector(j1,j2),y); 258 259 #ifdef XDEBUG 260 cout<<"y => "<<y<<endl; 261 if (!(Norm(y-y2) <= 262 0.001*(Norm(A0)*Norm(x0)+ 263 (add?Norm(y0):TMV_RealType(T)(0))))) { 264 cerr<<"MultMV: \n"; 265 cerr<<"add = "<<add<<endl; 266 cerr<<"A = "<<TMV_Text(A); 267 if (A.rowsize() < 30 && A.colsize() < 30) cerr<<" "<<A0; 268 else cerr<<" "<<A.colsize()<<" x "<<A.rowsize(); 269 cerr<<endl<<"x = "<<TMV_Text(x)<<" step "<<x.step(); 270 if (x.size() < 30) cerr<<" "<<x0; 271 cerr<<endl<<"y = "<<TMV_Text(y)<<" step "<<y.step(); 272 if (add && y.size() < 30) cerr<<" "<<y0; 273 cerr<<endl<<"Aptr = "<<A.cptr(); 274 cerr<<", xptr = "<<x.cptr()<<", yptr = "<<y.cptr()<<endl; 275 if (y.size() < 200) { 276 cerr<<"--> y = "<<y<<endl; 277 cerr<<"y2 = "<<y2<<endl; 278 } else { 279 ptrdiff_t imax; 280 (y-y2).maxAbsElement(&imax); 281 cerr<<"y("<<imax<<") = "<<y(imax)<<endl; 282 cerr<<"y2("<<imax<<") = "<<y2(imax)<<endl; 283 } 284 cerr<<"Norm(A0) = "<<Norm(A0)<<endl; 285 cerr<<"Norm(x0) = "<<Norm(x0)<<endl; 286 if (add) cerr<<"Norm(y0) = "<<Norm(y0)<<endl; 287 cerr<<"|A0|*|x0|+?|y0| = "<< 288 Norm(A0)*Norm(x0)+ 289 (add?Norm(y0):TMV_RealType(T)(0))<<endl; 290 cerr<<"Norm(y-y2) = "<<Norm(y-y2)<<endl; 291 cerr<<"NormInf(y-y2) = "<<NormInf(y-y2)<<endl; 292 cerr<<"Norm1(y-y2) = "<<Norm1(y-y2)<<endl; 293 abort(); 294 } 295 #endif 296 } 297 NonBlasMultMV(const T alpha,const GenMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)298 template <bool add, class T, class Ta, class Tx> static void NonBlasMultMV( 299 const T alpha, const GenMatrix<Ta>& A, const GenVector<Tx>& 
x, 300 VectorView<T> y) 301 // y (+)= alpha * A * x 302 { 303 #ifdef XDEBUG 304 cout<<"Start MultMV: alpha = "<<alpha<<endl; 305 cout<<"add = "<<add<<endl; 306 cout<<"A = "<<TMV_Text(A)<<" "<<A<<endl; 307 cout<<"x = "<<TMV_Text(x)<<" step "<<x.step()<<" "<<x<<endl; 308 cout<<"y = "<<TMV_Text(y)<<" step "<<y.step()<<" "<<y<<endl; 309 Vector<Tx> x0 = x; 310 Vector<T> y0 = y; 311 Matrix<Ta> A0 = A; 312 Vector<T> y2 = y; 313 for(ptrdiff_t i=0;i<y.size();i++) { 314 if (add) 315 y2(i) += alpha * (A.row(i) * x0); 316 else 317 y2(i) = alpha * (A.row(i) * x0); 318 } 319 cout<<"y2 = "<<y2<<endl; 320 #endif 321 TMVAssert(A.rowsize() == x.size()); 322 TMVAssert(A.colsize() == y.size()); 323 TMVAssert(alpha != T(0)); 324 TMVAssert(x.size() > 0); 325 TMVAssert(y.size() > 0); 326 TMVAssert(y.ct() == NonConj); 327 328 const ptrdiff_t M = A.colsize(); 329 const ptrdiff_t N = A.rowsize(); 330 331 if (x.step() != 1 || SameStorage(x,y) || 332 (alpha != TMV_RealType(T)(1) && y.step() == 1 && M/4 >= N)) { 333 // This last check is taken from the ATLAS version of this code. 
334 // Apparently M = 4N is the dividing line between applying alpha 335 // here versus at the end when adding Ax to y 336 if (TMV_IMAG(alpha) == TMV_RealType(T)(0)) { 337 Vector<Tx> xx = TMV_REAL(alpha)*x; 338 if (y.step()!=1) { 339 Vector<T> yy(y.size()); 340 UnitAMultMV<false,false>(A,xx,yy.view()); 341 if (add) y += yy; 342 else y = yy; 343 } 344 else 345 UnitAMultMV<add,false>(A,xx,y); 346 } else { 347 Vector<T> xx = alpha*x; 348 if (y.step() != 1) { 349 Vector<T> yy(y.size()); 350 UnitAMultMV<false,false>(A,xx,yy.view()); 351 if (add) y += yy; 352 else y = yy; 353 } 354 else 355 UnitAMultMV<add,false>(A,xx,y); 356 } 357 } else if (y.step() != 1 || alpha != TMV_RealType(T)(1)) { 358 Vector<T> yy(y.size()); 359 if (x.isconj()) 360 UnitAMultMV<false,true>(A,x,yy.view()); 361 else 362 UnitAMultMV<false,false>(A,x,yy.view()); 363 if (add) y += alpha*yy; 364 else y = alpha*yy; 365 } else { 366 TMVAssert(alpha == T(1)); 367 TMVAssert(y.step() == 1); 368 TMVAssert(x.step() == 1); 369 TMVAssert(!SameStorage(x,y)); 370 if (x.isconj()) 371 UnitAMultMV<add,true>(A,x,y); 372 else 373 UnitAMultMV<add,false>(A,x,y); 374 } 375 #ifdef XDEBUG 376 cout<<"y => "<<y<<endl; 377 if (!(Norm(y-y2) <= 378 0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(x0)+ 379 (add?Norm(y0):TMV_RealType(T)(0))))) { 380 cerr<<"MultMV: alpha = "<<alpha<<endl; 381 cerr<<"add = "<<add<<endl; 382 cerr<<"A = "<<TMV_Text(A); 383 if (A.rowsize() < 30 && A.colsize() < 30) cerr<<" "<<A0; 384 else cerr<<" "<<A.colsize()<<" x "<<A.rowsize(); 385 cerr<<endl<<"x = "<<TMV_Text(x)<<" step "<<x.step(); 386 if (x.size() < 30) cerr<<" "<<x0; 387 cerr<<endl<<"y = "<<TMV_Text(y)<<" step "<<y.step(); 388 if (add && y.size() < 30) cerr<<" "<<y0; 389 cerr<<endl<<"Aptr = "<<A.cptr(); 390 cerr<<", xptr = "<<x.cptr()<<", yptr = "<<y.cptr()<<endl; 391 if (y.size() < 200) { 392 cerr<<"--> y = "<<y<<endl; 393 cerr<<"y2 = "<<y2<<endl; 394 } else { 395 ptrdiff_t imax; 396 (y-y2).maxAbsElement(&imax); 397 cerr<<"y("<<imax<<") = 
"<<y(imax)<<endl; 398 cerr<<"y2("<<imax<<") = "<<y2(imax)<<endl; 399 } 400 cerr<<"Norm(A0) = "<<Norm(A0)<<endl; 401 cerr<<"Norm(x0) = "<<Norm(x0)<<endl; 402 if (add) cerr<<"Norm(y0) = "<<Norm(y0)<<endl; 403 cerr<<"|alpha|*|A0|*|x0|+?|y0| = "<< 404 TMV_ABS(alpha)*Norm(A0)*Norm(x0)+ 405 (add?Norm(y0):TMV_RealType(T)(0))<<endl; 406 cerr<<"Norm(y-y2) = "<<Norm(y-y2)<<endl; 407 cerr<<"NormInf(y-y2) = "<<NormInf(y-y2)<<endl; 408 cerr<<"Norm1(y-y2) = "<<Norm1(y-y2)<<endl; 409 abort(); 410 } 411 #endif 412 } 413 414 #ifdef BLAS BlasMultMV(const T alpha,const GenMatrix<Ta> & A,const GenVector<Tx> & x,const int beta,VectorView<T> y)415 template <class T, class Ta, class Tx> static inline void BlasMultMV( 416 const T alpha, const GenMatrix<Ta>& A, 417 const GenVector<Tx>& x, const int beta, VectorView<T> y) 418 { 419 if (beta == 0) NonBlasMultMV<false>(alpha,A,x,y); 420 else NonBlasMultMV<true>(alpha,A,x,y); 421 } 422 #ifdef INST_DOUBLE BlasMultMV(const double alpha,const GenMatrix<double> & A,const GenVector<double> & x,const int beta,VectorView<double> y)423 template <> void BlasMultMV( 424 const double alpha, const GenMatrix<double>& A, 425 const GenVector<double>& x, const int beta, VectorView<double> y) 426 { 427 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 428 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 429 int lda = BlasIsCM(A) ? A.stepj() : A.stepi(); 430 if (lda < m) { TMVAssert(n==1); lda = m; } 431 int xs = x.step(); 432 int ys = y.step(); 433 if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; } 434 if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; } 435 const double* xp = x.cptr(); 436 if (xs < 0) xp += (x.size()-1)*xs; 437 double* yp = y.ptr(); 438 if (ys < 0) yp += (y.size()-1)*ys; 439 // Some BLAS implementations seem to have trouble if the 440 // input y has a nan in it. 441 // They propagate the nan into the output. 442 // I guess they strictly interpret y = beta*y + alpha*m*x, 443 // so if beta = 0, then beta*nan = nan. 
444 // Anyway, to fix this problem, we always use beta=1, and just 445 // zero out y before calling the blas function if beta is 0. 446 if (beta == 0) y.setZero(); 447 double xbeta(1); 448 449 #if 0 450 std::cout<<"Before dgemv"<<std::endl; 451 std::cout<<"A = "<<A<<std::endl; 452 std::cout<<"x = "<<x<<std::endl; 453 std::cout<<"y = "<<y<<std::endl; 454 std::cout<<"m = "<<m<<std::endl; 455 std::cout<<"n = "<<n<<std::endl; 456 std::cout<<"lda = "<<lda<<std::endl; 457 std::cout<<"xs = "<<xs<<std::endl; 458 std::cout<<"ys = "<<ys<<std::endl; 459 std::cout<<"alpha = "<<alpha<<std::endl; 460 std::cout<<"beta = "<<xbeta<<std::endl; 461 std::cout<<"aptr = "<<A.cptr()<<std::endl; 462 std::cout<<"xp = "<<xp<<std::endl; 463 std::cout<<"yp = "<<yp<<std::endl; 464 std::cout<<"NT = "<<(A.isrm()?'T':'N')<<std::endl; 465 if (A.isrm()) { 466 std::cout<<"x.size = "<<x.size()<<std::endl; 467 std::cout<<"x = "; 468 for(int i=0;i<m;i++) std::cout<<*(xp+i*xs)<<" "; 469 std::cout<<std::endl; 470 std::cout<<"y.size = "<<y.size()<<std::endl; 471 std::cout<<"y = "; 472 for(int i=0;i<n;i++) std::cout<<*(yp+i*ys)<<" "; 473 std::cout<<std::endl; 474 std::cout<<"A.size = "<<A.colsize()<<','<<A.rowsize()<<std::endl; 475 std::cout<<"A = "; 476 for(int i=0;i<n*lda;i++) std::cout<<*(A.cptr()+i)<<" "; 477 std::cout<<std::endl; 478 } 479 #endif 480 BLASNAME(dgemv) ( 481 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 482 BLASV(m),BLASV(n),BLASV(alpha),BLASP(A.cptr()),BLASV(lda), 483 BLASP(xp),BLASV(xs),BLASV(xbeta),BLASP(yp),BLASV(ys) BLAS1); 484 #if 0 485 std::cout<<"After dgemv"<<std::endl; 486 #endif 487 } BlasMultMV(const std::complex<double> alpha,const GenMatrix<std::complex<double>> & A,const GenVector<std::complex<double>> & x,const int beta,VectorView<std::complex<double>> y)488 template <> void BlasMultMV( 489 const std::complex<double> alpha, 490 const GenMatrix<std::complex<double> >& A, 491 const GenVector<std::complex<double> >& x, 492 const int beta, VectorView<std::complex<double> > y) 493 { 
494 if (x.isconj() 495 #ifndef CBLAS 496 && !(A.isconj() && BlasIsCM(A)) 497 #endif 498 ) { 499 Vector<std::complex<double> > xx = alpha*x; 500 return BlasMultMV(std::complex<double>(1),A,xx,beta,y); 501 } 502 503 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 504 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 505 int lda = BlasIsCM(A) ? A.stepj() : A.stepi(); 506 if (lda < m) { TMVAssert(n==1); lda = m; } 507 int xs = x.step(); 508 int ys = y.step(); 509 if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; } 510 if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; } 511 const std::complex<double>* xp = x.cptr(); 512 if (xs < 0) xp += (x.size()-1)*xs; 513 std::complex<double>* yp = y.ptr(); 514 if (ys < 0) yp += (y.size()-1)*ys; 515 if (beta == 0) y.setZero(); 516 std::complex<double> xbeta(1); 517 #if 0 518 std::cout<<"Before zgemv"<<std::endl; 519 std::cout<<"A = "<<A<<std::endl; 520 std::cout<<"x = "<<x<<std::endl; 521 std::cout<<"y = "<<y<<std::endl; 522 std::cout<<"m = "<<m<<std::endl; 523 std::cout<<"n = "<<n<<std::endl; 524 std::cout<<"lda = "<<lda<<std::endl; 525 std::cout<<"xs = "<<xs<<std::endl; 526 std::cout<<"ys = "<<ys<<std::endl; 527 std::cout<<"alpha = "<<alpha<<std::endl; 528 std::cout<<"beta = "<<xbeta<<std::endl; 529 std::cout<<"aptr = "<<A.cptr()<<std::endl; 530 std::cout<<"xp = "<<xp<<std::endl; 531 std::cout<<"yp = "<<yp<<std::endl; 532 #endif 533 if (A.isconj() && BlasIsCM(A)) { 534 #ifdef CBLAS 535 TMV_SWAP(m,n); 536 BLASNAME(zgemv) ( 537 BLASRM BLASCH_CT, 538 BLASV(m),BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda), 539 BLASP(xp),BLASV(xs),BLASP(&xbeta),BLASP(yp),BLASV(ys) BLAS1); 540 #else 541 std::complex<double> ca = TMV_CONJ(alpha); 542 if (x.isconj()) { 543 y.conjugateSelf(); 544 BLASNAME(zgemv) ( 545 BLASCM BLASCH_NT, 546 BLASV(m),BLASV(n),BLASP(&ca),BLASP(A.cptr()),BLASV(lda), 547 BLASP(xp),BLASV(xs),BLASP(&xbeta),BLASP(yp),BLASV(ys) 548 BLAS1); 549 y.conjugateSelf(); 550 } else { 551 Vector<std::complex<double> > xx = ca*x.conjugate(); 
552 ca = std::complex<double>(1); 553 xs = 1; 554 xp = xx.cptr(); 555 y.conjugateSelf(); 556 BLASNAME(zgemv) ( 557 BLASCM BLASCH_NT, 558 BLASV(m),BLASV(n),BLASP(&ca),BLASP(A.cptr()),BLASV(lda), 559 BLASP(xp),BLASV(xs),BLASP(&xbeta),BLASP(yp),BLASV(ys) 560 BLAS1); 561 y.conjugateSelf(); 562 } 563 #endif 564 } else { 565 BLASNAME(zgemv) ( 566 BLASCM BlasIsCM(A)?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T, 567 BLASV(m),BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda), 568 BLASP(xp),BLASV(xs),BLASP(&xbeta),BLASP(yp),BLASV(ys) 569 BLAS1); 570 } 571 } BlasMultMV(const std::complex<double> alpha,const GenMatrix<std::complex<double>> & A,const GenVector<double> & x,const int beta,VectorView<std::complex<double>> y)572 template <> void BlasMultMV( 573 const std::complex<double> alpha, 574 const GenMatrix<std::complex<double> >& A, 575 const GenVector<double>& x, 576 const int beta, VectorView<std::complex<double> > y) 577 { 578 if (BlasIsCM(A)) { 579 if (y.step() != 1) { 580 Vector<std::complex<double> > yy(y.size()); 581 BlasMultMV(std::complex<double>(1),A,x,0,yy.view()); 582 if (beta == 0) y = alpha*yy; 583 else y += alpha*yy; 584 } else { 585 if (beta == 0) { 586 int m = 2*A.colsize(); 587 int n = A.rowsize(); 588 int lda = 2*A.stepj(); 589 if (lda < m) { TMVAssert(n==1); lda = m; } 590 int xs = x.step(); 591 int ys = 1; 592 const double* xp = x.cptr(); 593 if (xs < 0) xp += (x.size()-1)*xs; 594 double* yp = (double*) y.ptr(); 595 double xalpha(1); 596 y.setZero(); 597 double xbeta(1); 598 BLASNAME(dgemv) ( 599 BLASCM BLASCH_NT, 600 BLASV(m),BLASV(n),BLASV(xalpha), 601 BLASP((double*)A.cptr()),BLASV(lda), 602 BLASP(xp),BLASV(xs),BLASV(xbeta), 603 BLASP(yp),BLASV(ys) BLAS1); 604 if (A.isconj()) y.conjugateSelf(); 605 y *= alpha; 606 } else if (A.isconj()) { 607 Vector<std::complex<double> > yy(y.size()); 608 BlasMultMV( 609 std::complex<double>(1),A.conjugate(),x,0,yy.view()); 610 y += alpha*yy.conjugate(); 611 } else if (TMV_IMAG(alpha) == 0.) 
{ 612 int m = 2*A.colsize(); 613 int n = A.rowsize(); 614 int lda = 2*A.stepj(); 615 if (lda < m) { TMVAssert(n==1); lda = m; } 616 int xs = x.step(); 617 int ys = 1; 618 const double* xp = x.cptr(); 619 if (xs < 0) xp += (x.size()-1)*xs; 620 double* yp = (double*) y.ptr(); 621 if (ys < 0) yp += (y.size()-1)*ys; 622 double xalpha(TMV_REAL(alpha)); 623 double xbeta(1); 624 BLASNAME(dgemv) ( 625 BLASCM BLASCH_NT, 626 BLASV(m),BLASV(n),BLASV(xalpha), 627 BLASP((double*)A.cptr()),BLASV(lda), 628 BLASP(xp),BLASV(xs),BLASV(xbeta), 629 BLASP(yp),BLASV(ys) BLAS1); 630 } else { 631 Vector<std::complex<double> > yy(y.size()); 632 BlasMultMV(std::complex<double>(1),A,x,0,yy.view()); 633 y += alpha*yy; 634 } 635 } 636 } else { // A.isrm 637 BlasMultMV(alpha,A,Vector<std::complex<double> >(x),beta,y); 638 } 639 } BlasMultMV(const std::complex<double> alpha,const GenMatrix<double> & A,const GenVector<std::complex<double>> & x,const int beta,VectorView<std::complex<double>> y)640 template <> void BlasMultMV( 641 const std::complex<double> alpha, 642 const GenMatrix<double>& A, 643 const GenVector<std::complex<double> >& x, 644 const int beta, VectorView<std::complex<double> > y) 645 { 646 if (beta == 0) { 647 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 648 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 649 int lda = BlasIsCM(A) ? 
A.stepj() : A.stepi(); 650 if (lda < m) { TMVAssert(n==1); lda = m; } 651 int xs = 2*x.step(); 652 int ys = 2*y.step(); 653 if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; } 654 if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; } 655 const double* xp = (const double*) x.cptr(); 656 if (xs < 0) xp += (x.size()-1)*xs; 657 double* yp = (double*) y.ptr(); 658 if (ys < 0) yp += (y.size()-1)*ys; 659 double xalpha(1); 660 y.setZero(); 661 double xbeta(1); 662 BLASNAME(dgemv) ( 663 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 664 BLASV(m),BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda), 665 BLASP(xp),BLASV(xs),BLASV(xbeta), 666 BLASP(yp),BLASV(ys) BLAS1); 667 BLASNAME(dgemv) ( 668 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 669 BLASV(m),BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda), 670 BLASP(xp+1),BLASV(xs),BLASV(xbeta), 671 BLASP(yp+1),BLASV(ys) BLAS1); 672 if (x.isconj()) y.conjugateSelf(); 673 y *= alpha; 674 } else if (TMV_IMAG(alpha) == 0. && !x.isconj()) { 675 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 676 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 677 int lda = BlasIsCM(A) ? 
A.stepj() : A.stepi(); 678 if (lda < m) { TMVAssert(n==1); lda = m; } 679 int xs = 2*x.step(); 680 int ys = 2*y.step(); 681 if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; } 682 if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; } 683 const double* xp = (const double*) x.cptr(); 684 if (xs < 0) xp += (x.size()-1)*xs; 685 double* yp = (double*) y.ptr(); 686 if (ys < 0) yp += (y.size()-1)*ys; 687 double xalpha(TMV_REAL(alpha)); 688 double xbeta(1); 689 BLASNAME(dgemv) ( 690 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 691 BLASV(m),BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda), 692 BLASP(xp),BLASV(xs),BLASV(xbeta), 693 BLASP(yp),BLASV(ys) BLAS1); 694 BLASNAME(dgemv) ( 695 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 696 BLASV(m),BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda), 697 BLASP(xp+1),BLASV(xs),BLASV(xbeta), 698 BLASP(yp+1),BLASV(ys) BLAS1); 699 } else { 700 Vector<std::complex<double> > xx = alpha*x; 701 BlasMultMV(std::complex<double>(1),A,xx,1,y); 702 } 703 } BlasMultMV(const std::complex<double> alpha,const GenMatrix<double> & A,const GenVector<double> & x,const int beta,VectorView<std::complex<double>> y)704 template <> void BlasMultMV( 705 const std::complex<double> alpha, 706 const GenMatrix<double>& A, 707 const GenVector<double>& x, 708 const int beta, VectorView<std::complex<double> > y) 709 { 710 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 711 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 712 int lda = BlasIsCM(A) ? A.stepj() : A.stepi(); 713 if (lda < m) { TMVAssert(n==1); lda = m; } 714 int xs = x.step(); 715 int ys = 2*y.step(); 716 if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; } 717 if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; } 718 const double* xp = x.cptr(); 719 if (xs < 0) xp += (x.size()-1)*xs; 720 double* yp = (double*) y.ptr(); 721 if (ys < 0) yp += (y.size()-1)*ys; 722 double ar(TMV_REAL(alpha)); 723 double ai(TMV_IMAG(alpha)); 724 double xbeta(beta); 725 if (beta == 0) y.setZero(); 726 if (ar != 0.) 
{ 727 BLASNAME(dgemv) ( 728 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 729 BLASV(m),BLASV(n),BLASV(ar),BLASP(A.cptr()),BLASV(lda), 730 BLASP(xp),BLASV(xs),BLASV(xbeta), 731 BLASP(yp),BLASV(ys) BLAS1); 732 } 733 if (ai != 0.) { 734 BLASNAME(dgemv) ( 735 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 736 BLASV(m),BLASV(n),BLASV(ai),BLASP(A.cptr()),BLASV(lda), 737 BLASP(xp),BLASV(xs),BLASV(xbeta), 738 BLASP(yp+1),BLASV(ys) BLAS1); 739 } 740 } 741 #endif 742 #ifdef INST_FLOAT BlasMultMV(const float alpha,const GenMatrix<float> & A,const GenVector<float> & x,const int beta,VectorView<float> y)743 template <> void BlasMultMV( 744 const float alpha, const GenMatrix<float>& A, 745 const GenVector<float>& x, const int beta, VectorView<float> y) 746 { 747 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 748 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 749 int lda = BlasIsCM(A) ? A.stepj() : A.stepi(); 750 if (lda < m) { TMVAssert(n==1); lda = m; } 751 int xs = x.step(); 752 int ys = y.step(); 753 if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; } 754 if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; } 755 const float* xp = x.cptr(); 756 if (xs < 0) xp += (x.size()-1)*xs; 757 float* yp = y.ptr(); 758 if (ys < 0) yp += (y.size()-1)*ys; 759 if (beta == 0) y.setZero(); 760 float xbeta(1); 761 #if 0 762 std::cout<<"Before sgemv:\n"; 763 std::cout<<"A = "<<A<<std::endl; 764 std::cout<<"x = "<<x<<std::endl; 765 std::cout<<"y = "<<y<<std::endl; 766 std::cout<<"m,n = "<<m<<','<<n<<std::endl; 767 std::cout<<"alpha,beta = "<<alpha<<','<<xbeta<<std::endl; 768 std::cout<<"A.cptr = "<<A.cptr()<<std::endl; 769 std::cout<<"xp = "<<xp<<std::endl; 770 std::cout<<"yp = "<<yp<<std::endl; 771 std::cout<<"lda,xs,ys = "<<lda<<','<<xs<<','<<ys<<std::endl; 772 std::cout<<"cm = "<<BlasIsCM(A)<<std::endl; 773 #endif 774 775 BLASNAME(sgemv) ( 776 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 777 BLASV(m),BLASV(n),BLASV(alpha),BLASP(A.cptr()),BLASV(lda), 778 BLASP(xp),BLASV(xs),BLASV(xbeta),BLASP(yp),BLASV(ys) 779 
BLAS1); 780 #if 0 781 std::cout<<"After sgemv:"<<std::endl; 782 std::cout<<"y -> "<<y<<std::endl; 783 #endif 784 } BlasMultMV(const std::complex<float> alpha,const GenMatrix<std::complex<float>> & A,const GenVector<std::complex<float>> & x,const int beta,VectorView<std::complex<float>> y)785 template <> void BlasMultMV( 786 const std::complex<float> alpha, 787 const GenMatrix<std::complex<float> >& A, 788 const GenVector<std::complex<float> >& x, 789 const int beta, VectorView<std::complex<float> > y) 790 { 791 if (x.isconj() 792 #ifndef CBLAS 793 && !(A.isconj() && BlasIsCM(A)) 794 #endif 795 ) { 796 Vector<std::complex<float> > xx = alpha*x; 797 return BlasMultMV(std::complex<float>(1),A,xx,beta,y); 798 } 799 800 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 801 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 802 int lda = BlasIsCM(A) ? A.stepj() : A.stepi(); 803 if (lda < m) { TMVAssert(n==1); lda = m; } 804 int xs = x.step(); 805 int ys = y.step(); 806 if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; } 807 if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; } 808 const std::complex<float>* xp = x.cptr(); 809 if (xs < 0) xp += (x.size()-1)*xs; 810 std::complex<float>* yp = y.ptr(); 811 if (ys < 0) yp += (y.size()-1)*ys; 812 if (beta == 0) y.setZero(); 813 std::complex<float> xbeta(1); 814 #if 0 815 std::cout<<"Before cgemv:\n"; 816 std::cout<<"A = "<<A<<std::endl; 817 std::cout<<"x = "<<x<<std::endl; 818 std::cout<<"y = "<<y<<std::endl; 819 std::cout<<"m,n = "<<m<<','<<n<<std::endl; 820 std::cout<<"alpha,beta = "<<alpha<<','<<xbeta<<std::endl; 821 std::cout<<"A.cptr = "<<A.cptr()<<std::endl; 822 std::cout<<"xp = "<<xp<<std::endl; 823 std::cout<<"yp = "<<yp<<std::endl; 824 std::cout<<"lda,xs,ys = "<<lda<<','<<xs<<','<<ys<<std::endl; 825 std::cout<<"conj = "<<A.isconj()<<std::endl; 826 std::cout<<"cm = "<<A.iscm()<<std::endl; 827 #endif 828 if (A.isconj() && BlasIsCM(A)) { 829 #ifdef CBLAS 830 TMV_SWAP(m,n); 831 BLASNAME(cgemv) ( 832 BLASRM BLASCH_CT, 833 
BLASV(m),BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda), 834 BLASP(xp),BLASV(xs),BLASP(&xbeta),BLASP(yp),BLASV(ys) BLAS1); 835 #else 836 std::complex<float> ca = TMV_CONJ(alpha); 837 if (x.isconj()) { 838 y.conjugateSelf(); 839 BLASNAME(cgemv) ( 840 BLASCM BLASCH_NT, 841 BLASV(m),BLASV(n),BLASP(&ca),BLASP(A.cptr()),BLASV(lda), 842 BLASP(xp),BLASV(xs),BLASP(&xbeta),BLASP(yp),BLASV(ys) 843 BLAS1); 844 y.conjugateSelf(); 845 } else { 846 Vector<std::complex<float> > xx = ca*x.conjugate(); 847 ca = std::complex<float>(1); 848 xs = 1; 849 xp = xx.cptr(); 850 y.conjugateSelf(); 851 BLASNAME(cgemv) ( 852 BLASCM BLASCH_NT, 853 BLASV(m),BLASV(n),BLASP(&ca),BLASP(A.cptr()),BLASV(lda), 854 BLASP(xp),BLASV(xs),BLASP(&xbeta),BLASP(yp),BLASV(ys) 855 BLAS1); 856 y.conjugateSelf(); 857 } 858 #endif 859 } else { 860 BLASNAME(cgemv) ( 861 BLASCM BlasIsCM(A)?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T, 862 BLASV(m),BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda), 863 BLASP(xp),BLASV(xs),BLASP(&xbeta),BLASP(yp),BLASV(ys) 864 BLAS1); 865 } 866 #if 0 867 std::cout<<"After cgemv:"<<std::endl; 868 std::cout<<"y -> "<<y<<std::endl; 869 #endif 870 } BlasMultMV(const std::complex<float> alpha,const GenMatrix<std::complex<float>> & A,const GenVector<float> & x,const int beta,VectorView<std::complex<float>> y)871 template <> void BlasMultMV( 872 const std::complex<float> alpha, 873 const GenMatrix<std::complex<float> >& A, 874 const GenVector<float>& x, 875 const int beta, VectorView<std::complex<float> > y) 876 { 877 if (BlasIsCM(A)) { 878 if (y.step() != 1) { 879 Vector<std::complex<float> > yy(y.size()); 880 BlasMultMV(std::complex<float>(1),A,x,0,yy.view()); 881 if (beta == 0) y = alpha*yy; 882 else y += alpha*yy; 883 } else { 884 if (beta == 0) { 885 int m = 2*A.colsize(); 886 int n = A.rowsize(); 887 int lda = 2*A.stepj(); 888 if (lda < m) { TMVAssert(n==1); lda = m; } 889 int xs = x.step(); 890 int ys = 1; 891 const float* xp = x.cptr(); 892 if (xs < 0) xp += (x.size()-1)*xs; 893 
float* yp = (float*) y.ptr(); 894 float xalpha(1); 895 y.setZero(); 896 float xbeta(1); 897 BLASNAME(sgemv) ( 898 BLASCM BLASCH_NT, 899 BLASV(m),BLASV(n),BLASV(xalpha), 900 BLASP((float*)A.cptr()),BLASV(lda), 901 BLASP(xp),BLASV(xs),BLASV(xbeta), 902 BLASP(yp),BLASV(ys) BLAS1); 903 if (A.isconj()) y.conjugateSelf(); 904 y *= alpha; 905 } else if (A.isconj()) { 906 Vector<std::complex<float> > yy(y.size()); 907 BlasMultMV(std::complex<float>(1),A.conjugate(),x,0,yy.view()); 908 y += alpha*yy.conjugate(); 909 } else if (TMV_IMAG(alpha) == 0.F) { 910 int m = 2*A.colsize(); 911 int n = A.rowsize(); 912 int lda = 2*A.stepj(); 913 if (lda < m) { TMVAssert(n==1); lda = m; } 914 int xs = x.step(); 915 int ys = 1; 916 const float* xp = x.cptr(); 917 if (xs < 0) xp += (x.size()-1)*xs; 918 float* yp = (float*) y.ptr(); 919 float xalpha(TMV_REAL(alpha)); 920 float xbeta(1); 921 BLASNAME(sgemv) ( 922 BLASCM BLASCH_NT, 923 BLASV(m),BLASV(n),BLASV(xalpha), 924 BLASP((float*)A.cptr()),BLASV(lda), 925 BLASP(xp),BLASV(xs),BLASV(xbeta), 926 BLASP(yp),BLASV(ys) BLAS1); 927 } else { 928 Vector<std::complex<float> > yy(y.size()); 929 BlasMultMV(std::complex<float>(1),A,x,0,yy.view()); 930 y += alpha*yy; 931 } 932 } 933 } else { // A.isrm 934 BlasMultMV(alpha,A,Vector<std::complex<float> >(x),beta,y); 935 } 936 } BlasMultMV(const std::complex<float> alpha,const GenMatrix<float> & A,const GenVector<std::complex<float>> & x,const int beta,VectorView<std::complex<float>> y)937 template <> void BlasMultMV( 938 const std::complex<float> alpha, 939 const GenMatrix<float>& A, 940 const GenVector<std::complex<float> >& x, 941 const int beta, VectorView<std::complex<float> > y) 942 { 943 if (beta == 0) { 944 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 945 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 946 int lda = BlasIsCM(A) ? 
A.stepj() : A.stepi(); 947 if (lda < m) { TMVAssert(n==1); lda = m; } 948 int xs = 2*x.step(); 949 int ys = 2*y.step(); 950 if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; } 951 if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; } 952 const float* xp = (const float*) x.cptr(); 953 if (xs < 0) xp += (x.size()-1)*xs; 954 float* yp = (float*) y.ptr(); 955 if (ys < 0) yp += (y.size()-1)*ys; 956 float xalpha(1); 957 y.setZero(); 958 float xbeta(1); 959 BLASNAME(sgemv) ( 960 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 961 BLASV(m),BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda), 962 BLASP(xp),BLASV(xs),BLASV(xbeta), 963 BLASP(yp),BLASV(ys) BLAS1); 964 BLASNAME(sgemv) ( 965 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 966 BLASV(m),BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda), 967 BLASP(xp+1),BLASV(xs),BLASV(xbeta), 968 BLASP(yp+1),BLASV(ys) BLAS1); 969 if (x.isconj()) y.conjugateSelf(); 970 y *= alpha; 971 } else if (TMV_IMAG(alpha) == 0.F && !x.isconj()) { 972 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 973 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 974 int lda = BlasIsCM(A) ? 
A.stepj() : A.stepi(); 975 if (lda < m) { TMVAssert(n==1); lda = m; } 976 int xs = 2*x.step(); 977 int ys = 2*y.step(); 978 if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; } 979 if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; } 980 const float* xp = (const float*) x.cptr(); 981 if (xs < 0) xp += (x.size()-1)*xs; 982 float* yp = (float*) y.ptr(); 983 if (ys < 0) yp += (y.size()-1)*ys; 984 float xalpha(TMV_REAL(alpha)); 985 float xbeta(1); 986 BLASNAME(sgemv) ( 987 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 988 BLASV(m),BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda), 989 BLASP(xp),BLASV(xs),BLASV(xbeta), 990 BLASP(yp),BLASV(ys) BLAS1); 991 BLASNAME(sgemv) ( 992 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 993 BLASV(m),BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda), 994 BLASP(xp+1),BLASV(xs),BLASV(xbeta), 995 BLASP(yp+1),BLASV(ys) BLAS1); 996 } else { 997 Vector<std::complex<float> > xx = alpha*x; 998 BlasMultMV(std::complex<float>(1),A,xx,1,y); 999 } 1000 } BlasMultMV(const std::complex<float> alpha,const GenMatrix<float> & A,const GenVector<float> & x,const int beta,VectorView<std::complex<float>> y)1001 template <> void BlasMultMV( 1002 const std::complex<float> alpha, 1003 const GenMatrix<float>& A, 1004 const GenVector<float>& x, 1005 const int beta, VectorView<std::complex<float> > y) 1006 { 1007 int m = BlasIsCM(A) ? A.colsize() : A.rowsize(); 1008 int n = BlasIsCM(A) ? A.rowsize() : A.colsize(); 1009 int lda = BlasIsCM(A) ? 
A.stepj() : A.stepi(); 1010 if (lda < m) { TMVAssert(n==1); lda = m; } 1011 int xs = x.step(); 1012 int ys = 2*y.step(); 1013 if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; } 1014 if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; } 1015 const float* xp = x.cptr(); 1016 if (xs < 0) xp += (x.size()-1)*xs; 1017 float* yp = (float*) y.ptr(); 1018 if (ys < 0) yp += (y.size()-1)*ys; 1019 float ar(TMV_REAL(alpha)); 1020 float ai(TMV_IMAG(alpha)); 1021 if (beta == 0) y.setZero(); 1022 float xbeta(1); 1023 if (ar != 0.F) { 1024 BLASNAME(sgemv) ( 1025 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 1026 BLASV(m),BLASV(n),BLASV(ar),BLASP(A.cptr()),BLASV(lda), 1027 BLASP(xp),BLASV(xs),BLASV(xbeta), 1028 BLASP(yp),BLASV(ys) BLAS1); 1029 } 1030 if (ai != 0.F) { 1031 BLASNAME(sgemv) ( 1032 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T, 1033 BLASV(m),BLASV(n),BLASV(ai),BLASP(A.cptr()),BLASV(lda), 1034 BLASP(xp),BLASV(xs),BLASV(xbeta), 1035 BLASP(yp+1),BLASV(ys) BLAS1); 1036 } 1037 } 1038 #endif 1039 #endif // BLAS 1040 1041 template <bool add, class T, class Ta, class Tx> DoMultMV(const T alpha,const GenMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)1042 static void DoMultMV( 1043 const T alpha, const GenMatrix<Ta>& A, 1044 const GenVector<Tx>& x, VectorView<T> y) 1045 { 1046 #ifdef XDEBUG 1047 std::cout<<"Start DoMultMV\n"; 1048 std::cout<<"alpha = "<<alpha<<std::endl; 1049 std::cout<<"add = "<<add<<std::endl; 1050 std::cout<<"A = "<<A<<std::endl; 1051 std::cout<<"x = "<<x<<std::endl; 1052 if (add) std::cout<<"y = "<<y<<std::endl; 1053 #endif 1054 1055 TMVAssert(A.rowsize() == x.size()); 1056 TMVAssert(A.colsize() == y.size()); 1057 TMVAssert(alpha != T(0)); 1058 TMVAssert(x.size() > 0); 1059 TMVAssert(y.size() > 0); 1060 TMVAssert(y.ct() == NonConj); 1061 1062 #ifdef BLAS 1063 if (x.step() == 0) { 1064 if (x.size() <= 1) 1065 DoMultMV<add>( 1066 alpha,A,ConstVectorView<Tx>(x.cptr(),x.size(),1,x.ct()),y); 1067 else 1068 DoMultMV<add>(alpha,A,Vector<Tx>(x),y); 1069 } else if (y.step() == 
0) { 1070 TMVAssert(y.size() <= 1); 1071 DoMultMV<add>(alpha,A,x,VectorView<T>(y.ptr(),y.size(),1,y.ct())); 1072 #if 1 1073 } else if (y.step() != 1) { 1074 // Most BLAS implementations do fine with the y.step() < 0. 1075 // And in fact, they _should_ do ok even if the step is negative. 1076 // However, some implementations seem to propagate nan's from 1077 // the temporary memory they create to do the unit-1 calculation. 1078 // So to make sure they don't have to make a temporary, I just 1079 // do it here for them. 1080 #else 1081 } else if (y.step() < 0) { 1082 #endif 1083 Vector<T> yy(y.size()); 1084 DoMultMV<false>(T(1),A,x,yy.view()); 1085 if (add) y += alpha*yy; 1086 else y = alpha*yy; 1087 #if 1 1088 } else if (x.step() != 1) { 1089 // I don't think the non-unit step is a problem, but just to be 1090 // sure... 1091 #else 1092 } else if (x.step() < 0) { 1093 #endif 1094 Vector<T> xx = alpha*x; 1095 DoMultMV<add>(T(1),A,xx,y); 1096 } else if (BlasIsCM(A) || BlasIsRM(A)) { 1097 if (!SameStorage(A,y)) { 1098 if (!SameStorage(x,y) && !SameStorage(A,x)) { 1099 BlasMultMV(alpha,A,x,add?1:0,y); 1100 } else { 1101 Vector<T> xx = alpha*x; 1102 BlasMultMV(T(1),A,xx,add?1:0,y); 1103 } 1104 } else { 1105 Vector<T> yy(y.size()); 1106 if (!SameStorage(A,x)) { 1107 BlasMultMV(T(1),A,x,0,yy.view()); 1108 if (add) y += alpha*yy; 1109 else y = alpha*yy; 1110 } else { 1111 Vector<T> xx = alpha*x; 1112 BlasMultMV(T(1),A,xx,0,yy.view()); 1113 if (add) y += yy; 1114 else y = yy; 1115 } 1116 } 1117 } else { 1118 if (TMV_IMAG(alpha) == T(0)) { 1119 Matrix<Ta,RowMajor> A2 = TMV_REAL(alpha)*A; 1120 DoMultMV<add>(T(1),A2,x,y); 1121 } else { 1122 Matrix<T,RowMajor> A2 = alpha*A; 1123 DoMultMV<add>(T(1),A2,x,y); 1124 } 1125 } 1126 #else 1127 NonBlasMultMV<add>(alpha,A,x,y); 1128 #endif 1129 #ifdef XDEBUG 1130 std::cout<<"y => "<<y<<std::endl; 1131 #endif 1132 } 1133 MultMV(const T alpha,const GenMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)1134 template <bool add, class T, 
class Ta, class Tx> void MultMV( 1135 const T alpha, const GenMatrix<Ta>& A, const GenVector<Tx>& x, 1136 VectorView<T> y) 1137 // y (+)= alpha * A * x 1138 { 1139 #ifdef XDEBUG 1140 cout<<"Start MultMV: alpha = "<<alpha<<endl; 1141 cout<<"add = "<<add<<endl; 1142 cout<<"A = "<<TMV_Text(A)<<" "<<A<<endl; 1143 cout<<"x = "<<TMV_Text(x)<<" step "<<x.step()<<" "<<x<<endl; 1144 cout<<"y = "<<TMV_Text(y)<<" step "<<y.step()<<endl; 1145 if (add) cout<<"y = "<<y<<endl; 1146 cout<<"ptrs = "<<A.cptr()<<" "<<x.cptr()<<" "<<y.cptr()<<std::endl; 1147 Vector<Tx> x0 = x; 1148 Vector<T> y0 = y; 1149 Matrix<Ta> A0 = A; 1150 Vector<T> y2 = y; 1151 for(ptrdiff_t i=0;i<y.size();i++) { 1152 if (add) y2(i) += alpha * (A.row(i) * x0); 1153 else y2(i) = alpha * (A.row(i) * x0); 1154 } 1155 cout<<"y2 = "<<y2<<endl; 1156 #endif 1157 TMVAssert(A.rowsize() == x.size()); 1158 TMVAssert(A.colsize() == y.size()); 1159 1160 if (y.size() > 0) { 1161 if (x.size()==0 || alpha==T(0)) { 1162 if (!add) y.setZero(); 1163 } else if (y.isconj()) { 1164 DoMultMV<add>( 1165 TMV_CONJ(alpha),A.conjugate(),x.conjugate(),y.conjugate()); 1166 } else { 1167 DoMultMV<add>(alpha,A,x,y); 1168 } 1169 } 1170 1171 #ifdef XDEBUG 1172 cout<<"y => "<<y<<endl; 1173 if (!(Norm(y-y2) <= 1174 0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(x0)+ 1175 (add?Norm(y0):TMV_RealType(T)(0))))) { 1176 cerr<<"MultMV: alpha = "<<alpha<<endl; 1177 cerr<<"add = "<<add<<endl; 1178 cerr<<"A = "<<TMV_Text(A); 1179 if (A.rowsize() < 30 && A.colsize() < 30) cerr<<" "<<A0; 1180 else cerr<<" "<<A.colsize()<<" x "<<A.rowsize(); 1181 cerr<<endl<<"x = "<<TMV_Text(x)<<" step "<<x.step(); 1182 if (x.size() < 30) cerr<<" "<<x0; 1183 cerr<<endl<<"y = "<<TMV_Text(y)<<" step "<<y.step(); 1184 if (add && y.size() < 30) cerr<<" "<<y0; 1185 cerr<<endl<<"Aptr = "<<A.cptr(); 1186 cerr<<", xptr = "<<x.cptr()<<", yptr = "<<y.cptr()<<endl; 1187 if (y.size() < 200) { 1188 cerr<<"--> y = "<<y<<endl; 1189 cerr<<"y2 = "<<y2<<endl; 1190 } else { 1191 ptrdiff_t imax; 1192 
(y-y2).maxAbsElement(&imax); 1193 cerr<<"y("<<imax<<") = "<<y(imax)<<endl; 1194 cerr<<"y2("<<imax<<") = "<<y2(imax)<<endl; 1195 } 1196 cerr<<"Norm(A0) = "<<Norm(A0)<<endl; 1197 cerr<<"Norm(x0) = "<<Norm(x0)<<endl; 1198 if (add) cerr<<"Norm(y0) = "<<Norm(y0)<<endl; 1199 cerr<<"|alpha|*|A0|*|x0|+?|y0| = "<< 1200 TMV_ABS(alpha)*Norm(A0)*Norm(x0)+ 1201 (add?Norm(y0):TMV_RealType(T)(0))<<endl; 1202 cerr<<"Norm(y-y2) = "<<Norm(y-y2)<<endl; 1203 cerr<<"NormInf(y-y2) = "<<NormInf(y-y2)<<endl; 1204 cerr<<"Norm1(y-y2) = "<<Norm1(y-y2)<<endl; 1205 abort(); 1206 } 1207 #endif 1208 } 1209 1210 #define InstFile "TMV_MultMV.inst" 1211 #include "TMV_Inst.h" 1212 #undef InstFile 1213 1214 } // namespace tmv 1215 1216 1217