#include "relapack.h"
#include "stdlib.h"
#include <string.h>
3
4 static void RELAPACK_zgbtrf_rec(const blasint *, const blasint *, const blasint *,
5 const blasint *, double *, const blasint *, blasint *, double *, const blasint *, double *,
6 const blasint *, blasint *);
7
8
9 /** ZGBTRF computes an LU factorization of a complex m-by-n band matrix A using partial pivoting with row interchanges.
10 *
11 * This routine is functionally equivalent to LAPACK's zgbtrf.
12 * For details on its interface, see
13 * http://www.netlib.org/lapack/explore-html/dc/dcb/zgbtrf_8f.html
14 * */
RELAPACK_zgbtrf(const blasint * m,const blasint * n,const blasint * kl,const blasint * ku,double * Ab,const blasint * ldAb,blasint * ipiv,blasint * info)15 void RELAPACK_zgbtrf(
16 const blasint *m, const blasint *n, const blasint *kl, const blasint *ku,
17 double *Ab, const blasint *ldAb, blasint *ipiv,
18 blasint *info
19 ) {
20
21 // Check arguments
22 *info = 0;
23 if (*m < 0)
24 *info = -1;
25 else if (*n < 0)
26 *info = -2;
27 else if (*kl < 0)
28 *info = -3;
29 else if (*ku < 0)
30 *info = -4;
31 else if (*ldAb < 2 * *kl + *ku + 1)
32 *info = -6;
33 if (*info) {
34 const blasint minfo = -*info;
35 LAPACK(xerbla)("ZGBTRF", &minfo, strlen("ZGBTRF"));
36 return;
37 }
38
39 if (*m == 0 || *n == 0) return;
40
41 // Constant
42 const double ZERO[] = { 0., 0. };
43
44 // Result upper band width
45 const blasint kv = *ku + *kl;
46
47 // Unskew A
48 const blasint ldA[] = { *ldAb - 1 };
49 double *const A = Ab + 2 * kv;
50
51 // Zero upper diagonal fill-in elements
52 blasint i, j;
53 for (j = 0; j < *n; j++) {
54 double *const A_j = A + 2 * *ldA * j;
55 for (i = MAX(0, j - kv); i < j - *ku; i++)
56 A_j[2 * i] = A_j[2 * i + 1] = 0.;
57 }
58
59 // Allocate work space
60 const blasint n1 = ZREC_SPLIT(*n);
61 const blasint mWorkl = abs ( (kv > n1) ? MAX(1, *m - *kl) : kv);
62 const blasint nWorkl = abs ( (kv > n1) ? n1 : kv);
63 const blasint mWorku = abs ( (*kl > n1) ? n1 : *kl);
64 const blasint nWorku = abs ( (*kl > n1) ? MAX(0, *n - *kl) : *kl);
65 double *Workl = malloc(mWorkl * nWorkl * 2 * sizeof(double));
66 double *Worku = malloc(mWorku * nWorku * 2 * sizeof(double));
67 LAPACK(zlaset)("L", &mWorkl, &nWorkl, ZERO, ZERO, Workl, &mWorkl);
68 LAPACK(zlaset)("U", &mWorku, &nWorku, ZERO, ZERO, Worku, &mWorku);
69
70 // Recursive kernel
71 RELAPACK_zgbtrf_rec(m, n, kl, ku, Ab, ldAb, ipiv, Workl, &mWorkl, Worku, &mWorku, info);
72
73 // Free work space
74 free(Workl);
75 free(Worku);
76 }
77
78
79 /** zgbtrf's recursive compute kernel */
RELAPACK_zgbtrf_rec(const blasint * m,const blasint * n,const blasint * kl,const blasint * ku,double * Ab,const blasint * ldAb,blasint * ipiv,double * Workl,const blasint * ldWorkl,double * Worku,const blasint * ldWorku,blasint * info)80 static void RELAPACK_zgbtrf_rec(
81 const blasint *m, const blasint *n, const blasint *kl, const blasint *ku,
82 double *Ab, const blasint *ldAb, blasint *ipiv,
83 double *Workl, const blasint *ldWorkl, double *Worku, const blasint *ldWorku,
84 blasint *info
85 ) {
86
87 if (*n <= MAX(CROSSOVER_ZGBTRF, 1) || *n > *kl || *ldAb == 1) {
88 // Unblocked
89 LAPACK(zgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
90 return;
91 }
92
93 // Constants
94 const double ONE[] = { 1., 0. };
95 const double MONE[] = { -1., 0. };
96 const blasint iONE[] = { 1 };
97 const blasint min11 = -11;
98
99 // Loop iterators
100 blasint i, j;
101
102 // Output upper band width
103 const blasint kv = *ku + *kl;
104
105 // Unskew A
106 const blasint ldA[] = { *ldAb - 1 };
107 double *const A = Ab + 2 * kv;
108
109 // Splitting
110 const blasint n1 = MIN(ZREC_SPLIT(*n), *kl);
111 const blasint n2 = *n - n1;
112 const blasint m1 = MIN(n1, *m);
113 const blasint m2 = *m - m1;
114 const blasint mn1 = MIN(m1, n1);
115 const blasint mn2 = MIN(m2, n2);
116
117 // Ab_L *
118 // Ab_BR
119 double *const Ab_L = Ab;
120 double *const Ab_BR = Ab + 2 * *ldAb * n1;
121
122 // A_L A_R
123 double *const A_L = A;
124 double *const A_R = A + 2 * *ldA * n1;
125
126 // A_TL A_TR
127 // A_BL A_BR
128 double *const A_TL = A;
129 double *const A_TR = A + 2 * *ldA * n1;
130 double *const A_BL = A + 2 * m1;
131 double *const A_BR = A + 2 * *ldA * n1 + 2 * m1;
132
133 // ipiv_T
134 // ipiv_B
135 blasint *const ipiv_T = ipiv;
136 blasint *const ipiv_B = ipiv + n1;
137
138 // Banded splitting
139 const blasint n21 = MIN(n2, kv - n1);
140 const blasint n22 = MIN(n2 - n21, n1);
141 const blasint m21 = MIN(m2, *kl - m1);
142 const blasint m22 = MIN(m2 - m21, m1);
143
144 // n1 n21 n22
145 // m * A_Rl ARr
146 double *const A_Rl = A_R;
147 double *const A_Rr = A_R + 2 * *ldA * n21;
148
149 // n1 n21 n22
150 // m1 * A_TRl A_TRr
151 // m21 A_BLt A_BRtl A_BRtr
152 // m22 A_BLb A_BRbl A_BRbr
153 double *const A_TRl = A_TR;
154 double *const A_TRr = A_TR + 2 * *ldA * n21;
155 double *const A_BLt = A_BL;
156 double *const A_BLb = A_BL + 2 * m21;
157 double *const A_BRtl = A_BR;
158 double *const A_BRtr = A_BR + 2 * *ldA * n21;
159 double *const A_BRbl = A_BR + 2 * m21;
160 double *const A_BRbr = A_BR + 2 * *ldA * n21 + 2 * m21;
161
162 // recursion(Ab_L, ipiv_T)
163 RELAPACK_zgbtrf_rec(m, &n1, kl, ku, Ab_L, ldAb, ipiv_T, Workl, ldWorkl, Worku, ldWorku, info);
164 if (*info) return;
165
166 // Workl = A_BLb
167 LAPACK(zlacpy)("U", &m22, &n1, A_BLb, ldA, Workl, ldWorkl);
168
169 // partially redo swaps in A_L
170 for (i = 0; i < mn1; i++) {
171 const blasint ip = ipiv_T[i] - 1;
172 if (ip != i) {
173 if (ip < *kl)
174 BLAS(zswap)(&i, A_L + 2 * i, ldA, A_L + 2 * ip, ldA);
175 else
176 BLAS(zswap)(&i, A_L + 2 * i, ldA, Workl + 2 * (ip - *kl), ldWorkl);
177 }
178 }
179
180 // apply pivots to A_Rl
181 LAPACK(zlaswp)(&n21, A_Rl, ldA, iONE, &mn1, ipiv_T, iONE);
182
183 // apply pivots to A_Rr columnwise
184 for (j = 0; j < n22; j++) {
185 double *const A_Rrj = A_Rr + 2 * *ldA * j;
186 for (i = j; i < mn1; i++) {
187 const blasint ip = ipiv_T[i] - 1;
188 if (ip != i) {
189 const double tmpr = A_Rrj[2 * i];
190 const double tmpc = A_Rrj[2 * i + 1];
191 A_Rrj[2 * i] = A_Rrj[2 * ip];
192 A_Rrj[2 * i + 1] = A_Rrj[2 * ip + 1];
193 A_Rrj[2 * ip] = tmpr;
194 A_Rrj[2 * ip + 1] = tmpc;
195 }
196 }
197 }
198
199 // A_TRl = A_TL \ A_TRl
200 if (*ldA < MAX(1,m1)) {
201 LAPACK(xerbla)("ZGBTRF", &min11, strlen("ZGBTRF"));
202 return;
203 } else {
204 BLAS(ztrsm)("L", "L", "N", "U", &m1, &n21, ONE, A_TL, ldA, A_TRl, ldA);
205 }
206 // Worku = A_TRr
207 LAPACK(zlacpy)("L", &m1, &n22, A_TRr, ldA, Worku, ldWorku);
208 // Worku = A_TL \ Worku
209 if (*ldWorku < MAX(1,m1)) {
210 LAPACK(xerbla)("ZGBTRF", &min11, strlen("ZGBTRF"));
211 return;
212 } else {
213 BLAS(ztrsm)("L", "L", "N", "U", &m1, &n22, ONE, A_TL, ldA, Worku, ldWorku);
214 }
215 // A_TRr = Worku
216 LAPACK(zlacpy)("L", &m1, &n22, Worku, ldWorku, A_TRr, ldA);
217 // A_BRtl = A_BRtl - A_BLt * A_TRl
218 BLAS(zgemm)("N", "N", &m21, &n21, &n1, MONE, A_BLt, ldA, A_TRl, ldA, ONE, A_BRtl, ldA);
219 // A_BRbl = A_BRbl - Workl * A_TRl
220 BLAS(zgemm)("N", "N", &m22, &n21, &n1, MONE, Workl, ldWorkl, A_TRl, ldA, ONE, A_BRbl, ldA);
221 // A_BRtr = A_BRtr - A_BLt * Worku
222 BLAS(zgemm)("N", "N", &m21, &n22, &n1, MONE, A_BLt, ldA, Worku, ldWorku, ONE, A_BRtr, ldA);
223 // A_BRbr = A_BRbr - Workl * Worku
224 BLAS(zgemm)("N", "N", &m22, &n22, &n1, MONE, Workl, ldWorkl, Worku, ldWorku, ONE, A_BRbr, ldA);
225
226 // partially undo swaps in A_L
227 for (i = mn1 - 1; i >= 0; i--) {
228 const blasint ip = ipiv_T[i] - 1;
229 if (ip != i) {
230 if (ip < *kl)
231 BLAS(zswap)(&i, A_L + 2 * i, ldA, A_L + 2 * ip, ldA);
232 else
233 BLAS(zswap)(&i, A_L + 2 * i, ldA, Workl + 2 * (ip - *kl), ldWorkl);
234 }
235 }
236
237 // recursion(Ab_BR, ipiv_B)
238 // RELAPACK_zgbtrf_rec(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, Workl, ldWorkl, Worku, ldWorku, info);
239 LAPACK(zgbtf2)(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, info);
240
241 if (*info)
242 *info += n1;
243 // shift pivots
244 for (i = 0; i < mn2; i++)
245 ipiv_B[i] += n1;
246 }
247