1 ///////////////////////////////////////////////////////////////////////////////
2 //                                                                           //
3 // The Template Matrix/Vector Library for C++ was created by Mike Jarvis     //
4 // Copyright (C) 1998 - 2016                                                 //
5 // All rights reserved                                                       //
6 //                                                                           //
7 // The project is hosted at https://code.google.com/p/tmv-cpp/               //
// where you can find the current version and current documentation.         //
9 //                                                                           //
10 // For concerns or problems with the software, Mike may be contacted at      //
11 // mike_jarvis17 [at] gmail.                                                 //
12 //                                                                           //
13 // This software is licensed under a FreeBSD license.  The file              //
// TMV_LICENSE should have been included with this distribution.             //
// If not, you can get a copy from https://code.google.com/p/tmv-cpp/.       //
16 //                                                                           //
17 // Essentially, you can use this software however you want provided that     //
18 // you include the TMV_LICENSE file in any distribution that uses it.        //
19 //                                                                           //
20 ///////////////////////////////////////////////////////////////////////////////
21 
22 
23 //#define XDEBUG
24 
25 
26 #include "TMV_Blas.h"
27 #include "tmv/TMV_TriDiv.h"
28 #include "tmv/TMV_TriMatrix.h"
29 #include "tmv/TMV_Vector.h"
30 
// Using CblasRowMajor in all the other files isn't working correctly with
32 // MKL 10.2.2.  The usage in here hasn't given me any problems, but
33 // just to be safe, I'm disabling it here as well.  My guess is that I
34 // haven't tested the V/U code as thoroughly as the others, so it might
35 // turn up as a problem in more complicated code.
36 #ifdef CBLAS
37 #undef CBLAS
38 #endif
39 
40 #ifdef NOTHROW
41 #include <iostream>
42 #endif
43 
44 #ifdef XDEBUG
45 #include "tmv/TMV_MatrixArith.h"
46 #include "tmv/TMV_VectorArith.h"
47 #include <iostream>
48 using std::cout;
49 using std::cerr;
50 using std::endl;
51 #endif
52 
53 namespace tmv {
54 
55     //
56     // TriLDivEq V
57     //
58 
59     template <bool rm, bool ca, bool ua, class T, class Ta>
DoRowTriLDivEq(const GenUpperTriMatrix<Ta> & A,VectorView<T> b)60     static void DoRowTriLDivEq(
61         const GenUpperTriMatrix<Ta>& A, VectorView<T> b)
62     {
63         // Solve A x = y  where A is an upper triangle matrix
64         //cout<<"Row Upper\n";
65         TMVAssert(b.step()==1);
66         TMVAssert(A.size() == b.size());
67         TMVAssert(b.size() > 0);
68         TMVAssert(b.ct() == NonConj);
69         TMVAssert(rm == A.isrm());
70         TMVAssert(ca == A.isconj());
71         TMVAssert(ua == A.isunit());
72 
73         const ptrdiff_t N = A.size();
74 
75         const ptrdiff_t sj = (rm?1:A.stepj());
76         const ptrdiff_t ds = A.stepi()+sj;
77         const Ta* Aii = A.cptr() + (ua ? N-2 : N-1)*ds;
78         T* bi = b.ptr() + (ua ? N-2 : N-1);
79 
80         if (!ua) {
81             if (*Aii==Ta(0)) {
82 #ifdef NOTHROW
83                 std::cerr<<"Singular UpperTriMatrix found\n";
84                 exit(1);
85 #else
86                 throw SingularUpperTriMatrix<Ta>(A);
87 #endif
88             }
89 #ifdef TMVFLDEBUG
90             TMVAssert(bi >= b._first);
91             TMVAssert(bi < b._last);
92 #endif
93             *bi = TMV_Divide(*bi,(ca ? TMV_CONJ(*Aii) : *Aii));
94             Aii -= ds;
95             --bi;
96         }
97         if (N==1) return;
98 
99         for(ptrdiff_t i=N-1,len=1; i>0; --i,++len,Aii-=ds,--bi) {
100             // Actual row being done is i-1, not i
101 
102             // *bi -= A.row(i,i+1,N) * b.subVector(i+1,N);
103             const T* bj = bi+1;
104             const Ta* Aij = Aii + sj;
105             for(ptrdiff_t j=len;j>0;--j,++bj,(rm?++Aij:Aij+=sj)) {
106 #ifdef TMVFLDEBUG
107                 TMVAssert(bi >= b._first);
108                 TMVAssert(bi < b._last);
109 #endif
110                 *bi -= (*bj) * (ca ? TMV_CONJ(*Aij) : *Aij);
111             }
112 
113             if (!ua) {
114                 if (*Aii==Ta(0)) {
115 #ifdef NOTHROW
116                     std::cerr<<"Singular UpperTriMatrix found\n";
117                     exit(1);
118 #else
119                     throw SingularUpperTriMatrix<Ta>(A);
120 #endif
121                 }
122 #ifdef TMVFLDEBUG
123                 TMVAssert(bi >= b._first);
124                 TMVAssert(bi < b._last);
125 #endif
126                 *bi = TMV_Divide(*bi,(ca ? TMV_CONJ(*Aii) : *Aii));
127             }
128         }
129     }
130 
131     template <bool rm, class T, class Ta>
RowTriLDivEq(const GenUpperTriMatrix<Ta> & A,VectorView<T> b)132     static inline void RowTriLDivEq(
133         const GenUpperTriMatrix<Ta>& A, VectorView<T> b)
134     {
135         if (A.isconj())
136             if (A.isunit())
137                 DoRowTriLDivEq<rm,true,true>(A,b);
138             else
139                 DoRowTriLDivEq<rm,true,false>(A,b);
140         else
141             if (A.isunit())
142                 DoRowTriLDivEq<rm,false,true>(A,b);
143             else
144                 DoRowTriLDivEq<rm,false,false>(A,b);
145     }
146 
147     template <bool cm, bool ca, bool ua, class T, class Ta>
DoColTriLDivEq(const GenUpperTriMatrix<Ta> & A,VectorView<T> b)148     static void DoColTriLDivEq(
149         const GenUpperTriMatrix<Ta>& A, VectorView<T> b)
150     {
151         //cout<<"colmajor upper\n";
152         // Solve A x = y  where A is an upper triangle matrix
153         TMVAssert(b.step()==1);
154         TMVAssert(A.size() == b.size());
155         TMVAssert(b.size() > 0);
156         TMVAssert(b.ct() == NonConj);
157         TMVAssert(cm == A.iscm());
158         TMVAssert(ca == A.isconj());
159         TMVAssert(ua == A.isunit());
160 
161         const ptrdiff_t N = A.size();
162 
163         const ptrdiff_t si = (cm ? 1 : A.stepi());
164         const ptrdiff_t sj = A.stepj();
165         const ptrdiff_t ds = si+sj;
166         const Ta* A0j = A.cptr()+(N-1)*sj;
167         const Ta* Ajj = (ua ? 0 : A0j+(N-1)*si); // if unit, this isn't used.
168         T*const b0 = b.ptr();
169         T* bj = b0 + N-1;
170 
171         for(ptrdiff_t j=N-1; j>0; --j,--bj,A0j-=sj) {
172             if (*bj != T(0)) {
173                 if (!ua) {
174                     if (*Ajj==Ta(0)) {
175 #ifdef NOTHROW
176                         std::cerr<<"Singular UpperTriMatrix found\n";
177                         exit(1);
178 #else
179                         throw SingularUpperTriMatrix<Ta>(A);
180 #endif
181                     }
182 #ifdef TMVFLDEBUG
183                     TMVAssert(bj >= b._first);
184                     TMVAssert(bj < b._last);
185 #endif
186                     *bj = TMV_Divide(*bj,(ca ? TMV_CONJ(*Ajj) : *Ajj));
187                     Ajj-=ds;
188                 }
189 
190                 // b.subVector(0,j) -= *bj * A.col(j,0,j);
191                 T* bi = b0;
192                 const Ta* Aij = A0j;
193                 for(ptrdiff_t i=j;i>0;--i,++bi,(cm?++Aij:Aij+=si)) {
194 #ifdef TMVFLDEBUG
195                     TMVAssert(bi >= b._first);
196                     TMVAssert(bi < b._last);
197 #endif
198                     *bi -= *bj * (ca ? TMV_CONJ(*Aij) : *Aij);
199                 }
200             }
201             else if (!ua) Ajj -= ds;
202         }
203         if (!ua && *bj != T(0)) {
204             if (*Ajj==Ta(0)) {
205 #ifdef NOTHROW
206                 std::cerr<<"Singular UpperTriMatrix found\n";
207                 exit(1);
208 #else
209                 throw SingularUpperTriMatrix<Ta>(A);
210 #endif
211             }
212 #ifdef TMVFLDEBUG
213             TMVAssert(bj >= b._first);
214             TMVAssert(bj < b._last);
215 #endif
216             *bj = TMV_Divide(*bj,(ca ? TMV_CONJ(*Ajj) : *Ajj));
217         }
218     }
219 
220     template <bool cm, class T, class Ta>
ColTriLDivEq(const GenUpperTriMatrix<Ta> & A,VectorView<T> b)221     static inline void ColTriLDivEq(
222         const GenUpperTriMatrix<Ta>& A, VectorView<T> b)
223     {
224         if (A.isconj())
225             if (A.isunit())
226                 DoColTriLDivEq<cm,true,true>(A,b);
227             else
228                 DoColTriLDivEq<cm,true,false>(A,b);
229         else
230             if (A.isunit())
231                 DoColTriLDivEq<cm,false,true>(A,b);
232             else
233                 DoColTriLDivEq<cm,false,false>(A,b);
234     }
235 
236     template <bool rm, bool ca, bool ua, class T, class Ta>
DoRowTriLDivEq(const GenLowerTriMatrix<Ta> & A,VectorView<T> b)237     static void DoRowTriLDivEq(
238         const GenLowerTriMatrix<Ta>& A, VectorView<T> b)
239     {
240         // Solve A x = y  where A is a lower triangle matrix
241         TMVAssert(b.step()==1);
242         TMVAssert(A.size() == b.size());
243         TMVAssert(b.size() > 0);
244         TMVAssert(b.ct() == NonConj);
245         TMVAssert(rm == A.isrm());
246         TMVAssert(ca == A.isconj());
247         TMVAssert(ua == A.isunit());
248 
249         const ptrdiff_t N = A.size();
250 
251         const ptrdiff_t sj = (rm ? 1 : A.stepj());
252         const ptrdiff_t si = A.stepi();
253 
254         const Ta* Ai0 = A.cptr();
255         T* b0 = b.ptr();
256 
257         if (!ua) {
258             if (*Ai0==Ta(0)) {
259 #ifdef NOTHROW
260                 std::cerr<<"Singular LowerTriMatrix found\n";
261                 exit(1);
262 #else
263                 throw SingularLowerTriMatrix<Ta>(A);
264 #endif
265             }
266 #ifdef TMVFLDEBUG
267             TMVAssert(b0 >= b._first);
268             TMVAssert(b0 < b._last);
269 #endif
270             *b0 = TMV_Divide(*b0,(ca ? TMV_CONJ(*Ai0) : *Ai0));
271         }
272 
273         T* bi = b0+1;
274         Ai0 += si;
275         for(ptrdiff_t i=1,len=1;i<N;++i,++len,++bi,Ai0+=si) {
276             // *bi -= A.row(i,0,i) * b.subVector(0,i);
277             const Ta* Aij = Ai0;
278             const T* bj = b0;
279             for(ptrdiff_t j=len;j>0;--j,++bj,(rm?++Aij:Aij+=sj)){
280 #ifdef TMVFLDEBUG
281                 TMVAssert(bi >= b._first);
282                 TMVAssert(bi < b._last);
283 #endif
284                 *bi -= (*bj) * (ca ? TMV_CONJ(*Aij) : *Aij);
285             }
286             if (!ua) {
287                 // Aij is Aii after the above for loop
288                 if (*Aij==Ta(0)) {
289 #ifdef NOTHROW
290                     std::cerr<<"Singular LowerTriMatrix found\n";
291                     exit(1);
292 #else
293                     throw SingularLowerTriMatrix<Ta>(A);
294 #endif
295                 }
296 #ifdef TMVFLDEBUG
297                 TMVAssert(bi >= b._first);
298                 TMVAssert(bi < b._last);
299 #endif
300                 *bi = TMV_Divide(*bi,(ca ? TMV_CONJ(*Aij) : *Aij));
301             }
302         }
303     }
304 
305     template <bool rm, class T, class Ta>
RowTriLDivEq(const GenLowerTriMatrix<Ta> & A,VectorView<T> b)306     static inline void RowTriLDivEq(
307         const GenLowerTriMatrix<Ta>& A, VectorView<T> b)
308     {
309         if (A.isconj())
310             if (A.isunit())
311                 DoRowTriLDivEq<rm,true,true>(A,b);
312             else
313                 DoRowTriLDivEq<rm,true,false>(A,b);
314         else
315             if (A.isunit())
316                 DoRowTriLDivEq<rm,false,true>(A,b);
317             else
318                 DoRowTriLDivEq<rm,false,false>(A,b);
319     }
320 
321     template <bool cm, bool ca, bool ua, class T, class Ta>
DoColTriLDivEq(const GenLowerTriMatrix<Ta> & A,VectorView<T> b)322     static void DoColTriLDivEq(
323         const GenLowerTriMatrix<Ta>& A, VectorView<T> b)
324     {
325         // Solve A x = y  where A is a lower triangle matrix
326         TMVAssert(b.step()==1);
327         TMVAssert(A.size() == b.size());
328         TMVAssert(b.size() > 0);
329         TMVAssert(b.ct() == NonConj);
330         TMVAssert(cm == A.iscm());
331         TMVAssert(ca == A.isconj());
332         TMVAssert(ua == A.isunit());
333 
334         const ptrdiff_t N = A.size();
335 
336         const ptrdiff_t si = (cm ? 1 : A.stepi());
337         const ptrdiff_t ds = A.stepj()+si;
338         const Ta* Ajj = A.cptr();
339         T* bj = b.ptr();
340 
341         for(ptrdiff_t j=0,len=N-1;len>0;++j,--len,++bj,Ajj+=ds) if (*bj != T(0)) {
342             if (!ua) {
343                 if (*Ajj==Ta(0)) {
344 #ifdef NOTHROW
345                     std::cerr<<"Singular LowerTriMatrix found\n";
346                     exit(1);
347 #else
348                     throw SingularLowerTriMatrix<Ta>(A);
349 #endif
350                 }
351 #ifdef TMVFLDEBUG
352                 TMVAssert(bj >= b._first);
353                 TMVAssert(bj < b._last);
354 #endif
355                 *bj = TMV_Divide(*bj,(ca ? TMV_CONJ(*Ajj) : *Ajj));
356             }
357             // b.subVector(j+1,N) -= *bj * A.col(j,j+1,N);
358             T* bi = bj+1;
359             const Ta* Aij = Ajj+si;
360             for(ptrdiff_t i=len;i>0;--i,++bi,(cm?++Aij:Aij+=si)){
361 #ifdef TMVFLDEBUG
362                 TMVAssert(bi >= b._first);
363                 TMVAssert(bi < b._last);
364 #endif
365                 *bi -= *bj * (ca ? TMV_CONJ(*Aij) : *Aij);
366             }
367         }
368         if (!ua && *bj != T(0)) {
369             if (*Ajj==Ta(0)) {
370 #ifdef NOTHROW
371                 std::cerr<<"Singular LowerTriMatrix found\n";
372                 exit(1);
373 #else
374                 throw SingularLowerTriMatrix<Ta>(A);
375 #endif
376             }
377 #ifdef TMVFLDEBUG
378             TMVAssert(bj >= b._first);
379             TMVAssert(bj < b._last);
380 #endif
381             *bj = TMV_Divide(*bj,(ca ? TMV_CONJ(*Ajj) : *Ajj));
382         }
383     }
384 
385     template <bool cm, class T, class Ta>
ColTriLDivEq(const GenLowerTriMatrix<Ta> & A,VectorView<T> b)386     static inline void ColTriLDivEq(
387         const GenLowerTriMatrix<Ta>& A, VectorView<T> b)
388     {
389         if (A.isconj())
390             if (A.isunit())
391                 DoColTriLDivEq<cm,true,true>(A,b);
392             else
393                 DoColTriLDivEq<cm,true,false>(A,b);
394         else
395             if (A.isunit())
396                 DoColTriLDivEq<cm,false,true>(A,b);
397             else
398                 DoColTriLDivEq<cm,false,false>(A,b);
399     }
400 
401     template <class T, class Ta>
DoTriLDivEq(const GenUpperTriMatrix<Ta> & A,VectorView<T> b)402     static inline void DoTriLDivEq(
403         const GenUpperTriMatrix<Ta>& A, VectorView<T> b)
404     {
405         if (A.isrm()) RowTriLDivEq<true>(A,b);
406         else if (A.iscm()) ColTriLDivEq<true>(A,b);
407         else RowTriLDivEq<false>(A,b);
408     }
409 
410     template <class T, class Ta>
DoTriLDivEq(const GenLowerTriMatrix<Ta> & A,VectorView<T> b)411     static inline void DoTriLDivEq(
412         const GenLowerTriMatrix<Ta>& A, VectorView<T> b)
413     {
414         if (A.isrm()) RowTriLDivEq<true>(A,b);
415         else if (A.iscm()) ColTriLDivEq<true>(A,b);
416         else RowTriLDivEq<false>(A,b);
417     }
418 
419     template <class T, class Ta>
NonBlasTriLDivEq(const GenUpperTriMatrix<Ta> & A,VectorView<T> b)420     static void NonBlasTriLDivEq(
421         const GenUpperTriMatrix<Ta>& A, VectorView<T> b)
422     {
423         //cout<<"Upper LDivEq vect\n";
424         TMVAssert(A.size() == b.size());
425         TMVAssert(b.size()>0);
426         TMVAssert(b.ct() == NonConj);
427 
428         if (b.step() == 1) {
429             ptrdiff_t i2 = b.size();
430             for(const T* b2 = b.cptr()+i2-1; i2>0 && *b2==T(0); --i2,--b2);
431             if (i2==0) return;
432             else DoTriLDivEq(A.subTriMatrix(0,i2),b.subVector(0,i2));
433         } else {
434             Vector<T> bb = b;
435             NonBlasTriLDivEq(A,bb.view());
436             b = bb;
437         }
438     }
439 
440     template <class T, class Ta>
NonBlasTriLDivEq(const GenLowerTriMatrix<Ta> & A,VectorView<T> b)441     static void NonBlasTriLDivEq(
442         const GenLowerTriMatrix<Ta>& A, VectorView<T> b)
443     {
444         //cout<<"Lower LDivEq vect\n";
445         TMVAssert(A.size() == b.size());
446         TMVAssert(b.size()>0);
447         TMVAssert(b.ct() == NonConj);
448 
449         if (b.step() == 1) {
450             const ptrdiff_t N = b.size();
451             ptrdiff_t i1 = 0;
452             for(const T* b1 = b.cptr(); i1<N && *b1==T(0); ++i1,++b1);
453             if (i1==N) return;
454             else
455                 DoTriLDivEq(A.subTriMatrix(i1,N),b.subVector(i1,N));
456         } else {
457             Vector<T> bb = b;
458             NonBlasTriLDivEq(A,bb.view());
459             b = bb;
460         }
461     }
462 
463 #ifdef BLAS
464     template <class T, class Ta>
BlasTriLDivEq(const GenUpperTriMatrix<Ta> & A,VectorView<T> b)465     static inline void BlasTriLDivEq(
466         const GenUpperTriMatrix<Ta>& A, VectorView<T> b)
467     { NonBlasTriLDivEq(A,b); }
468     template <class T, class Ta>
BlasTriLDivEq(const GenLowerTriMatrix<Ta> & A,VectorView<T> b)469     static inline void BlasTriLDivEq(
470         const GenLowerTriMatrix<Ta>& A, VectorView<T> b)
471     { NonBlasTriLDivEq(A,b); }
472 #ifdef INST_DOUBLE
473     template <>
BlasTriLDivEq(const GenUpperTriMatrix<double> & A,VectorView<double> b)474     void BlasTriLDivEq(
475         const GenUpperTriMatrix<double>& A, VectorView<double> b)
476     {
477         int n=A.size();
478         int lda = A.isrm()?A.stepi():A.stepj();
479         int bs = b.step();
480         double* bp = b.ptr();
481         if (bs < 0) bp += (n-1)*bs;
482         BLASNAME(dtrsv) (
483             BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
484             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
485             BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
486             BLAS1 BLAS1 BLAS1);
487     }
488     template <>
BlasTriLDivEq(const GenLowerTriMatrix<double> & A,VectorView<double> b)489     void BlasTriLDivEq(
490         const GenLowerTriMatrix<double>& A, VectorView<double> b)
491     {
492         int n=A.size();
493         int lda = A.isrm()?A.stepi():A.stepj();
494         int bs = b.step();
495         double* bp = b.ptr();
496         if (bs < 0) bp += (n-1)*bs;
497         BLASNAME(dtrsv) (
498             BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
499             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
500             BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
501             BLAS1 BLAS1 BLAS1);
502     }
503     template <>
BlasTriLDivEq(const GenUpperTriMatrix<std::complex<double>> & A,VectorView<std::complex<double>> b)504     void BlasTriLDivEq(
505         const GenUpperTriMatrix<std::complex<double> >& A,
506         VectorView<std::complex<double> > b)
507     {
508         int n=A.size();
509         int lda = A.isrm()?A.stepi():A.stepj();
510         int bs = b.step();
511         std::complex<double>* bp = b.ptr();
512         if (bs < 0) bp += (n-1)*bs;
513         if (A.iscm() && A.isconj()) {
514 #ifdef CBLAS
515             BLASNAME(ztrsv) (
516                 BLASRM BLASCH_LO,BLASCH_CT,
517                 A.isunit()?BLASCH_U:BLASCH_NU,
518                 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
519                 BLAS1 BLAS1 BLAS1);
520 #else
521             b.conjugateSelf();
522             BLASNAME(ztrsv) (
523                 BLASCM BLASCH_UP,BLASCH_NT,
524                 A.isunit()?BLASCH_U:BLASCH_NU,
525                 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
526                 BLAS1 BLAS1 BLAS1);
527             b.conjugateSelf();
528 #endif
529         } else {
530             BLASNAME(ztrsv) (
531                 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
532                 A.iscm()?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T,
533                 A.isunit()?BLASCH_U:BLASCH_NU,
534                 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
535                 BLAS1 BLAS1 BLAS1);
536         }
537     }
538     template <>
BlasTriLDivEq(const GenLowerTriMatrix<std::complex<double>> & A,VectorView<std::complex<double>> b)539     void BlasTriLDivEq(
540         const GenLowerTriMatrix<std::complex<double> >& A,
541         VectorView<std::complex<double> > b)
542     {
543         int n=A.size();
544         int lda = A.isrm()?A.stepi():A.stepj();
545         int bs = b.step();
546         std::complex<double>* bp = b.ptr();
547         if (bs < 0) bp += (n-1)*bs;
548         if (A.iscm() && A.isconj()) {
549 #ifdef CBLAS
550             BLASNAME(ztrsv) (
551                 BLASRM BLASCH_UP,BLASCH_CT,
552                 A.isunit()?BLASCH_U:BLASCH_NU,
553                 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
554                 BLAS1 BLAS1 BLAS1);
555 #else
556             b.conjugateSelf();
557             BLASNAME(ztrsv) (
558                 BLASCM BLASCH_LO,BLASCH_NT,
559                 A.isunit()?BLASCH_U:BLASCH_NU,
560                 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
561                 BLAS1 BLAS1 BLAS1);
562             b.conjugateSelf();
563 #endif
564         } else {
565             BLASNAME(ztrsv) (
566                 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
567                 A.iscm()?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T,
568                 A.isunit()?BLASCH_U:BLASCH_NU,
569                 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
570                 BLAS1 BLAS1 BLAS1);
571         }
572     }
573     template <>
BlasTriLDivEq(const GenUpperTriMatrix<double> & A,VectorView<std::complex<double>> b)574     void BlasTriLDivEq(
575         const GenUpperTriMatrix<double>& A,
576         VectorView<std::complex<double> > b)
577     {
578         int n=A.size();
579         int lda = A.isrm()?A.stepi():A.stepj();
580         int bs = 2*b.step();
581         double* bp = (double*)(b.ptr());
582         if (bs < 0) bp += (n-1)*bs;
583         BLASNAME(dtrsv) (
584             BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
585             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
586             BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
587             BLAS1 BLAS1 BLAS1);
588         BLASNAME(dtrsv) (
589             BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
590             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
591             BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp+1),BLASV(bs)
592             BLAS1 BLAS1 BLAS1);
593     }
594     template <>
BlasTriLDivEq(const GenLowerTriMatrix<double> & A,VectorView<std::complex<double>> b)595     void BlasTriLDivEq(
596         const GenLowerTriMatrix<double>& A,
597         VectorView<std::complex<double> > b)
598     {
599         int n=A.size();
600         int lda = A.isrm()?A.stepi():A.stepj();
601         int bs = 2*b.step();
602         double* bp = (double*)(b.ptr());
603         if (bs < 0) bp += (n-1)*bs;
604         BLASNAME(dtrsv) (
605             BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
606             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
607             BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
608             BLAS1 BLAS1 BLAS1);
609         BLASNAME(dtrsv) (
610             BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
611             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
612             BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp+1),BLASV(bs)
613             BLAS1 BLAS1 BLAS1);
614     }
615 #endif
616 #ifdef INST_FLOAT
617     template <>
BlasTriLDivEq(const GenUpperTriMatrix<float> & A,VectorView<float> b)618     void BlasTriLDivEq(
619         const GenUpperTriMatrix<float>& A, VectorView<float> b)
620     {
621         int n=A.size();
622         int lda = A.isrm()?A.stepi():A.stepj();
623         int bs = b.step();
624         float* bp = (b.ptr());
625         if (bs < 0) bp += (n-1)*bs;
626         BLASNAME(strsv) (
627             BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
628             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
629             BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
630             BLAS1 BLAS1 BLAS1);
631     }
632     template <>
BlasTriLDivEq(const GenLowerTriMatrix<float> & A,VectorView<float> b)633     void BlasTriLDivEq(
634         const GenLowerTriMatrix<float>& A, VectorView<float> b)
635     {
636         int n=A.size();
637         int lda = A.isrm()?A.stepi():A.stepj();
638         int bs = b.step();
639         float* bp = (b.ptr());
640         if (bs < 0) bp += (n-1)*bs;
641         BLASNAME(strsv) (
642             BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
643             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
644             BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
645             BLAS1 BLAS1 BLAS1);
646     }
647     template <>
BlasTriLDivEq(const GenUpperTriMatrix<std::complex<float>> & A,VectorView<std::complex<float>> b)648     void BlasTriLDivEq(
649         const GenUpperTriMatrix<std::complex<float> >& A,
650         VectorView<std::complex<float> > b)
651     {
652         int n=A.size();
653         int lda = A.isrm()?A.stepi():A.stepj();
654         int bs = b.step();
655         std::complex<float>* bp = (b.ptr());
656         if (bs < 0) bp += (n-1)*bs;
657         if (A.iscm() && A.isconj()) {
658 #ifdef CBLAS
659             BLASNAME(ctrsv) (
660                 BLASRM BLASCH_LO,BLASCH_CT,
661                 A.isunit()?BLASCH_U:BLASCH_NU,
662                 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
663                 BLAS1 BLAS1 BLAS1);
664 #else
665             b.conjugateSelf();
666             BLASNAME(ctrsv) (
667                 BLASCM BLASCH_UP,BLASCH_NT,
668                 A.isunit()?BLASCH_U:BLASCH_NU,
669                 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
670                 BLAS1 BLAS1 BLAS1);
671             b.conjugateSelf();
672 #endif
673         } else {
674             BLASNAME(ctrsv) (
675                 BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
676                 A.iscm()?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T,
677                 A.isunit()?BLASCH_U:BLASCH_NU,
678                 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
679                 BLAS1 BLAS1 BLAS1);
680         }
681     }
682     template <>
BlasTriLDivEq(const GenLowerTriMatrix<std::complex<float>> & A,VectorView<std::complex<float>> b)683     void BlasTriLDivEq(
684         const GenLowerTriMatrix<std::complex<float> >& A,
685         VectorView<std::complex<float> > b)
686     {
687         int n=A.size();
688         int lda = A.isrm()?A.stepi():A.stepj();
689         int bs = b.step();
690         std::complex<float>* bp = (b.ptr());
691         if (bs < 0) bp += (n-1)*bs;
692         if (A.iscm() && A.isconj()) {
693 #ifdef CBLAS
694             BLASNAME(ctrsv) (
695                 BLASRM BLASCH_UP,BLASCH_CT,
696                 A.isunit()?BLASCH_U:BLASCH_NU,
697                 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
698                 BLAS1 BLAS1 BLAS1);
699 #else
700             b.conjugateSelf();
701             BLASNAME(ctrsv) (
702                 BLASCM BLASCH_LO,BLASCH_NT,
703                 A.isunit()?BLASCH_U:BLASCH_NU,
704                 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
705                 BLAS1 BLAS1 BLAS1);
706             b.conjugateSelf();
707 #endif
708         } else {
709             BLASNAME(ctrsv) (
710                 BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
711                 A.iscm()?BLASCH_NT:A.isconj()?BLASCH_CT:BLASCH_T,
712                 A.isunit()?BLASCH_U:BLASCH_NU,
713                 BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
714                 BLAS1 BLAS1 BLAS1);
715         }
716     }
717     template <>
BlasTriLDivEq(const GenUpperTriMatrix<float> & A,VectorView<std::complex<float>> b)718     void BlasTriLDivEq(
719         const GenUpperTriMatrix<float>& A,
720         VectorView<std::complex<float> > b)
721     {
722         int n=A.size();
723         int lda = A.isrm()?A.stepi():A.stepj();
724         int bs = 2*b.step();
725         float* bp = (float*)(b.ptr());
726         if (bs < 0) bp += (n-1)*bs;
727         BLASNAME(strsv) (
728             BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
729             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
730             BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
731             BLAS1 BLAS1 BLAS1);
732         BLASNAME(strsv) (
733             BLASCM A.iscm()?BLASCH_UP:BLASCH_LO,
734             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
735             BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp+1),BLASV(bs)
736             BLAS1 BLAS1 BLAS1);
737     }
738     template <>
BlasTriLDivEq(const GenLowerTriMatrix<float> & A,VectorView<std::complex<float>> b)739     void BlasTriLDivEq(
740         const GenLowerTriMatrix<float>& A,
741         VectorView<std::complex<float> > b)
742     {
743         int n=A.size();
744         int lda = A.isrm()?A.stepi():A.stepj();
745         int bs = 2*b.step();
746         float* bp = (float*)(b.ptr());
747         if (bs < 0) bp += (n-1)*bs;
748         BLASNAME(strsv) (
749             BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
750             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
751             BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp),BLASV(bs)
752             BLAS1 BLAS1 BLAS1);
753         BLASNAME(strsv) (
754             BLASCM A.iscm()?BLASCH_LO:BLASCH_UP,
755             A.iscm()?BLASCH_NT:BLASCH_T, A.isunit()?BLASCH_U:BLASCH_NU,
756             BLASV(n),BLASP(A.cptr()),BLASV(lda),BLASP(bp+1),BLASV(bs)
757             BLAS1 BLAS1 BLAS1);
758     }
759 #endif // FLOAT
760 #endif // BLAS
761 
762     template <class T, class Ta>
TriLDivEq(const GenUpperTriMatrix<Ta> & A,VectorView<T> b)763     void TriLDivEq(
764         const GenUpperTriMatrix<Ta>& A, VectorView<T> b)
765     {
766         TMVAssert(b.size() == A.size());
767 #ifdef XDEBUG
768         Matrix<Ta> A0(A);
769         Vector<T> b0(b);
770 #endif
771         if (b.size() > 0) {
772             if (b.isconj()) TriLDivEq(A.conjugate(),b.conjugate());
773             else
774 #ifdef BLAS
775                 if (isComplex(T()) && isReal(Ta()))
776                     BlasTriLDivEq(A,b);
777                 else if ( !((A.isrm() && A.stepi()>0) ||
778                             (A.iscm() && A.stepj()>0)) ) {
779                     if (A.isunit()) {
780                         UpperTriMatrix<Ta,UnitDiag|ColMajor> AA = A;
781                         BlasTriLDivEq(AA,b);
782                     } else {
783                         UpperTriMatrix<Ta,NonUnitDiag|ColMajor> AA = A;
784                         BlasTriLDivEq(AA,b);
785                     }
786                 }
787                 else BlasTriLDivEq(A,b);
788 #else
789             NonBlasTriLDivEq(A,b);
790 #endif
791         }
792 #ifdef XDEBUG
793         Vector<T> b2 = A0*b;
794         if (Norm(b2-b0) > 0.001*Norm(A0)*Norm(b0)) {
795             cerr<<"TriLDivEq: v/Upper\n";
796             cerr<<"A = "<<TMV_Text(A)<<"  "<<A0<<endl;
797             cerr<<"b = "<<TMV_Text(b)<<"  "<<b0<<endl;
798             cerr<<"Done: b = "<<b<<endl;
799             cerr<<"A*b = "<<b2<<endl;
800             abort();
801         }
802 #endif
803     }
804 
805     template <class T, class Ta>
TriLDivEq(const GenLowerTriMatrix<Ta> & A,VectorView<T> b)806     void TriLDivEq(const GenLowerTriMatrix<Ta>& A, VectorView<T> b)
807     {
808         TMVAssert(b.size() == A.size());
809 #ifdef XDEBUG
810         Matrix<Ta> A0(A);
811         Vector<T> b0(b);
812 #endif
813         if (b.size() > 0) {
814             if (b.isconj()) TriLDivEq(A.conjugate(),b.conjugate());
815             else {
816 #ifdef BLAS
817                 if (isComplex(T()) && isReal(Ta()))
818                     BlasTriLDivEq(A,b);
819                 else if ( !((A.isrm() && A.stepi()>0) ||
820                             (A.iscm() && A.stepj()>0)) ) {
821                     if (A.isunit()) {
822                         LowerTriMatrix<Ta,UnitDiag|ColMajor> AA = A;
823                         BlasTriLDivEq(AA,b);
824                     } else {
825                         LowerTriMatrix<Ta,NonUnitDiag|ColMajor> AA = A;
826                         BlasTriLDivEq(AA,b);
827                     }
828                 } else {
829                     BlasTriLDivEq(A,b);
830                 }
831 #else
832                 NonBlasTriLDivEq(A,b);
833 #endif
834             }
835         }
836 #ifdef XDEBUG
837         Vector<T> b2 = A0*b;
838         if (Norm(b2-b0) > 0.001*Norm(A0)*Norm(b0)) {
839             cerr<<"TriLDivEq: v/Lower\n";
840             cerr<<"A = "<<TMV_Text(A)<<"  "<<A<<endl;
841             cerr<<"b = "<<TMV_Text(b)<<"  "<<b0<<endl;
842             cerr<<"Done: b = "<<b<<endl;
843             cerr<<"A*b = "<<b2<<endl;
844             abort();
845         }
846 #endif
847     }
848 
849 #ifdef INST_INT
850 #undef INST_INT
851 #endif
852 
853 #define InstFile "TMV_TriDiv_V.inst"
854 #include "TMV_Inst.h"
855 #undef InstFile
856 
857 } // namespace tmv
858 
859 
860