1 ///////////////////////////////////////////////////////////////////////////////
2 //                                                                           //
3 // The Template Matrix/Vector Library for C++ was created by Mike Jarvis     //
4 // Copyright (C) 1998 - 2016                                                 //
5 // All rights reserved                                                       //
6 //                                                                           //
7 // The project is hosted at https://code.google.com/p/tmv-cpp/               //
// where you can find the current version and current documentation.         //
9 //                                                                           //
10 // For concerns or problems with the software, Mike may be contacted at      //
11 // mike_jarvis17 [at] gmail.                                                 //
12 //                                                                           //
13 // This software is licensed under a FreeBSD license.  The file              //
// TMV_LICENSE should have been included with this distribution.             //
// If not, you can get a copy from https://code.google.com/p/tmv-cpp/.       //
16 //                                                                           //
17 // Essentially, you can use this software however you want provided that     //
18 // you include the TMV_LICENSE file in any distribution that uses it.        //
19 //                                                                           //
20 ///////////////////////////////////////////////////////////////////////////////
21 
22 
23 //#define XDEBUG
24 
25 
26 #include "TMV_Blas.h"
27 #include "tmv/TMV_BandMatrixArithFunc.h"
28 #include "tmv/TMV_BandMatrix.h"
29 #include "tmv/TMV_VectorArith.h"
30 #include "tmv/TMV_DiagMatrix.h"
31 #include "tmv/TMV_DiagMatrixArithFunc.h"
32 #ifdef BLAS
33 #include "tmv/TMV_MatrixArith.h"
34 #include "tmv/TMV_BandMatrixArith.h"
35 #endif
36 #include <iostream>
37 
38 // CBLAS trick of using RowMajor with ConjTrans when we have a
39 // case of A.conjugate() * x doesn't seem to be working with MKL 10.2.2.
40 // I haven't been able to figure out why.  (e.g. Is it a bug in the MKL
41 // code, or am I doing something wrong?)  So for now, just disable it.
42 #ifdef CBLAS
43 #undef CBLAS
44 #endif
45 
46 #ifdef XDEBUG
47 #include "tmv/TMV_MatrixArith.h"
48 #include <iostream>
49 using std::cout;
50 using std::cerr;
51 using std::endl;
52 #endif
53 
54 namespace tmv {
55 
56     //
57     // BandMatrixComposite
58     //
59 
60     template <class T>
ls() const61     ptrdiff_t BandMatrixComposite<T>::ls() const
62     {
63         return BandStorageLength(
64             ColMajor,this->colsize(),this->rowsize(),
65             this->nlo(),this->nhi());
66     }
67 
68     template <class T>
constLinearView() const69     ConstVectorView<T> BandMatrixComposite<T>::constLinearView() const
70     {
71         cptr(); // This makes the instantiation, but we don't need the result.
72         return ConstVectorView<T>(itsm1.get(),ls(),1,NonConj);
73     }
74 
75     template <class T>
cptr() const76     const T* BandMatrixComposite<T>::cptr() const
77     {
78         if (!itsm1.get()) {
79             ptrdiff_t cs = this->colsize();
80             ptrdiff_t rs = this->rowsize();
81             ptrdiff_t lo = this->nlo();
82             ptrdiff_t hi = this->nhi();
83             ptrdiff_t len = ls();
84             itsm1.resize(len);
85             ptrdiff_t si = stepi();
86             ptrdiff_t sj = stepj();
87             ptrdiff_t ds = si + sj;
88             itsm = this->isdm() ? itsm1.get()-lo*si : itsm1.get();
89             this->assignToB(BandMatrixView<T>(
90                     itsm,cs,rs,lo,hi,si,sj,ds,NonConj,len
91                     TMV_FIRSTLAST1(itsm1.get(),itsm1.get()+len) ) );
92         }
93         return itsm;
94     }
95 
96     template <class T>
stepi() const97     ptrdiff_t BandMatrixComposite<T>::stepi() const
98     { return 1; }
99 
100     template <class T>
stepj() const101     ptrdiff_t BandMatrixComposite<T>::stepj() const
102     { return this->nlo()+this->nhi(); }
103 
104     template <class T>
diagstep() const105     ptrdiff_t BandMatrixComposite<T>::diagstep() const
106     { return this->nlo()+this->nhi()+1; }
107 
108 
109     //
110     // MultMV
111     //
112 
    // Row-oriented kernel for y (+)= A * x on a BandMatrix.  For each row i
    // it adds the dot product of the stored part of that row, columns
    // [j1,j2), with x[j1:j2] into y[i].
    //
    // Template flags:
    //   add -- accumulate into y; otherwise each y[i] is overwritten.
    //   cx  -- conjugate elements of x as they are read (asserted == x.isconj()).
    //   ca  -- conjugate elements of A as they are read (asserted == A.isconj()).
    //   rm  -- A is row-major (asserted == A.isrm()); walking along a row
    //          then uses ++ instead of += A.stepj().
    // Preconditions (asserted): x and y non-empty with unit step, not sharing
    // storage, y not conjugated.
    template <bool add, bool cx, bool ca, bool rm,  class T, class Ta, class Tx>
    static void RowMultMV(
        const GenBandMatrix<Ta>& A, const GenVector<Tx>& x,
        VectorView<T> y)
    {
        TMVAssert(A.rowsize()==x.size());
        TMVAssert(A.colsize()==y.size());
        TMVAssert(x.size() > 0);
        TMVAssert(y.size() > 0);
        TMVAssert(y.ct() == NonConj);
        TMVAssert(x.step()==1);
        TMVAssert(y.step()==1);
        TMVAssert(!SameStorage(x,y));
        TMVAssert(cx == x.isconj());
        TMVAssert(ca == A.isconj());
        TMVAssert(rm == A.isrm());

        const ptrdiff_t si = A.stepi();
        const ptrdiff_t sj = (rm ? 1 : A.stepj());
        const ptrdiff_t ds = A.diagstep();
        const ptrdiff_t M = A.colsize();
        const ptrdiff_t N = A.rowsize();

        // Aij1 points at A(i,j1), the first stored element of row i;
        // xj1 points at x[j1]; yi points at y[i].
        const Ta* Aij1 = A.cptr();
        const Tx* xj1 = x.cptr();
        T* yi = y.ptr();

        // k counts how many of the following rows still start at column 0
        // (rows 0..nlo() do).  [j1,j2) is the column range of the current
        // row and len tracks j2-j1.
        ptrdiff_t k=A.nlo();
        ptrdiff_t i=0;
        ptrdiff_t j1=0;
        ptrdiff_t j2=A.nhi()+1;
        ptrdiff_t len = j2; // len = j2-j1
        for(; i<M; ++i, ++yi) {
#ifdef TMVFLDEBUG
            TMVAssert(yi >= y._first);
            TMVAssert(yi < y._last);
#endif
            if (!add) *yi = T(0);

            // *yi += A.row(i,j1,j2) * x.subVector(j1,j2);
            const Ta* Aij = Aij1;
            const Tx* xj = xj1;
            for(ptrdiff_t j=len;j>0;--j,++xj,(rm?++Aij:Aij+=sj)) {
#ifdef TMVFLDEBUG
                TMVAssert(yi >= y._first);
                TMVAssert(yi < y._last);
#endif
                *yi += (cx ? TMV_CONJ(*xj) : *xj) * (ca ? TMV_CONJ(*Aij) : *Aij);
            }

            // Advance to row i+1.  While it still starts at column 0, step
            // down one row (si) and the row gains one element; afterwards
            // step diagonally (ds) to A(i+1,j1+1).
            if (k>0) { --k; ++len; Aij1+=si; }
            else { ++j1; ++xj1; Aij1+=ds; }
            // The right edge grows until it reaches column N, after which
            // each row loses one element.  Once j1 reaches N no columns
            // remain, so the loop stops (stepping past row i by hand).
            if (j2<N) ++j2;
            else { --len; if (j1==N) { ++i, ++yi; break; } }
        }
        // Rows left after an early exit have no stored elements; when
        // overwriting, they must still be set to zero.
        if (!add) for(;i<M; ++i, ++yi) {
#ifdef TMVFLDEBUG
            TMVAssert(yi >= y._first);
            TMVAssert(yi < y._last);
#endif
            *yi = T(0);
        }
    }
176 
    // Column-oriented kernel for y (+)= A * x on a BandMatrix.  For each
    // column j with x[j] != 0 it adds x[j] times the stored part of that
    // column, rows [i1,i2), into y[i1:i2].
    //
    // Template flags:
    //   add -- accumulate into y; otherwise y is zeroed first.
    //   cx  -- conjugate elements of x as they are read (asserted == x.isconj()).
    //   ca  -- conjugate elements of A as they are read (asserted == A.isconj()).
    //   cm  -- A is column-major (asserted == A.iscm()); walking down a
    //          column then uses ++ instead of += A.stepi().
    // Preconditions (asserted): x and y non-empty with unit step, not sharing
    // storage, y not conjugated.
    template <bool add, bool cx, bool ca, bool cm, class T, class Ta, class Tx>
    static void ColMultMV(
        const GenBandMatrix<Ta>& A, const GenVector<Tx>& x, VectorView<T> y)
    {
        TMVAssert(A.rowsize() == x.size());
        TMVAssert(A.colsize() == y.size());
        TMVAssert(x.size() > 0);
        TMVAssert(y.size() > 0);
        TMVAssert(y.ct() == NonConj);
        TMVAssert(x.step()==1);
        TMVAssert(y.step()==1);
        TMVAssert(!SameStorage(x,y));
        TMVAssert(cx == x.isconj());
        TMVAssert(ca == A.isconj());
        TMVAssert(cm == A.iscm());

        const ptrdiff_t N = A.rowsize();
        const ptrdiff_t M = A.colsize();
        const ptrdiff_t si = (cm ? 1 : A.stepi());
        const ptrdiff_t sj = A.stepj();
        const ptrdiff_t ds = A.diagstep();

        // Ai1j points at A(i1,j), the first stored element of column j;
        // xj points at x[j]; yi1 points at y[i1].
        const Ta* Ai1j = A.cptr();
        const Tx* xj = x.cptr();
        T* yi1 = y.ptr();

        // k counts how many of the following columns still start at row 0
        // (columns 0..nhi() do).  [i1,i2) is the row range of the current
        // column and len tracks i2-i1.
        ptrdiff_t k=A.nhi();
        ptrdiff_t i1=0;
        ptrdiff_t i2=A.nlo()+1;
        ptrdiff_t len = i2; // = i2-i1

        if (!add) y.setZero();

        for(ptrdiff_t j=N; j>0; --j,++xj) {
            // A zero x[j] contributes nothing, so the column is skipped.
            if (*xj != Tx(0)) {
                // y.subVector(i1,i2) += *xj * A.col(j,i1,i2);
                const Ta* Aij = Ai1j;
                T* yi = yi1;
                for(ptrdiff_t i=len;i>0;--i,++yi,(cm?++Aij:Aij+=si)) {
#ifdef TMVFLDEBUG
                    TMVAssert(yi >= y._first);
                    TMVAssert(yi < y._last);
#endif
                    *yi +=
                        (cx ? TMV_CONJ(*xj) : *xj) *
                        (ca ? TMV_CONJ(*Aij) : *Aij);
                }
            }
            // Advance to the next column.  While it still starts at row 0,
            // step right one column (sj) and the column gains one element;
            // afterwards step diagonally (ds) to A(i1+1,j+1).
            if (k>0) { --k; Ai1j+=sj; ++len; }
            else { ++i1; ++yi1; Ai1j+=ds; }
            // The bottom edge grows until it reaches row M, after which each
            // column loses one element; once i1 reaches M no rows remain.
            if (i2<M) ++i2;
            else { --len; if (i1==M) break; }
        }
    }
231 
    // Diagonal-oriented kernel for y (+)= A * x on a BandMatrix.  It walks
    // the diagonals k = -nlo() .. nhi() of A.  For each one it adds the
    // elementwise product of that diagonal with x[j1:j2] into y[i1:i2].
    //
    // Template flags:
    //   add -- accumulate into y; otherwise y is zeroed first.
    //   cx  -- conjugate elements of x as they are read (asserted == x.isconj()).
    //   ca  -- conjugate elements of A as they are read (asserted == A.isconj()).
    //   dm  -- asserted == A.isdm(); walking along a diagonal then uses ++
    //          instead of += A.diagstep().
    // Preconditions (asserted): x and y non-empty with unit step, not sharing
    // storage, y not conjugated.
    template <bool add, bool cx, bool ca, bool dm, class T, class Ta, class Tx>
    static void DiagMultMV(
        const GenBandMatrix<Ta>& A, const GenVector<Tx>& x,
        VectorView<T> y)
    {
        TMVAssert(A.rowsize() == x.size());
        TMVAssert(A.colsize() == y.size());
        TMVAssert(x.size() > 0);
        TMVAssert(y.size() > 0);
        TMVAssert(y.ct() == NonConj);
        TMVAssert(x.step()==1);
        TMVAssert(y.step()==1);
        TMVAssert(!SameStorage(x,y));
        TMVAssert(cx == x.isconj());
        TMVAssert(ca == A.isconj());
        TMVAssert(dm == A.isdm());

        const ptrdiff_t si = A.stepi();
        const ptrdiff_t sj = A.stepj();
        const ptrdiff_t ds = A.diagstep();
        const ptrdiff_t lo = A.nlo();
        const ptrdiff_t hi = A.nhi();
        const ptrdiff_t M = A.colsize();
        const ptrdiff_t N = A.rowsize();

        // Ai1j1 points at the first element of the current diagonal.  The
        // walk starts on the lowest sub-diagonal, whose first element is
        // A(lo,0).  xj1 and yi1 point at the matching x[j1] and y[i1].
        const Ta* Ai1j1 = A.cptr() + lo*si;
        const Tx* xj1 = x.cptr();
        T* yi1 = y.ptr() + lo;

        // j2 is one past the last column of the current diagonal, and len
        // is the diagonal's length.  The first diagonal has min(M-lo,N)
        // elements.
        ptrdiff_t j2=TMV_MIN(M-lo,N);
        ptrdiff_t len=j2; // == j2-j1

        if (!add) y.setZero();

        for(ptrdiff_t k=-A.nlo(); k<=hi; ++k) {
            // y.subVector(i1,i2) += DiagMatrixViewOf(A.diag(k)) *
            //     x.subVector(j1,j2);
            const Ta* Aij = Ai1j1;
            const Tx* xj = xj1;
            T* yi = yi1;
            for(ptrdiff_t i=len;i>0;--i,++yi,++xj,(dm?++Aij:Aij+=ds)) {
#ifdef TMVFLDEBUG
                TMVAssert(yi >= y._first);
                TMVAssert(yi < y._last);
#endif
                *yi +=
                    (cx ? TMV_CONJ(*xj) : *xj) *
                    (ca ? TMV_CONJ(*Aij) : *Aij);
            }
            // Move to diagonal k+1.  Below the main diagonal its start moves
            // up one row, gaining an element at the front.  On or above it,
            // the start moves right one column.
            if (k<0) { --yi1; ++len; Ai1j1-=si; }
            else { ++xj1; Ai1j1+=sj; }
            // The end column advances until it reaches N.  After that it
            // cannot advance, so the next diagonal is one element shorter.
            if (j2 < N) ++j2; else --len;
        }
    }
286 
287     template <bool add, bool cx, class T, class Ta, class Tx>
UnitAMultMV1(const GenBandMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)288     static void UnitAMultMV1(
289         const GenBandMatrix<Ta>& A, const GenVector<Tx>& x,
290         VectorView<T> y)
291     {
292         TMVAssert(A.rowsize() == x.size());
293         TMVAssert(A.colsize() == y.size());
294         TMVAssert(x.size() > 0);
295         TMVAssert(y.size() > 0);
296         TMVAssert(y.ct() == NonConj);
297         TMVAssert(x.step()==1);
298         TMVAssert(y.step()==1);
299         TMVAssert(!SameStorage(x,y));
300         TMVAssert(cx == x.isconj());
301 
302         if (A.isrm())
303             if (A.isconj())
304                 RowMultMV<add,cx,true,true>(A,x,y);
305             else
306                 RowMultMV<add,cx,false,true>(A,x,y);
307         else if (A.iscm())
308             if (A.isconj())
309                 ColMultMV<add,cx,true,true>(A,x,y);
310             else
311                 ColMultMV<add,cx,false,true>(A,x,y);
312         else if (A.isdm())
313             if (A.isconj())
314                 DiagMultMV<add,cx,true,true>(A,x,y);
315             else
316                 DiagMultMV<add,cx,false,true>(A,x,y);
317         else
318             if (A.isconj())
319                 DiagMultMV<add,cx,true,false>(A,x,y);
320             else
321                 DiagMultMV<add,cx,false,false>(A,x,y);
322     }
323 
324     template <bool add, bool cx, class T, class Ta, class Tx>
UnitAMultMV(const GenBandMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)325     static void UnitAMultMV(
326         const GenBandMatrix<Ta>& A, const GenVector<Tx>& x,
327         VectorView<T> y)
328     {
329         // Check for 0's in beginning or end of x:
330         // y += [ A1 A2 A3 ] [ 0 ]  -->  y += A2 x
331         //                   [ x ]
332         //                   [ 0 ]
333 
334         const ptrdiff_t N = x.size(); // = A.rowsize()
335         ptrdiff_t j2 = N;
336         for(const Tx* x2=x.cptr()+N-1; j2>0 && *x2==Tx(0); --j2,--x2);
337         if (j2 == 0) {
338             if (!add) y.setZero();
339             return;
340         }
341         ptrdiff_t j1 = 0;
342         for(const Tx* x1=x.cptr(); *x1==Tx(0); ++j1,++x1);
343         if (j1 == 0 && j2 == N) UnitAMultMV1<add,cx>(A,x,y);
344         else {
345             const ptrdiff_t hi = A.nhi();
346             const ptrdiff_t lo = A.nlo();
347             const ptrdiff_t M = y.size(); // = A.colsize()
348             // This next bit is copied from the BandMatrix colRange function
349             ptrdiff_t i1 = j1 > hi ? j1-hi : 0;
350             ptrdiff_t i2 = TMV_MIN(j2+lo,M);
351             ptrdiff_t newhi = j1 < hi ? hi-j1 : 0;
352             ptrdiff_t newlo = lo+hi-newhi;
353             ptrdiff_t newM = i2-i1;
354             ptrdiff_t newN = j2-j1;
355             TMVAssert(newM > 0);
356             TMVAssert(newN > 0);
357             if (newhi >= newN) newhi = newN-1;
358             if (newlo >= newM) newlo = newM-1;
359             TMVAssert(A.hasSubBandMatrix(i1,i2,j1,j2,newlo,newhi,1,1));
360             const Ta* p = A.cptr()+i1*A.stepi()+j1*A.stepj();
361             ConstBandMatrixView<Ta> Acols(
362                 p,newM,newN,newlo,newhi,
363                 A.stepi(),A.stepj(),A.diagstep(),A.ct());
364             UnitAMultMV1<add,cx>(Acols,x.subVector(j1,j2),y.subVector(i1,i2));
365             if (!add) {
366                 y.subVector(0,i1).setZero();
367                 y.subVector(i2,M).setZero();
368             }
369         }
370     }
371 
372     template <bool add, class T, class Ta, class Tx>
NonBlasMultMV(const T alpha,const GenBandMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)373     static void NonBlasMultMV(
374         const T alpha, const GenBandMatrix<Ta>& A, const GenVector<Tx>& x,
375         VectorView<T> y)
376     // y (+)= alpha * A * x
377     {
378         TMVAssert(A.rowsize() == x.size());
379         TMVAssert(A.colsize() == y.size());
380         TMVAssert(alpha != T(0));
381         TMVAssert(x.size() > 0);
382         TMVAssert(y.size() > 0);
383         TMVAssert(y.ct() == NonConj);
384 
385 #ifdef XDEBUG
386         cout<<"NonBlasMultMV: A = "<<A<<endl;
387         Vector<T> y0 = y;
388         Vector<Tx> x0 = x;
389         Matrix<Ta> A0 = A;
390         Vector<T> y2 = alpha*A0*x0;
391         if (add) y2 += y0;
392 #endif
393 
394         if (x.step() != 1 || SameStorage(x,y)) {
395             if (TMV_IMAG(alpha) == TMV_RealType(T)(0)) {
396                 Vector<Tx> xx = TMV_REAL(alpha) * x;
397                 if (y.step() != 1) {
398                     Vector<T> yy(y.size());
399                     UnitAMultMV<false,false>(A,xx,yy.view());
400                     if (add) y += yy;
401                     else y = yy;
402                 } else {
403                     UnitAMultMV<add,false>(A,xx,y);
404                 }
405             } else {
406                 Vector<T> xx = alpha * x;
407                 if (y.step() != 1) {
408                     Vector<T> yy(y.size());
409                     UnitAMultMV<false,false>(A,xx,yy.view());
410                     if (add) y += yy;
411                     else y = yy;
412                 } else {
413                     UnitAMultMV<add,false>(A,xx,y);
414                 }
415             }
416         } else if (y.step() != 1 || alpha != TMV_RealType(T)(1)) {
417             Vector<T> yy(y.size());
418             if (x.isconj())
419                 UnitAMultMV<false,true>(A,x,yy.view());
420             else
421                 UnitAMultMV<false,false>(A,x,yy.view());
422             if (add) y += alpha * yy;
423             else y = alpha * yy;
424         } else {
425             TMVAssert(alpha == T(1));
426             TMVAssert(y.step() == 1);
427             TMVAssert(x.step() == 1);
428             TMVAssert(!SameStorage(x,y));
429             if (x.isconj())
430                 UnitAMultMV<add,true>(A,x,y);
431             else
432                 UnitAMultMV<add,false>(A,x,y);
433         }
434 
435 #ifdef XDEBUG
436         if (!(Norm(y2-y) <= 0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(x0)+
437                                 (add?Norm(y0):TMV_RealType(T)(0))))) {
438             cerr<<"NonBlas MultMV: alpha = "<<alpha<<endl;
439             cerr<<"add = "<<add<<endl;
440             cerr<<"A = "<<TMV_Text(A)<<"  "<<A.cptr()<<"  "<<A0<<endl;
441             cerr<<"x = "<<TMV_Text(x)<<"  "<<x.cptr()<<
442                 " step "<<x.step()<<"  "<<x0<<endl;
443             cerr<<"y = "<<TMV_Text(y)<<"  "<<y.cptr()<<
444                 " step "<<y.step()<<"  "<<y0<<endl;
445             cerr<<"--> y = "<<y<<endl;
446             cerr<<"y2 = "<<y2<<endl;
447             abort();
448         }
449 #endif
450     }
451 
452 #ifdef BLAS
453     template <class T, class Ta, class Tx>
BlasMultMV(const T alpha,const GenBandMatrix<Ta> & A,const GenVector<Tx> & x,int beta,VectorView<T> y)454     static inline void BlasMultMV(
455         const T alpha, const GenBandMatrix<Ta>& A,
456         const GenVector<Tx>& x, int beta, VectorView<T> y)
457     {
458         if (beta == 1) NonBlasMultMV<true>(alpha,A,x,y);
459         else NonBlasMultMV<false>(alpha,A,x,y);
460     }
461 #ifdef INST_DOUBLE
462     template <>
BlasMultMV(const double alpha,const GenBandMatrix<double> & A,const GenVector<double> & x,int beta,VectorView<double> y)463     void BlasMultMV(
464         const double alpha,
465         const GenBandMatrix<double>& A, const GenVector<double>& x,
466         int beta, VectorView<double> y)
467     {
468         int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
469         int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
470         int lo = BlasIsCM(A) ? A.nlo() : A.nhi();
471         int hi = BlasIsCM(A) ? A.nhi() : A.nlo();
472         int ds = A.diagstep();
473         int xs = x.step();
474         int ys = y.step();
475         const double* xp = x.cptr();
476         if (xs < 0) xp += (x.size()-1)*xs;
477         double* yp = y.ptr();
478         if (ys < 0) yp += (y.size()-1)*ys;
479         if (beta == 0) y.setZero();
480         double xbeta(1);
481         BLASNAME(dgbmv) (
482             BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
483             BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
484             BLASV(alpha),BLASP(A.cptr()-hi),BLASV(ds),
485             BLASP(xp),BLASV(xs),BLASV(xbeta),
486             BLASP(yp),BLASV(ys) BLAS1);
487     }
    // y (+)= alpha * A * x for complex double data via BLAS zgbmv.
    //
    // gbmv's TRANS option offers only no-transpose, transpose and
    // conjugate-transpose.  It cannot conjugate x, and it cannot conjugate a
    // column-major A without also transposing it.  Those cases are rewritten:
    //  - conjugated x (unless A is also conjugated and column-major): copy
    //    alpha*x into an unconjugated vector and recurse with alpha = 1.
    //  - conjugated, column-major A: use
    //      conj(conj(y) + conj(alpha) * conj(A) * conj(x)) = y + alpha*A*x.
    //    y is conjugated, the product uses the stored data of A and x, and y
    //    is conjugated back.
    // CBLAS is #undef'd near the top of this file, so the RowMajor/ConjTrans
    // branch under #ifdef CBLAS is currently inactive.
    // beta == 0 is handled by zeroing y and always calling zgbmv with beta=1.
    template <>
    void BlasMultMV(
        const std::complex<double> alpha,
        const GenBandMatrix<std::complex<double> >& A,
        const GenVector<std::complex<double> >& x,
        int beta, VectorView<std::complex<double> > y)
    {
        if (x.isconj()
#ifndef CBLAS
            && !(A.isconj() && BlasIsCM(A))
#endif
        ) {
            Vector<std::complex<double> > xx = alpha*x;
            return BlasMultMV(std::complex<double>(1),A,xx,beta,y);
        }

        // A row-major A is passed as the transpose of a column-major one.
        int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
        int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
        int lo = BlasIsCM(A) ? A.nlo() : A.nhi();
        int hi = BlasIsCM(A) ? A.nhi() : A.nlo();
        int ds = A.diagstep();
        int xs = x.step();
        int ys = y.step();
        // With a negative step, BLAS wants the lowest-addressed element.
        const std::complex<double>* xp = x.cptr();
        if (xs < 0) xp += (x.size()-1)*xs;
        std::complex<double>* yp = y.ptr();
        if (ys < 0) yp += (y.size()-1)*ys;
        if (beta == 0) y.setZero();
        std::complex<double> xbeta(1);

        if (A.isconj() && BlasIsCM(A)) {
#ifdef CBLAS
            TMV_SWAP(m,n);
            TMV_SWAP(lo,hi);
            BLASNAME(zgbmv) (
                BLASRM BLASCH_CT, BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
                BLASP(&alpha),BLASP(A.cptr()-lo),BLASV(ds),
                BLASP(xp),BLASV(xs),BLASP(&xbeta),
                BLASP(yp),BLASV(ys) BLAS1);
#else
            std::complex<double> ca = TMV_CONJ(alpha);
            if (x.isconj()) {
                // xp already addresses the stored (conjugated) data of x.
                y.conjugateSelf();
                BLASNAME(zgbmv) (
                    BLASCM BLASCH_NT, BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
                    BLASP(&ca),BLASP(A.cptr()-hi),BLASV(ds),
                    BLASP(xp),BLASV(xs),BLASP(&xbeta),
                    BLASP(yp),BLASV(ys) BLAS1);
                y.conjugateSelf();
            } else {
                // Build conj(alpha)*conj(x) explicitly; it already carries
                // the scale, so zgbmv runs with alpha = 1.
                Vector<std::complex<double> > xx=ca*x.conjugate();
                ca = std::complex<double>(1);
                xs = 1;
                xp = xx.cptr();
                y.conjugateSelf();
                BLASNAME(zgbmv) (
                    BLASCM BLASCH_NT, BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
                    BLASP(&ca),BLASP(A.cptr()-hi),BLASV(ds),
                    BLASP(xp),BLASV(xs),BLASP(&xbeta),
                    BLASP(yp),BLASV(ys) BLAS1);
                y.conjugateSelf();
            }
#endif
        } else {
            // Column-major: no transpose.  Row-major: transpose, or
            // conjugate-transpose when A is conjugated.
            BLASNAME(zgbmv) (
                BLASCM BlasIsCM(A)?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T,
                BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
                BLASP(&alpha),BLASP(A.cptr()-hi),BLASV(ds),
                BLASP(xp),BLASV(xs),BLASP(&xbeta),
                BLASP(yp),BLASV(ys) BLAS1);
        }
    }
