1 ///////////////////////////////////////////////////////////////////////////////
2 //                                                                           //
3 // The Template Matrix/Vector Library for C++ was created by Mike Jarvis     //
4 // Copyright (C) 1998 - 2016                                                 //
5 // All rights reserved                                                       //
6 //                                                                           //
7 // The project is hosted at https://code.google.com/p/tmv-cpp/               //
// where you can find the current version and current documentation.         //
9 //                                                                           //
10 // For concerns or problems with the software, Mike may be contacted at      //
11 // mike_jarvis17 [at] gmail.                                                 //
12 //                                                                           //
13 // This software is licensed under a FreeBSD license.  The file              //
// TMV_LICENSE should have been included with this distribution.             //
// If not, you can get a copy from https://code.google.com/p/tmv-cpp/.       //
16 //                                                                           //
17 // Essentially, you can use this software however you want provided that     //
18 // you include the TMV_LICENSE file in any distribution that uses it.        //
19 //                                                                           //
20 ///////////////////////////////////////////////////////////////////////////////
21 
22 
23 //#define XDEBUG
24 
25 
26 #include "TMV_Blas.h"
27 #include "tmv/TMV_SymMatrixArithFunc.h"
28 #include "tmv/TMV_SymMatrix.h"
29 #include "tmv/TMV_Matrix.h"
30 #include "tmv/TMV_SymMatrixArith.h"
31 #include "tmv/TMV_MatrixArith.h"
32 #include "tmv/TMV_VectorArith.h"
33 #ifdef BLAS
34 #include "tmv/TMV_TriMatrixArith.h"
35 #endif
36 
37 #ifdef XDEBUG
38 #include <iostream>
39 using std::cout;
40 using std::cerr;
41 using std::endl;
42 #endif
43 
44 namespace tmv {
45 
// Block sizes for the blocked algorithms in this file.
// SYM_MM_BLOCKSIZE is the column-panel width used by the BlockTempMultMM
// routines and the rounding unit for the split point in RecursiveMultMM.
// SYM_MM_BLOCKSIZE2 is the matrix size at or below which RecursiveMultMM
// stops recursing and calls the direct row/column kernels.
// Both follow TMV_BLOCKSIZE when it is defined; otherwise they are 64 and 32.
#ifdef TMV_BLOCKSIZE
#define SYM_MM_BLOCKSIZE TMV_BLOCKSIZE
#define SYM_MM_BLOCKSIZE2 (TMV_BLOCKSIZE/2)
#else
#define SYM_MM_BLOCKSIZE 64
#define SYM_MM_BLOCKSIZE2 32
#endif
53 
54     //
55     // MultMM
56     //
57 
58     template <bool add, class T, class Ta, class Tb>
RRowMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)59     static void RRowMultMM(
60         const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B,
61         MatrixView<T> C)
62     {
63         TMVAssert(A.size() == C.colsize());
64         TMVAssert(A.size() == B.colsize());
65         TMVAssert(B.rowsize() == C.rowsize());
66         TMVAssert(C.colsize() > 0);
67         TMVAssert(C.rowsize() > 0);
68         TMVAssert(A.rowsize() > 0);
69         TMVAssert(alpha != T(0));
70         TMVAssert(C.ct()==NonConj);
71         TMVAssert(A.uplo() == Lower);
72 
73         const ptrdiff_t N = A.size();
74         for(ptrdiff_t j=0;j<N;++j) {
75             if (add) C.row(j) += alpha * A.row(j,0,j+1) * B.rowRange(0,j+1);
76             else C.row(j) = alpha * A.row(j,0,j+1) * B.rowRange(0,j+1);
77             C.rowRange(0,j) += alpha * A.col(j,0,j) ^ B.row(j);
78         }
79     }
80 
81     template <bool add, class T, class Ta, class Tb>
CRowMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)82     static void CRowMultMM(
83         const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B,
84         MatrixView<T> C)
85     {
86         TMVAssert(A.size() == C.colsize());
87         TMVAssert(A.size() == B.colsize());
88         TMVAssert(B.rowsize() == C.rowsize());
89         TMVAssert(C.colsize() > 0);
90         TMVAssert(C.rowsize() > 0);
91         TMVAssert(A.rowsize() > 0);
92         TMVAssert(alpha != T(0));
93         TMVAssert(C.ct()==NonConj);
94         TMVAssert(A.uplo() == Lower);
95 
96         const ptrdiff_t N = A.size();
97         for(ptrdiff_t j=N-1;j>=0;--j) {
98             if (add) C.row(j) += alpha * A.row(j,j,N) * B.rowRange(j,N);
99             else C.row(j) = alpha * A.row(j,j,N) * B.rowRange(j,N);
100             C.rowRange(j+1,N) += alpha * A.col(j,j+1,N) ^ B.row(j);
101         }
102     }
103 
104     template <bool add, class T, class Ta, class Tb>
RowMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)105     static inline void RowMultMM(
106         const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B,
107         MatrixView<T> C)
108     {
109         if (A.iscm()) CRowMultMM<add>(alpha,A,B,C);
110         else RRowMultMM<add>(alpha,A,B,C);
111     }
112 
113     template <bool add, class T, class Ta, class Tb>
ColMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)114     static void ColMultMM(
115         const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B,
116         MatrixView<T> C)
117     {
118         TMVAssert(A.size() == C.colsize());
119         TMVAssert(A.size() == B.colsize());
120         TMVAssert(B.rowsize() == C.rowsize());
121         TMVAssert(C.colsize() > 0);
122         TMVAssert(C.rowsize() > 0);
123         TMVAssert(A.rowsize() > 0);
124         TMVAssert(alpha != T(0));
125         TMVAssert(C.ct()==NonConj);
126         TMVAssert(A.uplo() == Lower);
127 
128         const ptrdiff_t N = C.rowsize();
129         for(ptrdiff_t j=0;j<N;++j)
130             if (add) C.col(j) += alpha * A * B.col(j);
131             else C.col(j) = alpha * A * B.col(j);
132     }
133 
134     template <bool add, class T, class Ta, class Tb>
RecursiveMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)135     static void RecursiveMultMM(
136         const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B,
137         MatrixView<T> C)
138     {
139         TMVAssert(A.size() == C.colsize());
140         TMVAssert(A.size() == B.colsize());
141         TMVAssert(B.rowsize() == C.rowsize());
142         TMVAssert(C.colsize() > 0);
143         TMVAssert(C.rowsize() > 0);
144         TMVAssert(A.rowsize() > 0);
145         TMVAssert(alpha != T(0));
146         TMVAssert(C.ct()==NonConj);
147         TMVAssert(A.uplo() == Lower);
148 
149         const ptrdiff_t N = A.size();
150         if (N <= SYM_MM_BLOCKSIZE2) {
151             if (B.isrm() && C.isrm()) RowMultMM<add>(alpha,A,B,C);
152             else if (B.iscm() && C.iscm()) ColMultMM<add>(alpha,A,B,C);
153             else if (C.colsize() < C.rowsize()) RowMultMM<add>(alpha,A,B,C);
154             else ColMultMM<add>(alpha,A,B,C);
155         } else {
156             ptrdiff_t k = N/2;
157             const ptrdiff_t nb = SYM_MM_BLOCKSIZE;
158             if (k > nb) k = k/nb*nb;
159 
160             // [ A00 A10t ] [ B0 ] = [ A00 B0 + A10t B1 ]
161             // [ A10 A11  ] [ B1 ]   [ A10 B0 + A11 B1  ]
162 
163             ConstSymMatrixView<Ta> A00 = A.subSymMatrix(0,k);
164             ConstSymMatrixView<Ta> A11 = A.subSymMatrix(k,N);
165             ConstMatrixView<Ta> A10 = A.subMatrix(k,N,0,k);
166             ConstMatrixView<Tb> B0 = B.rowRange(0,k);
167             ConstMatrixView<Tb> B1 = B.rowRange(k,N);
168             MatrixView<T> C0 = C.rowRange(0,k);
169             MatrixView<T> C1 = C.rowRange(k,N);
170 
171             RecursiveMultMM<add>(alpha,A00,B0,C0);
172             RecursiveMultMM<add>(alpha,A11,B1,C1);
173             C1 += alpha * A10 * B0;
174             if (A.issym())
175                 C0 += alpha * A10.transpose() * B1;
176             else
177                 C0 += alpha * A10.adjoint() * B1;
178         }
179     }
180 
181     template <bool add, class T, class Ta, class Tb>
NonBlasMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)182     static void NonBlasMultMM(
183         const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B,
184         MatrixView<T> C)
185     {
186         TMVAssert(A.size() == C.colsize());
187         TMVAssert(A.size() == B.colsize());
188         TMVAssert(B.rowsize() == C.rowsize());
189         TMVAssert(C.colsize() > 0);
190         TMVAssert(C.rowsize() > 0);
191         TMVAssert(A.rowsize() > 0);
192         TMVAssert(alpha != T(0));
193 
194         if (A.uplo() == Upper)
195             if (A.isherm()) NonBlasMultMM<add>(alpha,A.adjoint(),B,C);
196             else NonBlasMultMM<add>(alpha,A.transpose(),B,C);
197         else if (C.isconj())
198             NonBlasMultMM<add>(
199                 TMV_CONJ(alpha),A.conjugate(),B.conjugate(),C.conjugate());
200         else RecursiveMultMM<add>(alpha,A,B,C);
201     }
202 
203 #ifdef BLAS
204     template <class T, class Ta, class Tb>
BlasMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,const int beta,MatrixView<T> C)205     static inline void BlasMultMM(
206         const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B,
207         const int beta, MatrixView<T> C)
208     {
209         if (beta == 1) NonBlasMultMM<true>(alpha,A,B,C);
210         else NonBlasMultMM<false>(alpha,A,B,C);
211     }
212 #ifdef INST_DOUBLE
213     template <>
BlasMultMM(const double alpha,const GenSymMatrix<double> & A,const GenMatrix<double> & B,const int beta,MatrixView<double> C)214     void BlasMultMM(
215         const double alpha, const GenSymMatrix<double>& A,
216         const GenMatrix<double>& B, const int beta, MatrixView<double> C)
217     {
218         int m = C.iscm() ? C.colsize() : C.rowsize();
219         int n = C.iscm() ? C.rowsize() : C.colsize();
220         int lda = A.stepj();
221         int ldb = B.iscm()?B.stepj():B.stepi();
222         int ldc = C.iscm()?C.stepj():C.stepi();
223         if (beta == 0) C.setZero();
224         double xbeta(1);
225         BLASNAME(dsymm) (
226             BLASCM C.iscm()?BLASCH_L:BLASCH_R,
227             A.uplo() == Upper ? BLASCH_UP : BLASCH_LO,
228             BLASV(m),BLASV(n),BLASV(alpha),BLASP(A.cptr()),BLASV(lda),
229             BLASP(B.cptr()),BLASV(ldb),BLASV(xbeta),
230             BLASP(C.ptr()),BLASV(ldc) BLAS1 BLAS1);
231     }
232     template <>
BlasMultMM(std::complex<double> alpha,const GenSymMatrix<std::complex<double>> & A,const GenMatrix<std::complex<double>> & B,const int beta,MatrixView<std::complex<double>> C)233     void BlasMultMM(
234         std::complex<double> alpha,
235         const GenSymMatrix<std::complex<double> >& A,
236         const GenMatrix<std::complex<double> >& B,
237         const int beta, MatrixView<std::complex<double> > C)
238     {
239         int m = C.iscm() ? C.colsize() : C.rowsize();
240         int n = C.iscm() ? C.rowsize() : C.colsize();
241         int lda = A.stepj();
242         int ldb = B.iscm()?B.stepj():B.stepi();
243         int ldc = C.iscm()?C.stepj():C.stepi();
244         if (beta == 0) C.setZero();
245         std::complex<double> xbeta(1);
246         if (A.issym())
247             BLASNAME(zsymm) (
248                 BLASCM C.iscm()?BLASCH_L:BLASCH_R,
249                 A.uplo() == Upper ? BLASCH_UP : BLASCH_LO,
250                 BLASV(m),BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda),
251                 BLASP(B.cptr()),BLASV(ldb),BLASP(&xbeta),
252                 BLASP(C.ptr()),BLASV(ldc) BLAS1 BLAS1);
253         else {
254             if (!C.iscm()) alpha = TMV_CONJ(alpha);
255             BLASNAME(zhemm) (
256                 BLASCM C.iscm()?BLASCH_L:BLASCH_R,
257                 A.uplo() == Upper ? BLASCH_UP : BLASCH_LO,
258                 BLASV(m),BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda),
259                 BLASP(B.cptr()),BLASV(ldb),BLASP(&xbeta),
260                 BLASP(C.ptr()),BLASV(ldc) BLAS1 BLAS1);
261         }
262     }
263 #endif
264 #ifdef INST_FLOAT
265     template <>
BlasMultMM(const float alpha,const GenSymMatrix<float> & A,const GenMatrix<float> & B,const int beta,MatrixView<float> C)266     void BlasMultMM(
267         const float alpha, const GenSymMatrix<float>& A,
268         const GenMatrix<float>& B, const int beta, MatrixView<float> C)
269     {
270         int m = C.iscm() ? C.colsize() : C.rowsize();
271         int n = C.iscm() ? C.rowsize() : C.colsize();
272         int lda = A.stepj();
273         int ldb = B.iscm()?B.stepj():B.stepi();
274         int ldc = C.iscm()?C.stepj():C.stepi();
275         if (beta == 0) C.setZero();
276         float xbeta(1);
277         BLASNAME(ssymm) (
278             BLASCM C.iscm()?BLASCH_L:BLASCH_R,
279             A.uplo() == Upper ? BLASCH_UP : BLASCH_LO,
280             BLASV(m),BLASV(n),BLASV(alpha),BLASP(A.cptr()),BLASV(lda),
281             BLASP(B.cptr()),BLASV(ldb),BLASV(xbeta),
282             BLASP(C.ptr()),BLASV(ldc) BLAS1 BLAS1);
283     }
284     template <>
BlasMultMM(std::complex<float> alpha,const GenSymMatrix<std::complex<float>> & A,const GenMatrix<std::complex<float>> & B,const int beta,MatrixView<std::complex<float>> C)285     void BlasMultMM(
286         std::complex<float> alpha,
287         const GenSymMatrix<std::complex<float> >& A,
288         const GenMatrix<std::complex<float> >& B,
289         const int beta, MatrixView<std::complex<float> > C)
290     {
291         int m = C.iscm() ? C.colsize() : C.rowsize();
292         int n = C.iscm() ? C.rowsize() : C.colsize();
293         int lda = A.stepj();
294         int ldb = B.iscm()?B.stepj():B.stepi();
295         int ldc = C.iscm()?C.stepj():C.stepi();
296         if (beta == 0) C.setZero();
297         std::complex<float> xbeta(1);
298         if (A.issym())
299             BLASNAME(csymm) (
300                 BLASCM C.iscm()?BLASCH_L:BLASCH_R,
301                 A.uplo() == Upper ? BLASCH_UP : BLASCH_LO,
302                 BLASV(m),BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda),
303                 BLASP(B.cptr()),BLASV(ldb),BLASP(&xbeta),
304                 BLASP(C.ptr()),BLASV(ldc) BLAS1 BLAS1);
305         else {
306             if (!C.iscm()) alpha = TMV_CONJ(alpha);
307             BLASNAME(chemm) (
308                 BLASCM C.iscm()?BLASCH_L:BLASCH_R,
309                 A.uplo() == Upper ? BLASCH_UP : BLASCH_LO,
310                 BLASV(m),BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda),
311                 BLASP(B.cptr()),BLASV(ldb),BLASP(&xbeta),
312                 BLASP(C.ptr()),BLASV(ldc) BLAS1 BLAS1);
313         }
314     }
315 #endif
    // Mixed-type overload: complex A, real B, complex C and alpha.
    //
    // The product is assembled from real-valued pieces: real temporaries
    // are built from A.realPart() and A.imagPart(), multiplied by B, and
    // the results are written to C.realPart() and C.imagPart().
    // beta == 0 overwrites those parts of C; any other beta accumulates.
    //
    // For a symmetric A the imaginary part is copied into a SymMatrix
    // directly.  For a hermitian A only the lower triangle L of the
    // imaginary part is stored and the full product is formed as
    // L*B - L.transpose()*B (see "A.imagPart() = L - LT" below).
    // The sign flips that depend on isconj() presumably compensate for
    // imagPart() exposing the stored values of a conjugated view rather
    // than the logical ones -- TODO confirm against the view classes.
    template <class T>
    static void BlasMultMM(
        const std::complex<T> alpha,
        const GenSymMatrix<std::complex<T> >& A, const GenMatrix<T>& B,
        const int beta, MatrixView<std::complex<T> > C)
    {
        if (TMV_IMAG(alpha) == T(0)) {
            // Real alpha: real and imaginary parts of C decouple.
            // A1 first holds Re(A); it is reused below for Im(A).
            SymMatrix<T,Lower|ColMajor> A1 = A.realPart();
            Matrix<T,ColMajor> C1 = TMV_REAL(alpha)*A1*B;
            if (beta == 0) C.realPart() = C1;
            else C.realPart() += C1;
            if (A.issym()) {
                A1 = A.imagPart();
                if (C.isconj()) C1 = -TMV_REAL(alpha)*A1*B;
                else C1 = TMV_REAL(alpha)*A1*B;
            } else {
                // L is a view into A1's storage, so this assignment
                // overwrites the lower triangle of A1.
                LowerTriMatrixView<T> L = A1.lowerTri();
                L = A.lowerTri().imagPart();
                // A.imagPart() = L - LT
                if (A.lowerTri().isconj() != C.isconj()) {
                    C1 = -TMV_REAL(alpha)*L*B;
                    C1 += TMV_REAL(alpha)*L.transpose()*B;
                } else {
                    C1 = TMV_REAL(alpha)*L*B;
                    C1 -= TMV_REAL(alpha)*L.transpose()*B;
                }
            }
            if (beta == 0) C.imagPart() = C1;
            else C.imagPart() += C1;
        } else {
            // Complex alpha: both Re(A) and Im(A) contribute to each part
            // of C, so both are kept (Ar, and Ai or its lower triangle L).
            SymMatrix<T,Lower|ColMajor> Ar = A.realPart();
            SymMatrix<T,Lower|ColMajor> Ai(A.size());
            LowerTriMatrixView<T> L = Ai.lowerTri();
            // Real part: Re(alpha)*Re(A)*B - Im(alpha)*Im(A)*B
            Matrix<T,ColMajor> C1 = TMV_REAL(alpha)*Ar*B;
            if (A.issym()) {
                Ai = A.imagPart();
                C1 -= TMV_IMAG(alpha)*Ai*B;
            } else {
                L = A.lowerTri().imagPart();
                if (A.lowerTri().isconj()) L *= T(-1);
                C1 -= TMV_IMAG(alpha)*L*B;
                C1 += TMV_IMAG(alpha)*L.transpose()*B;
            }
            if (beta == 0) C.realPart() = C1;
            else C.realPart() += C1;
            // Imaginary part: Im(alpha)*Re(A)*B + Re(alpha)*Im(A)*B
            C1 = TMV_IMAG(alpha)*Ar*B;
            if (A.issym()) {
                C1 += TMV_REAL(alpha)*Ai*B;
            } else {
                C1 += TMV_REAL(alpha)*L*B;
                C1 -= TMV_REAL(alpha)*L.transpose()*B;
            }
            if (C.isconj()) C1 *= T(-1);
            if (beta == 0) C.imagPart() = C1;
            else C.imagPart() += C1;
        }
    }
    // Mixed-type overload: real A, complex B, complex C and alpha.
    //
    // B is split into real temporaries taken from B.realPart() and
    // B.imagPart(); each is multiplied by the real A and the results are
    // written to C.realPart() and C.imagPart().
    // beta == 0 overwrites those parts of C; any other beta accumulates.
    // The sign of the Im(B) terms is flipped when B.isconj(), presumably
    // because imagPart() exposes the stored values of a conjugated view
    // -- TODO confirm.  C.isconj() is not consulted here.
    template <class T>
    static void BlasMultMM(
        const std::complex<T> alpha,
        const GenSymMatrix<T>& A, const GenMatrix<std::complex<T> >& B,
        const int beta, MatrixView<std::complex<T> > C)
    {
        if (TMV_IMAG(alpha) == T(0)) {
            // Real alpha: real and imaginary parts decouple.
            // B1 first holds Re(B); it is reused below for Im(B).
            Matrix<T,ColMajor> B1 = B.realPart();
            Matrix<T,ColMajor> C1 = TMV_REAL(alpha)*A*B1;
            if (beta == 0) C.realPart() = C1;
            else C.realPart() += C1;
            B1 = B.imagPart();
            if (B.isconj()) C1 = -TMV_REAL(alpha)*A*B1;
            else C1 = TMV_REAL(alpha)*A*B1;
            if (beta == 0) C.imagPart() = C1;
            else C.imagPart() += C1;
        } else {
            // Complex alpha: both parts of B contribute to each part of C.
            Matrix<T,ColMajor> Br = B.realPart();
            Matrix<T,ColMajor> Bi = B.imagPart();
            // Real part: Re(alpha)*A*Re(B) - Im(alpha)*A*Im(B)
            Matrix<T,ColMajor> C1 = TMV_REAL(alpha)*A*Br;
            if (B.isconj()) C1 += TMV_IMAG(alpha)*A*Bi;
            else C1 -= TMV_IMAG(alpha)*A*Bi;
            if (beta == 0) C.realPart() = C1;
            else C.realPart() += C1;

            // Imaginary part: Re(alpha)*A*Im(B) + Im(alpha)*A*Re(B)
            if (B.isconj()) C1 = -TMV_REAL(alpha)*A*Bi;
            else C1 = TMV_REAL(alpha)*A*Bi;
            C1 += TMV_IMAG(alpha)*A*Br;
            if (beta == 0) C.imagPart() = C1;
            else C.imagPart() += C1;
        }
    }
