1 ///////////////////////////////////////////////////////////////////////////////
2 //                                                                           //
3 // The Template Matrix/Vector Library for C++ was created by Mike Jarvis     //
4 // Copyright (C) 1998 - 2016                                                 //
5 // All rights reserved                                                       //
6 //                                                                           //
// The project is hosted at https://code.google.com/p/tmv-cpp/               //
// where you can find the current version and current documentation.         //
9 //                                                                           //
10 // For concerns or problems with the software, Mike may be contacted at      //
11 // mike_jarvis17 [at] gmail.                                                 //
12 //                                                                           //
// This software is licensed under a FreeBSD license.  The file              //
// TMV_LICENSE should have been included with this distribution.             //
// If not, you can get a copy from https://code.google.com/p/tmv-cpp/.       //
16 //                                                                           //
17 // Essentially, you can use this software however you want provided that     //
18 // you include the TMV_LICENSE file in any distribution that uses it.        //
19 //                                                                           //
20 ///////////////////////////////////////////////////////////////////////////////
21 
22 
23 //#define XDEBUG
24 
25 #include "TMV_Blas.h"
26 #include "tmv/TMV_MatrixArithFunc.h"
27 #include "TMV_MultMV.h"
28 #include "tmv/TMV_Matrix.h"
29 #include "tmv/TMV_VectorArith.h"
30 #include "tmv/TMV_MatrixArith.h"
31 
32 #ifdef XDEBUG
33 #include "tmv/TMV_VIt.h"
34 #include <iostream>
35 using std::cout;
36 using std::cerr;
37 using std::endl;
38 #endif
39 
40 // CBLAS trick of using RowMajor with ConjTrans when we have a
41 // case of A.conjugate() * x doesn't seem to be working with MKL 10.2.2.
42 // I haven't been able to figure out why.  (e.g. Is it a bug in the MKL
43 // code, or am I doing something wrong?)  So for now, just disable it.
44 #ifdef CBLAS
45 #undef CBLAS
46 #endif
47 
48 namespace tmv {
49 
cptr() const50     template <class T> const T* MatrixComposite<T>::cptr() const
51     {
52         if (!itsm.get()) {
53             ptrdiff_t len = this->colsize()*this->rowsize();
54             itsm.resize(len);
55             MatrixView<T>(itsm.get(),this->colsize(),this->rowsize(),
56                           stepi(),stepj(),NonConj,len
57                           TMV_FIRSTLAST1(itsm.get(),itsm.get()+len) ) = *this;
58         }
59         return itsm.get();
60     }
61 
stepi() const62     template <class T> ptrdiff_t MatrixComposite<T>::stepi() const
63     { return 1; }
64 
stepj() const65     template <class T> ptrdiff_t MatrixComposite<T>::stepj() const
66     { return this->colsize(); }
67 
ls() const68     template <class T> ptrdiff_t MatrixComposite<T>::ls() const
69     { return this->rowsize() * this->colsize(); }
70 
71     //
72     //
73     // MultMV
74     //
75 
76     // These routines are designed to work even if y has the same storage
77     // as either x or the first row/column of A.
78 
79     template <bool add, bool cx, bool ca, bool rm, class T, class Ta, class Tx>
RowMultMV(const GenMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)80     static void RowMultMV(
81         const GenMatrix<Ta>& A, const GenVector<Tx>& x, VectorView<T> y)
82     {
83         TMVAssert(A.rowsize() == x.size());
84         TMVAssert(A.colsize() == y.size());
85         TMVAssert(x.size() > 0);
86         TMVAssert(y.size() > 0);
87         TMVAssert(y.ct()==NonConj);
88         TMVAssert(x.step() == 1);
89         TMVAssert(y.step() == 1);
90         TMVAssert(!SameStorage(x,y));
91         TMVAssert(cx == x.isconj());
92         TMVAssert(ca == A.isconj());
93 
94         const ptrdiff_t M = A.colsize();
95         const ptrdiff_t N = A.rowsize();
96         const ptrdiff_t si = A.stepi();
97         const ptrdiff_t sj = (rm ? 1 : A.stepj());
98 
99         const Ta* Ai0 = A.cptr();
100         const Tx*const x0 = x.cptr();
101         T* yi = y.ptr();
102 
103         for(ptrdiff_t i=M; i>0; --i,++yi,Ai0+=si) {
104             // *yi += A.row(i) * x
105 
106             const Ta* Aij = Ai0;
107             const Tx* xj = x0;
108             register T temp(0);
109             for(ptrdiff_t j=N; j>0; --j,++xj,(rm?++Aij:Aij+=sj))
110                 temp +=
111                     (cx ? TMV_CONJ(*xj) : *xj) *
112                     (ca ? TMV_CONJ(*Aij) : *Aij);
113 
114 #ifdef TMVFLDEBUG
115             TMVAssert(yi >= y._first);
116             TMVAssert(yi < y._last);
117 #endif
118             if (add) *yi += temp;
119             else *yi = temp;
120         }
121     }
122 
123     template <bool add, bool cx, bool ca, bool cm, class T, class Ta, class Tx>
ColMultMV(const GenMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)124     static void ColMultMV(
125         const GenMatrix<Ta>& A, const GenVector<Tx>& x, VectorView<T> y)
126     {
127         TMVAssert(A.rowsize() == x.size());
128         TMVAssert(A.colsize() == y.size());
129         TMVAssert(x.size() > 0);
130         TMVAssert(y.size() > 0);
131         TMVAssert(y.ct()==NonConj);
132         TMVAssert(x.step() == 1);
133         TMVAssert(y.step() == 1);
134         TMVAssert(!SameStorage(x,y));
135         TMVAssert(cx == x.isconj());
136         TMVAssert(ca == A.isconj());
137         TMVAssert(cm == A.iscm());
138 
139         const ptrdiff_t M = A.colsize();
140         ptrdiff_t N = A.rowsize();
141         const ptrdiff_t si = (cm ? 1 : A.stepi());
142         const ptrdiff_t sj = A.stepj();
143 
144         const Ta* A0j = A.cptr();
145         const Tx* xj = x.cptr();
146         T*const y0 = y.ptr();
147 
148         if (!add) {
149             if (*xj == Tx(0)) {
150                 y.setZero();
151             } else {
152                 const Ta* Aij = A0j;
153                 T* yi = y0;
154                 const Tx xjval = (cx ? TMV_CONJ(*xj) : *xj);
155                 for(ptrdiff_t i=M; i>0; --i,++yi,(cm?++Aij:Aij+=si)) {
156 #ifdef TMVFLDEBUG
157                     TMVAssert(yi >= y._first);
158                     TMVAssert(yi < y._last);
159 #endif
160                     *yi = xjval * (ca ? TMV_CONJ(*Aij) : *Aij);
161                 }
162             }
163             ++xj; A0j+=sj; --N;
164         }
165 
166         for(; N>0; --N,++xj,A0j+=sj) {
167             // y += *xj * A.col(j)
168             if (*xj != Tx(0)) {
169                 const Ta* Aij = A0j;
170                 T* yi = y0;
171                 const Tx xjval = (cx ? TMV_CONJ(*xj) : *xj);
172                 for(ptrdiff_t i=M; i>0; --i,++yi,(cm?++Aij:Aij+=si)) {
173 #ifdef TMVFLDEBUG
174                     TMVAssert(yi >= y._first);
175                     TMVAssert(yi < y._last);
176 #endif
177                     *yi += xjval * (ca ? TMV_CONJ(*Aij) : *Aij);
178                 }
179             }
180         }
181     }
182 
183     template <bool add, bool cx, class T, class Ta, class Tx>
UnitAMultMV1(const GenMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)184     void UnitAMultMV1(
185         const GenMatrix<Ta>& A, const GenVector<Tx>& x, VectorView<T> y)
186     {
187         TMVAssert(A.rowsize() == x.size());
188         TMVAssert(A.colsize() == y.size());
189         TMVAssert(x.size() > 0);
190         TMVAssert(y.size() > 0);
191         TMVAssert(y.ct() == NonConj);
192         TMVAssert(x.step() == 1);
193         TMVAssert(y.step() == 1);
194         TMVAssert(!SameStorage(x,y));
195         TMVAssert(cx == x.isconj());
196 
197         if (A.isrm())
198             if (A.isconj())
199                 RowMultMV<add,cx,true,true>(A,x,y);
200             else
201                 RowMultMV<add,cx,false,true>(A,x,y);
202         else if (A.iscm())
203             if (A.isconj())
204                 ColMultMV<add,cx,true,true>(A,x,y);
205             else
206                 ColMultMV<add,cx,false,true>(A,x,y);
207         else if ( A.rowsize() >= A.colsize() )
208             if (A.isconj())
209                 RowMultMV<add,cx,true,false>(A,x,y);
210             else
211                 RowMultMV<add,cx,false,false>(A,x,y);
212         else
213             if (A.isconj())
214                 ColMultMV<add,cx,true,false>(A,x,y);
215             else
216                 ColMultMV<add,cx,false,false>(A,x,y);
217     }
218 
219     template <bool add, bool cx, class T, class Ta, class Tx>
UnitAMultMV(const GenMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)220     static void UnitAMultMV(
221         const GenMatrix<Ta>& A, const GenVector<Tx>& x, VectorView<T> y)
222     {
223 #ifdef XDEBUG
224         cout<<"Start UnitAMultMV: \n";
225         cout<<"add = "<<add<<endl;
226         cout<<"A = "<<TMV_Text(A)<<"  "<<A<<endl;
227         cout<<"x = "<<TMV_Text(x)<<" step "<<x.step()<<"  "<<x<<endl;
228         cout<<"y = "<<TMV_Text(y)<<" step "<<y.step()<<"  "<<y<<endl;
229         Vector<Tx> x0 = x;
230         Vector<T> y0 = y;
231         Matrix<Ta> A0 = A;
232         Vector<T> y2 = y;
233         for(ptrdiff_t i=0;i<y.size();i++) {
234             if (add)
235                 y2(i) += (A.row(i) * x0);
236             else
237                 y2(i) = (A.row(i) * x0);
238         }
239         cout<<"y2 = "<<y2<<endl;
240 #endif
241         // Check for 0's in beginning or end of x:
242         // y += [ A1 A2 A3 ] [ 0 ]  -->  y += A2 x
243         //                   [ x ]
244         //                   [ 0 ]
245 
246         const ptrdiff_t N = x.size(); // = A.rowsize()
247         ptrdiff_t j2 = N;
248         for(const Tx* x2=x.cptr()+N-1; j2>0 && *x2==Tx(0); --j2,--x2);
249         if (j2 == 0) {
250             if (!add) y.setZero();
251             return;
252         }
253         ptrdiff_t j1 = 0;
254         for(const Tx* x1=x.cptr(); *x1==Tx(0); ++j1,++x1) {}
255         TMVAssert(j1 !=j2);
256         if (j1 == 0 && j2 == N) UnitAMultMV1<add,cx>(A,x,y);
257         else UnitAMultMV1<add,cx>(A.colRange(j1,j2),x.subVector(j1,j2),y);
258 
259 #ifdef XDEBUG
260         cout<<"y => "<<y<<endl;
261         if (!(Norm(y-y2) <=
262               0.001*(Norm(A0)*Norm(x0)+
263                      (add?Norm(y0):TMV_RealType(T)(0))))) {
264             cerr<<"MultMV: \n";
265             cerr<<"add = "<<add<<endl;
266             cerr<<"A = "<<TMV_Text(A);
267             if (A.rowsize() < 30 && A.colsize() < 30) cerr<<"  "<<A0;
268             else cerr<<"  "<<A.colsize()<<" x "<<A.rowsize();
269             cerr<<endl<<"x = "<<TMV_Text(x)<<" step "<<x.step();
270             if (x.size() < 30) cerr<<"  "<<x0;
271             cerr<<endl<<"y = "<<TMV_Text(y)<<" step "<<y.step();
272             if (add && y.size() < 30) cerr<<"  "<<y0;
273             cerr<<endl<<"Aptr = "<<A.cptr();
274             cerr<<", xptr = "<<x.cptr()<<", yptr = "<<y.cptr()<<endl;
275             if (y.size() < 200) {
276                 cerr<<"--> y = "<<y<<endl;
277                 cerr<<"y2 = "<<y2<<endl;
278             } else {
279                 ptrdiff_t imax;
280                 (y-y2).maxAbsElement(&imax);
281                 cerr<<"y("<<imax<<") = "<<y(imax)<<endl;
282                 cerr<<"y2("<<imax<<") = "<<y2(imax)<<endl;
283             }
284             cerr<<"Norm(A0) = "<<Norm(A0)<<endl;
285             cerr<<"Norm(x0) = "<<Norm(x0)<<endl;
286             if (add) cerr<<"Norm(y0) = "<<Norm(y0)<<endl;
287             cerr<<"|A0|*|x0|+?|y0| = "<<
288                 Norm(A0)*Norm(x0)+
289                 (add?Norm(y0):TMV_RealType(T)(0))<<endl;
290             cerr<<"Norm(y-y2) = "<<Norm(y-y2)<<endl;
291             cerr<<"NormInf(y-y2) = "<<NormInf(y-y2)<<endl;
292             cerr<<"Norm1(y-y2) = "<<Norm1(y-y2)<<endl;
293             abort();
294         }
295 #endif
296     }
297 
NonBlasMultMV(const T alpha,const GenMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)298     template <bool add, class T, class Ta, class Tx> static void NonBlasMultMV(
299         const T alpha, const GenMatrix<Ta>& A, const GenVector<Tx>& x,
300         VectorView<T> y)
301         // y (+)= alpha * A * x
302     {
303 #ifdef XDEBUG
304         cout<<"Start MultMV: alpha = "<<alpha<<endl;
305         cout<<"add = "<<add<<endl;
306         cout<<"A = "<<TMV_Text(A)<<"  "<<A<<endl;
307         cout<<"x = "<<TMV_Text(x)<<" step "<<x.step()<<"  "<<x<<endl;
308         cout<<"y = "<<TMV_Text(y)<<" step "<<y.step()<<"  "<<y<<endl;
309         Vector<Tx> x0 = x;
310         Vector<T> y0 = y;
311         Matrix<Ta> A0 = A;
312         Vector<T> y2 = y;
313         for(ptrdiff_t i=0;i<y.size();i++) {
314             if (add)
315                 y2(i) += alpha * (A.row(i) * x0);
316             else
317                 y2(i) = alpha * (A.row(i) * x0);
318         }
319         cout<<"y2 = "<<y2<<endl;
320 #endif
321         TMVAssert(A.rowsize() == x.size());
322         TMVAssert(A.colsize() == y.size());
323         TMVAssert(alpha != T(0));
324         TMVAssert(x.size() > 0);
325         TMVAssert(y.size() > 0);
326         TMVAssert(y.ct() == NonConj);
327 
328         const ptrdiff_t M = A.colsize();
329         const ptrdiff_t N = A.rowsize();
330 
331         if (x.step() != 1 || SameStorage(x,y) ||
332             (alpha != TMV_RealType(T)(1) && y.step() == 1 && M/4 >= N)) {
333             // This last check is taken from the ATLAS version of this code.
334             // Apparently M = 4N is the dividing line between applying alpha
335             // here versus at the end when adding Ax to y
336             if (TMV_IMAG(alpha) == TMV_RealType(T)(0)) {
337                 Vector<Tx> xx = TMV_REAL(alpha)*x;
338                 if (y.step()!=1) {
339                     Vector<T> yy(y.size());
340                     UnitAMultMV<false,false>(A,xx,yy.view());
341                     if (add) y += yy;
342                     else y = yy;
343                 }
344                 else
345                     UnitAMultMV<add,false>(A,xx,y);
346             } else {
347                 Vector<T> xx = alpha*x;
348                 if (y.step() != 1) {
349                     Vector<T> yy(y.size());
350                     UnitAMultMV<false,false>(A,xx,yy.view());
351                     if (add) y += yy;
352                     else y = yy;
353                 }
354                 else
355                     UnitAMultMV<add,false>(A,xx,y);
356             }
357         } else if (y.step() != 1 || alpha != TMV_RealType(T)(1)) {
358             Vector<T> yy(y.size());
359             if (x.isconj())
360                 UnitAMultMV<false,true>(A,x,yy.view());
361             else
362                 UnitAMultMV<false,false>(A,x,yy.view());
363             if (add) y += alpha*yy;
364             else y = alpha*yy;
365         } else {
366             TMVAssert(alpha == T(1));
367             TMVAssert(y.step() == 1);
368             TMVAssert(x.step() == 1);
369             TMVAssert(!SameStorage(x,y));
370             if (x.isconj())
371                 UnitAMultMV<add,true>(A,x,y);
372             else
373                 UnitAMultMV<add,false>(A,x,y);
374         }
375 #ifdef XDEBUG
376         cout<<"y => "<<y<<endl;
377         if (!(Norm(y-y2) <=
378               0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(x0)+
379                      (add?Norm(y0):TMV_RealType(T)(0))))) {
380             cerr<<"MultMV: alpha = "<<alpha<<endl;
381             cerr<<"add = "<<add<<endl;
382             cerr<<"A = "<<TMV_Text(A);
383             if (A.rowsize() < 30 && A.colsize() < 30) cerr<<"  "<<A0;
384             else cerr<<"  "<<A.colsize()<<" x "<<A.rowsize();
385             cerr<<endl<<"x = "<<TMV_Text(x)<<" step "<<x.step();
386             if (x.size() < 30) cerr<<"  "<<x0;
387             cerr<<endl<<"y = "<<TMV_Text(y)<<" step "<<y.step();
388             if (add && y.size() < 30) cerr<<"  "<<y0;
389             cerr<<endl<<"Aptr = "<<A.cptr();
390             cerr<<", xptr = "<<x.cptr()<<", yptr = "<<y.cptr()<<endl;
391             if (y.size() < 200) {
392                 cerr<<"--> y = "<<y<<endl;
393                 cerr<<"y2 = "<<y2<<endl;
394             } else {
395                 ptrdiff_t imax;
396                 (y-y2).maxAbsElement(&imax);
397                 cerr<<"y("<<imax<<") = "<<y(imax)<<endl;
398                 cerr<<"y2("<<imax<<") = "<<y2(imax)<<endl;
399             }
400             cerr<<"Norm(A0) = "<<Norm(A0)<<endl;
401             cerr<<"Norm(x0) = "<<Norm(x0)<<endl;
402             if (add) cerr<<"Norm(y0) = "<<Norm(y0)<<endl;
403             cerr<<"|alpha|*|A0|*|x0|+?|y0| = "<<
404                 TMV_ABS(alpha)*Norm(A0)*Norm(x0)+
405                 (add?Norm(y0):TMV_RealType(T)(0))<<endl;
406             cerr<<"Norm(y-y2) = "<<Norm(y-y2)<<endl;
407             cerr<<"NormInf(y-y2) = "<<NormInf(y-y2)<<endl;
408             cerr<<"Norm1(y-y2) = "<<Norm1(y-y2)<<endl;
409             abort();
410         }
411 #endif
412     }
413 
414 #ifdef BLAS
BlasMultMV(const T alpha,const GenMatrix<Ta> & A,const GenVector<Tx> & x,const int beta,VectorView<T> y)415     template <class T, class Ta, class Tx> static inline void BlasMultMV(
416         const T alpha, const GenMatrix<Ta>& A,
417         const GenVector<Tx>& x, const int beta, VectorView<T> y)
418     {
419         if (beta == 0) NonBlasMultMV<false>(alpha,A,x,y);
420         else NonBlasMultMV<true>(alpha,A,x,y);
421     }
422 #ifdef INST_DOUBLE
BlasMultMV(const double alpha,const GenMatrix<double> & A,const GenVector<double> & x,const int beta,VectorView<double> y)423     template <> void BlasMultMV(
424         const double alpha, const GenMatrix<double>& A,
425         const GenVector<double>& x, const int beta, VectorView<double> y)
426     {
427         int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
428         int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
429         int lda = BlasIsCM(A) ? A.stepj() : A.stepi();
430         if (lda < m) { TMVAssert(n==1); lda = m; }
431         int xs = x.step();
432         int ys = y.step();
433         if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; }
434         if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; }
435         const double* xp = x.cptr();
436         if (xs < 0) xp += (x.size()-1)*xs;
437         double* yp = y.ptr();
438         if (ys < 0) yp += (y.size()-1)*ys;
439         // Some BLAS implementations seem to have trouble if the
440         // input y has a nan in it.
441         // They propagate the nan into the  output.
442         // I guess they strictly interpret y = beta*y + alpha*m*x,
443         // so if beta = 0, then beta*nan = nan.
444         // Anyway, to fix this problem, we always use beta=1, and just
445         // zero out y before calling the blas function if beta is 0.
446         if (beta == 0) y.setZero();
447         double xbeta(1);
448 
449 #if 0
450         std::cout<<"Before dgemv"<<std::endl;
451         std::cout<<"A = "<<A<<std::endl;
452         std::cout<<"x = "<<x<<std::endl;
453         std::cout<<"y = "<<y<<std::endl;
454         std::cout<<"m = "<<m<<std::endl;
455         std::cout<<"n = "<<n<<std::endl;
456         std::cout<<"lda = "<<lda<<std::endl;
457         std::cout<<"xs = "<<xs<<std::endl;
458         std::cout<<"ys = "<<ys<<std::endl;
459         std::cout<<"alpha = "<<alpha<<std::endl;
460         std::cout<<"beta = "<<xbeta<<std::endl;
461         std::cout<<"aptr = "<<A.cptr()<<std::endl;
462         std::cout<<"xp = "<<xp<<std::endl;
463         std::cout<<"yp = "<<yp<<std::endl;
464         std::cout<<"NT = "<<(A.isrm()?'T':'N')<<std::endl;
465         if (A.isrm()) {
466             std::cout<<"x.size = "<<x.size()<<std::endl;
467             std::cout<<"x = ";
468             for(int i=0;i<m;i++) std::cout<<*(xp+i*xs)<<" ";
469             std::cout<<std::endl;
470             std::cout<<"y.size = "<<y.size()<<std::endl;
471             std::cout<<"y = ";
472             for(int i=0;i<n;i++) std::cout<<*(yp+i*ys)<<" ";
473             std::cout<<std::endl;
474             std::cout<<"A.size = "<<A.colsize()<<','<<A.rowsize()<<std::endl;
475             std::cout<<"A = ";
476             for(int i=0;i<n*lda;i++) std::cout<<*(A.cptr()+i)<<" ";
477             std::cout<<std::endl;
478         }
479 #endif
480         BLASNAME(dgemv) (
481             BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
482             BLASV(m),BLASV(n),BLASV(alpha),BLASP(A.cptr()),BLASV(lda),
483             BLASP(xp),BLASV(xs),BLASV(xbeta),BLASP(yp),BLASV(ys) BLAS1);
484 #if 0
485         std::cout<<"After dgemv"<<std::endl;
486 #endif
487     }
BlasMultMV(const std::complex<double> alpha,const GenMatrix<std::complex<double>> & A,const GenVector<std::complex<double>> & x,const int beta,VectorView<std::complex<double>> y)488     template <> void BlasMultMV(
489         const std::complex<double> alpha,
490         const GenMatrix<std::complex<double> >& A,
491         const GenVector<std::complex<double> >& x,
492         const int beta, VectorView<std::complex<double> > y)
493     {
494         if (x.isconj()
495 #ifndef CBLAS
496             && !(A.isconj() && BlasIsCM(A))
497 #endif
498         ) {
499             Vector<std::complex<double> > xx = alpha*x;
500             return BlasMultMV(std::complex<double>(1),A,xx,beta,y);
501         }
502 
503         int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
504         int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
505         int lda = BlasIsCM(A) ? A.stepj() : A.stepi();
506         if (lda < m) { TMVAssert(n==1); lda = m; }
507         int xs = x.step();
508         int ys = y.step();
509         if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; }
510         if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; }
511         const std::complex<double>* xp = x.cptr();
512         if (xs < 0) xp += (x.size()-1)*xs;
513         std::complex<double>* yp = y.ptr();
514         if (ys < 0) yp += (y.size()-1)*ys;
515         if (beta == 0) y.setZero();
516         std::complex<double> xbeta(1);
517 #if 0
518         std::cout<<"Before zgemv"<<std::endl;
519         std::cout<<"A = "<<A<<std::endl;
520         std::cout<<"x = "<<x<<std::endl;
521         std::cout<<"y = "<<y<<std::endl;
522         std::cout<<"m = "<<m<<std::endl;
523         std::cout<<"n = "<<n<<std::endl;
524         std::cout<<"lda = "<<lda<<std::endl;
525         std::cout<<"xs = "<<xs<<std::endl;
526         std::cout<<"ys = "<<ys<<std::endl;
527         std::cout<<"alpha = "<<alpha<<std::endl;
528         std::cout<<"beta = "<<xbeta<<std::endl;
529         std::cout<<"aptr = "<<A.cptr()<<std::endl;
530         std::cout<<"xp = "<<xp<<std::endl;
531         std::cout<<"yp = "<<yp<<std::endl;
532 #endif
533         if (A.isconj() && BlasIsCM(A)) {
534 #ifdef CBLAS
535             TMV_SWAP(m,n);
536             BLASNAME(zgemv) (
537                 BLASRM BLASCH_CT,
538                 BLASV(m),BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda),
539                 BLASP(xp),BLASV(xs),BLASP(&xbeta),BLASP(yp),BLASV(ys) BLAS1);
540 #else
541             std::complex<double> ca = TMV_CONJ(alpha);
542             if (x.isconj()) {
543                 y.conjugateSelf();
544                 BLASNAME(zgemv) (
545                     BLASCM BLASCH_NT,
546                     BLASV(m),BLASV(n),BLASP(&ca),BLASP(A.cptr()),BLASV(lda),
547                     BLASP(xp),BLASV(xs),BLASP(&xbeta),BLASP(yp),BLASV(ys)
548                     BLAS1);
549                 y.conjugateSelf();
550             } else {
551                 Vector<std::complex<double> > xx = ca*x.conjugate();
552                 ca = std::complex<double>(1);
553                 xs = 1;
554                 xp = xx.cptr();
555                 y.conjugateSelf();
556                 BLASNAME(zgemv) (
557                     BLASCM BLASCH_NT,
558                     BLASV(m),BLASV(n),BLASP(&ca),BLASP(A.cptr()),BLASV(lda),
559                     BLASP(xp),BLASV(xs),BLASP(&xbeta),BLASP(yp),BLASV(ys)
560                     BLAS1);
561                 y.conjugateSelf();
562             }
563 #endif
564         } else {
565             BLASNAME(zgemv) (
566                 BLASCM BlasIsCM(A)?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T,
567                 BLASV(m),BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda),
568                 BLASP(xp),BLASV(xs),BLASP(&xbeta),BLASP(yp),BLASV(ys)
569                 BLAS1);
570         }
571     }
BlasMultMV(const std::complex<double> alpha,const GenMatrix<std::complex<double>> & A,const GenVector<double> & x,const int beta,VectorView<std::complex<double>> y)572     template <> void BlasMultMV(
573         const std::complex<double> alpha,
574         const GenMatrix<std::complex<double> >& A,
575         const GenVector<double>& x,
576         const int beta, VectorView<std::complex<double> > y)
577     {
578         if (BlasIsCM(A)) {
579             if (y.step() != 1) {
580                 Vector<std::complex<double> > yy(y.size());
581                 BlasMultMV(std::complex<double>(1),A,x,0,yy.view());
582                 if (beta == 0) y = alpha*yy;
583                 else y += alpha*yy;
584             } else {
585                 if (beta == 0) {
586                     int m = 2*A.colsize();
587                     int n = A.rowsize();
588                     int lda = 2*A.stepj();
589                     if (lda < m) { TMVAssert(n==1); lda = m; }
590                     int xs = x.step();
591                     int ys = 1;
592                     const double* xp = x.cptr();
593                     if (xs < 0) xp += (x.size()-1)*xs;
594                     double* yp = (double*) y.ptr();
595                     double xalpha(1);
596                     y.setZero();
597                     double xbeta(1);
598                     BLASNAME(dgemv) (
599                         BLASCM BLASCH_NT,
600                         BLASV(m),BLASV(n),BLASV(xalpha),
601                         BLASP((double*)A.cptr()),BLASV(lda),
602                         BLASP(xp),BLASV(xs),BLASV(xbeta),
603                         BLASP(yp),BLASV(ys) BLAS1);
604                     if (A.isconj()) y.conjugateSelf();
605                     y *= alpha;
606                 } else if (A.isconj()) {
607                     Vector<std::complex<double> > yy(y.size());
608                     BlasMultMV(
609                         std::complex<double>(1),A.conjugate(),x,0,yy.view());
610                     y += alpha*yy.conjugate();
611                 } else if (TMV_IMAG(alpha) == 0.) {
612                     int m = 2*A.colsize();
613                     int n = A.rowsize();
614                     int lda = 2*A.stepj();
615                     if (lda < m) { TMVAssert(n==1); lda = m; }
616                     int xs = x.step();
617                     int ys = 1;
618                     const double* xp = x.cptr();
619                     if (xs < 0) xp += (x.size()-1)*xs;
620                     double* yp = (double*) y.ptr();
621                     if (ys < 0) yp += (y.size()-1)*ys;
622                     double xalpha(TMV_REAL(alpha));
623                     double xbeta(1);
624                     BLASNAME(dgemv) (
625                         BLASCM BLASCH_NT,
626                         BLASV(m),BLASV(n),BLASV(xalpha),
627                         BLASP((double*)A.cptr()),BLASV(lda),
628                         BLASP(xp),BLASV(xs),BLASV(xbeta),
629                         BLASP(yp),BLASV(ys) BLAS1);
630                 } else {
631                     Vector<std::complex<double> > yy(y.size());
632                     BlasMultMV(std::complex<double>(1),A,x,0,yy.view());
633                     y += alpha*yy;
634                 }
635             }
636         } else { // A.isrm
637             BlasMultMV(alpha,A,Vector<std::complex<double> >(x),beta,y);
638         }
639     }
BlasMultMV(const std::complex<double> alpha,const GenMatrix<double> & A,const GenVector<std::complex<double>> & x,const int beta,VectorView<std::complex<double>> y)640     template <> void BlasMultMV(
641         const std::complex<double> alpha,
642         const GenMatrix<double>& A,
643         const GenVector<std::complex<double> >& x,
644         const int beta, VectorView<std::complex<double> > y)
645     {
646         if (beta == 0) {
647             int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
648             int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
649             int lda = BlasIsCM(A) ? A.stepj() : A.stepi();
650             if (lda < m) { TMVAssert(n==1); lda = m; }
651             int xs = 2*x.step();
652             int ys = 2*y.step();
653             if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; }
654             if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; }
655             const double* xp = (const double*) x.cptr();
656             if (xs < 0) xp += (x.size()-1)*xs;
657             double* yp = (double*) y.ptr();
658             if (ys < 0) yp += (y.size()-1)*ys;
659             double xalpha(1);
660             y.setZero();
661             double xbeta(1);
662             BLASNAME(dgemv) (
663                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
664                 BLASV(m),BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda),
665                 BLASP(xp),BLASV(xs),BLASV(xbeta),
666                 BLASP(yp),BLASV(ys) BLAS1);
667             BLASNAME(dgemv) (
668                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
669                 BLASV(m),BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda),
670                 BLASP(xp+1),BLASV(xs),BLASV(xbeta),
671                 BLASP(yp+1),BLASV(ys) BLAS1);
672             if (x.isconj()) y.conjugateSelf();
673             y *= alpha;
674         } else if (TMV_IMAG(alpha) == 0. && !x.isconj()) {
675             int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
676             int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
677             int lda = BlasIsCM(A) ? A.stepj() : A.stepi();
678             if (lda < m) { TMVAssert(n==1); lda = m; }
679             int xs = 2*x.step();
680             int ys = 2*y.step();
681             if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; }
682             if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; }
683             const double* xp = (const double*) x.cptr();
684             if (xs < 0) xp += (x.size()-1)*xs;
685             double* yp = (double*) y.ptr();
686             if (ys < 0) yp += (y.size()-1)*ys;
687             double xalpha(TMV_REAL(alpha));
688             double xbeta(1);
689             BLASNAME(dgemv) (
690                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
691                 BLASV(m),BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda),
692                 BLASP(xp),BLASV(xs),BLASV(xbeta),
693                 BLASP(yp),BLASV(ys) BLAS1);
694             BLASNAME(dgemv) (
695                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
696                 BLASV(m),BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda),
697                 BLASP(xp+1),BLASV(xs),BLASV(xbeta),
698                 BLASP(yp+1),BLASV(ys) BLAS1);
699         } else {
700             Vector<std::complex<double> > xx = alpha*x;
701             BlasMultMV(std::complex<double>(1),A,xx,1,y);
702         }
703     }
BlasMultMV(const std::complex<double> alpha,const GenMatrix<double> & A,const GenVector<double> & x,const int beta,VectorView<std::complex<double>> y)704     template <> void BlasMultMV(
705         const std::complex<double> alpha,
706         const GenMatrix<double>& A,
707         const GenVector<double>& x,
708         const int beta, VectorView<std::complex<double> > y)
709     {
710         int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
711         int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
712         int lda = BlasIsCM(A) ? A.stepj() : A.stepi();
713         if (lda < m) { TMVAssert(n==1); lda = m; }
714         int xs = x.step();
715         int ys = 2*y.step();
716         if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; }
717         if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; }
718         const double* xp = x.cptr();
719         if (xs < 0) xp += (x.size()-1)*xs;
720         double* yp = (double*) y.ptr();
721         if (ys < 0) yp += (y.size()-1)*ys;
722         double ar(TMV_REAL(alpha));
723         double ai(TMV_IMAG(alpha));
724         double xbeta(beta);
725         if (beta == 0) y.setZero();
726         if (ar != 0.) {
727             BLASNAME(dgemv) (
728                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
729                 BLASV(m),BLASV(n),BLASV(ar),BLASP(A.cptr()),BLASV(lda),
730                 BLASP(xp),BLASV(xs),BLASV(xbeta),
731                 BLASP(yp),BLASV(ys) BLAS1);
732         }
733         if (ai != 0.) {
734             BLASNAME(dgemv) (
735                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
736                 BLASV(m),BLASV(n),BLASV(ai),BLASP(A.cptr()),BLASV(lda),
737                 BLASP(xp),BLASV(xs),BLASV(xbeta),
738                 BLASP(yp+1),BLASV(ys) BLAS1);
739         }
740     }
741 #endif
742 #ifdef INST_FLOAT
BlasMultMV(const float alpha,const GenMatrix<float> & A,const GenVector<float> & x,const int beta,VectorView<float> y)743     template <> void BlasMultMV(
744         const float alpha, const GenMatrix<float>& A,
745         const GenVector<float>& x, const int beta, VectorView<float> y)
746     {
747         int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
748         int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
749         int lda = BlasIsCM(A) ? A.stepj() : A.stepi();
750         if (lda < m) { TMVAssert(n==1); lda = m; }
751         int xs = x.step();
752         int ys = y.step();
753         if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; }
754         if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; }
755         const float* xp = x.cptr();
756         if (xs < 0) xp += (x.size()-1)*xs;
757         float* yp = y.ptr();
758         if (ys < 0) yp += (y.size()-1)*ys;
759         if (beta == 0) y.setZero();
760         float xbeta(1);
761 #if 0
762         std::cout<<"Before sgemv:\n";
763         std::cout<<"A = "<<A<<std::endl;
764         std::cout<<"x = "<<x<<std::endl;
765         std::cout<<"y = "<<y<<std::endl;
766         std::cout<<"m,n = "<<m<<','<<n<<std::endl;
767         std::cout<<"alpha,beta = "<<alpha<<','<<xbeta<<std::endl;
768         std::cout<<"A.cptr = "<<A.cptr()<<std::endl;
769         std::cout<<"xp = "<<xp<<std::endl;
770         std::cout<<"yp = "<<yp<<std::endl;
771         std::cout<<"lda,xs,ys = "<<lda<<','<<xs<<','<<ys<<std::endl;
772         std::cout<<"cm = "<<BlasIsCM(A)<<std::endl;
773 #endif
774 
775         BLASNAME(sgemv) (
776             BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
777             BLASV(m),BLASV(n),BLASV(alpha),BLASP(A.cptr()),BLASV(lda),
778             BLASP(xp),BLASV(xs),BLASV(xbeta),BLASP(yp),BLASV(ys)
779             BLAS1);
780 #if 0
781         std::cout<<"After sgemv:"<<std::endl;
782         std::cout<<"y -> "<<y<<std::endl;
783 #endif
784     }
    // Single-precision complex specialization:  y (+)= alpha * A * x  via
    // cgemv, where A, x and y are all complex.
    //
    // beta == 0 selects overwrite (y is zeroed first); otherwise the product
    // is accumulated into y.  cgemv is always called with a beta of 1.
    //
    // cgemv has no option for a conjugated x, and no "conjugate, no
    // transpose" option for A, so those cases are rewritten below.
    template <> void BlasMultMV(
        const std::complex<float> alpha,
        const GenMatrix<std::complex<float> >& A,
        const GenVector<std::complex<float> >& x,
        const int beta, VectorView<std::complex<float> > y)
    {
        // A conjugated x is materialized as xx = alpha*x (a plain Vector) and
        // the call is re-issued with alpha = 1.  Without CBLAS the
        // conjugated column-major A case is excluded here, because the
        // branch further down handles a conjugated x itself.
        if (x.isconj()
#ifndef CBLAS
            && !(A.isconj() && BlasIsCM(A))
#endif
        ) {
            Vector<std::complex<float> > xx = alpha*x;
            return BlasMultMV(std::complex<float>(1),A,xx,beta,y);
        }

        // Shape and leading dimension as passed to cgemv; for row-major A
        // the dimensions are exchanged and a transpose flag is used.
        int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
        int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
        int lda = BlasIsCM(A) ? A.stepj() : A.stepi();
        if (lda < m) { TMVAssert(n==1); lda = m; }
        int xs = x.step();
        int ys = y.step();
        if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; }
        if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; }
        // For negative increments move the pointer to the lowest-address
        // element before handing it to BLAS.
        const std::complex<float>* xp = x.cptr();
        if (xs < 0) xp += (x.size()-1)*xs;
        std::complex<float>* yp = y.ptr();
        if (ys < 0) yp += (y.size()-1)*ys;
        // Overwrite mode: clear y and accumulate with beta = 1 below.
        if (beta == 0) y.setZero();
        std::complex<float> xbeta(1);
