#include "relapack.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
4 static void RELAPACK_dgbtrf_rec(const blasint *, const blasint *, const blasint *,
5 const blasint *, double *, const blasint *, blasint *, double *, const blasint *, double *,
6 const blasint *, blasint *);
7
8
9 /** DGBTRF computes an LU factorization of a real m-by-n band matrix A using partial pivoting with row interchanges.
10 *
11 * This routine is functionally equivalent to LAPACK's dgbtrf.
12 * For details on its interface, see
13 * http://www.netlib.org/lapack/explore-html/da/d87/dgbtrf_8f.html
14 * */
RELAPACK_dgbtrf(const blasint * m,const blasint * n,const blasint * kl,const blasint * ku,double * Ab,const blasint * ldAb,blasint * ipiv,blasint * info)15 void RELAPACK_dgbtrf(
16 const blasint *m, const blasint *n, const blasint *kl, const blasint *ku,
17 double *Ab, const blasint *ldAb, blasint *ipiv,
18 blasint *info
19 ) {
20
21 // Check arguments
22 *info = 0;
23 if (*m < 0)
24 *info = -1;
25 else if (*n < 0)
26 *info = -2;
27 else if (*kl < 0)
28 *info = -3;
29 else if (*ku < 0)
30 *info = -4;
31 else if (*ldAb < 2 * *kl + *ku + 1)
32 *info = -6;
33 if (*info) {
34 const blasint minfo = -*info;
35 LAPACK(xerbla)("DGBTRF", &minfo, strlen("DGBTRF"));
36 return;
37 }
38
39 if (*m == 0 || *n == 0) return;
40
41 // Constant
42 const double ZERO[] = { 0. };
43
44 // Result upper band width
45 const blasint kv = *ku + *kl;
46
47 // Unskew A
48 const blasint ldA[] = { *ldAb - 1 };
49 double *const A = Ab + kv;
50
51 // Zero upper diagonal fill-in elements
52 blasint i, j;
53 for (j = 0; j < *n; j++) {
54 double *const A_j = A + *ldA * j;
55 for (i = MAX(0, j - kv); i < j - *ku; i++)
56 A_j[i] = 0.;
57 }
58
59 // Allocate work space
60 const blasint n1 = DREC_SPLIT(*n);
61 const blasint mWorkl = abs( (kv > n1) ? MAX(1, *m - *kl) : kv);
62 const blasint nWorkl = abs( (kv > n1) ? n1 : kv);
63 const blasint mWorku = abs( (*kl > n1) ? n1 : *kl);
64 // const blasint nWorku = abs( (*kl > n1) ? MAX(0, *n - *kl) : *kl);
65 const blasint nWorku = abs( (*kl > n1) ? MAX(1, *n - *kl) : *kl);
66 double *Workl = malloc(mWorkl * nWorkl * sizeof(double));
67 double *Worku = malloc(mWorku * nWorku * sizeof(double));
68 LAPACK(dlaset)("L", &mWorkl, &nWorkl, ZERO, ZERO, Workl, &mWorkl);
69 LAPACK(dlaset)("U", &mWorku, &nWorku, ZERO, ZERO, Worku, &mWorku);
70
71 // Recursive kernel
72 RELAPACK_dgbtrf_rec(m, n, kl, ku, Ab, ldAb, ipiv, Workl, &mWorkl, Worku, &mWorku, info);
73
74 // Free work space
75 free(Workl);
76 free(Worku);
77 }
78
79
80 /** dgbtrf's recursive compute kernel */
RELAPACK_dgbtrf_rec(const blasint * m,const blasint * n,const blasint * kl,const blasint * ku,double * Ab,const blasint * ldAb,blasint * ipiv,double * Workl,const blasint * ldWorkl,double * Worku,const blasint * ldWorku,blasint * info)81 static void RELAPACK_dgbtrf_rec(
82 const blasint *m, const blasint *n, const blasint *kl, const blasint *ku,
83 double *Ab, const blasint *ldAb, blasint *ipiv,
84 double *Workl, const blasint *ldWorkl, double *Worku, const blasint *ldWorku,
85 blasint *info
86 ) {
87
88 if (*n <= MAX(CROSSOVER_DGBTRF, 1) || *n > *kl || *ldAb == 1) {
89 // Unblocked
90 LAPACK(dgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
91 return;
92 }
93
94 // Constants
95 const double ONE[] = { 1. };
96 const double MONE[] = { -1. };
97 const blasint iONE[] = { 1 };
98
99 // Loop iterators
100 blasint i, j;
101
102 // Output upper band width
103 const blasint kv = *ku + *kl;
104
105 // Unskew A
106 const blasint ldA[] = { *ldAb - 1 };
107 double *const A = Ab + kv;
108
109 // Splitting
110 const blasint n1 = MIN(DREC_SPLIT(*n), *kl);
111 const blasint n2 = *n - n1;
112 const blasint m1 = MIN(n1, *m);
113 const blasint m2 = *m - m1;
114 const blasint mn1 = MIN(m1, n1);
115 const blasint mn2 = MIN(m2, n2);
116
117 // Ab_L *
118 // Ab_BR
119 double *const Ab_L = Ab;
120 double *const Ab_BR = Ab + *ldAb * n1;
121
122 // A_L A_R
123 double *const A_L = A;
124 double *const A_R = A + *ldA * n1;
125
126 // A_TL A_TR
127 // A_BL A_BR
128 double *const A_TL = A;
129 double *const A_TR = A + *ldA * n1;
130 double *const A_BL = A + m1;
131 double *const A_BR = A + *ldA * n1 + m1;
132
133 // ipiv_T
134 // ipiv_B
135 blasint *const ipiv_T = ipiv;
136 blasint *const ipiv_B = ipiv + n1;
137
138 // Banded splitting
139 const blasint n21 = MIN(n2, kv - n1);
140 const blasint n22 = MIN(n2 - n21, n1);
141 const blasint m21 = MIN(m2, *kl - m1);
142 const blasint m22 = MIN(m2 - m21, m1);
143
144 // n1 n21 n22
145 // m * A_Rl ARr
146 double *const A_Rl = A_R;
147 double *const A_Rr = A_R + *ldA * n21;
148
149 // n1 n21 n22
150 // m1 * A_TRl A_TRr
151 // m21 A_BLt A_BRtl A_BRtr
152 // m22 A_BLb A_BRbl A_BRbr
153 double *const A_TRl = A_TR;
154 double *const A_TRr = A_TR + *ldA * n21;
155 double *const A_BLt = A_BL;
156 double *const A_BLb = A_BL + m21;
157 double *const A_BRtl = A_BR;
158 double *const A_BRtr = A_BR + *ldA * n21;
159 double *const A_BRbl = A_BR + m21;
160 double *const A_BRbr = A_BR + *ldA * n21 + m21;
161
162 // recursion(Ab_L, ipiv_T)
163 RELAPACK_dgbtrf_rec(m, &n1, kl, ku, Ab_L, ldAb, ipiv_T, Workl, ldWorkl, Worku, ldWorku, info);
164
165 // Workl = A_BLb
166 LAPACK(dlacpy)("U", &m22, &n1, A_BLb, ldA, Workl, ldWorkl);
167
168 // partially redo swaps in A_L
169 for (i = 0; i < mn1; i++) {
170 const blasint ip = ipiv_T[i] - 1;
171 if (ip != i) {
172 if (ip < *kl)
173 BLAS(dswap)(&i, A_L + i, ldA, A_L + ip, ldA);
174 else
175 BLAS(dswap)(&i, A_L + i, ldA, Workl + ip - *kl, ldWorkl);
176 }
177 }
178
179 // apply pivots to A_Rl
180 LAPACK(dlaswp)(&n21, A_Rl, ldA, iONE, &mn1, ipiv_T, iONE);
181
182 // apply pivots to A_Rr columnwise
183 for (j = 0; j < n22; j++) {
184 double *const A_Rrj = A_Rr + *ldA * j;
185 for (i = j; i < mn1; i++) {
186 const blasint ip = ipiv_T[i] - 1;
187 if (ip != i) {
188 const double tmp = A_Rrj[i];
189 A_Rrj[i] = A_Rr[ip];
190 A_Rrj[ip] = tmp;
191 }
192 }
193 }
194
195 // A_TRl = A_TL \ A_TRl
196 BLAS(dtrsm)("L", "L", "N", "U", &m1, &n21, ONE, A_TL, ldA, A_TRl, ldA);
197 // Worku = A_TRr
198 LAPACK(dlacpy)("L", &m1, &n22, A_TRr, ldA, Worku, ldWorku);
199 // Worku = A_TL \ Worku
200 if (ldWorku <= 0) return;
201 BLAS(dtrsm)("L", "L", "N", "U", &m1, &n22, ONE, A_TL, ldA, Worku, ldWorku);
202 // A_TRr = Worku
203 LAPACK(dlacpy)("L", &m1, &n22, Worku, ldWorku, A_TRr, ldA);
204 // A_BRtl = A_BRtl - A_BLt * A_TRl
205 BLAS(dgemm)("N", "N", &m21, &n21, &n1, MONE, A_BLt, ldA, A_TRl, ldA, ONE, A_BRtl, ldA);
206 // A_BRbl = A_BRbl - Workl * A_TRl
207 BLAS(dgemm)("N", "N", &m22, &n21, &n1, MONE, Workl, ldWorkl, A_TRl, ldA, ONE, A_BRbl, ldA);
208 // A_BRtr = A_BRtr - A_BLt * Worku
209 BLAS(dgemm)("N", "N", &m21, &n22, &n1, MONE, A_BLt, ldA, Worku, ldWorku, ONE, A_BRtr, ldA);
210 // A_BRbr = A_BRbr - Workl * Worku
211 BLAS(dgemm)("N", "N", &m22, &n22, &n1, MONE, Workl, ldWorkl, Worku, ldWorku, ONE, A_BRbr, ldA);
212
213 // partially undo swaps in A_L
214 for (i = mn1 - 1; i >= 0; i--) {
215 const blasint ip = ipiv_T[i] - 1;
216 if (ip != i) {
217 if (ip < *kl)
218 BLAS(dswap)(&i, A_L + i, ldA, A_L + ip, ldA);
219 else
220 BLAS(dswap)(&i, A_L + i, ldA, Workl + ip - *kl, ldWorkl);
221 }
222 }
223
224 // recursion(Ab_BR, ipiv_B)
225 // RELAPACK_dgbtrf_rec(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, Workl, ldWorkl, Worku, ldWorku, info);
226 LAPACK(dgbtf2)(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, info);
227 if (*info)
228 *info += n1;
229 // shift pivots
230 for (i = 0; i < mn2; i++)
231 ipiv_B[i] += n1;
232 }
233