560     template <>
BlasMultMV(const std::complex<double> alpha,const GenBandMatrix<std::complex<double>> & A,const GenVector<double> & x,int beta,VectorView<std::complex<double>> y)561     void BlasMultMV(
562         const std::complex<double> alpha,
563         const GenBandMatrix<std::complex<double> >& A,
564         const GenVector<double>& x,
565         int beta, VectorView<std::complex<double> > y)
566     { BlasMultMV(alpha,A,Vector<std::complex<double> >(x),beta,y); }
567     template <>
BlasMultMV(const std::complex<double> alpha,const GenBandMatrix<double> & A,const GenVector<std::complex<double>> & x,int beta,VectorView<std::complex<double>> y)568     void BlasMultMV(
569         const std::complex<double> alpha,
570         const GenBandMatrix<double>& A,
571         const GenVector<std::complex<double> >& x,
572         int beta, VectorView<std::complex<double> > y)
573     {
574         if (beta == 0) {
575             int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
576             int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
577             int lo = BlasIsCM(A) ? A.nlo() : A.nhi();
578             int hi = BlasIsCM(A) ? A.nhi() : A.nlo();
579             int ds = A.diagstep();
580             int xs = 2*x.step();
581             int ys = 2*y.step();
582             const double* xp = (const double*) x.cptr();
583             if (xs < 0) xp += (x.size()-1)*xs;
584             double* yp = (double*) y.ptr();
585             if (ys < 0) yp += (y.size()-1)*ys;
586             double xalpha(1);
587             y.setZero();
588             double xbeta(1);
589             BLASNAME(dgbmv) (
590                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
591                 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
592                 BLASV(xalpha),BLASP(A.cptr()-hi),BLASV(ds),
593                 BLASP(xp),BLASV(xs),BLASV(xbeta),
594                 BLASP(yp),BLASV(ys) BLAS1);
595             BLASNAME(dgbmv) (
596                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
597                 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
598                 BLASV(xalpha),BLASP(A.cptr()-hi),BLASV(ds),
599                 BLASP(xp+1),BLASV(xs),BLASV(xbeta),
600                 BLASP(yp+1),BLASV(ys) BLAS1);
601             if (x.isconj()) y.conjugateSelf();
602             y *= alpha;
603         } else if (TMV_IMAG(alpha) == 0. && !x.isconj()) {
604             int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
605             int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
606             int lo = BlasIsCM(A) ? A.nlo() : A.nhi();
607             int hi = BlasIsCM(A) ? A.nhi() : A.nlo();
608             int ds = A.diagstep();
609             int xs = 2*x.step();
610             int ys = 2*y.step();
611             const double* xp = (const double*) x.cptr();
612             if (xs < 0) xp += (x.size()-1)*xs;
613             double* yp = (double*)  y.ptr();
614             if (ys < 0) yp += (y.size()-1)*ys;
615             double xalpha(TMV_REAL(alpha));
616             double xbeta(beta);
617             BLASNAME(dgbmv) (
618                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
619                 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
620                 BLASV(xalpha),BLASP(A.cptr()-hi),BLASV(ds),
621                 BLASP(xp),BLASV(xs),BLASV(xbeta),
622                 BLASP(yp),BLASV(ys) BLAS1);
623             BLASNAME(dgbmv) (
624                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
625                 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
626                 BLASV(xalpha),BLASP(A.cptr()-hi),BLASV(ds),
627                 BLASP(xp+1),BLASV(xs),BLASV(xbeta),
628                 BLASP(yp+1),BLASV(ys) BLAS1);
629         } else {
630             Vector<std::complex<double> > xx = alpha*x;
631             BlasMultMV(std::complex<double>(1),A,xx,1,y);
632         }
633     }
634     template <>
BlasMultMV(const std::complex<double> alpha,const GenBandMatrix<double> & A,const GenVector<double> & x,int beta,VectorView<std::complex<double>> y)635     void BlasMultMV(
636         const std::complex<double> alpha,
637         const GenBandMatrix<double>& A,
638         const GenVector<double>& x,
639         int beta, VectorView<std::complex<double> > y)
640     {
641         int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
642         int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
643         int lo = BlasIsCM(A) ? A.nlo() : A.nhi();
644         int hi = BlasIsCM(A) ? A.nhi() : A.nlo();
645         int ds = A.diagstep();
646         int xs = x.step();
647         int ys = 2*y.step();
648         const double* xp = x.cptr();
649         if (xs < 0) xp += (x.size()-1)*xs;
650         double* yp = (double*) y.ptr();
651         if (ys < 0) yp += (y.size()-1)*ys;
652         double ar(TMV_REAL(alpha));
653         double ai(TMV_IMAG(alpha));
654         if (beta == 0) y.setZero();
655         double xbeta(1);
656         if (ar != 0.) {
657             BLASNAME(dgbmv) (
658                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
659                 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
660                 BLASV(ar),BLASP(A.cptr()-hi),BLASV(ds),
661                 BLASP(xp),BLASV(xs),BLASV(xbeta),
662                 BLASP(yp),BLASV(ys) BLAS1);
663         }
664         if (ai != 0.) {
665             BLASNAME(dgbmv) (
666                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
667                 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
668                 BLASV(ai),BLASP(A.cptr()-hi),BLASV(ds),
669                 BLASP(xp),BLASV(xs),BLASV(xbeta),
670                 BLASP(yp+1),BLASV(ys) BLAS1);
671         }
672     }
673 #endif
674 #ifdef INST_FLOAT
675     template <>
BlasMultMV(const float alpha,const GenBandMatrix<float> & A,const GenVector<float> & x,int beta,VectorView<float> y)676     void BlasMultMV(
677         const float alpha,
678         const GenBandMatrix<float>& A, const GenVector<float>& x,
679         int beta, VectorView<float> y)
680     {
681         int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
682         int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
683         int lo = BlasIsCM(A) ? A.nlo() : A.nhi();
684         int hi = BlasIsCM(A) ? A.nhi() : A.nlo();
685         int ds = A.diagstep();
686         int xs = x.step();
687         int ys = y.step();
688         const float* xp = x.cptr();
689         if (xs < 0) xp += (x.size()-1)*xs;
690         float* yp = y.ptr();
691         if (ys < 0) yp += (y.size()-1)*ys;
692         if (beta == 0) y.setZero();
693         float xbeta(1);
694         BLASNAME(sgbmv) (
695             BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
696             BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
697             BLASV(alpha),BLASP(A.cptr()-hi),BLASV(ds),
698             BLASP(xp),BLASV(xs),BLASV(xbeta),
699             BLASP(yp),BLASV(ys) BLAS1);
700     }
    // y (+)= alpha * A * x for complex float data via BLAS cgbmv.
    //
    // gbmv's TRANS option offers only no-transpose, transpose and
    // conjugate-transpose.  It cannot conjugate x, and it cannot conjugate a
    // column-major A without also transposing it.  Those cases are rewritten:
    //  - conjugated x (unless A is also conjugated and column-major): copy
    //    alpha*x into an unconjugated vector and recurse with alpha = 1.
    //  - conjugated, column-major A: use
    //      conj(conj(y) + conj(alpha) * conj(A) * conj(x)) = y + alpha*A*x.
    //    y is conjugated, the product uses the stored data of A and x, and y
    //    is conjugated back.
    // CBLAS is #undef'd near the top of this file, so the RowMajor/ConjTrans
    // branch under #ifdef CBLAS is currently inactive.
    // beta == 0 is handled by zeroing y and always calling cgbmv with beta=1.
    template <>
    void BlasMultMV(
        const std::complex<float> alpha,
        const GenBandMatrix<std::complex<float> >& A,
        const GenVector<std::complex<float> >& x,
        int beta, VectorView<std::complex<float> > y)
    {
        if (x.isconj()
#ifndef CBLAS
            && !(A.isconj() && BlasIsCM(A))
#endif
        ) {
            Vector<std::complex<float> > xx = alpha*x;
            return BlasMultMV(std::complex<float>(1),A,xx,beta,y);
        }

        // A row-major A is passed as the transpose of a column-major one.
        int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
        int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
        int lo = BlasIsCM(A) ? A.nlo() : A.nhi();
        int hi = BlasIsCM(A) ? A.nhi() : A.nlo();
        int ds = A.diagstep();
        int xs = x.step();
        int ys = y.step();
        // With a negative step, BLAS wants the lowest-addressed element.
        const std::complex<float>* xp = x.cptr();
        if (xs < 0) xp += (x.size()-1)*xs;
        std::complex<float>* yp = y.ptr();
        if (ys < 0) yp += (y.size()-1)*ys;
        if (beta == 0) y.setZero();
        std::complex<float> xbeta(1);

        if (A.isconj() && BlasIsCM(A)) {
#ifdef CBLAS
            TMV_SWAP(m,n);
            TMV_SWAP(lo,hi);
            BLASNAME(cgbmv) (
                BLASRM BLASCH_CT, BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
                BLASP(&alpha),BLASP(A.cptr()-lo),BLASV(ds),
                BLASP(xp),BLASV(xs),BLASP(&xbeta),
                BLASP(yp),BLASV(ys) BLAS1);
#else
            std::complex<float> ca = TMV_CONJ(alpha);
            if (x.isconj()) {
                // xp already addresses the stored (conjugated) data of x.
                y.conjugateSelf();
                BLASNAME(cgbmv) (
                    BLASCM BLASCH_NT, BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
                    BLASP(&ca),BLASP(A.cptr()-hi),BLASV(ds),
                    BLASP(xp),BLASV(xs),BLASP(&xbeta),
                    BLASP(yp),BLASV(ys) BLAS1);
                y.conjugateSelf();
            } else {
                // Build conj(alpha)*conj(x) explicitly; it already carries
                // the scale, so cgbmv runs with alpha = 1.
                Vector<std::complex<float> > xx=ca*x.conjugate();
                ca = std::complex<float>(1);
                xs = 1;
                xp = xx.cptr();
                y.conjugateSelf();
                BLASNAME(cgbmv) (
                    BLASCM BLASCH_NT, BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
                    BLASP(&ca),BLASP(A.cptr()-hi),BLASV(ds),
                    BLASP(xp),BLASV(xs),BLASP(&xbeta),
                    BLASP(yp),BLASV(ys) BLAS1);
                y.conjugateSelf();
            }
#endif
        } else {
            // Column-major: no transpose.  Row-major: transpose, or
            // conjugate-transpose when A is conjugated.
            BLASNAME(cgbmv) (
                BLASCM BlasIsCM(A)?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T,
                BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
                BLASP(&alpha),BLASP(A.cptr()-hi),BLASV(ds),
                BLASP(xp),BLASV(xs),BLASP(&xbeta),
                BLASP(yp),BLASV(ys) BLAS1);
        }
    }