#if 0
        std::cout<<"Before cgemv:\n";
        std::cout<<"A = "<<A<<std::endl;
        std::cout<<"x = "<<x<<std::endl;
        std::cout<<"y = "<<y<<std::endl;
        std::cout<<"m,n = "<<m<<','<<n<<std::endl;
        std::cout<<"alpha,beta = "<<alpha<<','<<xbeta<<std::endl;
        std::cout<<"A.cptr = "<<A.cptr()<<std::endl;
        std::cout<<"xp = "<<xp<<std::endl;
        std::cout<<"yp = "<<yp<<std::endl;
        std::cout<<"lda,xs,ys = "<<lda<<','<<xs<<','<<ys<<std::endl;
        std::cout<<"conj = "<<A.isconj()<<std::endl;
        std::cout<<"cm = "<<A.iscm()<<std::endl;
#endif
        if (A.isconj() && BlasIsCM(A)) {
            // Conjugated, column-major A: needs conj(A) without a transpose.
#ifdef CBLAS
            // With CBLAS: hand the same storage over under the row-major
            // flag with the dimensions swapped and ask for the conjugate
            // transpose of that.
            TMV_SWAP(m,n);
            BLASNAME(cgemv) (
                BLASRM BLASCH_CT,
                BLASV(m),BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda),
                BLASP(xp),BLASV(xs),BLASP(&xbeta),BLASP(yp),BLASV(ys) BLAS1);
