1 ///////////////////////////////////////////////////////////////////////////////
2 //                                                                           //
3 // The Template Matrix/Vector Library for C++ was created by Mike Jarvis     //
4 // Copyright (C) 1998 - 2016                                                 //
5 // All rights reserved                                                       //
6 //                                                                           //
7 // The project is hosted at https://code.google.com/p/tmv-cpp/               //
8 // where you can find the current version and current documention.           //
9 //                                                                           //
10 // For concerns or problems with the software, Mike may be contacted at      //
11 // mike_jarvis17 [at] gmail.                                                 //
12 //                                                                           //
13 // This software is licensed under a FreeBSD license.  The file              //
14 // TMV_LICENSE should have bee included with this distribution.              //
15 // It not, you can get a copy from https://code.google.com/p/tmv-cpp/.       //
16 //                                                                           //
17 // Essentially, you can use this software however you want provided that     //
18 // you include the TMV_LICENSE file in any distribution that uses it.        //
19 //                                                                           //
20 ///////////////////////////////////////////////////////////////////////////////
21 
22 
23 //#define XDEBUG
24 
25 
26 #include "TMV_Blas.h"
27 #include "tmv/TMV_TriMatrixArithFunc.h"
28 #include "tmv/TMV_TriMatrix.h"
29 #include "tmv/TMV_Vector.h"
30 #include "tmv/TMV_VectorArith.h"
31 #include "TMV_MultMV.h"
32 
33 // CBLAS trick of using RowMajor with ConjTrans when we have a
34 // case of A.conjugate() * x doesn't seem to be working with MKL 10.2.2.
35 // I haven't been able to figure out why.  (e.g. Is it a bug in the MKL
36 // code, or am I doing something wrong?)  So for now, just disable it.
37 #ifdef CBLAS
38 #undef CBLAS
39 #endif
40 
41 #ifdef XDEBUG
42 #include "tmv/TMV_MatrixArith.h"
43 #include <iostream>
44 using std::cout;
45 using std::cerr;
46 using std::endl;
47 #endif
48 
49 namespace tmv {
50 
51     template <class T>
cptr() const52     const T* UpperTriMatrixComposite<T>::cptr() const
53     {
54         if (!itsm.get()) {
55             itsm.resize(this->size()*this->size());
56             UpperTriMatrixView<T>(
57                 itsm.get(),this->size(),stepi(),stepj(),this->dt(),NonConj) =
58                 *this;
59         }
60         return itsm.get();
61     }
62 
63     template <class T>
stepi() const64     ptrdiff_t UpperTriMatrixComposite<T>::stepi() const
65     { return 1; }
66 
67     template <class T>
stepj() const68     ptrdiff_t UpperTriMatrixComposite<T>::stepj() const
69     { return this->size(); }
70 
71     template <class T>
cptr() const72     const T* LowerTriMatrixComposite<T>::cptr() const
73     {
74         if (!itsm.get()) {
75             itsm.resize(this->size()*this->size());
76             LowerTriMatrixView<T>(
77                 itsm.get(),this->size(),stepi(),stepj(),this->dt(),NonConj) =
78                 *this;
79         }
80         return itsm.get();
81     }
82 
83     template <class T>
stepi() const84     ptrdiff_t LowerTriMatrixComposite<T>::stepi() const
85     { return 1; }
86 
87     template <class T>
stepj() const88     ptrdiff_t LowerTriMatrixComposite<T>::stepj() const
89     { return this->size(); }
90 
91     //
92     // MultEqMV
93     //
94 
95     template <bool rm, bool ca, bool ua, class T, class Ta>
DoRowMultEqMV(const GenUpperTriMatrix<Ta> & A,VectorView<T> x)96     static void DoRowMultEqMV(
97         const GenUpperTriMatrix<Ta>& A, VectorView<T> x)
98     {
99         //cout<<"RowMultEqMV Upper\n";
100         TMVAssert(x.step()==1);
101         TMVAssert(A.size() == x.size());
102         TMVAssert(x.size() > 0);
103         TMVAssert(x.ct() == NonConj);
104         TMVAssert(rm == A.isrm());
105         TMVAssert(ca == A.isconj());
106         TMVAssert(ua == A.isunit());
107 
108         const ptrdiff_t N = x.size();
109 
110         const ptrdiff_t sj = rm ? 1 : A.stepj();
111         const ptrdiff_t ds = A.stepi()+sj;
112         T* xi = x.ptr();
113         const Ta* Aii = A.cptr();
114         ptrdiff_t len = N-1;
115 
116         for(; len>0; --len,++xi,Aii+=ds) {
117             // i = 0..N-2
118             // x(i) = A.row(i,i,N) * x.subVector(i,N);
119             if (!ua) *xi *= (ca ? TMV_CONJ(*Aii) : *Aii);
120             const T* xj = xi+1;
121             const Ta* Aij = Aii+sj;
122             for(ptrdiff_t j=len;j>0;--j,++xj,(rm?++Aij:Aij+=sj)) {
123 #ifdef TMVFLDEBUG
124                 TMVAssert(xi >= x._first);
125                 TMVAssert(xi < x._last);
126 #endif
127                 *xi += (*xj) * (ca ? TMV_CONJ(*Aij) : *Aij);
128             }
129         }
130 #ifdef TMVFLDEBUG
131         TMVAssert(xi >= x._first);
132         TMVAssert(xi < x._last);
133 #endif
134         if (!ua) *xi *= (ca ? TMV_CONJ(*Aii) : *Aii);
135     }
136 
137     template <bool rm, class T, class Ta>
RowMultEqMV(const GenUpperTriMatrix<Ta> & A,VectorView<T> x)138     static void RowMultEqMV(
139         const GenUpperTriMatrix<Ta>& A, VectorView<T> x)
140     {
141         if (A.isconj())
142             if (A.isunit())
143                 DoRowMultEqMV<rm,true,true>(A,x);
144             else
145                 DoRowMultEqMV<rm,true,false>(A,x);
146         else
147             if (A.isunit())
148                 DoRowMultEqMV<rm,false,true>(A,x);
149             else
150                 DoRowMultEqMV<rm,false,false>(A,x);
151     }
152 
153     template <bool cm, bool ca, bool ua, class T, class Ta>
DoColMultEqMV(const GenUpperTriMatrix<Ta> & A,VectorView<T> x)154     static void DoColMultEqMV(
155         const GenUpperTriMatrix<Ta>& A, VectorView<T> x)
156     {
157         //cout<<"ColMultEqMV Upper\n";
158         TMVAssert(x.step()==1);
159         TMVAssert(A.size() == x.size());
160         TMVAssert(x.size() > 0);
161         TMVAssert(x.ct() == NonConj);
162         TMVAssert(cm == A.iscm());
163         TMVAssert(ca == A.isconj());
164         TMVAssert(ua == A.isunit());
165 
166         const ptrdiff_t N = x.size();
167 
168         T* x0 = x.ptr();
169         const T* xj = x0+1;
170         const ptrdiff_t si = cm ? 1 : A.stepi();
171         const Ta* A0j = A.cptr();
172 
173 #ifdef TMVFLDEBUG
174         TMVAssert(x0 >= x._first);
175         TMVAssert(x0 < x._last);
176 #endif
177         if (!ua) *x0 *= (ca ? TMV_CONJ(*A0j) : *A0j);
178         A0j += A.stepj();
179 
180         for(ptrdiff_t len=1; len<N; ++len,++xj,A0j+=A.stepj()) if (*xj != T(0)) {
181             // j = 1..N-1
182             // x.subVector(0,j) += *xj * A.col(j,0,j);
183             const Ta* Aij = A0j;
184             T* xi = x0;
185             for(ptrdiff_t i=len;i>0;--i,++xi,(cm?++Aij:Aij+=si)) {
186 #ifdef TMVFLDEBUG
187                 TMVAssert(xi >= x._first);
188                 TMVAssert(xi < x._last);
189 #endif
190                 *xi += *xj * (ca ? TMV_CONJ(*Aij) : *Aij);
191             }
192             // Now Aij == Ajj, xi == xj
193             // so this next statement is really *xj *= *Ajj
194 #ifdef TMVFLDEBUG
195             TMVAssert(xi >= x._first);
196             TMVAssert(xi < x._last);
197 #endif
198             if (!ua) *xi *= (ca ? TMV_CONJ(*Aij) : *Aij);
199         }
200     }
201 
202     template <bool cm, class T, class Ta>
ColMultEqMV(const GenUpperTriMatrix<Ta> & A,VectorView<T> x)203     static void ColMultEqMV(
204         const GenUpperTriMatrix<Ta>& A, VectorView<T> x)
205     {
206         if (A.isconj())
207             if (A.isunit())
208                 DoColMultEqMV<cm,true,true>(A,x);
209             else
210                 DoColMultEqMV<cm,true,false>(A,x);
211         else
212             if (A.isunit())
213                 DoColMultEqMV<cm,false,true>(A,x);
214             else
215                 DoColMultEqMV<cm,false,false>(A,x);
216     }
217 
218     template <bool rm, bool ca, bool ua, class T, class Ta>
DoRowMultEqMV(const GenLowerTriMatrix<Ta> & A,VectorView<T> x)219     static void DoRowMultEqMV(
220         const GenLowerTriMatrix<Ta>& A, VectorView<T> x)
221     {
222         //cout<<"RowMultEqMV Lower\n";
223         TMVAssert(x.step()==1);
224         TMVAssert(A.size() == x.size());
225         TMVAssert(x.size() > 0);
226         TMVAssert(x.ct() == NonConj);
227         TMVAssert(rm == A.isrm());
228         TMVAssert(ca == A.isconj());
229         TMVAssert(ua == A.isunit());
230 
231         const ptrdiff_t N = x.size();
232         const ptrdiff_t si = A.stepi();
233         const ptrdiff_t sj = rm ? 1 : A.stepj();
234         const ptrdiff_t ds = si+sj;
235 
236         const T* x0 = x.cptr();
237         T* xi = x.ptr() + N-1;
238         const Ta* Ai0 = A.cptr()+(N-1)*si;
239         const Ta* Aii = Ai0 + (N-1)*sj;
240 
241         for(ptrdiff_t len=N-1; len>0; --len,--xi,Ai0-=si,Aii-=ds) {
242             // i = N-1..1
243             // x(i) = A.row(i,0,i+1) * x.subVector(0,i+1);
244             T xx = *xi;
245             if (!ua) xx *= (ca ? TMV_CONJ(*Aii) : *Aii);
246             const Ta* Aij = Ai0;
247             const T* xj = x0;
248             for(ptrdiff_t j=len;j>0;--j,++xj,(rm?++Aij:Aij+=sj))
249                 xx += *xj * (ca ? TMV_CONJ(*Aij) : *Aij);
250 #ifdef TMVFLDEBUG
251             TMVAssert(xi >= x._first);
252             TMVAssert(xi < x._last);
253 #endif
254             *xi = xx;
255         }
256 #ifdef TMVFLDEBUG
257         TMVAssert(xi >= x._first);
258         TMVAssert(xi < x._last);
259 #endif
260         if (!ua) *xi *= (ca ? TMV_CONJ(*Aii) : *Aii);
261     }
262 
263     template <bool rm, class T, class Ta>
RowMultEqMV(const GenLowerTriMatrix<Ta> & A,VectorView<T> x)264     static void RowMultEqMV(
265         const GenLowerTriMatrix<Ta>& A, VectorView<T> x)
266     {
267         if (A.isconj())
268             if (A.isunit())
269                 DoRowMultEqMV<rm,true,true>(A,x);
270             else
271                 DoRowMultEqMV<rm,true,false>(A,x);
272         else
273             if (A.isunit())
274                 DoRowMultEqMV<rm,false,true>(A,x);
275             else
276                 DoRowMultEqMV<rm,false,false>(A,x);
277     }
278 
279     template <bool cm, bool ca, bool ua, class T, class Ta>
DoColMultEqMV(const GenLowerTriMatrix<Ta> & A,VectorView<T> x)280     static void DoColMultEqMV(
281         const GenLowerTriMatrix<Ta>& A, VectorView<T> x)
282     {
283         //cout<<"ColMultEqMV Lower\n";
284         TMVAssert(A.size() == x.size());
285         TMVAssert(x.size() > 0);
286         TMVAssert(x.ct() == NonConj);
287         TMVAssert(x.step() == 1);
288         TMVAssert(cm == A.iscm());
289         TMVAssert(ca == A.isconj());
290         TMVAssert(ua == A.isunit());
291 
292         const ptrdiff_t N = x.size();
293 
294         T* xj = x.ptr() + N-2;
295         const ptrdiff_t si = cm ? 1 : A.stepi();
296         const ptrdiff_t ds = A.stepj()+si;
297         const Ta* Ajj = A.cptr()+(N-2)*ds;
298 
299 #ifdef TMVFLDEBUG
300         TMVAssert(xj+1 >= x._first);
301         TMVAssert(xj+1 < x._last);
302 #endif
303         if (!ua) *(xj+1) *= (ca ? TMV_CONJ(*(Ajj+ds)) : *(Ajj+ds));
304         for(ptrdiff_t jj=N-1,len=1;jj>0;--jj,++len,--xj,Ajj-=ds) if (*xj!=T(0)) {
305             // j = N-2..0, jj = j+1
306             // x.subVector(j+1,N) += *xj * A.col(j,j+1,N);
307             T* xi = xj+1;
308             const Ta* Aij = Ajj+si;
309             for (ptrdiff_t i=len;i>0;--i,++xi,(cm?++Aij:Aij+=si)) {
310 #ifdef TMVFLDEBUG
311                 TMVAssert(xi >= x._first);
312                 TMVAssert(xi < x._last);
313 #endif
314                 *xi += *xj * (ca ? TMV_CONJ(*Aij) : *Aij);
315             }
316 #ifdef TMVFLDEBUG
317             TMVAssert(xj >= x._first);
318             TMVAssert(xj < x._last);
319 #endif
320             if (!ua) *xj *= (ca ? TMV_CONJ(*Ajj) : *Ajj);
321         }
322     }
323 
324     template <bool cm, class T, class Ta>
ColMultEqMV(const GenLowerTriMatrix<Ta> & A,VectorView<T> x)325     static void ColMultEqMV(
326         const GenLowerTriMatrix<Ta>& A, VectorView<T> x)
327     {
328         if (A.isconj())
329             if (A.isunit())
330                 DoColMultEqMV<cm,true,true>(A,x);
331             else
332                 DoColMultEqMV<cm,true,false>(A,x);
333         else
334             if (A.isunit())
335                 DoColMultEqMV<cm,false,true>(A,x);
336             else
337                 DoColMultEqMV<cm,false,false>(A,x);
338     }
339 
340     template <class T, class Ta>
DoMultEqMV(const GenUpperTriMatrix<Ta> & A,VectorView<T> x)341     static inline void DoMultEqMV(
342         const GenUpperTriMatrix<Ta>& A, VectorView<T> x)
343     // x = A * x
344     {
345         if (A.isrm()) RowMultEqMV<true>(A,x);
346         else if (A.iscm()) ColMultEqMV<true>(A,x);
347         else RowMultEqMV<false>(A,x);
348     }
349 
350     template <class T, class Ta>
NonBlasMultEqMV(const GenUpperTriMatrix<Ta> & A,VectorView<T> x)351     static void NonBlasMultEqMV(
352         const GenUpperTriMatrix<Ta>& A, VectorView<T> x)
353     {
354         TMVAssert(A.size() == x.size());
355         TMVAssert(x.size() > 0);
356         TMVAssert(x.step() == 1);
357         TMVAssert(x.ct() == NonConj);
358 
359         //     [ A11 A12 A13 ] [ 0  ]   [ A12 x2 ]
360         // x = [  0  A22 A23 ] [ x2 ] = [ A22 x2 ]
361         //     [  0   0  A33 ] [ 0  ]   [   0    ]
362 
363         const ptrdiff_t N = x.size(); // = A.size()
364         ptrdiff_t j2 = N;
365         for(const T* x2=x.cptr()+N-1; j2>0 && *x2==T(0); --j2,--x2);
366         if (j2 == 0) return;
367         ptrdiff_t j1 = 0;
368         for(const T* x1=x.cptr(); *x1==T(0); ++j1,++x1);
369         if (j1 == 0 && j2 == N) {
370             DoMultEqMV(A,x);
371         } else {
372             TMVAssert(j1 < j2);
373             ConstUpperTriMatrixView<Ta> A22 = A.subTriMatrix(j1,j2);
374             VectorView<T> x2 = x.subVector(j1,j2);
375 
376             if (j1 != 0) {
377                 ConstMatrixView<Ta> A12 = A.subMatrix(0,j1,j1,j2);
378                 VectorView<T> x1 = x.subVector(0,j1);
379                 UnitAMultMV1<true,false>(A12,x2,x1);
380             }
381             DoMultEqMV(A22,x2);
382         }
383     }
384 
385     template <class T, class Ta>
DoMultEqMV(const GenLowerTriMatrix<Ta> & A,VectorView<T> x)386     static inline void DoMultEqMV(
387         const GenLowerTriMatrix<Ta>& A, VectorView<T> x)
388     {
389         if (A.isrm()) RowMultEqMV<true>(A,x);
390         else if (A.iscm() && !SameStorage(A,x))
391             ColMultEqMV<true>(A,x);
392         else RowMultEqMV<false>(A,x);
393     }
394 
395     template <class T, class Ta>
NonBlasMultEqMV(const GenLowerTriMatrix<Ta> & A,VectorView<T> x)396     static void NonBlasMultEqMV(
397         const GenLowerTriMatrix<Ta>& A, VectorView<T> x)
398     // x = A * x
399     {
400         TMVAssert(A.size() == x.size());
401         TMVAssert(x.size() > 0);
402         TMVAssert(x.step() == 1);
403         TMVAssert(x.ct() == NonConj);
404 
405         //     [ A11  0   0  ] [ 0  ]   [   0    ]
406         // x = [ A21 A22  0  ] [ x2 ] = [ A22 x2 ]
407         //     [ A31 A32 A33 ] [ 0  ]   [ A32 x2 ]
408 
409         const ptrdiff_t N = x.size(); // = A.size()
410         ptrdiff_t j2 = N;
411         for(const T* x2=x.cptr()+N-1; j2>0 && *x2==T(0); --j2,--x2);
412         if (j2 == 0) return;
413         ptrdiff_t j1 = 0;
414         for(const T* x1=x.cptr(); *x1==T(0); ++j1,++x1);
415         if (j1 == 0 && j2 == N) {
416             DoMultEqMV(A,x);
417         } else {
418             TMVAssert(j1 < j2);
419             ConstLowerTriMatrixView<Ta> A22 = A.subTriMatrix(j1,j2);
420             VectorView<T> x2 = x.subVector(j1,j2);
421 
422             if (j2 != N) {
423                 ConstMatrixView<Ta> A32 = A.subMatrix(j2,N,j1,j2);
424                 VectorView<T> x3 = x.subVector(j2,N);
425                 UnitAMultMV1<true,false>(A32,x2,x3);
426             }
427             DoMultEqMV(A22,x2);
428         }
429     }
430 
431 #ifdef BLAS
432     template <class T, class Ta>
BlasMultEqMV(const GenUpperTriMatrix<Ta> & A,VectorView<T> x)433     static inline void BlasMultEqMV(
434         const GenUpperTriMatrix<Ta>& A, VectorView<T> x)
435     { NonBlasMultEqMV(A,x); }
436     template <class T, class Ta>
BlasMultEqMV(const GenLowerTriMatrix<Ta> & A,VectorView<T> x)437     static inline void BlasMultEqMV(
438         const GenLowerTriMatrix<Ta>& A, VectorView<T> x)
439     { NonBlasMultEqMV(A,x); }
440 #ifdef INST_DOUBLE
441     template <>
BlasMultEqMV(const GenUpperTriMatrix<double> & A,VectorView<double> x)442     void BlasMultEqMV(
443         const GenUpperTriMatrix<double>& A, VectorView<double> x)
444     {
445         int n=A.size();
446         int lda=A.isrm()?A.stepi():A.stepj();
447         int xs=x.step();
448         double* xp = x.ptr();
449         BLASNAME(dtrmv) (
450             BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
451             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
452             BLASV(n),BLASP(A.cptr()),BLASV(lda),
453             BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
454     }
455     template <>
BlasMultEqMV(const GenLowerTriMatrix<double> & A,VectorView<double> x)456     void BlasMultEqMV(
457         const GenLowerTriMatrix<double>& A, VectorView<double> x)
458     {
459         int n=A.size();
460         int lda=A.isrm()?A.stepi():A.stepj();
461         int xs=x.step();
462         double* xp = x.ptr();
463         BLASNAME(dtrmv) (
464             BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
465             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
466             BLASV(n),BLASP(A.cptr()),BLASV(lda),
467             BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
468     }
469     template <>
BlasMultEqMV(const GenUpperTriMatrix<std::complex<double>> & A,VectorView<std::complex<double>> x)470     void BlasMultEqMV(
471         const GenUpperTriMatrix<std::complex<double> >& A,
472         VectorView<std::complex<double> > x)
473     {
474         int n=A.size();
475         int lda=A.isrm()?A.stepi():A.stepj();
476         int xs=x.step();
477         std::complex<double>* xp = x.ptr();
478         if (A.iscm() && A.isconj()) {
479 #ifdef CBLAS
480             BLASNAME(ztrmv) (
481                 BLASRM BLASCH_LO, BLASCH_CT,
482                 A.isunit()?BLASCH_U:BLASCH_NU,
483                 BLASV(n),BLASP(A.cptr()),BLASV(lda),
484                 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
485 #else
486             x.conjugateSelf();
487             BLASNAME(ztrmv) (
488                 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
489                 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
490                 BLASV(n),BLASP(A.cptr()),BLASV(lda),
491                 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
492             x.conjugateSelf();
493 #endif
494         } else {
495             BLASNAME(ztrmv) (
496                 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
497                 A.iscm()?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T,
498                 A.isunit()?BLASCH_U:BLASCH_NU,
499                 BLASV(n),BLASP(A.cptr()),BLASV(lda),
500                 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
501         }
502     }
503     template <>
BlasMultEqMV(const GenLowerTriMatrix<std::complex<double>> & A,VectorView<std::complex<double>> x)504     void BlasMultEqMV(
505         const GenLowerTriMatrix<std::complex<double> >& A,
506         VectorView<std::complex<double> > x)
507     {
508         int n=A.size();
509         int lda=A.isrm()?A.stepi():A.stepj();
510         int xs=x.step();
511         std::complex<double>* xp = x.ptr();
512         if (A.iscm() && A.isconj()) {
513 #ifdef CBLAS
514             BLASNAME(ztrmv) (
515                 BLASRM BLASCH_UP, BLASCH_CT,
516                 A.isunit()?BLASCH_U:BLASCH_NU,
517                 BLASV(n),BLASP(A.cptr()),BLASV(lda),
518                 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
519 #else
520             x.conjugateSelf();
521             BLASNAME(ztrmv) (
522                 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
523                 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
524                 BLASV(n),BLASP(A.cptr()),BLASV(lda),
525                 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
526             x.conjugateSelf();
527 #endif
528         } else {
529             BLASNAME(ztrmv) (
530                 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
531                 A.iscm()?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T,
532                 A.isunit()?BLASCH_U:BLASCH_NU,
533                 BLASV(n),BLASP(A.cptr()),BLASV(lda),
534                 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
535         }
536     }
537     template <>
BlasMultEqMV(const GenUpperTriMatrix<double> & A,VectorView<std::complex<double>> x)538     void BlasMultEqMV(
539         const GenUpperTriMatrix<double>& A,
540         VectorView<std::complex<double> > x)
541     {
542         int n=A.size();
543         int lda=A.isrm()?A.stepi():A.stepj();
544         int xs=2*x.step();
545         double* xp = (double*) x.ptr();
546         BLASNAME(dtrmv) (
547             BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
548             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
549             BLASV(n),BLASP(A.cptr()),BLASV(lda),
550             BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
551         BLASNAME(dtrmv) (
552             BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
553             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
554             BLASV(n),BLASP(A.cptr()),BLASV(lda),
555             BLASP(xp+1),BLASV(xs) BLAS1 BLAS1 BLAS1);
556     }
557     template <>
BlasMultEqMV(const GenLowerTriMatrix<double> & A,VectorView<std::complex<double>> x)558     void BlasMultEqMV(
559         const GenLowerTriMatrix<double>& A,
560         VectorView<std::complex<double> > x)
561     {
562         int n=A.size();
563         int lda=A.isrm()?A.stepi():A.stepj();
564         int xs=2*x.step();
565         double* xp = (double*) x.ptr();
566         BLASNAME(dtrmv) (
567             BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
568             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
569             BLASV(n),BLASP(A.cptr()),BLASV(lda),
570             BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
571         BLASNAME(dtrmv) (
572             BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
573             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
574             BLASV(n),BLASP(A.cptr()),BLASV(lda),
575             BLASP(xp+1),BLASV(xs) BLAS1 BLAS1 BLAS1);
576     }
577 #endif
578 #ifdef INST_FLOAT
579     template <>
BlasMultEqMV(const GenUpperTriMatrix<float> & A,VectorView<float> x)580     void BlasMultEqMV(
581         const GenUpperTriMatrix<float>& A, VectorView<float> x)
582     {
583         int n=A.size();
584         int lda=A.isrm()?A.stepi():A.stepj();
585         int xs=x.step();
586         float* xp = x.ptr();
587         BLASNAME(strmv) (
588             BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
589             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
590             BLASV(n),BLASP(A.cptr()),BLASV(lda),
591             BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
592     }
593     template <>
BlasMultEqMV(const GenLowerTriMatrix<float> & A,VectorView<float> x)594     void BlasMultEqMV(
595         const GenLowerTriMatrix<float>& A, VectorView<float> x)
596     {
597         int n=A.size();
598         int lda=A.isrm()?A.stepi():A.stepj();
599         int xs=x.step();
600         float* xp = x.ptr();
601         BLASNAME(strmv) (
602             BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
603             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
604             BLASV(n),BLASP(A.cptr()),BLASV(lda),
605             BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
606     }
607     template <>
BlasMultEqMV(const GenUpperTriMatrix<std::complex<float>> & A,VectorView<std::complex<float>> x)608     void BlasMultEqMV(
609         const GenUpperTriMatrix<std::complex<float> >& A,
610         VectorView<std::complex<float> > x)
611     {
612         int n=A.size();
613         int lda=A.isrm()?A.stepi():A.stepj();
614         int xs=x.step();
615         std::complex<float>* xp = x.ptr();
616         if (A.iscm() && A.isconj()) {
617 #ifdef CBLAS
618             BLASNAME(ctrmv) (
619                 BLASRM BLASCH_LO, BLASCH_CT,
620                 A.isunit()?BLASCH_U:BLASCH_NU,
621                 BLASV(n),BLASP(A.cptr()),BLASV(lda),
622                 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
623 #else
624             x.conjugateSelf();
625             BLASNAME(ctrmv) (
626                 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
627                 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
628                 BLASV(n),BLASP(A.cptr()),BLASV(lda),
629                 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
630             x.conjugateSelf();
631 #endif
632         } else {
633             BLASNAME(ctrmv) (
634                 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
635                 A.iscm()?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T,
636                 A.isunit()?BLASCH_U:BLASCH_NU,
637                 BLASV(n),BLASP(A.cptr()),BLASV(lda),
638                 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
639         }
640     }
641     template <>
BlasMultEqMV(const GenLowerTriMatrix<std::complex<float>> & A,VectorView<std::complex<float>> x)642     void BlasMultEqMV(
643         const GenLowerTriMatrix<std::complex<float> >& A,
644         VectorView<std::complex<float> > x)
645     {
646         int n=A.size();
647         int lda=A.isrm()?A.stepi():A.stepj();
648         int xs=x.step();
649         std::complex<float>* xp = x.ptr();
650         if (A.iscm() && A.isconj()) {
651 #ifdef CBLAS
652             BLASNAME(ctrmv) (
653                 BLASRM BLASCH_UP, BLASCH_CT,
654                 A.isunit()?BLASCH_U:BLASCH_NU,
655                 BLASV(n),BLASP(A.cptr()),BLASV(lda),
656                 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
657 #else
658             x.conjugateSelf();
659             BLASNAME(ctrmv) (
660                 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
661                 A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
662                 BLASV(n),BLASP(A.cptr()),BLASV(lda),
663                 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
664             x.conjugateSelf();
665 #endif
666         } else {
667             BLASNAME(ctrmv) (
668                 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
669                 A.iscm()?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T,
670                 A.isunit()?BLASCH_U:BLASCH_NU,
671                 BLASV(n),BLASP(A.cptr()),BLASV(lda),
672                 BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
673         }
674     }
675     template <>
BlasMultEqMV(const GenUpperTriMatrix<float> & A,VectorView<std::complex<float>> x)676     void BlasMultEqMV(
677         const GenUpperTriMatrix<float>& A,
678         VectorView<std::complex<float> > x)
679     {
680         int n=A.size();
681         int lda=A.isrm()?A.stepi():A.stepj();
682         int xs=2*x.step();
683         float* xp = (float*) x.ptr();
684         BLASNAME(strmv) (
685             BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
686             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
687             BLASV(n),BLASP(A.cptr()),BLASV(lda),
688             BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
689         BLASNAME(strmv) (
690             BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
691             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
692             BLASV(n),BLASP(A.cptr()),BLASV(lda),
693             BLASP(xp+1),BLASV(xs) BLAS1 BLAS1 BLAS1);
694     }
695     template <>
BlasMultEqMV(const GenLowerTriMatrix<float> & A,VectorView<std::complex<float>> x)696     void BlasMultEqMV(
697         const GenLowerTriMatrix<float>& A,
698         VectorView<std::complex<float> > x)
699     {
700         int n=A.size();
701         int lda=A.isrm()?A.stepi():A.stepj();
702         int xs=2*x.step();
703         float* xp = (float*) x.ptr();
704         BLASNAME(strmv) (
705             BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
706             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
707             BLASV(n),BLASP(A.cptr()),BLASV(lda),
708             BLASP(xp),BLASV(xs) BLAS1 BLAS1 BLAS1);
709         BLASNAME(strmv) (
710             BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
711             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
712             BLASV(n),BLASP(A.cptr()),BLASV(lda),
713             BLASP(xp+1),BLASV(xs) BLAS1 BLAS1 BLAS1);
714     }
715 #endif // FLOAT
716 #endif // BLAS
717 
718     template <class T, class Ta>
MultEqMV(const GenUpperTriMatrix<Ta> & A,VectorView<T> x)719     static void MultEqMV(
720         const GenUpperTriMatrix<Ta>& A, VectorView<T> x)
721     {
722 #ifdef XDEBUG
723         Vector<T> x0 = x;
724         Matrix<Ta> A0 = A;
725         Vector<T> x2 = A0 * x0;
726         //cout<<"MultEqMV: A = "<<A<<"x = "<<x<<endl;
727 #endif
728         TMVAssert(A.size() == x.size());
729         TMVAssert(x.size() > 0);
730         TMVAssert(x.step() == 1);
731 
732         if (x.isconj()) MultEqMV(A.conjugate(),x.conjugate());
733         else {
734 #ifdef BLAS
735             if ((A.isrm() && A.stepi()>0) || (A.iscm() && A.stepj()>0))
736                 BlasMultEqMV(A,x);
737             else {
738                 if (A.isunit()) {
739                     UpperTriMatrix<Ta,UnitDiag|RowMajor> A2(A);
740                     BlasMultEqMV(A2,x);
741                 } else {
742                     UpperTriMatrix<Ta,NonUnitDiag|RowMajor> A2(A);
743                     BlasMultEqMV(A2,x);
744                 }
745             }
746 #else
747             NonBlasMultEqMV(A,x);
748 #endif
749         }
750 #ifdef XDEBUG
751         //cout<<"-> x = "<<x<<endl<<"x2 = "<<x2<<endl;
752         if (!(Norm(x-x2) <= 0.001*(Norm(A0)*Norm(x0)))) {
753             cerr<<"MultEqMV: \n";
754             cerr<<"A = "<<A.cptr()<<"  "<<TMV_Text(A)<<"  "<<A0<<endl;
755             cerr<<"x = "<<x.cptr()<<"  "<<TMV_Text(x)<<" step "<<x.step()<<"  "<<x0<<endl;
756             cerr<<"-> x = "<<x<<endl;
757             cerr<<"x2 = "<<x2<<endl;
758             abort();
759         }
760 #endif
761     }
762 
763     template <class T, class Ta>
MultEqMV(const GenLowerTriMatrix<Ta> & A,VectorView<T> x)764     static void MultEqMV(
765         const GenLowerTriMatrix<Ta>& A, VectorView<T> x)
766     {
767 #ifdef XDEBUG
768         Vector<T> x0 = x;
769         Matrix<Ta> A0 = A;
770         Vector<T> x2 = A0 * x0;
771 #endif
772         TMVAssert(A.size() == x.size());
773         TMVAssert(x.size() > 0);
774         TMVAssert(x.step() == 1);
775 
776         if (x.isconj()) MultEqMV(A.conjugate(),x.conjugate());
777         else {
778 #ifdef BLAS
779             if ( (A.isrm() && A.stepi()>0) || (A.iscm() && A.stepj()>0) )
780                 BlasMultEqMV(A,x);
781             else {
782                 if (A.isunit()) {
783                     LowerTriMatrix<Ta,UnitDiag|RowMajor> A2(A);
784                     BlasMultEqMV(A2,x);
785                 } else {
786                     LowerTriMatrix<Ta,NonUnitDiag|RowMajor> A2(A);
787                     BlasMultEqMV(A2,x);
788                 }
789             }
790 #else
791             NonBlasMultEqMV(A,x);
792 #endif
793         }
794 
795 #ifdef XDEBUG
796         if (!(Norm(x-x2) <= 0.001*(Norm(A0)*Norm(x0)))) {
797             cerr<<"MultEqMV: \n";
798             cerr<<"A = "<<A.cptr()<<"  "<<TMV_Text(A)<<"  "<<A0<<endl;
799             cerr<<"x = "<<x.cptr()<<"  "<<TMV_Text(x)<<" step "<<x.step()<<"  "<<x0<<endl;
800             cerr<<"-> x = "<<x<<endl;
801             cerr<<"x2 = "<<x2<<endl;
802             abort();
803         }
804 #endif
805     }
806 
807     template <bool add, class T, class Ta, class Tx>
MultMV(const T alpha,const GenUpperTriMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)808     void MultMV(
809         const T alpha, const GenUpperTriMatrix<Ta>& A,
810         const GenVector<Tx>& x, VectorView<T> y)
811     // y (+)= alpha * A * x
812     {
813 #ifdef XDEBUG
814         Vector<Tx> x0 = x;
815         Vector<T> y0 = y;
816         Matrix<Ta> A0 = A;
817         Vector<T> y2 = alpha*A0*x0;
818         if (add) y2 += y0;
819 #endif
820         TMVAssert(A.size() == x.size());
821         TMVAssert(A.size() == y.size());
822 
823         if (y.size() > 0) {
824             if (alpha==T(0)) {
825                 if (!add) y.setZero();
826             } else if (!add && y.step() == 1) {
827                 y = x;
828                 MultEqMV(A,y);
829                 y *= alpha;
830             } else {
831                 Vector<T> xx = alpha*x;
832                 MultEqMV(A,xx.view());
833                 if (add) y += xx;
834                 else y = xx;
835             }
836         }
837 #ifdef XDEBUG
838         if (!(Norm(y-y2) <=
839               0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(x0)+
840                      (add?Norm(y0):TMV_RealType(T)(0))))) {
841             cerr<<"MultMV: alpha = "<<alpha<<endl;
842             cerr<<"add = "<<add<<endl;
843             cerr<<"A = "<<TMV_Text(A)<<"  "<<A0<<endl;
844             cerr<<"x = "<<TMV_Text(x)<<" step "<<x.step()<<"  "<<x0<<endl;
845             cerr<<"y = "<<TMV_Text(y)<<" step "<<y.step()<<"  "<<y0<<endl;
846             cerr<<"-> y = "<<y<<endl;
847             cerr<<"y2 = "<<y2<<endl;
848             abort();
849         }
850 #endif
851     }
852 
853     template <bool add, class T, class Ta, class Tx>
MultMV(const T alpha,const GenLowerTriMatrix<Ta> & A,const GenVector<Tx> & x,VectorView<T> y)854     void MultMV(
855         const T alpha, const GenLowerTriMatrix<Ta>& A,
856         const GenVector<Tx>& x, VectorView<T> y)
857     // y (+)= alpha * A * x
858     {
859 #ifdef XDEBUG
860         Vector<T> y0 = y;
861         Vector<Tx> x0 = x;
862         Matrix<Ta> A0 = A;
863         Vector<T> y2 = alpha*A0*x0;
864         if (add) y2 += y0;
865 #endif
866 
867         TMVAssert(A.size() == x.size());
868         TMVAssert(A.size() == y.size());
869 
870         if (y.size() > 0) {
871             if (alpha==T(0)) {
872                 if (!add) y.setZero();
873             } else if (!add && y.step() == 1) {
874                 y = x;
875                 MultEqMV(A,y);
876                 if (alpha != T(1)) y *= alpha;
877             } else {
878                 Vector<T> xx = alpha*x;
879                 MultEqMV(A,xx.view());
880                 if (add) y += xx;
881                 else y = xx;
882             }
883         }
884 #ifdef XDEBUG
885         if (!(Norm(y-y2) <=
886               0.001*(TMV_ABS(alpha)*Norm(A0)*Norm(x0)+
887                      (add?Norm(y0):TMV_RealType(T)(0))))) {
888             cerr<<"MultMV: alpha = "<<alpha<<endl;
889             cerr<<"add = "<<add<<endl;
890             cerr<<"A = "<<A.cptr()<<"  "<<TMV_Text(A)<<"  "<<A0<<endl;
891             cerr<<"x = "<<x.cptr()<<"  "<<TMV_Text(x)<<" step "<<x.step()<<"  "<<x0<<endl;
892             cerr<<"y = "<<y.cptr()<<"  "<<TMV_Text(y)<<" step "<<y.step()<<"  "<<y0<<endl;
893             cerr<<"-> y = "<<y<<endl;
894             cerr<<"y2 = "<<y2<<endl;
895             abort();
896         }
897 #endif
898     }
899 
900 #define InstFile "TMV_MultUV.inst"
901 #include "TMV_Inst.h"
902 #undef InstFile
903 
904 } // namespace tmv
905 
906 
907