#include "relapack.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
4 static void RELAPACK_dgbtrf_rec(const blasint *, const blasint *, const blasint *,
5     const blasint *, double *, const blasint *, blasint *, double *, const blasint *, double *,
6     const blasint *, blasint *);
7 
8 
9 /** DGBTRF computes an LU factorization of a real m-by-n band matrix A using partial pivoting with row interchanges.
10  *
11  * This routine is functionally equivalent to LAPACK's dgbtrf.
12  * For details on its interface, see
13  * http://www.netlib.org/lapack/explore-html/da/d87/dgbtrf_8f.html
14  * */
void RELAPACK_dgbtrf(
    const blasint *m, const blasint *n, const blasint *kl, const blasint *ku,
    double *Ab, const blasint *ldAb, blasint *ipiv,
    blasint *info
) {

    // Check arguments
    *info = 0;
    if (*m < 0)
        *info = -1;
    else if (*n < 0)
        *info = -2;
    else if (*kl < 0)
        *info = -3;
    else if (*ku < 0)
        *info = -4;
    else if (*ldAb < 2 * *kl + *ku + 1)
        *info = -6;
    if (*info) {
        const blasint minfo = -*info;
        LAPACK(xerbla)("DGBTRF", &minfo, strlen("DGBTRF"));
        return;
    }

    if (*m == 0 || *n == 0) return;

    // Constant
    const double ZERO[] = { 0. };

    // Result upper band width (ku plus kl rows of fill-in from pivoting)
    const blasint kv = *ku + *kl;

    // Unskew A: with leading dimension ldAb - 1, A[i + ldA * j] addresses
    // element (i, j) of the full matrix inside the band storage Ab.
    const blasint ldA[] = { *ldAb - 1 };
    double *const A = Ab + kv;

    // Zero upper diagonal fill-in elements
    blasint i, j;
    for (j = 0; j < *n; j++) {
        double *const A_j = A + *ldA * j;
        for (i = MAX(0, j - kv); i < j - *ku; i++)
            A_j[i] = 0.;
    }

    // Work space dimensions. kl and ku are >= 0 after argument checking and
    // m, n >= 1 here, so every operand is non-negative (assuming DREC_SPLIT(n)
    // is non-negative for n >= 1). abs() is therefore unnecessary; it would
    // also convert through int and truncate a 64-bit blasint.
    const blasint n1 = DREC_SPLIT(*n);
    const blasint mWorkl = (kv > n1) ? MAX(1, *m - *kl) : kv;
    const blasint nWorkl = (kv > n1) ? n1 : kv;
    const blasint mWorku = (*kl > n1) ? n1 : *kl;
    const blasint nWorku = (*kl > n1) ? MAX(1, *n - *kl) : *kl;

    // Compute byte counts in size_t so the element product cannot overflow
    // a 32-bit blasint.
    const size_t sizeWorkl = (size_t)mWorkl * (size_t)nWorkl * sizeof(double);
    const size_t sizeWorku = (size_t)mWorku * (size_t)nWorku * sizeof(double);
    double *Workl = malloc(sizeWorkl);
    double *Worku = malloc(sizeWorku);

    // malloc(0) may legitimately return NULL, so only a NULL result for a
    // non-empty request is treated as an allocation failure.
    if ((sizeWorkl > 0 && Workl == NULL) || (sizeWorku > 0 && Worku == NULL)) {
        // Out of memory: fall back to the unblocked routine, which needs no
        // work space. The fill-in zeroing above does not affect its result.
        free(Workl);
        free(Worku);
        LAPACK(dgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
        return;
    }

    LAPACK(dlaset)("L", &mWorkl, &nWorkl, ZERO, ZERO, Workl, &mWorkl);
    LAPACK(dlaset)("U", &mWorku, &nWorku, ZERO, ZERO, Worku, &mWorku);

    // Recursive kernel
    RELAPACK_dgbtrf_rec(m, n, kl, ku, Ab, ldAb, ipiv, Workl, &mWorkl, Worku, &mWorku, info);

    // Free work space
    free(Workl);
    free(Worku);
}
78 
79 
80 /** dgbtrf's recursive compute kernel */
static void RELAPACK_dgbtrf_rec(
    const blasint *m, const blasint *n, const blasint *kl, const blasint *ku,
    double *Ab, const blasint *ldAb, blasint *ipiv,
    double *Workl, const blasint *ldWorkl, double *Worku, const blasint *ldWorku,
    blasint *info
) {
    // Splits the band matrix into a leading block of n1 columns, factorizes
    // it recursively, updates the trailing part, then factorizes the trailing
    // block with the unblocked dgbtf2. Workl / Worku (leading dimensions
    // *ldWorkl / *ldWorku) hold the pieces of the update that fall outside
    // the band storage Ab.
    //
    // On return *info is 0, or the (1-based) index of the FIRST exactly-zero
    // pivot, matching LAPACK's dgbtrf INFO convention.

    if (*n <= MAX(CROSSOVER_DGBTRF, 1) || *n > *kl || *ldAb == 1) {
        // Unblocked
        LAPACK(dgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
        return;
    }

    // Constants
    const double ONE[]  = { 1. };
    const double MONE[] = { -1. };
    const blasint    iONE[] = { 1 };

    // Loop iterators
    blasint i, j;

    // Output upper band width
    const blasint kv = *ku + *kl;

    // Unskew A
    const blasint ldA[] = { *ldAb - 1 };
    double *const A = Ab + kv;

    // Splitting
    const blasint n1  = MIN(DREC_SPLIT(*n), *kl);
    const blasint n2  = *n - n1;
    const blasint m1  = MIN(n1, *m);
    const blasint m2  = *m - m1;
    const blasint mn1 = MIN(m1, n1);
    const blasint mn2 = MIN(m2, n2);

    // Ab_L *
    //      Ab_BR
    double *const Ab_L  = Ab;
    double *const Ab_BR = Ab + *ldAb * n1;

    // A_L A_R
    double *const A_L = A;
    double *const A_R = A + *ldA * n1;

    // A_TL A_TR
    // A_BL A_BR
    double *const A_TL = A;
    double *const A_TR = A + *ldA * n1;
    double *const A_BL = A             + m1;
    double *const A_BR = A + *ldA * n1 + m1;

    // ipiv_T
    // ipiv_B
    blasint *const ipiv_T = ipiv;
    blasint *const ipiv_B = ipiv + n1;

    // Banded splitting.
    // NOTE: with the "*n > *kl" guard above, recursion only happens when
    // n <= kl <= kv, so kv - n1 >= n2, hence n21 == n2 and n22 == 0.
    const blasint n21 = MIN(n2, kv - n1);
    const blasint n22 = MIN(n2 - n21, n1);
    const blasint m21 = MIN(m2, *kl - m1);
    const blasint m22 = MIN(m2 - m21, m1);

    //   n1 n21  n22
    // m *  A_Rl ARr
    double *const A_Rl = A_R;
    double *const A_Rr = A_R + *ldA * n21;

    //     n1    n21    n22
    // m1  *     A_TRl  A_TRr
    // m21 A_BLt A_BRtl A_BRtr
    // m22 A_BLb A_BRbl A_BRbr
    double *const A_TRl  = A_TR;
    double *const A_TRr  = A_TR + *ldA * n21;
    double *const A_BLt  = A_BL;
    double *const A_BLb  = A_BL              + m21;
    double *const A_BRtl = A_BR;
    double *const A_BRtr = A_BR + *ldA * n21;
    double *const A_BRbl = A_BR              + m21;
    double *const A_BRbr = A_BR + *ldA * n21 + m21;

    // recursion(Ab_L, ipiv_T)
    RELAPACK_dgbtrf_rec(m, &n1, kl, ku, Ab_L, ldAb, ipiv_T, Workl, ldWorkl, Worku, ldWorku, info);

    // Workl = A_BLb
    LAPACK(dlacpy)("U", &m22, &n1, A_BLb, ldA, Workl, ldWorkl);

    // partially redo swaps in A_L
    for (i = 0; i < mn1; i++) {
        const blasint ip = ipiv_T[i] - 1;
        if (ip != i) {
            if (ip < *kl)
                BLAS(dswap)(&i, A_L + i, ldA, A_L + ip, ldA);
            else
                BLAS(dswap)(&i, A_L + i, ldA, Workl + ip - *kl, ldWorkl);
        }
    }

    // apply pivots to A_Rl
    LAPACK(dlaswp)(&n21, A_Rl, ldA, iONE, &mn1, ipiv_T, iONE);

    // apply pivots to A_Rr columnwise (rows above j lie outside the band)
    for (j = 0; j < n22; j++) {
        double *const A_Rrj = A_Rr + *ldA * j;
        for (i = j; i < mn1; i++) {
            const blasint ip = ipiv_T[i] - 1;
            if (ip != i) {
                // Swap rows i and ip within column j (not column 0).
                const double tmp = A_Rrj[i];
                A_Rrj[i] = A_Rrj[ip];
                A_Rrj[ip] = tmp;
            }
        }
    }

    // A_TRl = A_TL \ A_TRl
    BLAS(dtrsm)("L", "L", "N", "U", &m1, &n21, ONE, A_TL, ldA, A_TRl, ldA);
    if (n22 > 0) {
        // Worku = A_TRr
        LAPACK(dlacpy)("L", &m1, &n22, A_TRr, ldA, Worku, ldWorku);
        // Worku = A_TL \ Worku
        BLAS(dtrsm)("L", "L", "N", "U", &m1, &n22, ONE, A_TL, ldA, Worku, ldWorku);
        // A_TRr = Worku
        LAPACK(dlacpy)("L", &m1, &n22, Worku, ldWorku, A_TRr, ldA);
    }
    // A_BRtl = A_BRtl - A_BLt * A_TRl
    BLAS(dgemm)("N", "N", &m21, &n21, &n1, MONE, A_BLt, ldA, A_TRl, ldA, ONE, A_BRtl, ldA);
    // A_BRbl = A_BRbl - Workl * A_TRl
    BLAS(dgemm)("N", "N", &m22, &n21, &n1, MONE, Workl, ldWorkl, A_TRl, ldA, ONE, A_BRbl, ldA);
    if (n22 > 0) {
        // A_BRtr = A_BRtr - A_BLt * Worku
        BLAS(dgemm)("N", "N", &m21, &n22, &n1, MONE, A_BLt, ldA, Worku, ldWorku, ONE, A_BRtr, ldA);
        // A_BRbr = A_BRbr - Workl * Worku
        BLAS(dgemm)("N", "N", &m22, &n22, &n1, MONE, Workl, ldWorkl, Worku, ldWorku, ONE, A_BRbr, ldA);
    }

    // partially undo swaps in A_L
    for (i = mn1 - 1; i >= 0; i--) {
        const blasint ip = ipiv_T[i] - 1;
        if (ip != i) {
            if (ip < *kl)
                BLAS(dswap)(&i, A_L + i, ldA, A_L + ip, ldA);
            else
                BLAS(dswap)(&i, A_L + i, ldA, Workl + ip - *kl, ldWorkl);
        }
    }

    // Factorize the trailing block Ab_BR with the unblocked routine (not
    // recursively). dgbtf2 resets its INFO argument, so use a separate
    // variable and keep the first zero pivot already reported for the
    // leading n1 columns.
    blasint info2 = 0;
    LAPACK(dgbtf2)(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, &info2);
    if (info2 > 0 && *info == 0)
        *info = info2 + n1;

    // shift pivots
    for (i = 0; i < mn2; i++)
        ipiv_B[i] += n1;
}
233