#else
            // Without CBLAS: use
            //   conj(y) += conj(alpha) * (stored A) * conj(x)
            // i.e. conjugate y in place, run a plain no-transpose cgemv on
            // the stored data with conj(alpha), then conjugate y back.
            std::complex<float> ca = TMV_CONJ(alpha);
            if (x.isconj()) {
                // x is a conjugated view, so its stored data is passed
                // directly as the conj(x) operand (presumably the stored
                // values are the conjugates of the logical ones).
                y.conjugateSelf();
                BLASNAME(cgemv) (
                    BLASCM BLASCH_NT,
                    BLASV(m),BLASV(n),BLASP(&ca),BLASP(A.cptr()),BLASV(lda),
                    BLASP(xp),BLASV(xs),BLASP(&xbeta),BLASP(yp),BLASV(ys)
                    BLAS1);
                y.conjugateSelf();
            } else {
                // Materialize conj(alpha)*conj(x) in a unit-stride temporary
                // and call cgemv with a scale factor of 1.
                Vector<std::complex<float> > xx = ca*x.conjugate();
                ca = std::complex<float>(1);
                xs = 1;
                xp = xx.cptr();
                y.conjugateSelf();
                BLASNAME(cgemv) (
                    BLASCM BLASCH_NT,
                    BLASV(m),BLASV(n),BLASP(&ca),BLASP(A.cptr()),BLASV(lda),
                    BLASP(xp),BLASV(xs),BLASP(&xbeta),BLASP(yp),BLASV(ys)
                    BLAS1);
                y.conjugateSelf();
            }
