1 #include "relapack.h"
2 #include "stdlib.h"
3 
4 static void RELAPACK_zgbtrf_rec(const blasint *, const blasint *, const blasint *,
5     const blasint *, double *, const blasint *, blasint *, double *, const blasint *, double *,
6     const blasint *, blasint *);
7 
8 
9 /** ZGBTRF computes an LU factorization of a complex m-by-n band matrix A using partial pivoting with row interchanges.
10  *
11  * This routine is functionally equivalent to LAPACK's zgbtrf.
12  * For details on its interface, see
13  * http://www.netlib.org/lapack/explore-html/dc/dcb/zgbtrf_8f.html
14  * */
RELAPACK_zgbtrf(const blasint * m,const blasint * n,const blasint * kl,const blasint * ku,double * Ab,const blasint * ldAb,blasint * ipiv,blasint * info)15 void RELAPACK_zgbtrf(
16     const blasint *m, const blasint *n, const blasint *kl, const blasint *ku,
17     double *Ab, const blasint *ldAb, blasint *ipiv,
18     blasint *info
19 ) {
20 
21     // Check arguments
22     *info = 0;
23     if (*m < 0)
24         *info = -1;
25     else if (*n < 0)
26         *info = -2;
27     else if (*kl < 0)
28         *info = -3;
29     else if (*ku < 0)
30         *info = -4;
31     else if (*ldAb < 2 * *kl + *ku + 1)
32         *info = -6;
33     if (*info) {
34         const blasint minfo = -*info;
35         LAPACK(xerbla)("ZGBTRF", &minfo, strlen("ZGBTRF"));
36         return;
37     }
38 
39     if (*m == 0 || *n == 0) return;
40 
41     // Constant
42     const double ZERO[] = { 0., 0. };
43 
44     // Result upper band width
45     const blasint kv = *ku + *kl;
46 
47     // Unskew A
48     const blasint ldA[] = { *ldAb - 1 };
49     double *const A = Ab + 2 * kv;
50 
51     // Zero upper diagonal fill-in elements
52     blasint i, j;
53     for (j = 0; j < *n; j++) {
54         double *const A_j = A + 2 * *ldA * j;
55         for (i = MAX(0, j - kv); i < j - *ku; i++)
56             A_j[2 * i] = A_j[2 * i + 1] = 0.;
57     }
58 
59     // Allocate work space
60     const blasint n1 = ZREC_SPLIT(*n);
61     const blasint mWorkl = abs ( (kv > n1) ? MAX(1, *m - *kl) : kv);
62     const blasint nWorkl = abs ( (kv > n1) ? n1 : kv);
63     const blasint mWorku = abs ( (*kl > n1) ? n1 : *kl);
64     const blasint nWorku = abs ( (*kl > n1) ? MAX(0, *n - *kl) : *kl);
65     double *Workl = malloc(mWorkl * nWorkl * 2 * sizeof(double));
66     double *Worku = malloc(mWorku * nWorku * 2 * sizeof(double));
67     LAPACK(zlaset)("L", &mWorkl, &nWorkl, ZERO, ZERO, Workl, &mWorkl);
68     LAPACK(zlaset)("U", &mWorku, &nWorku, ZERO, ZERO, Worku, &mWorku);
69 
70     // Recursive kernel
71     RELAPACK_zgbtrf_rec(m, n, kl, ku, Ab, ldAb, ipiv, Workl, &mWorkl, Worku, &mWorku, info);
72 
73     // Free work space
74     free(Workl);
75     free(Worku);
76 }
77 
78 
79 /** zgbtrf's recursive compute kernel */
RELAPACK_zgbtrf_rec(const blasint * m,const blasint * n,const blasint * kl,const blasint * ku,double * Ab,const blasint * ldAb,blasint * ipiv,double * Workl,const blasint * ldWorkl,double * Worku,const blasint * ldWorku,blasint * info)80 static void RELAPACK_zgbtrf_rec(
81     const blasint *m, const blasint *n, const blasint *kl, const blasint *ku,
82     double *Ab, const blasint *ldAb, blasint *ipiv,
83     double *Workl, const blasint *ldWorkl, double *Worku, const blasint *ldWorku,
84     blasint *info
85 ) {
86 
87     if (*n <= MAX(CROSSOVER_ZGBTRF, 1) || *n > *kl || *ldAb == 1) {
88         // Unblocked
89         LAPACK(zgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
90         return;
91     }
92 
93     // Constants
94     const double ONE[]  = { 1., 0. };
95     const double MONE[] = { -1., 0. };
96     const blasint    iONE[] = { 1 };
97     const blasint min11 = -11;
98 
99     // Loop iterators
100     blasint i, j;
101 
102     // Output upper band width
103     const blasint kv = *ku + *kl;
104 
105     // Unskew A
106     const blasint ldA[] = { *ldAb - 1 };
107     double *const A = Ab + 2 * kv;
108 
109     // Splitting
110     const blasint n1  = MIN(ZREC_SPLIT(*n), *kl);
111     const blasint n2  = *n - n1;
112     const blasint m1  = MIN(n1, *m);
113     const blasint m2  = *m - m1;
114     const blasint mn1 = MIN(m1, n1);
115     const blasint mn2 = MIN(m2, n2);
116 
117     // Ab_L *
118     //      Ab_BR
119     double *const Ab_L  = Ab;
120     double *const Ab_BR = Ab + 2 * *ldAb * n1;
121 
122     // A_L A_R
123     double *const A_L = A;
124     double *const A_R = A + 2 * *ldA * n1;
125 
126     // A_TL A_TR
127     // A_BL A_BR
128     double *const A_TL = A;
129     double *const A_TR = A + 2 * *ldA * n1;
130     double *const A_BL = A                 + 2 * m1;
131     double *const A_BR = A + 2 * *ldA * n1 + 2 * m1;
132 
133     // ipiv_T
134     // ipiv_B
135     blasint *const ipiv_T = ipiv;
136     blasint *const ipiv_B = ipiv + n1;
137 
138     // Banded splitting
139     const blasint n21 = MIN(n2, kv - n1);
140     const blasint n22 = MIN(n2 - n21, n1);
141     const blasint m21 = MIN(m2, *kl - m1);
142     const blasint m22 = MIN(m2 - m21, m1);
143 
144     //   n1 n21  n22
145     // m *  A_Rl ARr
146     double *const A_Rl = A_R;
147     double *const A_Rr = A_R + 2 * *ldA * n21;
148 
149     //     n1    n21    n22
150     // m1  *     A_TRl  A_TRr
151     // m21 A_BLt A_BRtl A_BRtr
152     // m22 A_BLb A_BRbl A_BRbr
153     double *const A_TRl  = A_TR;
154     double *const A_TRr  = A_TR + 2 * *ldA * n21;
155     double *const A_BLt  = A_BL;
156     double *const A_BLb  = A_BL                  + 2 * m21;
157     double *const A_BRtl = A_BR;
158     double *const A_BRtr = A_BR + 2 * *ldA * n21;
159     double *const A_BRbl = A_BR                  + 2 * m21;
160     double *const A_BRbr = A_BR + 2 * *ldA * n21 + 2 * m21;
161 
162     // recursion(Ab_L, ipiv_T)
163     RELAPACK_zgbtrf_rec(m, &n1, kl, ku, Ab_L, ldAb, ipiv_T, Workl, ldWorkl, Worku, ldWorku, info);
164 if (*info) return;
165 
166     // Workl = A_BLb
167     LAPACK(zlacpy)("U", &m22, &n1, A_BLb, ldA, Workl, ldWorkl);
168 
169     // partially redo swaps in A_L
170     for (i = 0; i < mn1; i++) {
171         const blasint ip = ipiv_T[i] - 1;
172         if (ip != i) {
173             if (ip < *kl)
174                 BLAS(zswap)(&i, A_L + 2 * i, ldA, A_L + 2 * ip, ldA);
175             else
176                 BLAS(zswap)(&i, A_L + 2 * i, ldA, Workl + 2 * (ip - *kl), ldWorkl);
177         }
178     }
179 
180     // apply pivots to A_Rl
181     LAPACK(zlaswp)(&n21, A_Rl, ldA, iONE, &mn1, ipiv_T, iONE);
182 
183     // apply pivots to A_Rr columnwise
184     for (j = 0; j < n22; j++) {
185         double *const A_Rrj = A_Rr + 2 * *ldA * j;
186         for (i = j; i < mn1; i++) {
187             const blasint ip = ipiv_T[i] - 1;
188             if (ip != i) {
189                 const double tmpr = A_Rrj[2 * i];
190                 const double tmpc = A_Rrj[2 * i + 1];
191                 A_Rrj[2 * i]     = A_Rrj[2 * ip];
192                 A_Rrj[2 * i + 1] = A_Rrj[2 * ip + 1];
193                 A_Rrj[2 * ip]     = tmpr;
194                 A_Rrj[2 * ip + 1] = tmpc;
195             }
196         }
197     }
198 
199     // A_TRl = A_TL \ A_TRl
200     if (*ldA < MAX(1,m1)) {
201         LAPACK(xerbla)("ZGBTRF", &min11, strlen("ZGBTRF"));
202         return;
203     } else {
204     BLAS(ztrsm)("L", "L", "N", "U", &m1, &n21, ONE, A_TL, ldA, A_TRl, ldA);
205     }
206     // Worku = A_TRr
207     LAPACK(zlacpy)("L", &m1, &n22, A_TRr, ldA, Worku, ldWorku);
208     // Worku = A_TL \ Worku
209     if (*ldWorku < MAX(1,m1)) {
210         LAPACK(xerbla)("ZGBTRF", &min11, strlen("ZGBTRF"));
211         return;
212     } else {
213     BLAS(ztrsm)("L", "L", "N", "U", &m1, &n22, ONE, A_TL, ldA, Worku, ldWorku);
214     }
215     // A_TRr = Worku
216     LAPACK(zlacpy)("L", &m1, &n22, Worku, ldWorku, A_TRr, ldA);
217     // A_BRtl = A_BRtl - A_BLt * A_TRl
218     BLAS(zgemm)("N", "N", &m21, &n21, &n1, MONE, A_BLt, ldA, A_TRl, ldA, ONE, A_BRtl, ldA);
219     // A_BRbl = A_BRbl - Workl * A_TRl
220     BLAS(zgemm)("N", "N", &m22, &n21, &n1, MONE, Workl, ldWorkl, A_TRl, ldA, ONE, A_BRbl, ldA);
221     // A_BRtr = A_BRtr - A_BLt * Worku
222     BLAS(zgemm)("N", "N", &m21, &n22, &n1, MONE, A_BLt, ldA, Worku, ldWorku, ONE, A_BRtr, ldA);
223     // A_BRbr = A_BRbr - Workl * Worku
224     BLAS(zgemm)("N", "N", &m22, &n22, &n1, MONE, Workl, ldWorkl, Worku, ldWorku, ONE, A_BRbr, ldA);
225 
226     // partially undo swaps in A_L
227     for (i = mn1 - 1; i >= 0; i--) {
228         const blasint ip = ipiv_T[i] - 1;
229         if (ip != i) {
230             if (ip < *kl)
231                 BLAS(zswap)(&i, A_L + 2 * i, ldA, A_L + 2 * ip, ldA);
232             else
233                 BLAS(zswap)(&i, A_L + 2 * i, ldA, Workl + 2 * (ip - *kl), ldWorkl);
234         }
235     }
236 
237     // recursion(Ab_BR, ipiv_B)
238  //   RELAPACK_zgbtrf_rec(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, Workl, ldWorkl, Worku, ldWorku, info);
239  LAPACK(zgbtf2)(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, info);
240 
241     if (*info)
242         *info += n1;
243     // shift pivots
244     for (i = 0; i < mn2; i++)
245         ipiv_B[i] += n1;
246 }
247