1 #include "relapack.h"
2 
3 static void RELAPACK_dgetrf_rec(const blasint *, const blasint *, double *,
4     const blasint *, blasint *, blasint *);
5 
6 
7 /** DGETRF computes an LU factorization of a general M-by-N matrix A using partial pivoting with row interchanges.
8  *
9  * This routine is functionally equivalent to LAPACK's dgetrf.
10  * For details on its interface, see
11  * http://www.netlib.org/lapack/explore-html/d3/d6a/dgetrf_8f.html
12  * */
RELAPACK_dgetrf(const blasint * m,const blasint * n,double * A,const blasint * ldA,blasint * ipiv,blasint * info)13 void RELAPACK_dgetrf(
14     const blasint *m, const blasint *n,
15     double *A, const blasint *ldA, blasint *ipiv,
16     blasint *info
17 ) {
18     // Check arguments
19     *info = 0;
20     if (*m < 0)
21         *info = -1;
22     else if (*n < 0)
23         *info = -2;
24     else if (*ldA < MAX(1, *m))
25         *info = -4;
26     if (*info!=0) {
27         const blasint minfo = -*info;
28         LAPACK(xerbla)("DGETRF", &minfo, strlen("DGETRF"));
29         return;
30     }
31 
32     if (*m == 0 || *n == 0) return;
33 
34     const blasint sn = MIN(*m, *n);
35     RELAPACK_dgetrf_rec(m, &sn, A, ldA, ipiv, info);
36 
37     // Right remainder
38     if (*m < *n) {
39         // Constants
40         const double ONE[] = { 1. };
41         const blasint   iONE[] = { 1 };
42 
43         // Splitting
44         const blasint rn = *n - *m;
45 
46         // A_L A_R
47         const double *const A_L = A;
48         double *const       A_R = A + *ldA * *m;
49 
50         // A_R = apply(ipiv, A_R)
51         LAPACK(dlaswp)(&rn, A_R, ldA, iONE, m, ipiv, iONE);
52         // A_R = A_S \ A_R
53         BLAS(dtrsm)("L", "L", "N", "U", m, &rn, ONE, A_L, ldA, A_R, ldA);
54     }
55 }
56 
57 
58 /** dgetrf's recursive compute kernel */
RELAPACK_dgetrf_rec(const blasint * m,const blasint * n,double * A,const blasint * ldA,blasint * ipiv,blasint * info)59 static void RELAPACK_dgetrf_rec(
60     const blasint *m, const blasint *n,
61     double *A, const blasint *ldA, blasint *ipiv,
62     blasint *info
63 ) {
64     if ( *n <= MAX(CROSSOVER_DGETRF, 1)) {
65         // Unblocked
66         LAPACK(dgetrf2)(m, n, A, ldA, ipiv, info);
67         return;
68     }
69     // Constants
70     const double ONE[]  = { 1. };
71     const double MONE[] = { -1. };
72     const blasint    iONE[] = { 1 };
73 
74     // Splitting
75     const blasint n1 = DREC_SPLIT(*n);
76     const blasint n2 = *n - n1;
77     const blasint m2 = *m - n1;
78 
79     // A_L A_R
80     double *const A_L = A;
81     double *const A_R = A + *ldA * n1;
82 
83     // A_TL A_TR
84     // A_BL A_BR
85     double *const A_TL = A;
86     double *const A_TR = A + *ldA * n1;
87     double *const A_BL = A             + n1;
88     double *const A_BR = A + *ldA * n1 + n1;
89 
90     // ipiv_T
91     // ipiv_B
92     blasint *const ipiv_T = ipiv;
93     blasint *const ipiv_B = ipiv + n1;
94 
95     // recursion(A_L, ipiv_T)
96     RELAPACK_dgetrf_rec(m, &n1, A_L, ldA, ipiv_T, info);
97     if (*info) return;
98     // apply pivots to A_R
99     LAPACK(dlaswp)(&n2, A_R, ldA, iONE, &n1, ipiv_T, iONE);
100 
101     // A_TR = A_TL \ A_TR
102     BLAS(dtrsm)("L", "L", "N", "U", &n1, &n2, ONE, A_TL, ldA, A_TR, ldA);
103     // A_BR = A_BR - A_BL * A_TR
104     BLAS(dgemm)("N", "N", &m2, &n2, &n1, MONE, A_BL, ldA, A_TR, ldA, ONE, A_BR, ldA);
105 
106     // recursion(A_BR, ipiv_B)
107     RELAPACK_dgetrf_rec(&m2, &n2, A_BR, ldA, ipiv_B, info);
108     if (*info)
109         *info += n1;
110     // apply pivots to A_BL
111     LAPACK(dlaswp)(&n1, A_BL, ldA, iONE, &n2, ipiv_B, iONE);
112     // shift pivots
113     blasint i;
114     for (i = 0; i < n2; i++)
115         ipiv_B[i] += n1;
116 }
117