1 #include "relapack.h"
2 
3 static void RELAPACK_zgetrf_rec(const blasint *, const blasint *, double *,
4     const blasint *, blasint *, blasint *);
5 
6 
7 /** ZGETRF computes an LU factorization of a general M-by-N matrix A using partial pivoting with row interchanges.
8  *
9  * This routine is functionally equivalent to LAPACK's zgetrf.
10  * For details on its interface, see
11  * http://www.netlib.org/lapack/explore-html/dd/dd1/zgetrf_8f.html
12  * */
RELAPACK_zgetrf(const blasint * m,const blasint * n,double * A,const blasint * ldA,blasint * ipiv,blasint * info)13 void RELAPACK_zgetrf(
14     const blasint *m, const blasint *n,
15     double *A, const blasint *ldA, blasint *ipiv,
16     blasint *info
17 ) {
18 
19     // Check arguments
20     *info = 0;
21     if (*m < 0)
22         *info = -1;
23     else if (*n < 0)
24         *info = -2;
25     else if (*ldA < MAX(1, *m))
26         *info = -4;
27     if (*info) {
28         const blasint minfo = -*info;
29         LAPACK(xerbla)("ZGETRF", &minfo, strlen("ZGETRF"));
30         return;
31     }
32 
33     if (*m == 0 || *n == 0) return;
34     const blasint sn = MIN(*m, *n);
35 
36     RELAPACK_zgetrf_rec(m, &sn, A, ldA, ipiv, info);
37 
38     // Right remainder
39     if (*m < *n) {
40         // Constants
41         const double ONE[]  = { 1., 0. };
42         const blasint    iONE[] = { 1 };
43 
44         // Splitting
45         const blasint rn = *n - *m;
46 
47         // A_L A_R
48         const double *const A_L = A;
49         double *const       A_R = A + 2 * *ldA * *m;
50 
51         // A_R = apply(ipiv, A_R)
52         LAPACK(zlaswp)(&rn, A_R, ldA, iONE, m, ipiv, iONE);
53         // A_R = A_L \ A_R
54         BLAS(ztrsm)("L", "L", "N", "U", m, &rn, ONE, A_L, ldA, A_R, ldA);
55     }
56 }
57 
58 
59 /** zgetrf's recursive compute kernel */
RELAPACK_zgetrf_rec(const blasint * m,const blasint * n,double * A,const blasint * ldA,blasint * ipiv,blasint * info)60 static void RELAPACK_zgetrf_rec(
61     const blasint *m, const blasint *n,
62     double *A, const blasint *ldA, blasint *ipiv,
63     blasint *info
64 ) {
65 
66     if (*m == 0 || *n == 0) return;
67 
68     if ( *n <= MAX(CROSSOVER_ZGETRF, 1)) {
69         // Unblocked
70         LAPACK(zgetrf2)(m, n, A, ldA, ipiv, info);
71         return;
72     }
73 
74     // Constants
75     const double ONE[]  = { 1., 0. };
76     const double MONE[] = { -1., 0. };
77     const blasint    iONE[] = { 1. };
78 
79     // Splitting
80     const blasint n1 = ZREC_SPLIT(*n);
81     const blasint n2 = *n - n1;
82     const blasint m2 = *m - n1;
83 
84     // A_L A_R
85     double *const A_L = A;
86     double *const A_R = A + 2 * *ldA * n1;
87 
88     // A_TL A_TR
89     // A_BL A_BR
90     double *const A_TL = A;
91     double *const A_TR = A + 2 * *ldA * n1;
92     double *const A_BL = A                 + 2 * n1;
93     double *const A_BR = A + 2 * *ldA * n1 + 2 * n1;
94 
95     // ipiv_T
96     // ipiv_B
97     blasint *const ipiv_T = ipiv;
98     blasint *const ipiv_B = ipiv + n1;
99 
100     // recursion(A_L, ipiv_T)
101     RELAPACK_zgetrf_rec(m, &n1, A_L, ldA, ipiv_T, info);
102 if (*info) return;
103 
104     // apply pivots to A_R
105     LAPACK(zlaswp)(&n2, A_R, ldA, iONE, &n1, ipiv_T, iONE);
106 
107     // A_TR = A_TL \ A_TR
108     BLAS(ztrsm)("L", "L", "N", "U", &n1, &n2, ONE, A_TL, ldA, A_TR, ldA);
109     // A_BR = A_BR - A_BL * A_TR
110     BLAS(zgemm)("N", "N", &m2, &n2, &n1, MONE, A_BL, ldA, A_TR, ldA, ONE, A_BR, ldA);
111 
112     // recursion(A_BR, ipiv_B)
113     RELAPACK_zgetrf_rec(&m2, &n2, A_BR, ldA, ipiv_B, info);
114     if (*info)
115         *info += n1;
116     // apply pivots to A_BL
117     LAPACK(zlaswp)(&n1, A_BL, ldA, iONE, &n2, ipiv_B, iONE);
118     // shift pivots
119     blasint i;
120     for (i = 0; i < n2; i++)
121         ipiv_B[i] += n1;
122 }
123