773     template <>
BlasMultMV(const std::complex<float> alpha,const GenBandMatrix<std::complex<float>> & A,const GenVector<float> & x,int beta,VectorView<std::complex<float>> y)774     void BlasMultMV(
775         const std::complex<float> alpha,
776         const GenBandMatrix<std::complex<float> >& A,
777         const GenVector<float>& x,
778         int beta, VectorView<std::complex<float> > y)
779     { BlasMultMV(alpha,A,Vector<std::complex<float> >(x),beta,y); }
780     template <>
BlasMultMV(const std::complex<float> alpha,const GenBandMatrix<float> & A,const GenVector<std::complex<float>> & x,int beta,VectorView<std::complex<float>> y)781     void BlasMultMV(
782         const std::complex<float> alpha,
783         const GenBandMatrix<float>& A,
784         const GenVector<std::complex<float> >& x,
785         int beta, VectorView<std::complex<float> > y)
786     {
787         if (beta == 0) {
788             int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
789             int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
790             int lo = BlasIsCM(A) ? A.nlo() : A.nhi();
791             int hi = BlasIsCM(A) ? A.nhi() : A.nlo();
792             int ds = A.diagstep();
793             int xs = 2*x.step();
794             int ys = 2*y.step();
795             const float* xp = (const float*) x.cptr();
796             if (xs < 0) xp += (x.size()-1)*xs;
797             float* yp = (float*) y.ptr();
798             if (ys < 0) yp += (y.size()-1)*ys;
799             float xalpha(1);
800             y.setZero();
801             float xbeta(1);
802             BLASNAME(sgbmv) (
803                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
804                 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
805                 BLASV(xalpha),BLASP(A.cptr()-hi),BLASV(ds),
806                 BLASP(xp),BLASV(xs),BLASV(xbeta),
807                 BLASP(yp),BLASV(ys) BLAS1);
808             BLASNAME(sgbmv) (
809                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
810                 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
811                 BLASV(xalpha),BLASP(A.cptr()-hi),BLASV(ds),
812                 BLASP(xp+1),BLASV(xs),BLASV(xbeta),
813                 BLASP(yp+1),BLASV(ys) BLAS1);
814             if (x.isconj()) y.conjugateSelf();
815             y *= alpha;
816         } else if (TMV_IMAG(alpha) == 0.F && !x.isconj()) {
817             int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
818             int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
819             int lo = BlasIsCM(A) ? A.nlo() : A.nhi();
820             int hi = BlasIsCM(A) ? A.nhi() : A.nlo();
821             int ds = A.diagstep();
822             int xs = 2*x.step();
823             int ys = 2*y.step();
824             const float* xp = (const float*) x.cptr();
825             if (xs < 0) xp += (x.size()-1)*xs;
826             float* yp = (float*) y.ptr();
827             if (ys < 0) yp += (y.size()-1)*ys;
828             float xalpha(TMV_REAL(alpha));
829             float xbeta(1);
830             BLASNAME(sgbmv) (
831                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
832                 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
833                 BLASV(xalpha),BLASP(A.cptr()-hi),BLASV(ds),
834                 BLASP(xp),BLASV(xs),BLASV(xbeta),
835                 BLASP(yp),BLASV(ys) BLAS1);
836             BLASNAME(sgbmv) (
837                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
838                 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
839                 BLASV(xalpha),BLASP(A.cptr()-hi),BLASV(ds),
840                 BLASP(xp+1),BLASV(xs),BLASV(xbeta),
841                 BLASP(yp+1),BLASV(ys) BLAS1);
842         } else {
843             Vector<std::complex<float> > xx = alpha*x;
844             BlasMultMV(std::complex<float>(1),A,xx,1,y);
845         }
846     }
847     template <>
BlasMultMV(const std::complex<float> alpha,const GenBandMatrix<float> & A,const GenVector<float> & x,int beta,VectorView<std::complex<float>> y)848     void BlasMultMV(
849         const std::complex<float> alpha,
850         const GenBandMatrix<float>& A,
851         const GenVector<float>& x,
852         int beta, VectorView<std::complex<float> > y)
853     {
854         int m = BlasIsCM(A) ? A.colsize() : A.rowsize();
855         int n = BlasIsCM(A) ? A.rowsize() : A.colsize();
856         int lo = BlasIsCM(A) ? A.nlo() : A.nhi();
857         int hi = BlasIsCM(A) ? A.nhi() : A.nlo();
858         int ds = A.diagstep();
859         int xs = x.step();
860         int ys = 2*y.step();
861         const float* xp = x.cptr();
862         if (xs < 0) xp += (x.size()-1)*xs;
863         float* yp = (float*) y.ptr();
864         if (ys < 0) yp += (y.size()-1)*ys;
865         float ar(TMV_REAL(alpha));
866         float ai(TMV_IMAG(alpha));
867         if (beta == 0) y.setZero();
868         float xbeta(1);
869         if (ar != 0.F) {
870             BLASNAME(sgbmv) (
871                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
872                 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
873                 BLASV(ar),BLASP(A.cptr()-hi),BLASV(ds),
874                 BLASP(xp),BLASV(xs),BLASV(xbeta),
875                 BLASP(yp),BLASV(ys) BLAS1);
876         }
877         if (ai != 0.F) {
878             BLASNAME(sgbmv) (
879                 BLASCM BlasIsCM(A)?BLASCH_NT:BLASCH_T,
880                 BLASV(m),BLASV(n),BLASV(lo),BLASV(hi),
881                 BLASV(ai),BLASP(A.cptr()-hi),BLASV(ds),
882                 BLASP(xp),BLASV(xs),BLASV(xbeta),
883                 BLASP(yp+1),BLASV(ys) BLAS1);
884         }
885     }
886 #endif
887 #endif // BLAS
888 
    // DoMultMV: y = alpha*A*x (add == false) or y += alpha*A*x (add == true)
    // for a band matrix A.  Requires alpha != 0 and non-empty x and y.
    //
    // Without BLAS this defers to NonBlasMultMV.  With BLAS it normalizes
    // the arguments until a BlasMultMV call applies:
    //  - a conjugated y is handled by conjugating every operand;
    //  - a zero-step x (or y) is replaced by a unit-step view or a copy;
    //  - BLAS-compatible storage goes straight to BlasMultMV, with
    //    temporaries introduced when A, x and y share storage;
    //  - row/column-major storage whose leading step is smaller than
    //    nlo+nhi is split into dense (MatrixView) and band pieces;
    //  - anything else is copied into a RowMajor BandMatrix with alpha
    //    folded in, and the call is retried.
    template <bool add, class T, class Ta, class Tx>
    static void DoMultMV(
        const T alpha, const GenBandMatrix<Ta>& A, const GenVector<Tx>& x,
        VectorView<T> y)
    {
        //cout<<"Start DoMultMV\n";
        //cout<<"A = "<<TMV_Text(A)<<"  "<<A.cptr()<<"  "<<A<<endl;
        //cout<<"x = "<<TMV_Text(x)<<"  "<<x.cptr()<<"  step "<<x.step()<<"  "<<x<<endl;
        //cout<<"y = "<<TMV_Text(y)<<"  "<<y.cptr()<<"  step "<<y.step()<<"  "<<y<<endl;
        //cout<<"alpha = "<<alpha<<", add = "<<add<<endl;
        TMVAssert(A.rowsize() == x.size());
        TMVAssert(A.colsize() == y.size());
        TMVAssert(alpha != T(0));
        TMVAssert(x.size() > 0);
        TMVAssert(y.size() > 0);

        if (y.isconj()) {
            // conj(y) (+)= conj(alpha)*conj(A)*conj(x), so y itself is
            // non-conjugated in the recursive call.
            DoMultMV<add>(
                TMV_CONJ(alpha),A.conjugate(),x.conjugate(),y.conjugate());
        } else {
#ifdef BLAS
            if (x.step() == 0) {
                // BLAS cannot take incx == 0.  A single element just needs a
                // unit-step view; otherwise materialize the broadcast.
                if (x.size() <= 1)
                    DoMultMV<add>(
                        alpha,A,
                        ConstVectorView<Tx>(x.cptr(),x.size(),1,x.ct()),y);
                else
                    DoMultMV<add>(alpha,A,Vector<Tx>(x),y);
            } else if (y.step() == 0) {
                TMVAssert(y.size() <= 1);
                DoMultMV<add>(
                    alpha,A,x,VectorView<T>(y.ptr(),y.size(),1,y.ct()));
            } else if (BlasIsRM(A) || BlasIsCM(A)) {
                // BLAS requires the output not to alias the inputs.  Route
                // through temporaries whenever storage is shared.
                if (!SameStorage(A,y)) {
                    if (!SameStorage(x,y) && !SameStorage(A,x)) {
                        BlasMultMV(alpha,A,x,add?1:0,y);
                    } else {
                        Vector<T> xx = alpha*x;
                        BlasMultMV(T(1),A,xx,add?1:0,y);
                    }
                } else {
                    Vector<T> yy(y.size(),T(0));
                    if (!SameStorage(A,x)) {
                        BlasMultMV(T(1),A,x,0,yy.view());
                        if (add) y += alpha*yy;
                        else y = alpha*yy;
                    } else {
                        Vector<T> xx = alpha*x;
                        BlasMultMV(T(1),A,xx,0,yy.view());
                        if (add) y += yy;
                        else y = yy;
                    }
                }
            } else if ((A.isrm() && A.stepi() < A.nlo()+A.nhi()) ||
                       (A.iscm() && A.stepj() < A.nlo()+A.nhi())) {
                // Leading step shorter than the band width (presumably the
                // storage BlasIsRM/BlasIsCM rejected -- confirm).  Split A
                // into pieces that are each either a dense MatrixView or a
                // band view acceptable to BlasMultMV.
                if (SameStorage(A,y)) {
                    Vector<T> yy(y.size(),T(0));
                    DoMultMV<false>(T(1),A,x,yy.view());
                    if (add) y += alpha*yy;
                    else y = alpha*yy;
                } else if (SameStorage(x,y)) {
                    DoMultMV<add>(T(1),A,Vector<T>(alpha*x),y);
                } else if (A.nlo()+1 == A.colsize()) {
                    // Every column reaches the last row.  The first nhi
                    // columns start at row 0, so they form a dense block.
                    if (A.nhi()+1 == A.rowsize()) {
                        ConstMatrixView<Ta> A1 =
                            A.subMatrix(0,A.colsize(),0,A.rowsize());
                        MultMV<add>(alpha,A1,x,y);
                    } else {
                        ConstMatrixView<Ta> A1 =
                            A.subMatrix(0,A.colsize(),0,A.nhi());
                        MultMV<add>(alpha,A1,x.subVector(0,A.nhi()),y);
                        ConstBandMatrixView<Ta> A2 = A.colRange(A.nhi(),A.rowsize());
                        BlasMultMV(alpha,A2,x.subVector(A.nhi(),A.rowsize()),1,y);
                    }
                } else {
                    // Handle the first nlo rows separately from the rest.
                    TMVAssert(A.nlo()>0);
                    if (A.nhi()+1 == A.rowsize()) {
                        ConstMatrixView<Ta> A1 =
                            A.subMatrix(0,A.nlo(),0,A.rowsize());
                        MultMV<add>(alpha,A1,x,y.subVector(0,A.nlo()));
                    } else {
                        ConstBandMatrixView<Ta> A1 = A.rowRange(0,A.nlo());
                        BlasMultMV(alpha,A1,x.subVector(0,A1.rowsize()),
                                   add?1:0,y.subVector(0,A.nlo()));
                    }
                    // NOTE(review): the full x is passed here, which assumes
                    // A2.rowsize() == x.size() -- confirm rowRange keeps all
                    // columns in this case.
                    ConstBandMatrixView<Ta> A2 = A.rowRange(A.nlo(),A.colsize());
                    BlasMultMV(alpha,A2,x,add?1:0,y.subVector(A.nlo(),A.colsize()));
                }
            } else {
                // Unsupported strides: copy alpha*A into a RowMajor band
                // matrix, keeping it real when alpha is real.
                if (TMV_IMAG(alpha) == T(0)) {
                    BandMatrix<Ta,RowMajor> A2 = TMV_REAL(alpha)*A;
                    DoMultMV<add>(T(1),A2,x,y);
                } else {
                    BandMatrix<T,RowMajor> A2 = alpha*A;
                    DoMultMV<add>(T(1),A2,x,y);
                }
            }
#else
            NonBlasMultMV<add>(alpha,A,x,y);
#endif
        }
        //std::cout<<"Done DoMultMV\n";
    }
