1 #include "relapack.h"
2 #include "stdlib.h"
3
4 static void RELAPACK_sgbtrf_rec(const blasint *, const blasint *, const blasint *,
5 const blasint *, float *, const blasint *, blasint *, float *, const blasint *, float *,
6 const blasint *, blasint *);
7
8
9 /** SGBTRF computes an LU factorization of a real m-by-n band matrix A using partial pivoting with row interchanges.
10 *
11 * This routine is functionally equivalent to LAPACK's sgbtrf.
12 * For details on its interface, see
13 * http://www.netlib.org/lapack/explore-html/d5/d72/sgbtrf_8f.html
14 * */
RELAPACK_sgbtrf(const blasint * m,const blasint * n,const blasint * kl,const blasint * ku,float * Ab,const blasint * ldAb,blasint * ipiv,blasint * info)15 void RELAPACK_sgbtrf(
16 const blasint *m, const blasint *n, const blasint *kl, const blasint *ku,
17 float *Ab, const blasint *ldAb, blasint *ipiv,
18 blasint *info
19 ) {
20 // Check arguments
21 *info = 0;
22 if (*m < 0)
23 *info = -1;
24 else if (*n < 0)
25 *info = -2;
26 else if (*kl < 0)
27 *info = -3;
28 else if (*ku < 0)
29 *info = -4;
30 else if (*ldAb < 2 * *kl + *ku + 1)
31 *info = -6;
32 if (*info) {
33 const blasint minfo = -*info;
34 LAPACK(xerbla)("SGBTRF", &minfo, strlen("SGBTRF"));
35 return;
36 }
37
38 if (*m == 0 || *n == 0) return;
39
40 if (*ldAb == 1) {
41 LAPACK(sgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
42 return;
43 }
44
45 // Constant
46 const float ZERO[] = { 0. };
47
48 // Result upper band width
49 const blasint kv = *ku + *kl;
50
51 // Unskewg A
52 const blasint ldA[] = { *ldAb - 1 };
53 float *const A = Ab + kv;
54
55 // Zero upper diagonal fill-in elements
56 blasint i, j;
57 for (j = 0; j < *n; j++) {
58 float *const A_j = A + *ldA * j;
59 for (i = MAX(0, j - kv); i < j - *ku; i++)
60 A_j[i] = 0.;
61 }
62
63 // Allocate work space
64 const blasint n1 = SREC_SPLIT(*n);
65 const blasint mWorkl = abs( (kv > n1) ? MAX(1, *m - *kl) : kv );
66 const blasint nWorkl = abs( (kv > n1) ? n1 : kv );
67 const blasint mWorku = abs( (*kl > n1) ? n1 : *kl );
68 const blasint nWorku = abs( (*kl > n1) ? MAX(0, *n - *kl) : *kl );
69 float *Workl = malloc(mWorkl * nWorkl * sizeof(float));
70 float *Worku = malloc(mWorku * nWorku * sizeof(float));
71 LAPACK(slaset)("L", &mWorkl, &nWorkl, ZERO, ZERO, Workl, &mWorkl);
72 LAPACK(slaset)("U", &mWorku, &nWorku, ZERO, ZERO, Worku, &mWorku);
73
74
75 // Recursive kernel
76 RELAPACK_sgbtrf_rec(m, n, kl, ku, Ab, ldAb, ipiv, Workl, &mWorkl, Worku, &mWorku, info);
77
78 // Free work space
79 free(Workl);
80 free(Worku);
81 }
82
83
84 /** sgbtrf's recursive compute kernel */
RELAPACK_sgbtrf_rec(const blasint * m,const blasint * n,const blasint * kl,const blasint * ku,float * Ab,const blasint * ldAb,blasint * ipiv,float * Workl,const blasint * ldWorkl,float * Worku,const blasint * ldWorku,blasint * info)85 static void RELAPACK_sgbtrf_rec(
86 const blasint *m, const blasint *n, const blasint *kl, const blasint *ku,
87 float *Ab, const blasint *ldAb, blasint *ipiv,
88 float *Workl, const blasint *ldWorkl, float *Worku, const blasint *ldWorku,
89 blasint *info
90 ) {
91
92 if (*m == 0 || *n == 0) return;
93
94 if ( *n <= MAX(CROSSOVER_SGBTRF, 1) || *n > *kl || *ldAb == 1) {
95 // Unblocked
96 LAPACK(sgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
97 return;
98 }
99
100 // Constants
101 const float ONE[] = { 1. };
102 const float MONE[] = { -1. };
103 const blasint iONE[] = { 1 };
104
105 // Loop iterators
106 blasint i, j;
107
108 // Output upper band width
109 const blasint kv = *ku + *kl;
110
111 // Unskew A
112 const blasint ldA[] = { *ldAb - 1 };
113 float *const A = Ab + kv;
114
115 // Splitting
116 const blasint n1 = MIN(SREC_SPLIT(*n), *kl);
117 const blasint n2 = *n - n1;
118 const blasint m1 = MIN(n1, *m);
119 const blasint m2 = *m - m1;
120 const blasint mn1 = MIN(m1, n1);
121 const blasint mn2 = MIN(m2, n2);
122
123 // Ab_L *
124 // Ab_BR
125 float *const Ab_L = Ab;
126 float *const Ab_BR = Ab + *ldAb * n1;
127
128 // A_L A_R
129 float *const A_L = A;
130 float *const A_R = A + *ldA * n1;
131
132 // A_TL A_TR
133 // A_BL A_BR
134 float *const A_TL = A;
135 float *const A_TR = A + *ldA * n1;
136 float *const A_BL = A + m1;
137 float *const A_BR = A + *ldA * n1 + m1;
138
139 // ipiv_T
140 // ipiv_B
141 blasint *const ipiv_T = ipiv;
142 blasint *const ipiv_B = ipiv + n1;
143
144 // Banded splitting
145 const blasint n21 = MIN(n2, kv - n1);
146 const blasint n22 = MIN(n2 - n21, n1);
147 const blasint m21 = MIN(m2, *kl - m1);
148 const blasint m22 = MIN(m2 - m21, m1);
149
150 // n1 n21 n22
151 // m * A_Rl ARr
152 float *const A_Rl = A_R;
153 float *const A_Rr = A_R + *ldA * n21;
154
155 // n1 n21 n22
156 // m1 * A_TRl A_TRr
157 // m21 A_BLt A_BRtl A_BRtr
158 // m22 A_BLb A_BRbl A_BRbr
159 float *const A_TRl = A_TR;
160 float *const A_TRr = A_TR + *ldA * n21;
161 float *const A_BLt = A_BL;
162 float *const A_BLb = A_BL + m21;
163 float *const A_BRtl = A_BR;
164 float *const A_BRtr = A_BR + *ldA * n21;
165 float *const A_BRbl = A_BR + m21;
166 float *const A_BRbr = A_BR + *ldA * n21 + m21;
167
168
169 // recursion(Ab_L, ipiv_T)
170 RELAPACK_sgbtrf_rec(m, &n1, kl, ku, Ab_L, ldAb, ipiv_T, Workl, ldWorkl, Worku, ldWorku, info);
171 if (*info) return;
172 // Workl = A_BLb
173 LAPACK(slacpy)("U", &m22, &n1, A_BLb, ldA, Workl, ldWorkl);
174
175 // partially redo swaps in A_L
176 for (i = 0; i < mn1; i++) {
177 const blasint ip = ipiv_T[i] - 1;
178 if (ip != i) {
179 if (ip < *kl)
180 BLAS(sswap)(&i, A_L + i, ldA, A_L + ip, ldA);
181 else
182 BLAS(sswap)(&i, A_L + i, ldA, Workl + ip - *kl, ldWorkl);
183 }
184 }
185
186 // apply pivots to A_Rl
187 LAPACK(slaswp)(&n21, A_Rl, ldA, iONE, &mn1, ipiv_T, iONE);
188
189 // apply pivots to A_Rr columnwise
190 for (j = 0; j < n22; j++) {
191 float *const A_Rrj = A_Rr + *ldA * j;
192 for (i = j; i < mn1; i++) {
193 const blasint ip = ipiv_T[i] - 1;
194 if (ip != i) {
195 const float tmp = A_Rrj[i];
196 A_Rrj[i] = A_Rr[ip];
197 A_Rrj[ip] = tmp;
198 }
199 }
200 }
201
202 // A_TRl = A_TL \ A_TRl
203 BLAS(strsm)("L", "L", "N", "U", &m1, &n21, ONE, A_TL, ldA, A_TRl, ldA);
204 // Worku = A_TRr
205 LAPACK(slacpy)("L", &m1, &n22, A_TRr, ldA, Worku, ldWorku);
206 // Worku = A_TL \ Worku
207 BLAS(strsm)("L", "L", "N", "U", &m1, &n22, ONE, A_TL, ldA, Worku, ldWorku);
208 // A_TRr = Worku
209 LAPACK(slacpy)("L", &m1, &n22, Worku, ldWorku, A_TRr, ldA);
210 // A_BRtl = A_BRtl - A_BLt * A_TRl
211 BLAS(sgemm)("N", "N", &m21, &n21, &n1, MONE, A_BLt, ldA, A_TRl, ldA, ONE, A_BRtl, ldA);
212 // A_BRbl = A_BRbl - Workl * A_TRl
213 BLAS(sgemm)("N", "N", &m22, &n21, &n1, MONE, Workl, ldWorkl, A_TRl, ldA, ONE, A_BRbl, ldA);
214 // A_BRtr = A_BRtr - A_BLt * Worku
215 BLAS(sgemm)("N", "N", &m21, &n22, &n1, MONE, A_BLt, ldA, Worku, ldWorku, ONE, A_BRtr, ldA);
216 // A_BRbr = A_BRbr - Workl * Worku
217 BLAS(sgemm)("N", "N", &m22, &n22, &n1, MONE, Workl, ldWorkl, Worku, ldWorku, ONE, A_BRbr, ldA);
218
219 // partially undo swaps in A_L
220 for (i = mn1 - 1; i >= 0; i--) {
221 const blasint ip = ipiv_T[i] - 1;
222 if (ip != i) {
223 if (ip < *kl)
224 BLAS(sswap)(&i, A_L + i, ldA, A_L + ip, ldA);
225 else
226 BLAS(sswap)(&i, A_L + i, ldA, Workl + ip - *kl, ldWorkl);
227 }
228 }
229
230
231 // recursion(Ab_BR, ipiv_B)
232 //cause of infinite recursion here ?
233 RELAPACK_sgbtrf_rec(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, Workl, ldWorkl, Worku, ldWorku, info);
234 // LAPACK(sgbtf2)(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, info);
235 if (*info)
236 *info += n1;
237 // shift pivots
238 for (i = 0; i < mn2; i++)
239 ipiv_B[i] += n1;
240 }
241