1 /////////////////////////////////////////////////////////////////////////////// 2 // // 3 // The Template Matrix/Vector Library for C++ was created by Mike Jarvis // 4 // Copyright (C) 1998 - 2016 // 5 // All rights reserved // 6 // // 7 // The project is hosted at https://code.google.com/p/tmv-cpp/ // 8 // where you can find the current version and current documention. // 9 // // 10 // For concerns or problems with the software, Mike may be contacted at // 11 // mike_jarvis17 [at] gmail. // 12 // // 13 // This software is licensed under a FreeBSD license. The file // 14 // TMV_LICENSE should have bee included with this distribution. // 15 // It not, you can get a copy from https://code.google.com/p/tmv-cpp/. // 16 // // 17 // Essentially, you can use this software however you want provided that // 18 // you include the TMV_LICENSE file in any distribution that uses it. // 19 // // 20 /////////////////////////////////////////////////////////////////////////////// 21 22 23 //#define XDEBUG 24 25 26 #include "TMV_Blas.h" 27 #include "tmv/TMV_SymMatrixArithFunc.h" 28 #include "tmv/TMV_SymMatrix.h" 29 #include "tmv/TMV_Matrix.h" 30 #include "tmv/TMV_SymMatrixArith.h" 31 #include "tmv/TMV_MatrixArith.h" 32 #include "tmv/TMV_VectorArith.h" 33 #ifdef BLAS 34 #include "tmv/TMV_TriMatrixArith.h" 35 #endif 36 37 #ifdef XDEBUG 38 #include <iostream> 39 using std::cout; 40 using std::cerr; 41 using std::endl; 42 #endif 43 44 namespace tmv { 45 46 #ifdef TMV_BLOCKSIZE 47 #define SYM_MM_BLOCKSIZE TMV_BLOCKSIZE 48 #define SYM_MM_BLOCKSIZE2 (TMV_BLOCKSIZE/2) 49 #else 50 #define SYM_MM_BLOCKSIZE 64 51 #define SYM_MM_BLOCKSIZE2 32 52 #endif 53 54 // 55 // MultMM 56 // 57 58 template <bool add, class T, class Ta, class Tb> RRowMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)59 static void RRowMultMM( 60 const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B, 61 MatrixView<T> C) 62 { 63 TMVAssert(A.size() == C.colsize()); 64 TMVAssert(A.size() == B.colsize()); 65 TMVAssert(B.rowsize() == C.rowsize()); 66 TMVAssert(C.colsize() > 0); 67 TMVAssert(C.rowsize() > 0); 68 TMVAssert(A.rowsize() > 0); 69 TMVAssert(alpha != T(0)); 70 TMVAssert(C.ct()==NonConj); 71 TMVAssert(A.uplo() == Lower); 72 73 const ptrdiff_t N = A.size(); 74 for(ptrdiff_t j=0;j<N;++j) { 75 if (add) C.row(j) += alpha * A.row(j,0,j+1) * B.rowRange(0,j+1); 76 else C.row(j) = alpha * A.row(j,0,j+1) * B.rowRange(0,j+1); 77 C.rowRange(0,j) += alpha * A.col(j,0,j) ^ B.row(j); 78 } 79 } 80 81 template <bool add, class T, class Ta, class Tb> CRowMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)82 static void CRowMultMM( 83 const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B, 84 MatrixView<T> C) 85 { 86 TMVAssert(A.size() == C.colsize()); 87 TMVAssert(A.size() == B.colsize()); 88 TMVAssert(B.rowsize() == C.rowsize()); 89 TMVAssert(C.colsize() > 0); 90 TMVAssert(C.rowsize() > 0); 91 TMVAssert(A.rowsize() > 0); 92 TMVAssert(alpha != T(0)); 93 TMVAssert(C.ct()==NonConj); 94 TMVAssert(A.uplo() == Lower); 95 96 const ptrdiff_t N = A.size(); 97 for(ptrdiff_t j=N-1;j>=0;--j) { 98 if (add) C.row(j) += alpha * A.row(j,j,N) * B.rowRange(j,N); 99 else C.row(j) = alpha * A.row(j,j,N) * B.rowRange(j,N); 100 C.rowRange(j+1,N) += alpha * A.col(j,j+1,N) ^ B.row(j); 101 } 102 } 103 104 template <bool add, class T, class Ta, class Tb> RowMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)105 static inline void RowMultMM( 106 const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B, 107 MatrixView<T> C) 108 { 109 if (A.iscm()) CRowMultMM<add>(alpha,A,B,C); 110 else RRowMultMM<add>(alpha,A,B,C); 111 } 112 113 template <bool add, class T, class Ta, class Tb> ColMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)114 static void ColMultMM( 115 const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B, 116 MatrixView<T> C) 117 { 118 TMVAssert(A.size() == C.colsize()); 119 TMVAssert(A.size() == B.colsize()); 120 TMVAssert(B.rowsize() == C.rowsize()); 121 TMVAssert(C.colsize() > 0); 122 TMVAssert(C.rowsize() > 0); 123 TMVAssert(A.rowsize() > 0); 124 TMVAssert(alpha != T(0)); 125 TMVAssert(C.ct()==NonConj); 126 TMVAssert(A.uplo() == Lower); 127 128 const ptrdiff_t N = C.rowsize(); 129 for(ptrdiff_t j=0;j<N;++j) 130 if (add) C.col(j) += alpha * A * B.col(j); 131 else C.col(j) = alpha * A * B.col(j); 132 } 133 134 template <bool add, class T, class Ta, class Tb> RecursiveMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)135 static void RecursiveMultMM( 136 const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B, 137 MatrixView<T> C) 138 { 139 TMVAssert(A.size() == C.colsize()); 140 TMVAssert(A.size() == B.colsize()); 141 TMVAssert(B.rowsize() == C.rowsize()); 142 TMVAssert(C.colsize() > 0); 143 TMVAssert(C.rowsize() > 0); 144 TMVAssert(A.rowsize() > 0); 145 TMVAssert(alpha != T(0)); 146 TMVAssert(C.ct()==NonConj); 147 TMVAssert(A.uplo() == Lower); 148 149 const ptrdiff_t N = A.size(); 150 if (N <= SYM_MM_BLOCKSIZE2) { 151 if (B.isrm() && C.isrm()) RowMultMM<add>(alpha,A,B,C); 152 else if (B.iscm() && C.iscm()) ColMultMM<add>(alpha,A,B,C); 153 else if (C.colsize() < C.rowsize()) RowMultMM<add>(alpha,A,B,C); 154 else ColMultMM<add>(alpha,A,B,C); 155 } else { 156 ptrdiff_t k = N/2; 157 const ptrdiff_t nb = SYM_MM_BLOCKSIZE; 158 if (k > nb) k = k/nb*nb; 159 160 // [ A00 A10t ] [ B0 ] = [ A00 B0 + A10t B1 ] 161 // [ A10 A11 ] [ B1 ] [ A10 B0 + A11 B1 ] 162 163 ConstSymMatrixView<Ta> A00 = A.subSymMatrix(0,k); 164 ConstSymMatrixView<Ta> A11 = A.subSymMatrix(k,N); 165 ConstMatrixView<Ta> A10 = A.subMatrix(k,N,0,k); 166 ConstMatrixView<Tb> B0 = B.rowRange(0,k); 167 ConstMatrixView<Tb> B1 = B.rowRange(k,N); 168 MatrixView<T> C0 = C.rowRange(0,k); 169 MatrixView<T> C1 = C.rowRange(k,N); 170 171 RecursiveMultMM<add>(alpha,A00,B0,C0); 172 RecursiveMultMM<add>(alpha,A11,B1,C1); 173 C1 += alpha * A10 * B0; 174 if (A.issym()) 175 C0 += alpha * A10.transpose() * B1; 176 else 177 C0 += alpha * A10.adjoint() * B1; 178 } 179 } 180 181 template <bool add, class T, class Ta, class Tb> NonBlasMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)182 static void NonBlasMultMM( 183 const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B, 184 MatrixView<T> C) 185 { 186 TMVAssert(A.size() == C.colsize()); 187 TMVAssert(A.size() == B.colsize()); 188 TMVAssert(B.rowsize() == C.rowsize()); 189 TMVAssert(C.colsize() > 0); 190 TMVAssert(C.rowsize() > 0); 191 TMVAssert(A.rowsize() > 0); 192 TMVAssert(alpha != T(0)); 193 194 if (A.uplo() == Upper) 195 if (A.isherm()) NonBlasMultMM<add>(alpha,A.adjoint(),B,C); 196 else NonBlasMultMM<add>(alpha,A.transpose(),B,C); 197 else if (C.isconj()) 198 NonBlasMultMM<add>( 199 TMV_CONJ(alpha),A.conjugate(),B.conjugate(),C.conjugate()); 200 else RecursiveMultMM<add>(alpha,A,B,C); 201 } 202 203 #ifdef BLAS 204 template <class T, class Ta, class Tb> BlasMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,const int beta,MatrixView<T> C)205 static inline void BlasMultMM( 206 const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B, 207 const int beta, MatrixView<T> C) 208 { 209 if (beta == 1) NonBlasMultMM<true>(alpha,A,B,C); 210 else NonBlasMultMM<false>(alpha,A,B,C); 211 } 212 #ifdef INST_DOUBLE 213 template <> BlasMultMM(const double alpha,const GenSymMatrix<double> & A,const GenMatrix<double> & B,const int beta,MatrixView<double> C)214 void BlasMultMM( 215 const double alpha, const GenSymMatrix<double>& A, 216 const GenMatrix<double>& B, const int beta, MatrixView<double> C) 217 { 218 int m = C.iscm() ? C.colsize() : C.rowsize(); 219 int n = C.iscm() ? C.rowsize() : C.colsize(); 220 int lda = A.stepj(); 221 int ldb = B.iscm()?B.stepj():B.stepi(); 222 int ldc = C.iscm()?C.stepj():C.stepi(); 223 if (beta == 0) C.setZero(); 224 double xbeta(1); 225 BLASNAME(dsymm) ( 226 BLASCM C.iscm()?BLASCH_L:BLASCH_R, 227 A.uplo() == Upper ? BLASCH_UP : BLASCH_LO, 228 BLASV(m),BLASV(n),BLASV(alpha),BLASP(A.cptr()),BLASV(lda), 229 BLASP(B.cptr()),BLASV(ldb),BLASV(xbeta), 230 BLASP(C.ptr()),BLASV(ldc) BLAS1 BLAS1); 231 } 232 template <> BlasMultMM(std::complex<double> alpha,const GenSymMatrix<std::complex<double>> & A,const GenMatrix<std::complex<double>> & B,const int beta,MatrixView<std::complex<double>> C)233 void BlasMultMM( 234 std::complex<double> alpha, 235 const GenSymMatrix<std::complex<double> >& A, 236 const GenMatrix<std::complex<double> >& B, 237 const int beta, MatrixView<std::complex<double> > C) 238 { 239 int m = C.iscm() ? C.colsize() : C.rowsize(); 240 int n = C.iscm() ? C.rowsize() : C.colsize(); 241 int lda = A.stepj(); 242 int ldb = B.iscm()?B.stepj():B.stepi(); 243 int ldc = C.iscm()?C.stepj():C.stepi(); 244 if (beta == 0) C.setZero(); 245 std::complex<double> xbeta(1); 246 if (A.issym()) 247 BLASNAME(zsymm) ( 248 BLASCM C.iscm()?BLASCH_L:BLASCH_R, 249 A.uplo() == Upper ? BLASCH_UP : BLASCH_LO, 250 BLASV(m),BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda), 251 BLASP(B.cptr()),BLASV(ldb),BLASP(&xbeta), 252 BLASP(C.ptr()),BLASV(ldc) BLAS1 BLAS1); 253 else { 254 if (!C.iscm()) alpha = TMV_CONJ(alpha); 255 BLASNAME(zhemm) ( 256 BLASCM C.iscm()?BLASCH_L:BLASCH_R, 257 A.uplo() == Upper ? BLASCH_UP : BLASCH_LO, 258 BLASV(m),BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda), 259 BLASP(B.cptr()),BLASV(ldb),BLASP(&xbeta), 260 BLASP(C.ptr()),BLASV(ldc) BLAS1 BLAS1); 261 } 262 } 263 #endif 264 #ifdef INST_FLOAT 265 template <> BlasMultMM(const float alpha,const GenSymMatrix<float> & A,const GenMatrix<float> & B,const int beta,MatrixView<float> C)266 void BlasMultMM( 267 const float alpha, const GenSymMatrix<float>& A, 268 const GenMatrix<float>& B, const int beta, MatrixView<float> C) 269 { 270 int m = C.iscm() ? C.colsize() : C.rowsize(); 271 int n = C.iscm() ? C.rowsize() : C.colsize(); 272 int lda = A.stepj(); 273 int ldb = B.iscm()?B.stepj():B.stepi(); 274 int ldc = C.iscm()?C.stepj():C.stepi(); 275 if (beta == 0) C.setZero(); 276 float xbeta(1); 277 BLASNAME(ssymm) ( 278 BLASCM C.iscm()?BLASCH_L:BLASCH_R, 279 A.uplo() == Upper ? BLASCH_UP : BLASCH_LO, 280 BLASV(m),BLASV(n),BLASV(alpha),BLASP(A.cptr()),BLASV(lda), 281 BLASP(B.cptr()),BLASV(ldb),BLASV(xbeta), 282 BLASP(C.ptr()),BLASV(ldc) BLAS1 BLAS1); 283 } 284 template <> BlasMultMM(std::complex<float> alpha,const GenSymMatrix<std::complex<float>> & A,const GenMatrix<std::complex<float>> & B,const int beta,MatrixView<std::complex<float>> C)285 void BlasMultMM( 286 std::complex<float> alpha, 287 const GenSymMatrix<std::complex<float> >& A, 288 const GenMatrix<std::complex<float> >& B, 289 const int beta, MatrixView<std::complex<float> > C) 290 { 291 int m = C.iscm() ? C.colsize() : C.rowsize(); 292 int n = C.iscm() ? C.rowsize() : C.colsize(); 293 int lda = A.stepj(); 294 int ldb = B.iscm()?B.stepj():B.stepi(); 295 int ldc = C.iscm()?C.stepj():C.stepi(); 296 if (beta == 0) C.setZero(); 297 std::complex<float> xbeta(1); 298 if (A.issym()) 299 BLASNAME(csymm) ( 300 BLASCM C.iscm()?BLASCH_L:BLASCH_R, 301 A.uplo() == Upper ? BLASCH_UP : BLASCH_LO, 302 BLASV(m),BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda), 303 BLASP(B.cptr()),BLASV(ldb),BLASP(&xbeta), 304 BLASP(C.ptr()),BLASV(ldc) BLAS1 BLAS1); 305 else { 306 if (!C.iscm()) alpha = TMV_CONJ(alpha); 307 BLASNAME(chemm) ( 308 BLASCM C.iscm()?BLASCH_L:BLASCH_R, 309 A.uplo() == Upper ? BLASCH_UP : BLASCH_LO, 310 BLASV(m),BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda), 311 BLASP(B.cptr()),BLASV(ldb),BLASP(&xbeta), 312 BLASP(C.ptr()),BLASV(ldc) BLAS1 BLAS1); 313 } 314 } 315 #endif 316 template <class T> BlasMultMM(const std::complex<T> alpha,const GenSymMatrix<std::complex<T>> & A,const GenMatrix<T> & B,const int beta,MatrixView<std::complex<T>> C)317 static void BlasMultMM( 318 const std::complex<T> alpha, 319 const GenSymMatrix<std::complex<T> >& A, const GenMatrix<T>& B, 320 const int beta, MatrixView<std::complex<T> > C) 321 { 322 if (TMV_IMAG(alpha) == T(0)) { 323 SymMatrix<T,Lower|ColMajor> A1 = A.realPart(); 324 Matrix<T,ColMajor> C1 = TMV_REAL(alpha)*A1*B; 325 if (beta == 0) C.realPart() = C1; 326 else C.realPart() += C1; 327 if (A.issym()) { 328 A1 = A.imagPart(); 329 if (C.isconj()) C1 = -TMV_REAL(alpha)*A1*B; 330 else C1 = TMV_REAL(alpha)*A1*B; 331 } else { 332 LowerTriMatrixView<T> L = A1.lowerTri(); 333 L = A.lowerTri().imagPart(); 334 // A.imagPart() = L - LT 335 if (A.lowerTri().isconj() != C.isconj()) { 336 C1 = -TMV_REAL(alpha)*L*B; 337 C1 += TMV_REAL(alpha)*L.transpose()*B; 338 } else { 339 C1 = TMV_REAL(alpha)*L*B; 340 C1 -= TMV_REAL(alpha)*L.transpose()*B; 341 } 342 } 343 if (beta == 0) C.imagPart() = C1; 344 else C.imagPart() += C1; 345 } else { 346 SymMatrix<T,Lower|ColMajor> Ar = A.realPart(); 347 SymMatrix<T,Lower|ColMajor> Ai(A.size()); 348 LowerTriMatrixView<T> L = Ai.lowerTri(); 349 Matrix<T,ColMajor> C1 = TMV_REAL(alpha)*Ar*B; 350 if (A.issym()) { 351 Ai = A.imagPart(); 352 C1 -= TMV_IMAG(alpha)*Ai*B; 353 } else { 354 L = A.lowerTri().imagPart(); 355 if (A.lowerTri().isconj()) L *= T(-1); 356 C1 -= TMV_IMAG(alpha)*L*B; 357 C1 += TMV_IMAG(alpha)*L.transpose()*B; 358 } 359 if (beta == 0) C.realPart() = C1; 360 else C.realPart() += C1; 361 C1 = TMV_IMAG(alpha)*Ar*B; 362 if (A.issym()) { 363 C1 += TMV_REAL(alpha)*Ai*B; 364 } else { 365 C1 += TMV_REAL(alpha)*L*B; 366 C1 -= TMV_REAL(alpha)*L.transpose()*B; 367 } 368 if (C.isconj()) C1 *= T(-1); 369 if (beta == 0) C.imagPart() = C1; 370 else C.imagPart() += C1; 371 } 372 } 373 template <class T> BlasMultMM(const std::complex<T> alpha,const GenSymMatrix<T> & A,const GenMatrix<std::complex<T>> & B,const int beta,MatrixView<std::complex<T>> C)374 static void BlasMultMM( 375 const std::complex<T> alpha, 376 const GenSymMatrix<T>& A, const GenMatrix<std::complex<T> >& B, 377 const int beta, MatrixView<std::complex<T> > C) 378 { 379 if (TMV_IMAG(alpha) == T(0)) { 380 Matrix<T,ColMajor> B1 = B.realPart(); 381 Matrix<T,ColMajor> C1 = TMV_REAL(alpha)*A*B1; 382 if (beta == 0) C.realPart() = C1; 383 else C.realPart() += C1; 384 B1 = B.imagPart(); 385 if (B.isconj()) C1 = -TMV_REAL(alpha)*A*B1; 386 else C1 = TMV_REAL(alpha)*A*B1; 387 if (beta == 0) C.imagPart() = C1; 388 else C.imagPart() += C1; 389 } else { 390 Matrix<T,ColMajor> Br = B.realPart(); 391 Matrix<T,ColMajor> Bi = B.imagPart(); 392 Matrix<T,ColMajor> C1 = TMV_REAL(alpha)*A*Br; 393 if (B.isconj()) C1 += TMV_IMAG(alpha)*A*Bi; 394 else C1 -= TMV_IMAG(alpha)*A*Bi; 395 if (beta == 0) C.realPart() = C1; 396 else C.realPart() += C1; 397 398 if (B.isconj()) C1 = -TMV_REAL(alpha)*A*Bi; 399 else C1 = TMV_REAL(alpha)*A*Bi; 400 C1 += TMV_IMAG(alpha)*A*Br; 401 if (beta == 0) C.imagPart() = C1; 402 else C.imagPart() += C1; 403 } 404 } 405 template <class T> BlasMultMM(const std::complex<T> alpha,const GenSymMatrix<T> & A,const GenMatrix<T> & B,const int beta,MatrixView<std::complex<T>> C)406 static void BlasMultMM( 407 const std::complex<T> alpha, 408 const GenSymMatrix<T>& A, const GenMatrix<T>& B, 409 const int beta, MatrixView<std::complex<T> > C) 410 { 411 Matrix<T,ColMajor> C1 = A*B; 412 if (beta == 0) C = alpha*C1; 413 else C += alpha*C1; 414 } 415 #endif // BLAS 416 417 template <bool add, class T, class Ta, class Tb> DoMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)418 static void DoMultMM( 419 const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B, 420 MatrixView<T> C) 421 { 422 TMVAssert(A.size() == C.colsize()); 423 TMVAssert(A.size() == B.colsize()); 424 TMVAssert(B.rowsize() == C.rowsize()); 425 TMVAssert(C.colsize() > 0); 426 TMVAssert(C.rowsize() > 0); 427 TMVAssert(A.rowsize() > 0); 428 TMVAssert(alpha != T(0)); 429 430 #ifdef BLAS 431 if (A.isrm()) 432 DoMultMM<add>(alpha,A.issym()?A.transpose():A.adjoint(),B,C); 433 else if (A.isconj()) 434 DoMultMM<add>( 435 TMV_CONJ(alpha),A.conjugate(),B.conjugate(),C.conjugate()); 436 else if ( !((C.isrm() && C.stepi()>0) || (C.iscm() && C.stepj()>0)) || 437 (C.iscm() && C.isconj()) || 438 (C.isrm() && C.isconj()==A.issym()) ) { 439 Matrix<T,ColMajor> C2(C.colsize(),C.rowsize()); 440 DoMultMM<false>(T(1),A,B,C2.view()); 441 if (add) C += alpha*C2; 442 else C = alpha*C2; 443 } else if (!(A.iscm() && A.stepj()>0)) { 444 if (TMV_IMAG(alpha) == TMV_RealType(T)(0)) { 445 if (A.isherm()) { 446 if (A.uplo() == Upper) { 447 HermMatrix<Ta,Upper|ColMajor> A2 = TMV_REAL(alpha)*A; 448 DoMultMM<add>(T(1),A2,B,C); 449 } else { 450 HermMatrix<Ta,Lower|ColMajor> A2 = TMV_REAL(alpha)*A; 451 DoMultMM<add>(T(1),A2,B,C); 452 } 453 } else { 454 if (A.uplo() == Upper) { 455 SymMatrix<Ta,Upper|ColMajor> A2 = TMV_REAL(alpha)*A; 456 DoMultMM<add>(T(1),A2,B,C); 457 } else { 458 SymMatrix<Ta,Lower|ColMajor> A2 = TMV_REAL(alpha)*A; 459 DoMultMM<add>(T(1),A2,B,C); 460 } 461 } 462 } else { 463 if (!A.issym()) { 464 if (A.uplo() == Upper) { 465 // alpha * A is not Hermitian, so can't do 466 // A2 = alpha * A 467 HermMatrix<Ta,Upper|ColMajor> A2 = A; 468 DoMultMM<add>(alpha,A2,B,C); 469 } else { 470 HermMatrix<Ta,Lower|ColMajor> A2 = A; 471 DoMultMM<add>(alpha,A2,B,C); 472 } 473 } else { 474 if (A.uplo() == Upper) { 475 SymMatrix<T,Upper|ColMajor> A2 = alpha*A; 476 DoMultMM<add>(T(1),A2,B,C); 477 } else { 478 SymMatrix<T,Lower|ColMajor> A2 = alpha*A; 479 DoMultMM<add>(T(1),A2,B,C); 480 } 481 } 482 } 483 } else if (!(B.isrm()==C.isrm() && B.iscm()==C.iscm()) || 484 (isComplex(Tb()) && B.isconj() != C.isconj()) || 485 !((B.isrm() && B.stepi()>0) || (B.iscm() && B.stepj()>0))) { 486 if (TMV_IMAG(alpha) == TMV_RealType(T)(0)) { 487 if (C.isconj()) { 488 if (C.iscm()) { 489 Matrix<Tb,ColMajor> B2 = TMV_REAL(alpha)*B.conjugate(); 490 DoMultMM<add>(T(1),A,B2.conjugate(),C); 491 } else { 492 Matrix<Tb,RowMajor> B2 = TMV_REAL(alpha)*B.conjugate(); 493 DoMultMM<add>(T(1),A,B2.conjugate(),C); 494 } 495 } else { 496 if (C.iscm()) { 497 Matrix<Tb,ColMajor> B2 = TMV_REAL(alpha)*B; 498 DoMultMM<add>(T(1),A,B2,C); 499 } else { 500 Matrix<Tb,RowMajor> B2 = TMV_REAL(alpha)*B; 501 DoMultMM<add>(T(1),A,B2,C); 502 } 503 } 504 } else { 505 if (C.isconj()) { 506 if (C.iscm()) { 507 Matrix<T,ColMajor> B2 = TMV_CONJ(alpha)*B.conjugate(); 508 DoMultMM<add>(T(1),A,B2.conjugate(),C); 509 } else { 510 Matrix<T,RowMajor> B2 = TMV_CONJ(alpha)*B.conjugate(); 511 DoMultMM<add>(T(1),A,B2.conjugate(),C); 512 } 513 } else { 514 if (C.iscm()) { 515 Matrix<T,ColMajor> B2 = alpha*B; 516 DoMultMM<add>(T(1),A,B2,C); 517 } else { 518 Matrix<T,RowMajor> B2 = alpha*B; 519 DoMultMM<add>(T(1),A,B2,C); 520 } 521 } 522 } 523 } else { 524 BlasMultMM(alpha,A,B,add?1:0,C); 525 } 526 #else 527 NonBlasMultMM<add>(alpha,A,B,C); 528 #endif 529 } 530 531 template <bool add, class T, class Ta, class Tb> FullTempMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)532 static void FullTempMultMM( 533 const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B, 534 MatrixView<T> C) 535 { 536 if (C.isrm()) { 537 Matrix<T,RowMajor> C2(C.colsize(),C.rowsize()); 538 DoMultMM<false>(T(1),A,B,C2.view()); 539 if (add) C += alpha*C2; 540 else C = alpha*C2; 541 } else { 542 Matrix<T,ColMajor> C2(C.colsize(),C.rowsize()); 543 DoMultMM<false>(T(1),A,B,C2.view()); 544 if (add) C += alpha*C2; 545 else C = alpha*C2; 546 } 547 } 548 549 template <bool add, class T, class Ta, class Tb> BlockTempMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)550 static void BlockTempMultMM( 551 const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B, 552 MatrixView<T> C) 553 { 554 const ptrdiff_t N = C.rowsize(); 555 for(ptrdiff_t j=0;j<N;) { 556 ptrdiff_t j2 = TMV_MIN(N,j+SYM_MM_BLOCKSIZE); 557 if (TMV_IMAG(alpha) == TMV_RealType(T)(0)) { 558 if (C.isrm()) { 559 Matrix<Tb,RowMajor> B2 = TMV_REAL(alpha) * B.colRange(j,j2); 560 DoMultMM<add>(T(1),A,B2,C.colRange(j,j2)); 561 } else { 562 Matrix<Tb,ColMajor> B2 = TMV_REAL(alpha) * B.colRange(j,j2); 563 DoMultMM<add>(T(1),A,B2,C.colRange(j,j2)); 564 } 565 } else { 566 if (C.isrm()) { 567 Matrix<T,RowMajor> B2 = alpha * B.colRange(j,j2); 568 DoMultMM<add>(T(1),A,B2,C.colRange(j,j2)); 569 } else { 570 Matrix<T,ColMajor> B2 = alpha * B.colRange(j,j2); 571 DoMultMM<add>(T(1),A,B2,C.colRange(j,j2)); 572 } 573 } 574 j = j2; 575 } 576 } 577 578 template <bool add, class T, class Ta, class Tb> MultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)579 void MultMM( 580 const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B, 581 MatrixView<T> C) 582 // C (+)= alpha * A * B 583 { 584 TMVAssert(A.size() == C.colsize()); 585 TMVAssert(A.size() == B.colsize()); 586 TMVAssert(B.rowsize() == C.rowsize()); 587 #ifdef XDEBUG 588 //cout<<"Start MultMM: alpha = "<<alpha<<endl; 589 //cout<<"A = "<<A.cptr()<<" "<<TMV_Text(A)<<" "<<A<<endl; 590 //cout<<"B = "<<B.cptr()<<" "<<TMV_Text(B)<<" "<<B<<endl; 591 //cout<<"C = "<<C.cptr()<<" "<<TMV_Text(C)<<" "<<C<<endl; 592 Matrix<Ta> A0 = A; 593 Matrix<Tb> B0 = B; 594 Matrix<T> C0 = C; 595 Matrix<T> C2 = alpha*A0*B0; 596 if (add) C2 += C0; 597 #endif 598 599 if (C.colsize() > 0 && C.rowsize() > 0) { 600 if (alpha == T(0)) { 601 if (!add) C.setZero(); 602 } 603 else if (SameStorage(A,C)) 604 FullTempMultMM<add>(alpha,A,B,C); 605 else if (SameStorage(B,C)) 606 if (C.stepi() == B.stepi() && C.stepj() == B.stepj()) 607 BlockTempMultMM<add>(alpha,A,B,C); 608 else 609 FullTempMultMM<add>(alpha,A,B,C); 610 else DoMultMM<add>(alpha, A, B, C); 611 } 612 613 #ifdef XDEBUG 614 //cout<<"Done: C = "<<C<<endl; 615 if (!(Norm(C-C2) <= 616 0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(B0)+ 617 (add?Norm(C0):TMV_RealType(T)(0))))) { 618 cerr<<"MultMM: alpha = "<<alpha<<endl; 619 cerr<<"add = "<<add<<endl; 620 cerr<<"A = "<<TMV_Text(A)<<" "<<A0<<endl; 621 cerr<<"B = "<<TMV_Text(B)<<" "<<B0<<endl; 622 cerr<<"C = "<<TMV_Text(C)<<" "<<C0<<endl; 623 cerr<<"--> C = "<<C<<endl; 624 cerr<<"C2 = "<<C2<<endl; 625 abort(); 626 } 627 #endif 628 } 629 630 template <bool add, class T, class Ta, class Tb> BlockTempMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenSymMatrix<Tb> & B,MatrixView<T> C)631 static void BlockTempMultMM( 632 const T alpha, const GenSymMatrix<Ta>& A, const GenSymMatrix<Tb>& B, 633 MatrixView<T> C) 634 { 635 TMVAssert(A.size() == B.size()); 636 TMVAssert(A.size() == C.colsize()); 637 TMVAssert(A.size() == C.rowsize()); 638 TMVAssert(A.size() > 0); 639 TMVAssert(alpha != T(0)); 640 641 const ptrdiff_t N = A.size(); 642 643 for(ptrdiff_t j=0;j<N;) { 644 ptrdiff_t j2 = TMV_MIN(N,j+SYM_MM_BLOCKSIZE); 645 if (TMV_IMAG(alpha) == TMV_RealType(T)(0)) { 646 if (C.isrm()) { 647 Matrix<Tb,RowMajor> B2(N,j2-j); 648 B2.rowRange(0,j) = TMV_REAL(alpha) * B.subMatrix(0,j,j,j2); 649 B2.rowRange(j,j2) = TMV_REAL(alpha) * B.subSymMatrix(j,j2); 650 B2.rowRange(j2,N) = TMV_REAL(alpha) * B.subMatrix(j2,N,j,j2); 651 DoMultMM<add>(T(1),A,B2.view(),C.colRange(j,j2)); 652 } else { 653 Matrix<Tb,ColMajor> B2(N,j2-j); 654 B2.rowRange(0,j) = TMV_REAL(alpha) * B.subMatrix(0,j,j,j2); 655 B2.rowRange(j,j2) = TMV_REAL(alpha) * B.subSymMatrix(j,j2); 656 B2.rowRange(j2,N) = TMV_REAL(alpha) * B.subMatrix(j2,N,j,j2); 657 DoMultMM<add>(T(1),A,B2.view(),C.colRange(j,j2)); 658 } 659 } else { 660 if (C.isrm()) { 661 Matrix<T,RowMajor> B2(N,j2-j); 662 B2.rowRange(0,j) = alpha * B.subMatrix(0,j,j,j2); 663 B2.rowRange(j,j2) = alpha * B.subSymMatrix(j,j2); 664 B2.rowRange(j2,N) = alpha * B.subMatrix(j2,N,j,j2); 665 DoMultMM<add>(T(1),A,B2.view(),C.colRange(j,j2)); 666 } else { 667 Matrix<T,ColMajor> B2(N,j2-j); 668 B2.rowRange(0,j) = alpha * B.subMatrix(0,j,j,j2); 669 B2.rowRange(j,j2) = alpha * B.subSymMatrix(j,j2); 670 B2.rowRange(j2,N) = alpha * B.subMatrix(j2,N,j,j2); 671 DoMultMM<add>(T(1),A,B2.view(),C.colRange(j,j2)); 672 } 673 } 674 j = j2; 675 } 676 } 677 678 template <bool add, class T, class Ta, class Tb> FullTempMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenSymMatrix<Tb> & B,MatrixView<T> C)679 static void FullTempMultMM( 680 const T alpha, const GenSymMatrix<Ta>& A, const GenSymMatrix<Tb>& B, 681 MatrixView<T> C) 682 { 683 if (C.isrm()) { 684 Matrix<T,RowMajor> C2(C.colsize(),C.rowsize()); 685 BlockTempMultMM<false>(T(1),A,B,C2.view()); 686 if (add) C += alpha*C2; 687 else C = alpha*C2; 688 } else { 689 Matrix<T,ColMajor> C2(C.colsize(),C.rowsize()); 690 BlockTempMultMM<false>(T(1),A,B,C2.view()); 691 if (add) C += alpha*C2; 692 else C = alpha*C2; 693 } 694 } 695 696 template <bool add, class T, class Ta, class Tb> MultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenSymMatrix<Tb> & B,MatrixView<T> C)697 void MultMM( 698 const T alpha, const GenSymMatrix<Ta>& A, 699 const GenSymMatrix<Tb>& B, MatrixView<T> C) 700 // C (+)= alpha * A * B 701 { 702 TMVAssert(A.size() == B.size()); 703 TMVAssert(A.size() == C.colsize()); 704 TMVAssert(A.size() == C.rowsize()); 705 #ifdef XDEBUG 706 //cout<<"Start MultMM: alpha = "<<alpha<<endl; 707 //cout<<"A = "<<A.cptr()<<" "<<TMV_Text(A)<<" "<<A<<endl; 708 //cout<<"B = "<<B.cptr()<<" "<<TMV_Text(B)<<" "<<B<<endl; 709 //cout<<"C = "<<C.cptr()<<" "<<TMV_Text(C)<<" "<<C<<endl; 710 Matrix<Ta> A0 = A; 711 Matrix<Tb> B0 = B; 712 Matrix<T> C0 = C; 713 Matrix<T> C2 = alpha*A0*B0; 714 if (add) C2 += C0; 715 #endif 716 717 if (A.size() > 0) { 718 if (SameStorage(A,C) || SameStorage(B,C)) 719 FullTempMultMM<add>(alpha,A,B,C); 720 else BlockTempMultMM<add>(alpha, A, B, C); 721 } 722 723 #ifdef XDEBUG 724 //cout<<"done: C = "<<C<<endl; 725 if (!(Norm(C-C2) <= 726 0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(B0)+ 727 (add?Norm(C0):TMV_RealType(T)(0))))) { 728 cerr<<"MultMM: alpha = "<<alpha<<endl; 729 cerr<<"add = "<<add<<endl; 730 cerr<<"A = "<<TMV_Text(A)<<" "<<A0<<endl; 731 cerr<<"B = "<<TMV_Text(B)<<" "<<B0<<endl; 732 cerr<<"C = "<<TMV_Text(C)<<" "<<C0<<endl; 733 cerr<<"--> C = "<<C<<endl; 734 cerr<<"C2 = "<<C2<<endl; 735 abort(); 736 } 737 #endif 738 } 739 740 #define InstFile "TMV_MultSM.inst" 741 #include "TMV_Inst.h" 742 #undef InstFile 743 744 } // namespace tmv 745 746 747