992 
993     //
994     // MultEqMV
995     //
996 
    // In-place x = A*x where A is treated as an upper-triangular band: only
    // the diagonal and the A.nhi() superdiagonals are read.  Rows are
    // processed top-down.  Row i reads x(j) only for j >= i, and those have
    // not been overwritten yet, so no temporary is needed.
    // rm: A is row-major (unit step along a row); ca: A is conjugated.
    // Requires a square A, x.step() == 1 and a non-conjugated x.
    // NOTE(review): if A.nhi() == 0 and x.size() > 1, the loop below never
    // runs and only x(0) is scaled.  Presumably callers route diagonal
    // matrices elsewhere -- confirm.
    template <bool rm, bool ca, class T, class Ta>
    static void DoRowUpperMultEqMV(
        const GenBandMatrix<Ta>& A, VectorView<T> x)
    {
        TMVAssert(A.isSquare());
        TMVAssert(A.colsize() == x.size());
        TMVAssert(x.size() > 0);
        TMVAssert(x.step()==1);
        TMVAssert(x.ct() == NonConj);
        TMVAssert(rm == A.isrm());
        TMVAssert(ca == A.isconj());

        const ptrdiff_t N = x.size();
        const ptrdiff_t sj = (rm ? 1 : A.stepj());
        const ptrdiff_t ds = A.diagstep();

        T* xi = x.ptr();
        const Ta* Aii = A.cptr();
        // j2: one past the last in-band column of the current row.
        // len: number of superdiagonal entries in the current row.
        ptrdiff_t j2=A.nhi()+1;
        ptrdiff_t len = j2-1;

        for(; len>0; ++xi,Aii+=ds) {
            // i = 0..N-2
            // x(i) = A.row(i,i,j2) * x.subVector(i,j2);
#ifdef TMVFLDEBUG
            TMVAssert(xi >= x._first);
            TMVAssert(xi < x._last);
#endif
            *xi *= (ca ? TMV_CONJ(*Aii) : *Aii);
            const T* xj = xi+1;
            const Ta* Aij = Aii+sj;
            for(ptrdiff_t j=len;j>0;--j,++xj,(rm?++Aij:Aij+=sj)) {
#ifdef TMVFLDEBUG
                TMVAssert(xi >= x._first);
                TMVAssert(xi < x._last);
#endif
                *xi += (*xj) * (ca ? TMV_CONJ(*Aij) : *Aij);
            }

            // The band window slides right until it hits column N-1.
            // After that, each later row is one entry shorter.
            if (j2<N) ++j2;
            else --len;
        }
        // The remaining row has no superdiagonal entries: diagonal only.
#ifdef TMVFLDEBUG
        TMVAssert(xi >= x._first);
        TMVAssert(xi < x._last);
#endif
        *xi *= (ca ? TMV_CONJ(*Aii) : *Aii);
    }
