///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// The Template Matrix/Vector Library for C++ was created by Mike Jarvis     //
// Copyright (C) 1998 - 2016                                                 //
// All rights reserved                                                       //
//                                                                           //
// The project is hosted at https://code.google.com/p/tmv-cpp/               //
// where you can find the current version and current documentation.         //
//                                                                           //
// For concerns or problems with the software, Mike may be contacted at      //
// mike_jarvis17 [at] gmail.                                                 //
//                                                                           //
// This software is licensed under a FreeBSD license.  The file              //
// TMV_LICENSE should have been included with this distribution.             //
// If not, you can get a copy from https://code.google.com/p/tmv-cpp/.       //
//                                                                           //
// Essentially, you can use this software however you want provided that     //
// you include the TMV_LICENSE file in any distribution that uses it.        //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////


//#define XDEBUG


#include "TMV_Blas.h"
#include "tmv/TMV_TriDiv.h"
#include "tmv/TMV_TriMatrix.h"
#include "tmv/TMV_Vector.h"

// Using CblasRowMajor in all the other files isn't working correctly with
// MKL 10.2.2.  The usage in here hasn't given me any problems, but
// just to be safe, I'm disabling it here as well.  My guess is that I
// haven't tested the V/U code as thoroughly as the others, so it might
// turn up as a problem in more complicated code.
36 #ifdef CBLAS 37 #undef CBLAS 38 #endif 39 40 #ifdef NOTHROW 41 #include <iostream> 42 #endif 43 44 #ifdef XDEBUG 45 #include "tmv/TMV_MatrixArith.h" 46 #include "tmv/TMV_VectorArith.h" 47 #include <iostream> 48 using std::cout; 49 using std::cerr; 50 using std::endl; 51 #endif 52 53 namespace tmv { 54 55 // 56 // TriLDivEq V 57 // 58 59 template <bool rm, bool ca, bool ua, class T, class Ta> DoRowTriLDivEq(const GenUpperTriMatrix<Ta> & A,VectorView<T> b)60 static void DoRowTriLDivEq( 61 const GenUpperTriMatrix<Ta>& A, VectorView<T> b) 62 { 63 // Solve A x = y where A is an upper triangle matrix 64 //cout<<"Row Upper\n"; 65 TMVAssert(b.step()==1); 66 TMVAssert(A.size() == b.size()); 67 TMVAssert(b.size() > 0); 68 TMVAssert(b.ct() == NonConj); 69 TMVAssert(rm == A.isrm()); 70 TMVAssert(ca == A.isconj()); 71 TMVAssert(ua == A.isunit()); 72 73 const ptrdiff_t N = A.size(); 74 75 const ptrdiff_t sj = (rm?1:A.stepj()); 76 const ptrdiff_t ds = A.stepi()+sj; 77 const Ta* Aii = A.cptr() + (ua ? N-2 : N-1)*ds; 78 T* bi = b.ptr() + (ua ? N-2 : N-1); 79 80 if (!ua) { 81 if (*Aii==Ta(0)) { 82 #ifdef NOTHROW 83 std::cerr<<"Singular UpperTriMatrix found\n"; 84 exit(1); 85 #else 86 throw SingularUpperTriMatrix<Ta>(A); 87 #endif 88 } 89 #ifdef TMVFLDEBUG 90 TMVAssert(bi >= b._first); 91 TMVAssert(bi < b._last); 92 #endif 93 *bi = TMV_Divide(*bi,(ca ? TMV_CONJ(*Aii) : *Aii)); 94 Aii -= ds; 95 --bi; 96 } 97 if (N==1) return; 98 99 for(ptrdiff_t i=N-1,len=1; i>0; --i,++len,Aii-=ds,--bi) { 100 // Actual row being done is i-1, not i 101 102 // *bi -= A.row(i,i+1,N) * b.subVector(i+1,N); 103 const T* bj = bi+1; 104 const Ta* Aij = Aii + sj; 105 for(ptrdiff_t j=len;j>0;--j,++bj,(rm?++Aij:Aij+=sj)) { 106 #ifdef TMVFLDEBUG 107 TMVAssert(bi >= b._first); 108 TMVAssert(bi < b._last); 109 #endif 110 *bi -= (*bj) * (ca ? 
TMV_CONJ(*Aij) : *Aij); 111 } 112 113 if (!ua) { 114 if (*Aii==Ta(0)) { 115 #ifdef NOTHROW 116 std::cerr<<"Singular UpperTriMatrix found\n"; 117 exit(1); 118 #else 119 throw SingularUpperTriMatrix<Ta>(A); 120 #endif 121 } 122 #ifdef TMVFLDEBUG 123 TMVAssert(bi >= b._first); 124 TMVAssert(bi < b._last); 125 #endif 126 *bi = TMV_Divide(*bi,(ca ? TMV_CONJ(*Aii) : *Aii)); 127 } 128 } 129 } 130 131 template <bool rm, class T, class Ta> RowTriLDivEq(const GenUpperTriMatrix<Ta> & A,VectorView<T> b)132 static inline void RowTriLDivEq( 133 const GenUpperTriMatrix<Ta>& A, VectorView<T> b) 134 { 135 if (A.isconj()) 136 if (A.isunit()) 137 DoRowTriLDivEq<rm,true,true>(A,b); 138 else 139 DoRowTriLDivEq<rm,true,false>(A,b); 140 else 141 if (A.isunit()) 142 DoRowTriLDivEq<rm,false,true>(A,b); 143 else 144 DoRowTriLDivEq<rm,false,false>(A,b); 145 } 146 147 template <bool cm, bool ca, bool ua, class T, class Ta> DoColTriLDivEq(const GenUpperTriMatrix<Ta> & A,VectorView<T> b)148 static void DoColTriLDivEq( 149 const GenUpperTriMatrix<Ta>& A, VectorView<T> b) 150 { 151 //cout<<"colmajor upper\n"; 152 // Solve A x = y where A is an upper triangle matrix 153 TMVAssert(b.step()==1); 154 TMVAssert(A.size() == b.size()); 155 TMVAssert(b.size() > 0); 156 TMVAssert(b.ct() == NonConj); 157 TMVAssert(cm == A.iscm()); 158 TMVAssert(ca == A.isconj()); 159 TMVAssert(ua == A.isunit()); 160 161 const ptrdiff_t N = A.size(); 162 163 const ptrdiff_t si = (cm ? 1 : A.stepi()); 164 const ptrdiff_t sj = A.stepj(); 165 const ptrdiff_t ds = si+sj; 166 const Ta* A0j = A.cptr()+(N-1)*sj; 167 const Ta* Ajj = (ua ? 0 : A0j+(N-1)*si); // if unit, this isn't used. 
168 T*const b0 = b.ptr(); 169 T* bj = b0 + N-1; 170 171 for(ptrdiff_t j=N-1; j>0; --j,--bj,A0j-=sj) { 172 if (*bj != T(0)) { 173 if (!ua) { 174 if (*Ajj==Ta(0)) { 175 #ifdef NOTHROW 176 std::cerr<<"Singular UpperTriMatrix found\n"; 177 exit(1); 178 #else 179 throw SingularUpperTriMatrix<Ta>(A); 180 #endif 181 } 182 #ifdef TMVFLDEBUG 183 TMVAssert(bj >= b._first); 184 TMVAssert(bj < b._last); 185 #endif 186 *bj = TMV_Divide(*bj,(ca ? TMV_CONJ(*Ajj) : *Ajj)); 187 Ajj-=ds; 188 } 189 190 // b.subVector(0,j) -= *bj * A.col(j,0,j); 191 T* bi = b0; 192 const Ta* Aij = A0j; 193 for(ptrdiff_t i=j;i>0;--i,++bi,(cm?++Aij:Aij+=si)) { 194 #ifdef TMVFLDEBUG 195 TMVAssert(bi >= b._first); 196 TMVAssert(bi < b._last); 197 #endif 198 *bi -= *bj * (ca ? TMV_CONJ(*Aij) : *Aij); 199 } 200 } 201 else if (!ua) Ajj -= ds; 202 } 203 if (!ua && *bj != T(0)) { 204 if (*Ajj==Ta(0)) { 205 #ifdef NOTHROW 206 std::cerr<<"Singular UpperTriMatrix found\n"; 207 exit(1); 208 #else 209 throw SingularUpperTriMatrix<Ta>(A); 210 #endif 211 } 212 #ifdef TMVFLDEBUG 213 TMVAssert(bj >= b._first); 214 TMVAssert(bj < b._last); 215 #endif 216 *bj = TMV_Divide(*bj,(ca ? 
TMV_CONJ(*Ajj) : *Ajj)); 217 } 218 } 219 220 template <bool cm, class T, class Ta> ColTriLDivEq(const GenUpperTriMatrix<Ta> & A,VectorView<T> b)221 static inline void ColTriLDivEq( 222 const GenUpperTriMatrix<Ta>& A, VectorView<T> b) 223 { 224 if (A.isconj()) 225 if (A.isunit()) 226 DoColTriLDivEq<cm,true,true>(A,b); 227 else 228 DoColTriLDivEq<cm,true,false>(A,b); 229 else 230 if (A.isunit()) 231 DoColTriLDivEq<cm,false,true>(A,b); 232 else 233 DoColTriLDivEq<cm,false,false>(A,b); 234 } 235 236 template <bool rm, bool ca, bool ua, class T, class Ta> DoRowTriLDivEq(const GenLowerTriMatrix<Ta> & A,VectorView<T> b)237 static void DoRowTriLDivEq( 238 const GenLowerTriMatrix<Ta>& A, VectorView<T> b) 239 { 240 // Solve A x = y where A is a lower triangle matrix 241 TMVAssert(b.step()==1); 242 TMVAssert(A.size() == b.size()); 243 TMVAssert(b.size() > 0); 244 TMVAssert(b.ct() == NonConj); 245 TMVAssert(rm == A.isrm()); 246 TMVAssert(ca == A.isconj()); 247 TMVAssert(ua == A.isunit()); 248 249 const ptrdiff_t N = A.size(); 250 251 const ptrdiff_t sj = (rm ? 1 : A.stepj()); 252 const ptrdiff_t si = A.stepi(); 253 254 const Ta* Ai0 = A.cptr(); 255 T* b0 = b.ptr(); 256 257 if (!ua) { 258 if (*Ai0==Ta(0)) { 259 #ifdef NOTHROW 260 std::cerr<<"Singular LowerTriMatrix found\n"; 261 exit(1); 262 #else 263 throw SingularLowerTriMatrix<Ta>(A); 264 #endif 265 } 266 #ifdef TMVFLDEBUG 267 TMVAssert(b0 >= b._first); 268 TMVAssert(b0 < b._last); 269 #endif 270 *b0 = TMV_Divide(*b0,(ca ? TMV_CONJ(*Ai0) : *Ai0)); 271 } 272 273 T* bi = b0+1; 274 Ai0 += si; 275 for(ptrdiff_t i=1,len=1;i<N;++i,++len,++bi,Ai0+=si) { 276 // *bi -= A.row(i,0,i) * b.subVector(0,i); 277 const Ta* Aij = Ai0; 278 const T* bj = b0; 279 for(ptrdiff_t j=len;j>0;--j,++bj,(rm?++Aij:Aij+=sj)){ 280 #ifdef TMVFLDEBUG 281 TMVAssert(bi >= b._first); 282 TMVAssert(bi < b._last); 283 #endif 284 *bi -= (*bj) * (ca ? 
TMV_CONJ(*Aij) : *Aij); 285 } 286 if (!ua) { 287 // Aij is Aii after the above for loop 288 if (*Aij==Ta(0)) { 289 #ifdef NOTHROW 290 std::cerr<<"Singular LowerTriMatrix found\n"; 291 exit(1); 292 #else 293 throw SingularLowerTriMatrix<Ta>(A); 294 #endif 295 } 296 #ifdef TMVFLDEBUG 297 TMVAssert(bi >= b._first); 298 TMVAssert(bi < b._last); 299 #endif 300 *bi = TMV_Divide(*bi,(ca ? TMV_CONJ(*Aij) : *Aij)); 301 } 302 } 303 } 304 305 template <bool rm, class T, class Ta> RowTriLDivEq(const GenLowerTriMatrix<Ta> & A,VectorView<T> b)306 static inline void RowTriLDivEq( 307 const GenLowerTriMatrix<Ta>& A, VectorView<T> b) 308 { 309 if (A.isconj()) 310 if (A.isunit()) 311 DoRowTriLDivEq<rm,true,true>(A,b); 312 else 313 DoRowTriLDivEq<rm,true,false>(A,b); 314 else 315 if (A.isunit()) 316 DoRowTriLDivEq<rm,false,true>(A,b); 317 else 318 DoRowTriLDivEq<rm,false,false>(A,b); 319 } 320 321 template <bool cm, bool ca, bool ua, class T, class Ta> DoColTriLDivEq(const GenLowerTriMatrix<Ta> & A,VectorView<T> b)322 static void DoColTriLDivEq( 323 const GenLowerTriMatrix<Ta>& A, VectorView<T> b) 324 { 325 // Solve A x = y where A is a lower triangle matrix 326 TMVAssert(b.step()==1); 327 TMVAssert(A.size() == b.size()); 328 TMVAssert(b.size() > 0); 329 TMVAssert(b.ct() == NonConj); 330 TMVAssert(cm == A.iscm()); 331 TMVAssert(ca == A.isconj()); 332 TMVAssert(ua == A.isunit()); 333 334 const ptrdiff_t N = A.size(); 335 336 const ptrdiff_t si = (cm ? 1 : A.stepi()); 337 const ptrdiff_t ds = A.stepj()+si; 338 const Ta* Ajj = A.cptr(); 339 T* bj = b.ptr(); 340 341 for(ptrdiff_t j=0,len=N-1;len>0;++j,--len,++bj,Ajj+=ds) if (*bj != T(0)) { 342 if (!ua) { 343 if (*Ajj==Ta(0)) { 344 #ifdef NOTHROW 345 std::cerr<<"Singular LowerTriMatrix found\n"; 346 exit(1); 347 #else 348 throw SingularLowerTriMatrix<Ta>(A); 349 #endif 350 } 351 #ifdef TMVFLDEBUG 352 TMVAssert(bj >= b._first); 353 TMVAssert(bj < b._last); 354 #endif 355 *bj = TMV_Divide(*bj,(ca ? 
TMV_CONJ(*Ajj) : *Ajj)); 356 } 357 // b.subVector(j+1,N) -= *bj * A.col(j,j+1,N); 358 T* bi = bj+1; 359 const Ta* Aij = Ajj+si; 360 for(ptrdiff_t i=len;i>0;--i,++bi,(cm?++Aij:Aij+=si)){ 361 #ifdef TMVFLDEBUG 362 TMVAssert(bi >= b._first); 363 TMVAssert(bi < b._last); 364 #endif 365 *bi -= *bj * (ca ? TMV_CONJ(*Aij) : *Aij); 366 } 367 } 368 if (!ua && *bj != T(0)) { 369 if (*Ajj==Ta(0)) { 370 #ifdef NOTHROW 371 std::cerr<<"Singular LowerTriMatrix found\n"; 372 exit(1); 373 #else 374 throw SingularLowerTriMatrix<Ta>(A); 375 #endif 376 } 377 #ifdef TMVFLDEBUG 378 TMVAssert(bj >= b._first); 379 TMVAssert(bj < b._last); 380 #endif 381 *bj = TMV_Divide(*bj,(ca ? TMV_CONJ(*Ajj) : *Ajj)); 382 } 383 } 384 385 template <bool cm, class T, class Ta> ColTriLDivEq(const GenLowerTriMatrix<Ta> & A,VectorView<T> b)386 static inline void ColTriLDivEq( 387 const GenLowerTriMatrix<Ta>& A, VectorView<T> b) 388 { 389 if (A.isconj()) 390 if (A.isunit()) 391 DoColTriLDivEq<cm,true,true>(A,b); 392 else 393 DoColTriLDivEq<cm,true,false>(A,b); 394 else 395 if (A.isunit()) 396 DoColTriLDivEq<cm,false,true>(A,b); 397 else 398 DoColTriLDivEq<cm,false,false>(A,b); 399 } 400 401 template <class T, class Ta> DoTriLDivEq(const GenUpperTriMatrix<Ta> & A,VectorView<T> b)402 static inline void DoTriLDivEq( 403 const GenUpperTriMatrix<Ta>& A, VectorView<T> b) 404 { 405 if (A.isrm()) RowTriLDivEq<true>(A,b); 406 else if (A.iscm()) ColTriLDivEq<true>(A,b); 407 else RowTriLDivEq<false>(A,b); 408 } 409 410 template <class T, class Ta> DoTriLDivEq(const GenLowerTriMatrix<Ta> & A,VectorView<T> b)411 static inline void DoTriLDivEq( 412 const GenLowerTriMatrix<Ta>& A, VectorView<T> b) 413 { 414 if (A.isrm()) RowTriLDivEq<true>(A,b); 415 else if (A.iscm()) ColTriLDivEq<true>(A,b); 416 else RowTriLDivEq<false>(A,b); 417 } 418 419 template <class T, class Ta> NonBlasTriLDivEq(const GenUpperTriMatrix<Ta> & A,VectorView<T> b)420 static void NonBlasTriLDivEq( 421 const GenUpperTriMatrix<Ta>& A, VectorView<T> b) 422 
{ 423 //cout<<"Upper LDivEq vect\n"; 424 TMVAssert(A.size() == b.size()); 425 TMVAssert(b.size()>0); 426 TMVAssert(b.ct() == NonConj); 427 428 if (b.step() == 1) { 429 ptrdiff_t i2 = b.size(); 430 for(const T* b2 = b.cptr()+i2-1; i2>0 && *b2==T(0); --i2,--b2); 431 if (i2==0) return; 432 else DoTriLDivEq(A.subTriMatrix(0,i2),b.subVector(0,i2)); 433 } else { 434 Vector<T> bb = b; 435 NonBlasTriLDivEq(A,bb.view()); 436 b = bb; 437 } 438 } 439 440 template <class T, class Ta> NonBlasTriLDivEq(const GenLowerTriMatrix<Ta> & A,VectorView<T> b)441 static void NonBlasTriLDivEq( 442 const GenLowerTriMatrix<Ta>& A, VectorView<T> b) 443 { 444 //cout<<"Lower LDivEq vect\n"; 445 TMVAssert(A.size() == b.size()); 446 TMVAssert(b.size()>0); 447 TMVAssert(b.ct() == NonConj); 448 449 if (b.step() == 1) { 450 const ptrdiff_t N = b.size(); 451 ptrdiff_t i1 = 0; 452 for(const T* b1 = b.cptr(); i1<N && *b1==T(0); ++i1,++b1); 453 if (i1==N) return; 454 else 455 DoTriLDivEq(A.subTriMatrix(i1,N),b.subVector(i1,N)); 456 } else { 457 Vector<T> bb = b; 458 NonBlasTriLDivEq(A,bb.view()); 459 b = bb; 460 } 461 } 462 463 #ifdef BLAS 464 template <class T, class Ta> BlasTriLDivEq(const GenUpperTriMatrix<Ta> & A,VectorView<T> b)465 static inline void BlasTriLDivEq( 466 const GenUpperTriMatrix<Ta>& A, VectorView<T> b) 467 { NonBlasTriLDivEq(A,b); } 468 template <class T, class Ta> BlasTriLDivEq(const GenLowerTriMatrix<Ta> & A,VectorView<T> b)469 static inline void BlasTriLDivEq( 470 const GenLowerTriMatrix<Ta>& A, VectorView<T> b) 471 { NonBlasTriLDivEq(A,b); } 472 #ifdef INST_DOUBLE 473 template <> BlasTriLDivEq(const GenUpperTriMatrix<double> & A,VectorView<double> b)474 void BlasTriLDivEq( 475 const GenUpperTriMatrix<double>& A, VectorView<double> b) 476 { 477 int n=A.size(); 478 int lda = A.isrm()?A.stepi():A.stepj(); 479 int bs = b.step(); 480 double* bp = b.ptr(); 481 if (bs < 0) bp += (n-1)*bs; 482 BLASNAME(dtrsv) ( 483 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 484 A.iscm()?BLASCH_NT:BLASCH_T, 
A.isunit()?BLASCH_U:BLASCH_NU, 485 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 486 BLAS1 BLAS1 BLAS1); 487 } 488 template <> BlasTriLDivEq(const GenLowerTriMatrix<double> & A,VectorView<double> b)489 void BlasTriLDivEq( 490 const GenLowerTriMatrix<double>& A, VectorView<double> b) 491 { 492 int n=A.size(); 493 int lda = A.isrm()?A.stepi():A.stepj(); 494 int bs = b.step(); 495 double* bp = b.ptr(); 496 if (bs < 0) bp += (n-1)*bs; 497 BLASNAME(dtrsv) ( 498 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 499 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 500 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 501 BLAS1 BLAS1 BLAS1); 502 } 503 template <> BlasTriLDivEq(const GenUpperTriMatrix<std::complex<double>> & A,VectorView<std::complex<double>> b)504 void BlasTriLDivEq( 505 const GenUpperTriMatrix<std::complex<double> >& A, 506 VectorView<std::complex<double> > b) 507 { 508 int n=A.size(); 509 int lda = A.isrm()?A.stepi():A.stepj(); 510 int bs = b.step(); 511 std::complex<double>* bp = b.ptr(); 512 if (bs < 0) bp += (n-1)*bs; 513 if (A.iscm() && A.isconj()) { 514 #ifdef CBLAS 515 BLASNAME(ztrsv) ( 516 BLASRM BLASCH_LO,BLASCH_CT, 517 A.isunit()?BLASCH_U:BLASCH_NU, 518 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 519 BLAS1 BLAS1 BLAS1); 520 #else 521 b.conjugateSelf(); 522 BLASNAME(ztrsv) ( 523 BLASCM BLASCH_UP,BLASCH_NT, 524 A.isunit()?BLASCH_U:BLASCH_NU, 525 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 526 BLAS1 BLAS1 BLAS1); 527 b.conjugateSelf(); 528 #endif 529 } else { 530 BLASNAME(ztrsv) ( 531 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 532 A.iscm()?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T, 533 A.isunit()?BLASCH_U:BLASCH_NU, 534 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 535 BLAS1 BLAS1 BLAS1); 536 } 537 } 538 template <> BlasTriLDivEq(const GenLowerTriMatrix<std::complex<double>> & A,VectorView<std::complex<double>> b)539 void BlasTriLDivEq( 540 const GenLowerTriMatrix<std::complex<double> >& A, 541 
VectorView<std::complex<double> > b) 542 { 543 int n=A.size(); 544 int lda = A.isrm()?A.stepi():A.stepj(); 545 int bs = b.step(); 546 std::complex<double>* bp = b.ptr(); 547 if (bs < 0) bp += (n-1)*bs; 548 if (A.iscm() && A.isconj()) { 549 #ifdef CBLAS 550 BLASNAME(ztrsv) ( 551 BLASRM BLASCH_UP,BLASCH_CT, 552 A.isunit()?BLASCH_U:BLASCH_NU, 553 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 554 BLAS1 BLAS1 BLAS1); 555 #else 556 b.conjugateSelf(); 557 BLASNAME(ztrsv) ( 558 BLASCM BLASCH_LO,BLASCH_NT, 559 A.isunit()?BLASCH_U:BLASCH_NU, 560 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 561 BLAS1 BLAS1 BLAS1); 562 b.conjugateSelf(); 563 #endif 564 } else { 565 BLASNAME(ztrsv) ( 566 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 567 A.iscm()?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T, 568 A.isunit()?BLASCH_U:BLASCH_NU, 569 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 570 BLAS1 BLAS1 BLAS1); 571 } 572 } 573 template <> BlasTriLDivEq(const GenUpperTriMatrix<double> & A,VectorView<std::complex<double>> b)574 void BlasTriLDivEq( 575 const GenUpperTriMatrix<double>& A, 576 VectorView<std::complex<double> > b) 577 { 578 int n=A.size(); 579 int lda = A.isrm()?A.stepi():A.stepj(); 580 int bs = 2*b.step(); 581 double* bp = (double*)(b.ptr()); 582 if (bs < 0) bp += (n-1)*bs; 583 BLASNAME(dtrsv) ( 584 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 585 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 586 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 587 BLAS1 BLAS1 BLAS1); 588 BLASNAME(dtrsv) ( 589 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 590 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 591 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp+1),BLASV(bs) 592 BLAS1 BLAS1 BLAS1); 593 } 594 template <> BlasTriLDivEq(const GenLowerTriMatrix<double> & A,VectorView<std::complex<double>> b)595 void BlasTriLDivEq( 596 const GenLowerTriMatrix<double>& A, 597 VectorView<std::complex<double> > b) 598 { 599 int n=A.size(); 600 int lda = 
A.isrm()?A.stepi():A.stepj(); 601 int bs = 2*b.step(); 602 double* bp = (double*)(b.ptr()); 603 if (bs < 0) bp += (n-1)*bs; 604 BLASNAME(dtrsv) ( 605 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 606 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 607 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 608 BLAS1 BLAS1 BLAS1); 609 BLASNAME(dtrsv) ( 610 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 611 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 612 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp+1),BLASV(bs) 613 BLAS1 BLAS1 BLAS1); 614 } 615 #endif 616 #ifdef INST_FLOAT 617 template <> BlasTriLDivEq(const GenUpperTriMatrix<float> & A,VectorView<float> b)618 void BlasTriLDivEq( 619 const GenUpperTriMatrix<float>& A, VectorView<float> b) 620 { 621 int n=A.size(); 622 int lda = A.isrm()?A.stepi():A.stepj(); 623 int bs = b.step(); 624 float* bp = (b.ptr()); 625 if (bs < 0) bp += (n-1)*bs; 626 BLASNAME(strsv) ( 627 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 628 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 629 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 630 BLAS1 BLAS1 BLAS1); 631 } 632 template <> BlasTriLDivEq(const GenLowerTriMatrix<float> & A,VectorView<float> b)633 void BlasTriLDivEq( 634 const GenLowerTriMatrix<float>& A, VectorView<float> b) 635 { 636 int n=A.size(); 637 int lda = A.isrm()?A.stepi():A.stepj(); 638 int bs = b.step(); 639 float* bp = (b.ptr()); 640 if (bs < 0) bp += (n-1)*bs; 641 BLASNAME(strsv) ( 642 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 643 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 644 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 645 BLAS1 BLAS1 BLAS1); 646 } 647 template <> BlasTriLDivEq(const GenUpperTriMatrix<std::complex<float>> & A,VectorView<std::complex<float>> b)648 void BlasTriLDivEq( 649 const GenUpperTriMatrix<std::complex<float> >& A, 650 VectorView<std::complex<float> > b) 651 { 652 int n=A.size(); 653 int lda = A.isrm()?A.stepi():A.stepj(); 654 int bs = b.step(); 655 
std::complex<float>* bp = (b.ptr()); 656 if (bs < 0) bp += (n-1)*bs; 657 if (A.iscm() && A.isconj()) { 658 #ifdef CBLAS 659 BLASNAME(ctrsv) ( 660 BLASRM BLASCH_LO,BLASCH_CT, 661 A.isunit()?BLASCH_U:BLASCH_NU, 662 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 663 BLAS1 BLAS1 BLAS1); 664 #else 665 b.conjugateSelf(); 666 BLASNAME(ctrsv) ( 667 BLASCM BLASCH_UP,BLASCH_NT, 668 A.isunit()?BLASCH_U:BLASCH_NU, 669 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 670 BLAS1 BLAS1 BLAS1); 671 b.conjugateSelf(); 672 #endif 673 } else { 674 BLASNAME(ctrsv) ( 675 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 676 A.iscm()?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T, 677 A.isunit()?BLASCH_U:BLASCH_NU, 678 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 679 BLAS1 BLAS1 BLAS1); 680 } 681 } 682 template <> BlasTriLDivEq(const GenLowerTriMatrix<std::complex<float>> & A,VectorView<std::complex<float>> b)683 void BlasTriLDivEq( 684 const GenLowerTriMatrix<std::complex<float> >& A, 685 VectorView<std::complex<float> > b) 686 { 687 int n=A.size(); 688 int lda = A.isrm()?A.stepi():A.stepj(); 689 int bs = b.step(); 690 std::complex<float>* bp = (b.ptr()); 691 if (bs < 0) bp += (n-1)*bs; 692 if (A.iscm() && A.isconj()) { 693 #ifdef CBLAS 694 BLASNAME(ctrsv) ( 695 BLASRM BLASCH_UP,BLASCH_CT, 696 A.isunit()?BLASCH_U:BLASCH_NU, 697 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 698 BLAS1 BLAS1 BLAS1); 699 #else 700 b.conjugateSelf(); 701 BLASNAME(ctrsv) ( 702 BLASCM BLASCH_LO,BLASCH_NT, 703 A.isunit()?BLASCH_U:BLASCH_NU, 704 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 705 BLAS1 BLAS1 BLAS1); 706 b.conjugateSelf(); 707 #endif 708 } else { 709 BLASNAME(ctrsv) ( 710 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 711 A.iscm()?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T, 712 A.isunit()?BLASCH_U:BLASCH_NU, 713 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 714 BLAS1 BLAS1 BLAS1); 715 } 716 } 717 template <> BlasTriLDivEq(const GenUpperTriMatrix<float> & 
A,VectorView<std::complex<float>> b)718 void BlasTriLDivEq( 719 const GenUpperTriMatrix<float>& A, 720 VectorView<std::complex<float> > b) 721 { 722 int n=A.size(); 723 int lda = A.isrm()?A.stepi():A.stepj(); 724 int bs = 2*b.step(); 725 float* bp = (float*)(b.ptr()); 726 if (bs < 0) bp += (n-1)*bs; 727 BLASNAME(strsv) ( 728 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 729 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 730 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 731 BLAS1 BLAS1 BLAS1); 732 BLASNAME(strsv) ( 733 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO, 734 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 735 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp+1),BLASV(bs) 736 BLAS1 BLAS1 BLAS1); 737 } 738 template <> BlasTriLDivEq(const GenLowerTriMatrix<float> & A,VectorView<std::complex<float>> b)739 void BlasTriLDivEq( 740 const GenLowerTriMatrix<float>& A, 741 VectorView<std::complex<float> > b) 742 { 743 int n=A.size(); 744 int lda = A.isrm()?A.stepi():A.stepj(); 745 int bs = 2*b.step(); 746 float* bp = (float*)(b.ptr()); 747 if (bs < 0) bp += (n-1)*bs; 748 BLASNAME(strsv) ( 749 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 750 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 751 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs) 752 BLAS1 BLAS1 BLAS1); 753 BLASNAME(strsv) ( 754 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP, 755 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU, 756 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp+1),BLASV(bs) 757 BLAS1 BLAS1 BLAS1); 758 } 759 #endif // FLOAT 760 #endif // BLAS 761 762 template <class T, class Ta> TriLDivEq(const GenUpperTriMatrix<Ta> & A,VectorView<T> b)763 void TriLDivEq( 764 const GenUpperTriMatrix<Ta>& A, VectorView<T> b) 765 { 766 TMVAssert(b.size() == A.size()); 767 #ifdef XDEBUG 768 Matrix<Ta> A0(A); 769 Vector<T> b0(b); 770 #endif 771 if (b.size() > 0) { 772 if (b.isconj()) TriLDivEq(A.conjugate(),b.conjugate()); 773 else 774 #ifdef BLAS 775 if (isComplex(T()) && isReal(Ta())) 776 
BlasTriLDivEq(A,b); 777 else if ( !((A.isrm() && A.stepi()>0) || 778 (A.iscm() && A.stepj()>0)) ) { 779 if (A.isunit()) { 780 UpperTriMatrix<Ta,UnitDiag|ColMajor> AA = A; 781 BlasTriLDivEq(AA,b); 782 } else { 783 UpperTriMatrix<Ta,NonUnitDiag|ColMajor> AA = A; 784 BlasTriLDivEq(AA,b); 785 } 786 } 787 else BlasTriLDivEq(A,b); 788 #else 789 NonBlasTriLDivEq(A,b); 790 #endif 791 } 792 #ifdef XDEBUG 793 Vector<T> b2 = A0*b; 794 if (Norm(b2-b0) > 0.001*Norm(A0)*Norm(b0)) { 795 cerr<<"TriLDivEq: v/Upper\n"; 796 cerr<<"A = "<<TMV_Text(A)<<" "<<A0<<endl; 797 cerr<<"b = "<<TMV_Text(b)<<" "<<b0<<endl; 798 cerr<<"Done: b = "<<b<<endl; 799 cerr<<"A*b = "<<b2<<endl; 800 abort(); 801 } 802 #endif 803 } 804 805 template <class T, class Ta> TriLDivEq(const GenLowerTriMatrix<Ta> & A,VectorView<T> b)806 void TriLDivEq(const GenLowerTriMatrix<Ta>& A, VectorView<T> b) 807 { 808 TMVAssert(b.size() == A.size()); 809 #ifdef XDEBUG 810 Matrix<Ta> A0(A); 811 Vector<T> b0(b); 812 #endif 813 if (b.size() > 0) { 814 if (b.isconj()) TriLDivEq(A.conjugate(),b.conjugate()); 815 else { 816 #ifdef BLAS 817 if (isComplex(T()) && isReal(Ta())) 818 BlasTriLDivEq(A,b); 819 else if ( !((A.isrm() && A.stepi()>0) || 820 (A.iscm() && A.stepj()>0)) ) { 821 if (A.isunit()) { 822 LowerTriMatrix<Ta,UnitDiag|ColMajor> AA = A; 823 BlasTriLDivEq(AA,b); 824 } else { 825 LowerTriMatrix<Ta,NonUnitDiag|ColMajor> AA = A; 826 BlasTriLDivEq(AA,b); 827 } 828 } else { 829 BlasTriLDivEq(A,b); 830 } 831 #else 832 NonBlasTriLDivEq(A,b); 833 #endif 834 } 835 } 836 #ifdef XDEBUG 837 Vector<T> b2 = A0*b; 838 if (Norm(b2-b0) > 0.001*Norm(A0)*Norm(b0)) { 839 cerr<<"TriLDivEq: v/Lower\n"; 840 cerr<<"A = "<<TMV_Text(A)<<" "<<A<<endl; 841 cerr<<"b = "<<TMV_Text(b)<<" "<<b0<<endl; 842 cerr<<"Done: b = "<<b<<endl; 843 cerr<<"A*b = "<<b2<<endl; 844 abort(); 845 } 846 #endif 847 } 848 849 #ifdef INST_INT 850 #undef INST_INT 851 #endif 852 853 #define InstFile "TMV_TriDiv_V.inst" 854 #include "TMV_Inst.h" 855 #undef InstFile 856 857 } 
// namespace tmv 858 859 860