#endif
        } else {
            // Remaining cases map directly onto a cgemv option: no transpose
            // for column-major A, (conjugate) transpose for row-major A.
            BLASNAME(cgemv) (
                BLASCM BlasIsCM(A)?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T,
                BLASV(m),BLASV(n),BLASP(&alpha),BLASP(A.cptr()),BLASV(lda),
                BLASP(xp),BLASV(xs),BLASP(&xbeta),BLASP(yp),BLASV(ys)
                BLAS1);
        }
#if 0
        std::cout<<"After cgemv:"<<std::endl;
        std::cout<<"y -> "<<y<<std::endl;
#endif
    }
BlasMultMV(const std::complex<float> alpha,const GenMatrix<std::complex<float>> & A,const GenVector<float> & x,const int beta,VectorView<std::complex<float>> y)871     template <> void BlasMultMV(
872         const std::complex<float> alpha,
873         const GenMatrix<std::complex<float> >& A,
874         const GenVector<float>& x,
875         const int beta, VectorView<std::complex<float> > y)
876     {
877         if (BlasIsCM(A)) {
878             if (y.step() != 1) {
879                 Vector<std::complex<float> > yy(y.size());
880                 BlasMultMV(std::complex<float>(1),A,x,0,yy.view());
881                 if (beta == 0) y = alpha*yy;
882                 else y += alpha*yy;
883             } else {
884                 if (beta == 0) {
885                     int m = 2*A.colsize();
886                     int n = A.rowsize();
887                     int lda = 2*A.stepj();
888                     if (lda < m) { TMVAssert(n==1); lda = m; }
889                     int xs = x.step();
890                     int ys = 1;
891                     const float* xp = x.cptr();
892                     if (xs < 0) xp += (x.size()-1)*xs;
893                     float* yp = (float*) y.ptr();
894                     float xalpha(1);
895                     y.setZero();
896                     float xbeta(1);
897                     BLASNAME(sgemv) (
898                         BLASCM BLASCH_NT,
899                         BLASV(m),BLASV(n),BLASV(xalpha),
900                         BLASP((float*)A.cptr()),BLASV(lda),
901                         BLASP(xp),BLASV(xs),BLASV(xbeta),
902                         BLASP(yp),BLASV(ys) BLAS1);
903                     if (A.isconj()) y.conjugateSelf();
904                     y *= alpha;
905                 } else if (A.isconj()) {
906                     Vector<std::complex<float> > yy(y.size());
907                     BlasMultMV(std::complex<float>(1),A.conjugate(),x,0,yy.view());
908                     y += alpha*yy.conjugate();
909                 } else if (TMV_IMAG(alpha) == 0.F) {
910                     int m = 2*A.colsize();
911                     int n = A.rowsize();
912                     int lda = 2*A.stepj();
913                     if (lda < m) { TMVAssert(n==1); lda = m; }
914                     int xs = x.step();
915                     int ys = 1;
916                     const float* xp = x.cptr();
917                     if (xs < 0) xp += (x.size()-1)*xs;
918                     float* yp = (float*) y.ptr();
919                     float xalpha(TMV_REAL(alpha));
920                     float xbeta(1);
921                     BLASNAME(sgemv) (
922                         BLASCM BLASCH_NT,
923                         BLASV(m),BLASV(n),BLASV(xalpha),
924                         BLASP((float*)A.cptr()),BLASV(lda),
925                         BLASP(xp),BLASV(xs),BLASV(xbeta),
926                         BLASP(yp),BLASV(ys) BLAS1);
927                 } else {
928                     Vector<std::complex<float> > yy(y.size());
929                     BlasMultMV(std::complex<float>(1),A,x,0,yy.view());
930                     y += alpha*yy;
931                 }
932             }
933         } else { // A.isrm
934             BlasMultMV(alpha,A,Vector<std::complex<float> >(x),beta,y);
935         }
936     }
BlasMultMV(const std::complex<float> alpha,const GenMatrix<float> & A,const GenVector<std::complex<float>> & x,const int beta,VectorView<std::complex<float>> y)937     template <> void BlasMultMV(
938         const std::complex<float> alpha,
939         const GenMatrix<float>& A,
940         const GenVector<std::complex<float> >& x,
941         const int beta, VectorView<std::complex<float> > y)
942     {
943         if (beta == 0) {
944             int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
945             int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
946             int lda = BlasIsCM(A) ? A.stepj() : A.stepi();
947             if (lda < m) { TMVAssert(n==1); lda = m; }
948             int xs = 2*x.step();
949             int ys = 2*y.step();
950             if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; }
951             if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; }
952             const float* xp = (const float*) x.cptr();
953             if (xs < 0) xp += (x.size()-1)*xs;
954             float* yp = (float*) y.ptr();
955             if (ys < 0) yp += (y.size()-1)*ys;
956             float xalpha(1);
957             y.setZero();
958             float xbeta(1);
959             BLASNAME(sgemv) (
960                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
961                 BLASV(m),BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda),
962                 BLASP(xp),BLASV(xs),BLASV(xbeta),
963                 BLASP(yp),BLASV(ys) BLAS1);
964             BLASNAME(sgemv) (
965                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
966                 BLASV(m),BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda),
967                 BLASP(xp+1),BLASV(xs),BLASV(xbeta),
968                 BLASP(yp+1),BLASV(ys) BLAS1);
969             if (x.isconj()) y.conjugateSelf();
970             y *= alpha;
971         } else if (TMV_IMAG(alpha) == 0.F && !x.isconj()) {
972             int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
973             int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
974             int lda = BlasIsCM(A) ? A.stepj() : A.stepi();
975             if (lda < m) { TMVAssert(n==1); lda = m; }
976             int xs = 2*x.step();
977             int ys = 2*y.step();
978             if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; }
979             if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; }
980             const float* xp = (const float*) x.cptr();
981             if (xs < 0) xp += (x.size()-1)*xs;
982             float* yp = (float*) y.ptr();
983             if (ys < 0) yp += (y.size()-1)*ys;
984             float xalpha(TMV_REAL(alpha));
985             float xbeta(1);
986             BLASNAME(sgemv) (
987                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
988                 BLASV(m),BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda),
989                 BLASP(xp),BLASV(xs),BLASV(xbeta),
990                 BLASP(yp),BLASV(ys) BLAS1);
991             BLASNAME(sgemv) (
992                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
993                 BLASV(m),BLASV(n),BLASV(xalpha),BLASP(A.cptr()),BLASV(lda),
994                 BLASP(xp+1),BLASV(xs),BLASV(xbeta),
995                 BLASP(yp+1),BLASV(ys) BLAS1);
996         } else {
997             Vector<std::complex<float> > xx = alpha*x;
998             BlasMultMV(std::complex<float>(1),A,xx,1,y);
999         }
1000     }
BlasMultMV(const std::complex<float> alpha,const GenMatrix<float> & A,const GenVector<float> & x,const int beta,VectorView<std::complex<float>> y)1001     template <> void BlasMultMV(
1002         const std::complex<float> alpha,
1003         const GenMatrix<float>& A,
1004         const GenVector<float>& x,
1005         const int beta, VectorView<std::complex<float> > y)
1006     {
1007         int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
1008         int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
1009         int lda = BlasIsCM(A) ? A.stepj() : A.stepi();
1010         if (lda < m) { TMVAssert(n==1); lda = m; }
1011         int xs = x.step();
1012         int ys = 2*y.step();
1013         if (xs == 0) { TMVAssert(x.size() == 1); xs = 1; }
1014         if (ys == 0) { TMVAssert(y.size() == 1); ys = 1; }
1015         const float* xp = x.cptr();
1016         if (xs < 0) xp += (x.size()-1)*xs;
1017         float* yp = (float*) y.ptr();
1018         if (ys < 0) yp += (y.size()-1)*ys;
1019         float ar(TMV_REAL(alpha));
1020         float ai(TMV_IMAG(alpha));
1021         if (beta == 0) y.setZero();
1022         float xbeta(1);
1023         if (ar != 0.F) {
1024             BLASNAME(sgemv) (
1025                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
1026                 BLASV(m),BLASV(n),BLASV(ar),BLASP(A.cptr()),BLASV(lda),
1027                 BLASP(xp),BLASV(xs),BLASV(xbeta),
1028                 BLASP(yp),BLASV(ys) BLAS1);
1029         }
1030         if (ai != 0.F) {
1031             BLASNAME(sgemv) (
1032                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
1033                 BLASV(m),BLASV(n),BLASV(ai),BLASP(A.cptr()),BLASV(lda),
1034                 BLASP(xp),BLASV(xs),BLASV(xbeta),
1035                 BLASP(yp+1),BLASV(ys) BLAS1);
1036         }
1037     }
1038 #endif
1039 #endif // BLAS
1040 
1041     template <bool add, class T, class Ta, class Tx>
DoMultMV(const T alpha,const GenMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)1042     static void DoMultMV(
1043         const T alpha, const GenMatrix<Ta>& A,
1044         const GenVector<Tx>& x, VectorView<T> y)
1045     {
1046 #ifdef XDEBUG
1047         std::cout<<"Start DoMultMV\n";
1048         std::cout<<"alpha = "<<alpha<<std::endl;
1049         std::cout<<"add = "<<add<<std::endl;
1050         std::cout<<"A = "<<A<<std::endl;
1051         std::cout<<"x = "<<x<<std::endl;
1052         if (add) std::cout<<"y = "<<y<<std::endl;
1053 #endif
1054 
1055         TMVAssert(A.rowsize() == x.size());
1056         TMVAssert(A.colsize() == y.size());
1057         TMVAssert(alpha != T(0));
1058         TMVAssert(x.size() > 0);
1059         TMVAssert(y.size() > 0);
1060         TMVAssert(y.ct() == NonConj);
1061 
1062 #ifdef BLAS
1063         if (x.step() == 0) {
1064             if (x.size() <= 1)
1065                 DoMultMV<add>(
1066                     alpha,A,ConstVectorView<Tx>(x.cptr(),x.size(),1,x.ct()),y);
1067             else
1068                 DoMultMV<add>(alpha,A,Vector<Tx>(x),y);
1069         } else if (y.step() == 0) {
1070             TMVAssert(y.size() <= 1);
1071             DoMultMV<add>(alpha,A,x,VectorView<T>(y.ptr(),y.size(),1,y.ct()));
1072 #if 1
1073         } else if (y.step() != 1) {
1074             // Most BLAS implementations do fine with the y.step() < 0.
1075             // And in fact, they _should_ do ok even if the step is negative.
1076             // However, some implementations seem to propagate nan's from
1077             // the temporary memory they create to do the unit-1 calculation.
1078             // So to make sure they don't have to make a temporary, I just
1079             // do it here for them.
1080 #else
1081         } else if (y.step() < 0) {
1082 #endif
1083             Vector<T> yy(y.size());
1084             DoMultMV<false>(T(1),A,x,yy.view());
1085             if (add) y += alpha*yy;
1086             else y = alpha*yy;
1087 #if 1
1088         } else if (x.step() != 1) {
1089             // I don't think the non-unit step is a problem, but just to be
1090             // sure...
1091 #else
1092         } else if (x.step() < 0) {
1093 #endif
1094             Vector<T> xx = alpha*x;
1095             DoMultMV<add>(T(1),A,xx,y);
1096         } else if (BlasIsCM(A) || BlasIsRM(A)) {
1097             if (!SameStorage(A,y)) {
1098                 if (!SameStorage(x,y) && !SameStorage(A,x)) {
1099                     BlasMultMV(alpha,A,x,add?1:0,y);
1100                 } else {
1101                     Vector<T> xx = alpha*x;
1102                     BlasMultMV(T(1),A,xx,add?1:0,y);
1103                 }
1104             } else {
1105                 Vector<T> yy(y.size());
1106                 if (!SameStorage(A,x)) {
1107                     BlasMultMV(T(1),A,x,0,yy.view());
1108                     if (add) y += alpha*yy;
1109                     else y = alpha*yy;
1110                 } else {
1111                     Vector<T> xx = alpha*x;
1112                     BlasMultMV(T(1),A,xx,0,yy.view());
1113                     if (add) y += yy;
1114                     else y = yy;
1115                 }
1116             }
1117         } else {
1118             if (TMV_IMAG(alpha) == T(0)) {
1119                 Matrix<Ta,RowMajor> A2 = TMV_REAL(alpha)*A;
1120                 DoMultMV<add>(T(1),A2,x,y);
1121             } else {
1122                 Matrix<T,RowMajor> A2 = alpha*A;
1123                 DoMultMV<add>(T(1),A2,x,y);
1124             }
1125         }
1126 #else
1127         NonBlasMultMV<add>(alpha,A,x,y);
1128 #endif
1129 #ifdef XDEBUG
1130         std::cout<<"y => "<<y<<std::endl;
1131 #endif
1132     }
1133 
MultMV(const T alpha,const GenMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)1134     template <bool add, class T, class Ta, class Tx> void MultMV(
1135         const T alpha, const GenMatrix<Ta>& A, const GenVector<Tx>& x,
1136         VectorView<T> y)
1137         // y (+)= alpha * A * x
1138     {
1139 #ifdef XDEBUG
1140         cout<<"Start MultMV: alpha = "<<alpha<<endl;
1141         cout<<"add = "<<add<<endl;
1142         cout<<"A = "<<TMV_Text(A)<<"  "<<A<<endl;
1143         cout<<"x = "<<TMV_Text(x)<<" step "<<x.step()<<"  "<<x<<endl;
1144         cout<<"y = "<<TMV_Text(y)<<" step "<<y.step()<<endl;
1145         if (add) cout<<"y = "<<y<<endl;
1146         cout<<"ptrs = "<<A.cptr()<<"  "<<x.cptr()<<"  "<<y.cptr()<<std::endl;
1147         Vector<Tx> x0 = x;
1148         Vector<T> y0 = y;
1149         Matrix<Ta> A0 = A;
1150         Vector<T> y2 = y;
1151         for(ptrdiff_t i=0;i<y.size();i++) {
1152             if (add) y2(i) += alpha * (A.row(i) * x0);
1153             else y2(i) = alpha * (A.row(i) * x0);
1154         }
1155         cout<<"y2 = "<<y2<<endl;
1156 #endif
1157         TMVAssert(A.rowsize() == x.size());
1158         TMVAssert(A.colsize() == y.size());
1159 
1160         if (y.size() > 0) {
1161             if (x.size()==0 || alpha==T(0)) {
1162                 if (!add) y.setZero();
1163             } else if (y.isconj()) {
1164                 DoMultMV<add>(
1165                     TMV_CONJ(alpha),A.conjugate(),x.conjugate(),y.conjugate());
1166             } else {
1167                 DoMultMV<add>(alpha,A,x,y);
1168             }
1169         }
1170 
1171 #ifdef XDEBUG
1172         cout<<"y => "<<y<<endl;
1173         if (!(Norm(y-y2) <=
1174               0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(x0)+
1175                      (add?Norm(y0):TMV_RealType(T)(0))))) {
1176             cerr<<"MultMV: alpha = "<<alpha<<endl;
1177             cerr<<"add = "<<add<<endl;
1178             cerr<<"A = "<<TMV_Text(A);
1179             if (A.rowsize() < 30 && A.colsize() < 30) cerr<<"  "<<A0;
1180             else cerr<<"  "<<A.colsize()<<" x "<<A.rowsize();
1181             cerr<<endl<<"x = "<<TMV_Text(x)<<" step "<<x.step();
1182             if (x.size() < 30) cerr<<"  "<<x0;
1183             cerr<<endl<<"y = "<<TMV_Text(y)<<" step "<<y.step();
1184             if (add && y.size() < 30) cerr<<"  "<<y0;
1185             cerr<<endl<<"Aptr = "<<A.cptr();
1186             cerr<<", xptr = "<<x.cptr()<<", yptr = "<<y.cptr()<<endl;
1187             if (y.size() < 200) {
1188                 cerr<<"--> y = "<<y<<endl;
1189                 cerr<<"y2 = "<<y2<<endl;
1190             } else {
1191                 ptrdiff_t imax;
1192                 (y-y2).maxAbsElement(&imax);
1193                 cerr<<"y("<<imax<<") = "<<y(imax)<<endl;
1194                 cerr<<"y2("<<imax<<") = "<<y2(imax)<<endl;
1195             }
1196             cerr<<"Norm(A0) = "<<Norm(A0)<<endl;
1197             cerr<<"Norm(x0) = "<<Norm(x0)<<endl;
1198             if (add) cerr<<"Norm(y0) = "<<Norm(y0)<<endl;
1199             cerr<<"|alpha|*|A0|*|x0|+?|y0| = "<<
1200                 TMV_ABS(alpha)*Norm(A0)*Norm(x0)+
1201                 (add?Norm(y0):TMV_RealType(T)(0))<<endl;
1202             cerr<<"Norm(y-y2) = "<<Norm(y-y2)<<endl;
1203             cerr<<"NormInf(y-y2) = "<<NormInf(y-y2)<<endl;
1204             cerr<<"Norm1(y-y2) = "<<Norm1(y-y2)<<endl;
1205             abort();
1206         }
1207 #endif
1208     }
1209 
1210 #define InstFile "TMV_MultMV.inst"
1211 #include "TMV_Inst.h"
1212 #undef InstFile
1213 
1214 } // namespace tmv
1215 
1216 
1217