1045 
1046     template <bool rm, class T, class Ta>
RowUpperMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1047     static inline void RowUpperMultEqMV(
1048         const GenBandMatrix<Ta>& A, VectorView<T> x)
1049     {
1050         if (A.isconj())
1051             DoRowUpperMultEqMV<rm,true>(A,x);
1052         else
1053             DoRowUpperMultEqMV<rm,false>(A,x);
1054     }
1055 
    // In-place x = A*x for an upper-triangular band A, processed column by
    // column (j = 0..N-1).  Column j adds x(j)*A(i,j) into the in-band rows
    // i < j, then scales x(j) by A(j,j).  x(j) is overwritten only after its
    // own column has been used, so the update is safe in place.  Columns
    // with x(j) == 0 are skipped.
    // cm: A is column-major (unit step down a column); ca: A is conjugated.
    // Requires a square A, x.step() == 1 and a non-conjugated x.
    // NOTE(review): assumes A.nhi() >= 1 when x.size() > 1.  With nhi == 0
    // the first column reads A.cptr()+stepj, which lies outside the band --
    // confirm callers never pass a diagonal A here.
    template <bool cm, bool ca, class T, class Ta>
    static void DoColUpperMultEqMV(
        const GenBandMatrix<Ta>& A, VectorView<T> x)
    {
        TMVAssert(A.isSquare());
        TMVAssert(A.colsize() == x.size());
        TMVAssert(x.size() > 0);
        TMVAssert(x.step()==1);
        TMVAssert(x.ct() == NonConj);
        TMVAssert(cm == A.iscm());
        TMVAssert(ca == A.isconj());

        const ptrdiff_t N = x.size();
        const ptrdiff_t si = cm ? 1 : A.stepi();
        const ptrdiff_t sj = A.stepj();
        const ptrdiff_t ds = A.diagstep();

        // Ai1j / xi1: first in-band element of the current column j and the
        // x entry of that row.
        const Ta* Ai1j = A.cptr();
        T* xi1 = x.ptr();

        // Column 0 contributes only its diagonal element.
        *xi1 *= (ca ? TMV_CONJ(*Ai1j) : *Ai1j);
        Ai1j += sj;
        const T* xj = x.cptr()+1;

        // k: how many more columns still start at row 0 (band still
        // growing).  len: number of in-band rows above the diagonal.
        ptrdiff_t k=A.nhi()-1;
        ptrdiff_t len = 1;
        for(ptrdiff_t j=1; j<N; ++j,++xj) {
            if (*xj != T(0)) {
                // j = 1..N-1
                // x.subVector(i1,j) += x(j) * A.col(j,i1,j);
                const Ta* Aij = Ai1j;
                T* xi = xi1;
                for(ptrdiff_t i=len;i>0;--i,++xi,(cm?++Aij:Aij+=si)) {
#ifdef TMVFLDEBUG
                    TMVAssert(xi >= x._first);
                    TMVAssert(xi < x._last);
#endif
                    *xi += *xj * (ca ? TMV_CONJ(*Aij) : *Aij);
                }
                // Now Aij == Ajj, xi == xj
                // so this next statement is really *xj *= *Ajj
#ifdef TMVFLDEBUG
                TMVAssert(xi >= x._first);
                TMVAssert(xi < x._last);
#endif
                *xi *= (ca ? TMV_CONJ(*Aij) : *Aij);
            }

            // While the band is still growing, the next column starts at
            // the same row.  Afterwards the start row advances by one.
            if (k>0) { --k; Ai1j+=sj; ++len; }
            else { ++xi1; Ai1j+=ds; }
        }
    }