405     template <class T>
BlasMultMM(const std::complex<T> alpha,const GenSymMatrix<T> & A,const GenMatrix<T> & B,const int beta,MatrixView<std::complex<T>> C)406     static void BlasMultMM(
407         const std::complex<T> alpha,
408         const GenSymMatrix<T>& A, const GenMatrix<T>& B,
409         const int beta, MatrixView<std::complex<T> > C)
410     {
411         Matrix<T,ColMajor> C1 = A*B;
412         if (beta == 0) C = alpha*C1;
413         else C += alpha*C1;
414     }
415 #endif // BLAS
416 
    // Core dispatcher for C (+)= alpha * A * B.
    //
    // Preconditions (asserted): matching, non-empty dimensions and
    // alpha != 0.
    //
    // Without BLAS this simply forwards to NonBlasMultMM.  With BLAS the
    // operands are brought, one condition at a time, into the form that
    // BlasMultMM is finally called with; each step recurses back into
    // DoMultMM with a modified view or a freshly copied temporary:
    //   1. A row major        -> use A.transpose() / A.adjoint()
    //   2. A conjugated       -> conjugate the whole equation
    //   3. C layout unusable  -> compute into a column-major temporary
    //   4. A not column major with positive stepj -> copy A
    //   5. B layout/conjugation does not match C  -> copy B
    //   6. otherwise          -> BlasMultMM
    // When a temporary copy is made, alpha is folded into the copy where
    // possible so the recursive call uses T(1).
    template <bool add, class T, class Ta, class Tb>
    static void DoMultMM(
        const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B,
        MatrixView<T> C)
    {
        TMVAssert(A.size() == C.colsize());
        TMVAssert(A.size() == B.colsize());
        TMVAssert(B.rowsize() == C.rowsize());
        TMVAssert(C.colsize() > 0);
        TMVAssert(C.rowsize() > 0);
        TMVAssert(A.rowsize() > 0);
        TMVAssert(alpha != T(0));

#ifdef BLAS
        // (1) Row-major A: re-enter with the transpose (symmetric) or
        //     adjoint (hermitian) view.
        if (A.isrm())
            DoMultMM<add>(alpha,A.issym()?A.transpose():A.adjoint(),B,C);
        // (2) Conjugated A: conjugate every operand, including alpha.
        else if (A.isconj())
            DoMultMM<add>(
                TMV_CONJ(alpha),A.conjugate(),B.conjugate(),C.conjugate());
        // (3) C is neither row- nor column-major with a positive leading
        //     step, or its conjugation does not fit the storage order:
        //     compute A*B into a column-major temporary, then scale by
        //     alpha into C.
        else if ( !((C.isrm() && C.stepi()>0) || (C.iscm() && C.stepj()>0)) ||
                  (C.iscm() && C.isconj()) ||
                  (C.isrm() && C.isconj()==A.issym()) ) {
            Matrix<T,ColMajor> C2(C.colsize(),C.rowsize());
            DoMultMM<false>(T(1),A,B,C2.view());
            if (add) C += alpha*C2;
            else C = alpha*C2;
        // (4) A is not column major with positive stepj: copy it into a
        //     column-major temporary of the same kind (Herm/Sym) and
        //     triangle (Upper/Lower).
        } else if (!(A.iscm() && A.stepj()>0)) {
            if (TMV_IMAG(alpha) == TMV_RealType(T)(0)) {
                // Real alpha can be folded into the copy in every case.
                if (A.isherm()) {
                    if (A.uplo() == Upper) {
                        HermMatrix<Ta,Upper|ColMajor> A2 = TMV_REAL(alpha)*A;
                        DoMultMM<add>(T(1),A2,B,C);
                    } else {
                        HermMatrix<Ta,Lower|ColMajor> A2 = TMV_REAL(alpha)*A;
                        DoMultMM<add>(T(1),A2,B,C);
                    }
                } else {
                    if (A.uplo() == Upper) {
                        SymMatrix<Ta,Upper|ColMajor> A2 = TMV_REAL(alpha)*A;
                        DoMultMM<add>(T(1),A2,B,C);
                    } else {
                        SymMatrix<Ta,Lower|ColMajor> A2 = TMV_REAL(alpha)*A;
                        DoMultMM<add>(T(1),A2,B,C);
                    }
                }
            } else {
                if (!A.issym()) {
                    if (A.uplo() == Upper) {
                        // alpha * A is not Hermitian, so can't do
                        // A2 = alpha * A
                        HermMatrix<Ta,Upper|ColMajor> A2 = A;
                        DoMultMM<add>(alpha,A2,B,C);
                    } else {
                        HermMatrix<Ta,Lower|ColMajor> A2 = A;
                        DoMultMM<add>(alpha,A2,B,C);
                    }
                } else {
                    // Symmetric A: complex alpha is folded into a copy
                    // with element type T.
                    if (A.uplo() == Upper) {
                        SymMatrix<T,Upper|ColMajor> A2 = alpha*A;
                        DoMultMM<add>(T(1),A2,B,C);
                    } else {
                        SymMatrix<T,Lower|ColMajor> A2 = alpha*A;
                        DoMultMM<add>(T(1),A2,B,C);
                    }
                }
            }
        // (5) B does not share C's storage order, differs from C in
        //     conjugation (complex Tb only), or has no positive leading
        //     step: copy B into a temporary with C's storage order,
        //     folding alpha into the copy.  When C is conjugated the copy
        //     is built from B.conjugate() and passed back as
        //     B2.conjugate().
        } else if (!(B.isrm()==C.isrm() && B.iscm()==C.iscm()) ||
                   (isComplex(Tb()) && B.isconj() != C.isconj()) ||
                   !((B.isrm() && B.stepi()>0) || (B.iscm() && B.stepj()>0))) {
            if (TMV_IMAG(alpha) == TMV_RealType(T)(0)) {
                // Real alpha: the copy keeps B's element type Tb.
                if (C.isconj()) {
                    if (C.iscm()) {
                        Matrix<Tb,ColMajor> B2 = TMV_REAL(alpha)*B.conjugate();
                        DoMultMM<add>(T(1),A,B2.conjugate(),C);
                    } else {
                        Matrix<Tb,RowMajor> B2 = TMV_REAL(alpha)*B.conjugate();
                        DoMultMM<add>(T(1),A,B2.conjugate(),C);
                    }
                } else {
                    if (C.iscm()) {
                        Matrix<Tb,ColMajor> B2 = TMV_REAL(alpha)*B;
                        DoMultMM<add>(T(1),A,B2,C);
                    } else {
                        Matrix<Tb,RowMajor> B2 = TMV_REAL(alpha)*B;
                        DoMultMM<add>(T(1),A,B2,C);
                    }
                }
            } else {
                // Complex alpha: the copy has element type T.
                if (C.isconj()) {
                    if (C.iscm()) {
                        Matrix<T,ColMajor> B2 = TMV_CONJ(alpha)*B.conjugate();
                        DoMultMM<add>(T(1),A,B2.conjugate(),C);
                    } else {
                        Matrix<T,RowMajor> B2 = TMV_CONJ(alpha)*B.conjugate();
                        DoMultMM<add>(T(1),A,B2.conjugate(),C);
                    }
                } else {
                    if (C.iscm()) {
                        Matrix<T,ColMajor> B2 = alpha*B;
                        DoMultMM<add>(T(1),A,B2,C);
                    } else {
                        Matrix<T,RowMajor> B2 = alpha*B;
                        DoMultMM<add>(T(1),A,B2,C);
                    }
                }
            }
        // (6) Everything is in place: call BLAS, with beta = 1 for
        //     accumulate and 0 for overwrite.
        } else {
            BlasMultMM(alpha,A,B,add?1:0,C);
        }
#else
        NonBlasMultMM<add>(alpha,A,B,C);
#endif
    }