1108 
1109     template <bool cm, class T, class Ta>
ColUpperMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1110     static inline void ColUpperMultEqMV(
1111         const GenBandMatrix<Ta>& A, VectorView<T> x)
1112     {
1113         if (A.isconj())
1114             DoColUpperMultEqMV<cm,true>(A,x);
1115         else
1116             DoColUpperMultEqMV<cm,false>(A,x);
1117     }
1118 
    // In-place x = A*x where A is treated as a lower-triangular band: only
    // the diagonal and the A.nlo() subdiagonals are read.  Rows are
    // processed bottom-up.  Row i reads x(j) only for j <= i, and the
    // entries j < i have not been overwritten yet.
    // rm: A is row-major (unit step along a row); ca: A is conjugated.
    // Requires a square A, x.step() == 1 and a non-conjugated x.
    // NOTE(review): if A.nlo() == 0 and x.size() > 1, the loop below never
    // runs and only x(N-1) is scaled.  Presumably callers route diagonal
    // matrices elsewhere -- confirm.
    template <bool rm, bool ca, class T, class Ta>
    static void DoRowLowerMultEqMV(
        const GenBandMatrix<Ta>& A, VectorView<T> x)
    {
        TMVAssert(A.isSquare());
        TMVAssert(A.colsize() == x.size());
        TMVAssert(x.size() > 0);
        TMVAssert(x.step()==1);
        TMVAssert(x.ct() == NonConj);
        TMVAssert(rm == A.isrm());
        TMVAssert(ca == A.isconj());

        const ptrdiff_t N = x.size();
        const ptrdiff_t si = A.stepi();
        const ptrdiff_t sj = (rm ? 1 : A.stepj());
        const ptrdiff_t ds = A.diagstep();

        // j1: first in-band column of the current row (starting at row N-1).
        // len: number of subdiagonal entries in the current row.
        ptrdiff_t j1 = N-1-A.nlo();
        const T* xj1 = x.cptr() + j1;
        T* xi = x.ptr() + N-1;
        const Ta* Aii = A.cptr() + (N-1)*ds;
        const Ta* Aij1 = Aii - A.nlo()*sj;
        ptrdiff_t len = A.nlo();

        for(; len>0; --xi,Aii-=ds) {
            // i = N-1..1
            // x(i) = A.row(i,j1,i+1) * x.subVector(j1,i+1);
#ifdef TMVFLDEBUG
            TMVAssert(xi >= x._first);
            TMVAssert(xi < x._last);
#endif
            *xi *= (ca ? TMV_CONJ(*Aii) : *Aii);
            const Ta* Aij = Aij1;
            const T* xj = xj1;
            for(ptrdiff_t j=len;j>0;--j,++xj,(rm?++Aij:Aij+=sj)) {
#ifdef TMVFLDEBUG
                TMVAssert(xi >= x._first);
                TMVAssert(xi < x._last);
#endif
                *xi += *xj * (ca ? TMV_CONJ(*Aij) : *Aij);
            }

            // The band window slides left (diagonally up) until it reaches
            // column 0.  After that, each earlier row is one entry shorter.
            if (j1>0) { --j1; Aij1-=ds; --xj1; }
            else { --len; Aij1-=si; }
        }
        // The remaining row has no subdiagonal entries: diagonal only.
#ifdef TMVFLDEBUG
        TMVAssert(xi >= x._first);
        TMVAssert(xi < x._last);
#endif
        *xi *= (ca ? TMV_CONJ(*Aii) : *Aii);
    }
1170 
1171     template <bool rm, class T, class Ta>
RowLowerMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1172     static inline void RowLowerMultEqMV(
1173         const GenBandMatrix<Ta>& A, VectorView<T> x)
1174     {
1175         if (A.isconj())
1176             DoRowLowerMultEqMV<rm,true>(A,x);
1177         else
1178             DoRowLowerMultEqMV<rm,false>(A,x);
1179     }
1180 
    // In-place x = A*x for a lower-triangular band A, processed column by
    // column from j = N-1 down to 0.  Column j adds x(j)*A(i,j) into the
    // in-band rows i > j, then scales x(j) by A(j,j).  x(j) is overwritten
    // only after its own column has been used.  Columns with x(j) == 0 are
    // skipped.
    // cm: A is column-major (unit step down a column); ca: A is conjugated.
    // Requires a square A, x.step() == 1 and a non-conjugated x.
    // NOTE(review): assumes A.nlo() >= 1 when x.size() > 1.  With nlo == 0
    // the first column step reads A(j+1,j), which lies outside the band --
    // confirm.
    template <bool cm, bool ca, class T, class Ta>
    static void DoColLowerMultEqMV(
        const GenBandMatrix<Ta>& A, VectorView<T> x)
    {
        TMVAssert(A.isSquare());
        TMVAssert(A.colsize() == x.size());
        TMVAssert(x.size() > 0);
        TMVAssert(x.step() == 1);
        TMVAssert(x.ct() == NonConj);
        TMVAssert(cm == A.iscm());
        TMVAssert(ca == A.isconj());

        const ptrdiff_t N = x.size();
        const ptrdiff_t si = cm ? 1 : A.stepi();
        const ptrdiff_t ds = A.diagstep();

        T* xj = x.ptr() + N-1;
        const Ta* Ajj = A.cptr()+(N-1)*ds;

        // The last column has only its diagonal element.
#ifdef TMVFLDEBUG
        TMVAssert(xj >= x._first);
        TMVAssert(xj < x._last);
#endif
        *xj *= (ca ? TMV_CONJ(*Ajj) : *Ajj);
        --xj;
        Ajj -= ds;

        // len: number of in-band rows below the diagonal for the current
        // column.  It grows by one per column until it reaches nlo; k counts
        // the remaining growth.
        ptrdiff_t k=A.nlo()-1;
        for(ptrdiff_t j=N-1,len=1;j>0;--j,--xj,Ajj-=ds) {
            if (*xj!=T(0)) {
                // Actual j = N-2..0
                // x.subVector(j+1,N) += *xj * A.col(j,j+1,N);
                T* xi = xj+1;
                const Ta* Aij = Ajj+si;
                for (ptrdiff_t i=len;i>0;--i,++xi,(cm?++Aij:Aij+=si)) {
#ifdef TMVFLDEBUG
                    TMVAssert(xi >= x._first);
                    TMVAssert(xi < x._last);
#endif
                    *xi += *xj * (ca ? TMV_CONJ(*Aij) : *Aij);
                }
#ifdef TMVFLDEBUG
                TMVAssert(xj >= x._first);
                TMVAssert(xj < x._last);
#endif
                *xj *= (ca ? TMV_CONJ(*Ajj) : *Ajj);

            }
            if (k>0) { --k; ++len; }
        }
    }
1232 
1233     template <bool cm, class T, class Ta>
ColLowerMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1234     static inline void ColLowerMultEqMV(
1235         const GenBandMatrix<Ta>& A, VectorView<T> x)
1236     {
1237         if (A.isconj())
1238             DoColLowerMultEqMV<cm,true>(A,x);
1239         else
1240             DoColLowerMultEqMV<cm,false>(A,x);
1241     }
1242 
1243     template <class T, class Ta>
DoUpperMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1244     static inline void DoUpperMultEqMV(
1245         const GenBandMatrix<Ta>& A, VectorView<T> x)
1246     // x = A * x
1247     {
1248         if (A.isrm()) RowUpperMultEqMV<true>(A,x);
1249         else if (A.iscm()) ColUpperMultEqMV<true>(A,x);
1250         else RowUpperMultEqMV<false>(A,x);
1251     }
1252 
1253     template <class T, class Ta>
DoLowerMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1254     static inline void DoLowerMultEqMV(
1255         const GenBandMatrix<Ta>& A, VectorView<T> x)
1256     {
1257         if (A.isrm()) RowLowerMultEqMV<true>(A,x);
1258         else if (A.iscm() && !SameStorage(A,x))
1259             ColLowerMultEqMV<true>(A,x);
1260         else RowLowerMultEqMV<false>(A,x);
1261     }
1262 
    // In-place x = A*x for an upper-triangular band A that exploits leading
    // and trailing zeros in x.  If x is nonzero only in [j1,j2), the result
    // is:
    //  - A22*x2 in [j1,j2);
    //  - zero after j2 (unchanged);
    //  - A12*x2 in the (at most nhi) rows just above j1.
    // A12 is applied first because it reads x2 before A22 overwrites it.
    // Requires a square A, x.step() == 1 and a non-conjugated x.
    template <class T, class Ta>
    static void NonBlasUpperMultEqMV(
        const GenBandMatrix<Ta>& A, VectorView<T> x)
    {
        TMVAssert(A.isSquare());
        TMVAssert(A.colsize() == x.size());
        TMVAssert(x.size() > 0);
        TMVAssert(x.step() == 1);
        TMVAssert(x.ct() == NonConj);

        //     [ A11 A12  0  ] [ 0  ]   [ A12 x2 ]
        // x = [  0  A22 A23 ] [ x2 ] = [ A22 x2 ]
        //     [  0   0  A33 ] [ 0  ]   [   0    ]

        const ptrdiff_t N = x.size(); // = A.size()
        // j2 = one past the last nonzero of x; an all-zero x is left as is.
        ptrdiff_t j2 = N;
        for(const T* x2=x.cptr()+N-1; j2>0 && *x2==T(0); --j2,--x2);
        if (j2 == 0) return;
        // j1 = first nonzero of x (this loop terminates since j2 > 0).
        ptrdiff_t j1 = 0;
        for(const T* x1=x.cptr(); *x1==T(0); ++j1,++x1);
        if (j1 == 0 && j2 == N) DoUpperMultEqMV(A,x);
        else {
            TMVAssert(j1 < j2);
            const Ta* p22 = A.cptr() + j1*A.diagstep();
            const ptrdiff_t N22 = j2-j1;
            VectorView<T> x2 = x.subVector(j1,j2);
            if (N22 > A.nhi()) {
                // x2 is longer than the band width: A22 keeps the full nhi.
                ConstBandMatrixView<Ta> A22(
                    p22,N22,N22,0,A.nhi(),
                    A.stepi(),A.stepj(),A.diagstep(),A.ct());
                if (j1 > 0) {
                    const ptrdiff_t jx = j1+A.nhi();
                    if (j1 < A.nhi()) {
                        // Fewer than nhi rows above j1: A12 covers rows
                        // 0..j1 and columns j1..jx.  Presumably this sets
                        // x(0..j1) = A12*x(j1..jx) (first flag looks like
                        // 'add' = false) -- confirm.
                        const Ta* p12 = A.cptr() + j1*A.stepj();
                        ConstBandMatrixView<Ta> A12(
                            p12,j1,A.nhi(),j1-1,A.nhi()-j1,
                            A.stepi(),A.stepj(),A.diagstep(),A.ct());
                        UnitAMultMV1<false,false>(
                            A12,x.subVector(j1,jx),x.subVector(0,j1));
                    } else {
                        // A12 is the nhi x nhi lower triangle above-right of
                        // A22.  Copy x(j1..jx) into the target rows, then
                        // multiply in place.
                        const Ta* p12 = p22 - A.nhi()*A.stepi();
                        ConstBandMatrixView<Ta> A12(
                            p12,A.nhi(),A.nhi(),A.nhi()-1,0,
                            A.stepi(),A.stepj(),A.diagstep(),A.ct());
                        VectorView<T> x1x = x.subVector(j1-A.nhi(),j1);
                        x1x = x.subVector(j1,jx);
                        DoLowerMultEqMV(A12,x1x);
                    }
                }
                DoUpperMultEqMV(A22,x2);
            } else {
                // x2 fits within one band width: A22 is a full upper
                // triangle.  A12 covers the M12 rows above j1, with its
                // superdiagonal count clamped to N22-1.
                ConstBandMatrixView<Ta> A22(
                    p22,N22,N22,0,N22-1,
                    A.stepi(),A.stepj(),A.diagstep(),A.ct());
                if (j1 > 0) {
                    const ptrdiff_t M12 = (j1 < A.nhi()) ? j1 : A.nhi();
                    const Ta* p12 = p22 - M12*A.stepi();
                    ptrdiff_t newhi = A.nhi()-M12;
                    if (newhi >= N22) newhi = N22-1;
                    ConstBandMatrixView<Ta> A12(
                        p12,M12,N22,M12-1,newhi,
                        A.stepi(),A.stepj(),A.diagstep(),A.ct());
                    UnitAMultMV1<false,false>(A12,x2,x.subVector(j1-M12,j1));
                }
                DoUpperMultEqMV(A22,x2);
            }
        }
    }