530 
531     template <bool add, class T, class Ta, class Tb>
FullTempMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)532     static void FullTempMultMM(
533         const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B,
534         MatrixView<T> C)
535     {
536         if (C.isrm()) {
537             Matrix<T,RowMajor> C2(C.colsize(),C.rowsize());
538             DoMultMM<false>(T(1),A,B,C2.view());
539             if (add) C += alpha*C2;
540             else C = alpha*C2;
541         } else {
542             Matrix<T,ColMajor> C2(C.colsize(),C.rowsize());
543             DoMultMM<false>(T(1),A,B,C2.view());
544             if (add) C += alpha*C2;
545             else C = alpha*C2;
546         }
547     }
548 
549     template <bool add, class T, class Ta, class Tb>
BlockTempMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)550     static void BlockTempMultMM(
551         const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B,
552         MatrixView<T> C)
553     {
554         const ptrdiff_t N = C.rowsize();
555         for(ptrdiff_t j=0;j<N;) {
556             ptrdiff_t j2 = TMV_MIN(N,j+SYM_MM_BLOCKSIZE);
557             if (TMV_IMAG(alpha) == TMV_RealType(T)(0)) {
558                 if (C.isrm()) {
559                     Matrix<Tb,RowMajor> B2 = TMV_REAL(alpha) * B.colRange(j,j2);
560                     DoMultMM<add>(T(1),A,B2,C.colRange(j,j2));
561                 } else {
562                     Matrix<Tb,ColMajor> B2 = TMV_REAL(alpha) * B.colRange(j,j2);
563                     DoMultMM<add>(T(1),A,B2,C.colRange(j,j2));
564                 }
565             } else {
566                 if (C.isrm()) {
567                     Matrix<T,RowMajor> B2 = alpha * B.colRange(j,j2);
568                     DoMultMM<add>(T(1),A,B2,C.colRange(j,j2));
569                 } else {
570                     Matrix<T,ColMajor> B2 = alpha * B.colRange(j,j2);
571                     DoMultMM<add>(T(1),A,B2,C.colRange(j,j2));
572                 }
573             }
574             j = j2;
575         }
576     }
577 
578     template <bool add, class T, class Ta, class Tb>
MultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenMatrix<Tb> & B,MatrixView<T> C)579     void MultMM(
580         const T alpha, const GenSymMatrix<Ta>& A, const GenMatrix<Tb>& B,
581         MatrixView<T> C)
582     // C (+)= alpha * A * B
583     {
584         TMVAssert(A.size() == C.colsize());
585         TMVAssert(A.size() == B.colsize());
586         TMVAssert(B.rowsize() == C.rowsize());
587 #ifdef XDEBUG
588         //cout<<"Start MultMM: alpha = "<<alpha<<endl;
589         //cout<<"A = "<<A.cptr()<<"  "<<TMV_Text(A)<<"  "<<A<<endl;
590         //cout<<"B = "<<B.cptr()<<"  "<<TMV_Text(B)<<"  "<<B<<endl;
591         //cout<<"C = "<<C.cptr()<<"  "<<TMV_Text(C)<<"  "<<C<<endl;
592         Matrix<Ta> A0 = A;
593         Matrix<Tb> B0 = B;
594         Matrix<T> C0 = C;
595         Matrix<T> C2 = alpha*A0*B0;
596         if (add) C2 += C0;
597 #endif
598 
599         if (C.colsize() > 0 && C.rowsize() > 0) {
600             if (alpha == T(0)) {
601                 if (!add) C.setZero();
602             }
603             else if (SameStorage(A,C))
604                 FullTempMultMM<add>(alpha,A,B,C);
605             else if (SameStorage(B,C))
606                 if (C.stepi() == B.stepi() && C.stepj() == B.stepj())
607                     BlockTempMultMM<add>(alpha,A,B,C);
608                 else
609                     FullTempMultMM<add>(alpha,A,B,C);
610             else DoMultMM<add>(alpha, A, B, C);
611         }
612 
613 #ifdef XDEBUG
614         //cout<<"Done: C = "<<C<<endl;
615         if (!(Norm(C-C2) <=
616               0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(B0)+
617                      (add?Norm(C0):TMV_RealType(T)(0))))) {
618             cerr<<"MultMM: alpha = "<<alpha<<endl;
619             cerr<<"add = "<<add<<endl;
620             cerr<<"A = "<<TMV_Text(A)<<"  "<<A0<<endl;
621             cerr<<"B = "<<TMV_Text(B)<<"  "<<B0<<endl;
622             cerr<<"C = "<<TMV_Text(C)<<"  "<<C0<<endl;
623             cerr<<"--> C = "<<C<<endl;
624             cerr<<"C2 = "<<C2<<endl;
625             abort();
626         }
627 #endif
628     }
629 
630     template <bool add, class T, class Ta, class Tb>
BlockTempMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenSymMatrix<Tb> & B,MatrixView<T> C)631     static void BlockTempMultMM(
632         const T alpha, const GenSymMatrix<Ta>& A, const GenSymMatrix<Tb>& B,
633         MatrixView<T> C)
634     {
635         TMVAssert(A.size() == B.size());
636         TMVAssert(A.size() == C.colsize());
637         TMVAssert(A.size() == C.rowsize());
638         TMVAssert(A.size() > 0);
639         TMVAssert(alpha != T(0));
640 
641         const ptrdiff_t N = A.size();
642 
643         for(ptrdiff_t j=0;j<N;) {
644             ptrdiff_t j2 = TMV_MIN(N,j+SYM_MM_BLOCKSIZE);
645             if (TMV_IMAG(alpha) == TMV_RealType(T)(0)) {
646                 if (C.isrm()) {
647                     Matrix<Tb,RowMajor> B2(N,j2-j);
648                     B2.rowRange(0,j) = TMV_REAL(alpha) * B.subMatrix(0,j,j,j2);
649                     B2.rowRange(j,j2) = TMV_REAL(alpha) * B.subSymMatrix(j,j2);
650                     B2.rowRange(j2,N) = TMV_REAL(alpha) * B.subMatrix(j2,N,j,j2);
651                     DoMultMM<add>(T(1),A,B2.view(),C.colRange(j,j2));
652                 } else {
653                     Matrix<Tb,ColMajor> B2(N,j2-j);
654                     B2.rowRange(0,j) = TMV_REAL(alpha) * B.subMatrix(0,j,j,j2);
655                     B2.rowRange(j,j2) = TMV_REAL(alpha) * B.subSymMatrix(j,j2);
656                     B2.rowRange(j2,N) = TMV_REAL(alpha) * B.subMatrix(j2,N,j,j2);
657                     DoMultMM<add>(T(1),A,B2.view(),C.colRange(j,j2));
658                 }
659             } else {
660                 if (C.isrm()) {
661                     Matrix<T,RowMajor> B2(N,j2-j);
662                     B2.rowRange(0,j) = alpha * B.subMatrix(0,j,j,j2);
663                     B2.rowRange(j,j2) = alpha * B.subSymMatrix(j,j2);
664                     B2.rowRange(j2,N) = alpha * B.subMatrix(j2,N,j,j2);
665                     DoMultMM<add>(T(1),A,B2.view(),C.colRange(j,j2));
666                 } else {
667                     Matrix<T,ColMajor> B2(N,j2-j);
668                     B2.rowRange(0,j) = alpha * B.subMatrix(0,j,j,j2);
669                     B2.rowRange(j,j2) = alpha * B.subSymMatrix(j,j2);
670                     B2.rowRange(j2,N) = alpha * B.subMatrix(j2,N,j,j2);
671                     DoMultMM<add>(T(1),A,B2.view(),C.colRange(j,j2));
672                 }
673             }
674             j = j2;
675         }
676     }
677 
678     template <bool add, class T, class Ta, class Tb>
FullTempMultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenSymMatrix<Tb> & B,MatrixView<T> C)679     static void FullTempMultMM(
680         const T alpha, const GenSymMatrix<Ta>& A, const GenSymMatrix<Tb>& B,
681         MatrixView<T> C)
682     {
683         if (C.isrm()) {
684             Matrix<T,RowMajor> C2(C.colsize(),C.rowsize());
685             BlockTempMultMM<false>(T(1),A,B,C2.view());
686             if (add) C += alpha*C2;
687             else C = alpha*C2;
688         } else {
689             Matrix<T,ColMajor> C2(C.colsize(),C.rowsize());
690             BlockTempMultMM<false>(T(1),A,B,C2.view());
691             if (add) C += alpha*C2;
692             else C = alpha*C2;
693         }
694     }
695 
696     template <bool add, class T, class Ta, class Tb>
MultMM(const T alpha,const GenSymMatrix<Ta> & A,const GenSymMatrix<Tb> & B,MatrixView<T> C)697     void MultMM(
698         const T alpha, const GenSymMatrix<Ta>& A,
699         const GenSymMatrix<Tb>& B, MatrixView<T> C)
700     // C (+)= alpha * A * B
701     {
702         TMVAssert(A.size() == B.size());
703         TMVAssert(A.size() == C.colsize());
704         TMVAssert(A.size() == C.rowsize());
705 #ifdef XDEBUG
706         //cout<<"Start MultMM: alpha = "<<alpha<<endl;
707         //cout<<"A = "<<A.cptr()<<"  "<<TMV_Text(A)<<"  "<<A<<endl;
708         //cout<<"B = "<<B.cptr()<<"  "<<TMV_Text(B)<<"  "<<B<<endl;
709         //cout<<"C = "<<C.cptr()<<"  "<<TMV_Text(C)<<"  "<<C<<endl;
710         Matrix<Ta> A0 = A;
711         Matrix<Tb> B0 = B;
712         Matrix<T> C0 = C;
713         Matrix<T> C2 = alpha*A0*B0;
714         if (add) C2 += C0;
715 #endif
716 
717         if (A.size() > 0) {
718             if (SameStorage(A,C) || SameStorage(B,C))
719                 FullTempMultMM<add>(alpha,A,B,C);
720             else BlockTempMultMM<add>(alpha, A, B, C);
721         }
722 
723 #ifdef XDEBUG
724         //cout<<"done: C = "<<C<<endl;
725         if (!(Norm(C-C2) <=
726               0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(B0)+
727                      (add?Norm(C0):TMV_RealType(T)(0))))) {
728             cerr<<"MultMM: alpha = "<<alpha<<endl;
729             cerr<<"add = "<<add<<endl;
730             cerr<<"A = "<<TMV_Text(A)<<"  "<<A0<<endl;
731             cerr<<"B = "<<TMV_Text(B)<<"  "<<B0<<endl;
732             cerr<<"C = "<<TMV_Text(C)<<"  "<<C0<<endl;
733             cerr<<"--> C = "<<C<<endl;
734             cerr<<"C2 = "<<C2<<endl;
735             abort();
736         }
737 #endif
738     }
739 
740 #define InstFile "TMV_MultSM.inst"
741 #include "TMV_Inst.h"
742 #undef InstFile
743 
744 } // namespace tmv
745 
746 
747