1331 
    // In-place x = A*x for a lower-triangular band A that exploits leading
    // and trailing zeros in x.  If x is nonzero only in [j1,j2), the result
    // is:
    //  - zero before j1 (unchanged);
    //  - A22*x2 in [j1,j2);
    //  - A32*x2 in the (at most nlo) rows just below j2.
    // A32 is applied first because it reads x2 before A22 overwrites it.
    // Requires a square A, x.step() == 1 and a non-conjugated x.
    template <class T, class Ta>
    static void NonBlasLowerMultEqMV(
        const GenBandMatrix<Ta>& A, VectorView<T> x)
    // x = A * x
    {
        TMVAssert(A.isSquare());
        TMVAssert(A.colsize() == x.size());
        TMVAssert(x.size() > 0);
        TMVAssert(x.step() == 1);
        TMVAssert(x.ct() == NonConj);

        //     [ A11  0   0  ] [ 0  ]   [   0    ]
        // x = [ A21 A22  0  ] [ x2 ] = [ A22 x2 ]
        //     [  0  A32 A33 ] [ 0  ]   [ A32 x2 ]

        const ptrdiff_t N = x.size(); // = A.size()
        // j2 = one past the last nonzero of x; an all-zero x is left as is.
        ptrdiff_t j2 = N;
        for(const T* x2=x.cptr()+N-1; j2>0 && *x2==T(0); --j2,--x2);
        if (j2 == 0) return;
        // j1 = first nonzero of x (this loop terminates since j2 > 0).
        ptrdiff_t j1 = 0;
        for(const T* x1=x.cptr(); *x1==T(0); ++j1,++x1);
        if (j1 == 0 && j2 == N) DoLowerMultEqMV(A,x);
        else {
            TMVAssert(j1 < j2);
            const Ta* p22 = A.cptr() + j1*A.diagstep();
            const ptrdiff_t N22 = j2-j1;
            VectorView<T> x2 = x.subVector(j1,j2);
            if (N22 > A.nlo()) {
                // x2 is longer than the band width: A22 keeps the full nlo.
                ConstBandMatrixView<Ta> A22(
                    p22,N22,N22,A.nlo(),0,
                    A.stepi(),A.stepj(),A.diagstep(),A.ct());
                if (j2 < N) {
                    // A32 starts at element (j2, j2-nlo).
                    const ptrdiff_t jx = j2-A.nlo();
                    const Ta* p32 = A.cptr() +
                        j2*A.diagstep() - A.nlo()*A.stepj();
                    if (j2+A.nlo() > N) {
                        // Fewer than nlo rows remain below j2.  Presumably
                        // this sets x(j2..N) = A32*x(jx..j2) (first flag
                        // looks like 'add' = false) -- confirm.
                        ConstBandMatrixView<Ta> A32(
                            p32,N-j2,A.nlo(),0,A.nlo()-1,
                            A.stepi(),A.stepj(),A.diagstep(),A.ct());
                        UnitAMultMV1<false,false>(
                            A32,x.subVector(jx,j2),x.subVector(j2,N));
                    } else {
                        // A32 is an nlo x nlo upper triangle.  Copy
                        // x(jx..j2) into the target rows, then multiply in
                        // place.
                        ConstBandMatrixView<Ta> A32(
                            p32,A.nlo(),A.nlo(),0,A.nlo()-1,
                            A.stepi(),A.stepj(),A.diagstep(),A.ct());
                        VectorView<T> x3x = x.subVector(j2,j2+A.nlo());
                        x3x = x.subVector(jx,j2);
                        DoUpperMultEqMV(A32,x3x);
                    }
                }
                DoLowerMultEqMV(A22,x2);
            } else {
                // x2 fits within one band width: A22 is a full lower
                // triangle.  A32 covers the M32 rows below j2, with its
                // subdiagonal count clamped to M32-1.
                ConstBandMatrixView<Ta> A22(
                    p22,N22,N22,N22-1,0,
                    A.stepi(),A.stepj(),A.diagstep(),A.ct());
                if (j2 < N) {
                    const Ta* p32 = p22 + N22*A.stepi();
                    const ptrdiff_t M32 = (j2+A.nlo() > N) ? N-j2 : A.nlo();
                    ptrdiff_t newlo = A.nlo()-N22;
                    if (newlo >= M32) newlo = M32-1;
                    ConstBandMatrixView<Ta> A32(
                        p32,M32,N22,newlo,N22-1,
                        A.stepi(),A.stepj(),A.diagstep(),A.ct());
                    UnitAMultMV1<false,false>(A32,x2,x.subVector(j2,j2+M32));
                }
                DoLowerMultEqMV(A22,x2);
            }
        }
    }
1401 
1402 #ifdef BLAS
1403     template <class T, class Ta>
BlasMultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1404     static inline void BlasMultEqMV(
1405         const GenBandMatrix<Ta>& A, VectorView<T> x)
1406     {
1407         if (A.nlo() == 0) NonBlasUpperMultEqMV(A,x);
1408         else NonBlasLowerMultEqMV(A,x);
1409     }
1410 #ifdef INST_DOUBLE
1411     template <>
BlasMultEqMV(const GenBandMatrix<double> & A,VectorView<double> x)1412     void BlasMultEqMV(
1413         const GenBandMatrix<double>& A, VectorView<double> x)
1414     {
1415         bool up = A.nlo()==0;
1416         int n=A.colsize();
1417         int lohi = up ? A.nhi() : A.nlo();
1418         int aoffset =
1419             up && BlasIsCM(A) ? A.nhi() :
1420             !up && !BlasIsCM(A) ? A.nlo() : 0;
1421         int ds = A.diagstep();
1422         int xs = x.step();
1423         BLASNAME(dtbmv) (
1424             BLASCM BlasIsCM(A) == up?BLASCH_UP:BLASCH_LO,
1425             BlasIsCM(A) ? BLASCH_NT : BLASCH_T, BLASCH_NU,
1426             BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds),
1427             BLASP(x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1);
1428     }
1429     template <>
BlasMultEqMV(const GenBandMatrix<std::complex<double>> & A,VectorView<std::complex<double>> x)1430     void BlasMultEqMV(
1431         const GenBandMatrix<std::complex<double> >& A,
1432         VectorView<std::complex<double> > x)
1433     {
1434         bool up = A.nlo()==0;
1435         int n=A.colsize();
1436         int lohi = up ? A.nhi() : A.nlo();
1437         int aoffset =
1438             up && BlasIsCM(A) ? A.nhi() :
1439             !up && !BlasIsCM(A) ? A.nlo() : 0;
1440         int ds = A.diagstep();
1441         int xs = x.step();
1442         if (BlasIsCM(A) && A.isconj()) {
1443 #ifdef CBLAS
1444             BLASNAME(ztbmv) (
1445                 BLASRM up ? BLASCH_LO : BLASCH_UP, BLASCH_CT, BLASCH_NU,
1446                 BLASV(n),BLASV(lohi),BLASP(A.cptr()-A.nhi()),
1447                 BLASV(ds),BLASP(x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1);
1448 #else
1449             x.conjugateSelf();
1450             BLASNAME(ztbmv) (
1451                 BLASCM up?BLASCH_UP:BLASCH_LO, BLASCH_NT, BLASCH_NU,
1452                 BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds),
1453                 BLASP(x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1);
1454             x.conjugateSelf();
1455 #endif
1456         } else {
1457             BLASNAME(ztbmv) (
1458                 BLASCM BlasIsCM(A) == up?BLASCH_UP:BLASCH_LO,
1459                 BlasIsCM(A) ? BLASCH_NT : A.isconj() ? BLASCH_CT : BLASCH_T,
1460                 BLASCH_NU,
1461                 BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds),
1462                 BLASP(x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1);
1463         }
1464     }
    // x = A * x for a real triangular band matrix A and a complex vector x.
    // A real matrix acts on the real and imaginary parts of x independently,
    // so the complex vector is viewed as interleaved doubles and dtbmv is
    // called twice: once on the real parts (starting at x.ptr()) and once on
    // the imaginary parts (one double further), each with a stride of
    // 2*x.step() doubles.
    // Assumes A is square and triangular (nlo()==0 or nhi()==0) and stored
    // row- or column-major in the BLAS sense.  Row-major storage is given to
    // BLAS as the column-major transpose, which is why uplo is chosen by
    // comparing BlasIsCM(A) with up and why trans is 'T' for row-major A.
    template <>
    void BlasMultEqMV(
        const GenBandMatrix<double>& A,
        VectorView<std::complex<double> > x)
    {
        // Upper-triangular band when there are no sub-diagonals.
        bool up = A.nlo()==0;
        int n=A.colsize();
        // Number of off-diagonals in the triangle (BLAS k argument).
        int lohi = up ? A.nhi() : A.nlo();
        // Offset from A.cptr() back to the start of the BLAS band array.
        // When BLAS sees an upper band (CM upper, or RM lower seen as its
        // transpose), band storage puts the diagonal k entries into each
        // column, so A(0,0) lies k elements past the array start.
        int aoffset =
            up && BlasIsCM(A) ? A.nhi() :
            !up && !BlasIsCM(A) ? A.nlo() : 0;
        // Passed as lda: in band storage a step along the diagonal advances
        // lda elements.
        int ds = A.diagstep();
        // Stride between consecutive real (or imaginary) parts, in doubles.
        int xs = 2*x.step();
        // Real parts of x.
        BLASNAME(dtbmv) (
            BLASCM BlasIsCM(A) == up?BLASCH_UP:BLASCH_LO,
            BlasIsCM(A) ? BLASCH_NT : BLASCH_T, BLASCH_NU,
            BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds),
            BLASP((double*)x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1);
        // Imaginary parts of x.
        BLASNAME(dtbmv) (
            BLASCM BlasIsCM(A) == up?BLASCH_UP:BLASCH_LO,
            BlasIsCM(A) ? BLASCH_NT : BLASCH_T, BLASCH_NU,
            BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds),
            BLASP((double*)x.ptr()+1),BLASV(xs) BLAS1 BLAS1 BLAS1);
    }
1489 #endif
1490 #ifdef INST_FLOAT
    // x = A * x for a real (float) triangular band matrix A and float vector
    // x, done by a single stbmv call.
    // Assumes A is square and triangular (nlo()==0 or nhi()==0) and stored
    // row- or column-major in the BLAS sense.  Row-major storage is given to
    // BLAS as the column-major transpose: uplo is flipped (BlasIsCM(A)==up)
    // and trans is 'T'.
    template <>
    void BlasMultEqMV(
        const GenBandMatrix<float>& A, VectorView<float> x)
    {
        // Upper-triangular band when there are no sub-diagonals.
        bool up = A.nlo()==0;
        int n=A.colsize();
        // Number of off-diagonals in the triangle (BLAS k argument).
        int lohi = up ? A.nhi() : A.nlo();
        // Offset from A.cptr() back to the start of the BLAS band array.
        // When BLAS sees an upper band, A(0,0) lies k elements past the
        // array start.
        int aoffset =
            up && BlasIsCM(A) ? A.nhi() :
            !up && !BlasIsCM(A) ? A.nlo() : 0;
        // Passed as lda.
        int ds = A.diagstep();
        int xs = x.step();
        BLASNAME(stbmv) (
            BLASCM BlasIsCM(A) == up?BLASCH_UP:BLASCH_LO,
            BlasIsCM(A) ? BLASCH_NT : BLASCH_T, BLASCH_NU,
            BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds),
            BLASP(x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1);
    }
    // x = A * x for a complex<float> triangular band matrix A and complex
    // vector x, using ctbmv.
    // Assumes A is square and triangular (nlo()==0 or nhi()==0) and stored
    // row- or column-major in the BLAS sense.  When A.isconj(), the stored
    // values S satisfy A = conj(S).
    template <>
    void BlasMultEqMV(
        const GenBandMatrix<std::complex<float> >& A,
        VectorView<std::complex<float> > x)
    {
        // Upper-triangular band when there are no sub-diagonals.
        bool up = A.nlo()==0;
        int n=A.colsize();
        // Number of off-diagonals in the triangle (BLAS k argument).
        int lohi = up ? A.nhi() : A.nlo();
        // Offset from A.cptr() back to the start of the BLAS band array.
        // When BLAS sees an upper band, A(0,0) lies k elements past the
        // array start.
        int aoffset =
            up && BlasIsCM(A) ? A.nhi() :
            !up && !BlasIsCM(A) ? A.nlo() : 0;
        // Passed as lda.
        int ds = A.diagstep();
        // Complex BLAS strides are counted in complex elements.
        int xs = x.step();
        if (BlasIsCM(A) && A.isconj()) {
            // Column-major storage of conj(A): BLAS needs conj(S) with no
            // transpose, which plain tbmv cannot express directly.
#ifdef CBLAS
            // Pass the column-major data as a row-major matrix (i.e. S^T,
            // with uplo flipped) and apply ConjTrans: conj(S^T)^T = conj(S).
            // Since A is triangular, A.nhi() equals the band-origin offset
            // for both the up and !up cases here.
            BLASNAME(ctbmv) (
                BLASRM up ? BLASCH_LO : BLASCH_UP, BLASCH_CT, BLASCH_NU,
                BLASV(n),BLASV(lohi),BLASP(A.cptr()-A.nhi()),
                BLASV(ds),BLASP(x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1);
#else
            // Use conj(S * conj(x)) = conj(S) * x = A * x: conjugate x in
            // place, multiply by the stored values, and conjugate back.
            x.conjugateSelf();
            BLASNAME(ctbmv) (
                BLASCM up?BLASCH_UP:BLASCH_LO, BLASCH_NT, BLASCH_NU,
                BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds),
                BLASP(x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1);
            x.conjugateSelf();
#endif
        } else {
            // Column-major A: no transpose.  Row-major A is seen by BLAS as
            // its column-major transpose (uplo flipped); 'T' undoes the
            // transpose, or 'C' when the stored values are also conjugated.
            BLASNAME(ctbmv) (
                BLASCM BlasIsCM(A) == up?BLASCH_UP:BLASCH_LO,
                BlasIsCM(A) ? BLASCH_NT : A.isconj() ? BLASCH_CT : BLASCH_T,
                BLASCH_NU,
                BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds),
                BLASP(x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1);
        }
    }
    // x = A * x for a real (float) triangular band matrix A and a
    // complex<float> vector x.  A real matrix acts on real and imaginary
    // parts independently, so the complex vector is viewed as interleaved
    // floats and stbmv is called twice: on the real parts (starting at
    // x.ptr()) and on the imaginary parts (one float further), each with a
    // stride of 2*x.step() floats.
    // Assumes A is square and triangular (nlo()==0 or nhi()==0) and stored
    // row- or column-major in the BLAS sense.  Row-major storage is given to
    // BLAS as the column-major transpose: uplo flipped and trans = 'T'.
    template <>
    void BlasMultEqMV(
        const GenBandMatrix<float>& A,
        VectorView<std::complex<float> > x)
    {
        // Upper-triangular band when there are no sub-diagonals.
        bool up = A.nlo()==0;
        int n=A.colsize();
        // Number of off-diagonals in the triangle (BLAS k argument).
        int lohi = up ? A.nhi() : A.nlo();
        // Offset from A.cptr() back to the start of the BLAS band array.
        int aoffset =
            up && BlasIsCM(A) ? A.nhi() :
            !up && !BlasIsCM(A) ? A.nlo() : 0;
        // Passed as lda.
        int ds = A.diagstep();
        // Stride between consecutive real (or imaginary) parts, in floats.
        int xs = 2*x.step();
        // Real parts of x.
        BLASNAME(stbmv) (
            BLASCM BlasIsCM(A) == up?BLASCH_UP:BLASCH_LO,
            BlasIsCM(A) ? BLASCH_NT : BLASCH_T, BLASCH_NU,
            BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds),
            BLASP((float*)x.ptr()),BLASV(xs) BLAS1 BLAS1 BLAS1);
        // Imaginary parts of x.
        BLASNAME(stbmv) (
            BLASCM BlasIsCM(A) == up?BLASCH_UP:BLASCH_LO,
            BlasIsCM(A) ? BLASCH_NT : BLASCH_T, BLASCH_NU,
            BLASV(n),BLASV(lohi),BLASP(A.cptr()-aoffset),BLASV(ds),
            BLASP((float*)x.ptr()+1),BLASV(xs) BLAS1 BLAS1 BLAS1);
    }
1578 #endif
1579 #endif // BLAS
1580 
1581     template <class T, class Ta>
MultEqMV(const GenBandMatrix<Ta> & A,VectorView<T> x)1582     static void MultEqMV(
1583         const GenBandMatrix<Ta>& A, VectorView<T> x)
1584     {
1585 #ifdef XDEBUG
1586         cout<<"Start MultEqMV\n";
1587         Vector<T> x0 = x;
1588         Matrix<Ta> A0 = A;
1589         Vector<T> x2 = A0 * x0;
1590 #endif
1591         TMVAssert(A.isSquare());
1592         TMVAssert(A.colsize() == x.size());
1593         TMVAssert(x.size() > 0);
1594         TMVAssert(x.step() == 1);
1595         TMVAssert(A.nlo() == 0 || A.nhi() == 0);
1596 
1597         if (x.isconj()) MultEqMV(A.conjugate(),x.conjugate());
1598         else {
1599 #ifdef BLAS
1600             if ( !(BlasIsRM(A) || BlasIsCM(A)) ) {
1601                 //cout<<"!(rm or cm)\n";
1602                 BandMatrix<Ta,ColMajor> AA = A;
1603                 BlasMultEqMV(AA,x);
1604             } else if (SameStorage(A,x) || x.step() != 1) {
1605                 //cout<<"copy x\n";
1606                 Vector<T> xx = x;
1607                 BlasMultEqMV(A,xx.view());
1608                 x = xx;
1609             } else {
1610                 //cout<<"normal\n";
1611                 BlasMultEqMV(A,x);
1612             }
1613 #else
1614             if (A.nlo() == 0) NonBlasUpperMultEqMV(A,x);
1615             else NonBlasLowerMultEqMV(A,x);
1616 #endif
1617         }
1618 
1619 #ifdef XDEBUG
1620         cout<<"Done MultEqMV\n";
1621         if (!(Norm(x-x2) <= 0.001*(Norm(A0)*Norm(x0)))) {
1622             cerr<<"MultEqMV: \n";
1623             cerr<<"A = "<<TMV_Text(A)<<"  "<<A0<<endl;
1624             cerr<<"x = "<<TMV_Text(x)<<" step "<<x.step()<<"  "<<x0<<endl;
1625             cerr<<"-> x = "<<x<<endl;
1626             cerr<<"x2 = "<<x2<<endl;
1627             abort();
1628         }
1629 #endif
1630     }
1631 
1632     template <bool add, class T, class Ta, class Tx>
MultMV(const T alpha,const GenBandMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)1633     void MultMV(
1634         const T alpha, const GenBandMatrix<Ta>& A, const GenVector<Tx>& x,
1635         VectorView<T> y)
1636     // y (+)= alpha * A * x
1637     {
1638         TMVAssert(A.rowsize() == x.size());
1639         TMVAssert(A.colsize() == y.size());
1640 #ifdef XDEBUG
1641         cout<<"Start Band: MultMV\n";
1642         cout<<"A = "<<TMV_Text(A)<<"  "<<A.cptr()<<"  "<<A<<endl;
1643         cout<<"x = "<<TMV_Text(x)<<"  "<<x.cptr()<<"  step "<<x.step()<<"  "<<x<<endl;
1644         cout<<"y = "<<TMV_Text(y)<<"  "<<y.cptr()<<"  step "<<y.step()<<"  "<<y<<endl;
1645         cout<<"alpha = "<<alpha<<", add = "<<add<<endl;
1646         Vector<T> y0 = y;
1647         cout<<"y0 = "<<y0<<std::endl;
1648         Vector<Tx> x0 = x;
1649         cout<<"x0 = "<<x0<<std::endl;
1650         Matrix<Ta> A0 = A;
1651         cout<<"A0 = "<<A0<<std::endl;
1652         Vector<T> y2 = alpha*A0*x0;
1653         cout<<"y2 = "<<y2<<std::endl;
1654         if (add) y2 += y0;
1655         cout<<"y2 => "<<y2<<std::endl;
1656 #endif
1657 
1658         if (y.size() > 0) {
1659             if (x.size()==0 || alpha==T(0)) {
1660                 if (!add) y.setZero();
1661             } else if (A.rowsize() > A.colsize()+A.nhi()) {
1662                 MultMV<add>(
1663                     alpha,A.colRange(0,A.colsize()+A.nhi()),
1664                     x.subVector(0,A.colsize()+A.nhi()),y);
1665             } else if (A.colsize() > A.rowsize()+A.nlo()) {
1666                 MultMV<add>(
1667                     alpha,A.rowRange(0,A.rowsize()+A.nlo()),
1668                     x,y.subVector(0,A.rowsize()+A.nlo()));
1669                 if (!add)
1670                     y.subVector(A.rowsize()+A.nlo(),A.colsize()).setZero();
1671             } else if (A.isSquare() && (A.nlo() == 0 || A.nhi() == 0)) {
1672                 if (A.nlo() == 0 && A.nhi() == 0)
1673                     MultMV<add>(alpha,DiagMatrixViewOf(A.diag()),x,y);
1674                 else if (!add && y.step() == 1) {
1675                     y = alpha * x;
1676                     MultEqMV(A,y);
1677                 } else {
1678                     Vector<T> xx = alpha*x;
1679                     MultEqMV(A,xx.view());
1680                     if (add) y += xx;
1681                     else y = xx;
1682                 }
1683             } else {
1684                 if (SameStorage(y,A)) {
1685                     Vector<T> yy(y.size());
1686                     DoMultMV<false>(T(1),A,x,yy.view());
1687                     if (add) y += alpha*yy;
1688                     else y = alpha*yy;
1689                 } else {
1690                     DoMultMV<add>(alpha,A,x,y);
1691                 }
1692             }
1693         }
1694 #ifdef XDEBUG
1695         cout<<"->y = "<<y<<endl;
1696         if (!(Norm(y2-y) <= 0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(x0)+
1697                                 (add?Norm(y0):TMV_RealType(T)(0))))) {
1698             cerr<<"MultMV: alpha = "<<alpha<<endl;
1699             cerr<<"add = "<<add<<endl;
1700             cerr<<"A = "<<TMV_Text(A)<<"  "<<A.cptr()<<"  "<<A0<<endl;
1701             cerr<<"x = "<<TMV_Text(x)<<"  "<<x.cptr()<<" step "<<x.step()<<"  "<<x0<<endl;
1702             cerr<<"y = "<<TMV_Text(y)<<"  "<<y.cptr()<<" step "<<y.step()<<"  "<<y0<<endl;
1703             cerr<<"--> y = "<<y<<endl;
1704             cerr<<"y2 = "<<y2<<endl;
1705             cerr<<"Norm(diff) = "<<Norm(y2-y)<<endl;
1706             cerr<<"abs(alpha)*Norm(A0)*Norm(x0) = "<<TMV_ABS(alpha)*Norm(A0)*Norm(x0)<<endl;
1707             cerr<<"Norm(y0) = "<<Norm(y0)<<endl;
1708             abort();
1709         }
1710 #endif
1711     }
1712 
1713 #define InstFile "TMV_MultBV.inst"
1714 #include "TMV_Inst.h"
1715 #undef InstFile
1716 
1717 } // namespace tmv
1718 
1719 
1720