1 #include "relapack.h"
2 
3 static void RELAPACK_cgetrf_rec(const blasint *, const blasint *, float *,
4     const blasint *, blasint *, blasint *);
5 
6 
7 /** CGETRF computes an LU factorization of a general M-by-N matrix A using partial pivoting with row interchanges.
8  *
9  * This routine is functionally equivalent to LAPACK's cgetrf.
10  * For details on its interface, see
11  * http://www.netlib.org/lapack/explore-html/d9/dfb/cgetrf_8f.html
12  */
RELAPACK_cgetrf(const blasint * m,const blasint * n,float * A,const blasint * ldA,blasint * ipiv,blasint * info)13 void RELAPACK_cgetrf(
14     const blasint *m, const blasint *n,
15     float *A, const blasint *ldA, blasint *ipiv,
16     blasint *info
17 ) {
18 
19     // Check arguments
20     *info = 0;
21     if (*m < 0)
22         *info = -1;
23     else if (*n < 0)
24         *info = -2;
25     else if (*ldA < MAX(1, *m))
26         *info = -4;
27     if (*info) {
28         const blasint minfo = -*info;
29         LAPACK(xerbla)("CGETRF", &minfo, strlen("CGETRF"));
30         return;
31     }
32 
33     if (*m == 0 || *n == 0) return;
34 
35     const blasint sn = MIN(*m, *n);
36 
37     RELAPACK_cgetrf_rec(m, &sn, A, ldA, ipiv, info);
38 
39     // Right remainder
40     if (*m < *n) {
41         // Constants
42         const float ONE[]  = { 1., 0. };
43         const blasint   iONE[] = { 1 };
44 
45         // Splitting
46         const blasint rn = *n - *m;
47 
48         // A_L A_R
49         const float *const A_L = A;
50         float *const       A_R = A + 2 * *ldA * *m;
51 
52         // A_R = apply(ipiv, A_R)
53         LAPACK(claswp)(&rn, A_R, ldA, iONE, m, ipiv, iONE);
54         // A_R = A_L \ A_R
55         BLAS(ctrsm)("L", "L", "N", "U", m, &rn, ONE, A_L, ldA, A_R, ldA);
56     }
57 }
58 
59 
60 /** cgetrf's recursive compute kernel */
RELAPACK_cgetrf_rec(const blasint * m,const blasint * n,float * A,const blasint * ldA,blasint * ipiv,blasint * info)61 static void RELAPACK_cgetrf_rec(
62     const blasint *m, const blasint *n,
63     float *A, const blasint *ldA, blasint *ipiv,
64     blasint *info
65 ) {
66 
67     if (*m == 0 || *n == 0) return;
68 
69     if ( *n <= MAX(CROSSOVER_CGETRF, 1)) {
70         // Unblocked
71         LAPACK(cgetrf2)(m, n, A, ldA, ipiv, info);
72         return;
73     }
74 
75     // Constants
76     const float ONE[]  = { 1., 0. };
77     const float MONE[] = { -1., 0. };
78     const blasint   iONE[] = { 1 };
79 
80     // Splitting
81     const blasint n1 = CREC_SPLIT(*n);
82     const blasint n2 = *n - n1;
83     const blasint m2 = *m - n1;
84 
85     // A_L A_R
86     float *const A_L = A;
87     float *const A_R = A + 2 * *ldA * n1;
88 
89     // A_TL A_TR
90     // A_BL A_BR
91     float *const A_TL = A;
92     float *const A_TR = A + 2 * *ldA * n1;
93     float *const A_BL = A                 + 2 * n1;
94     float *const A_BR = A + 2 * *ldA * n1 + 2 * n1;
95 
96     // ipiv_T
97     // ipiv_B
98     blasint *const ipiv_T = ipiv;
99     blasint *const ipiv_B = ipiv + n1;
100 
101     // recursion(A_L, ipiv_T)
102     RELAPACK_cgetrf_rec(m, &n1, A_L, ldA, ipiv_T, info);
103     if (*info) return;
104     // apply pivots to A_R
105     LAPACK(claswp)(&n2, A_R, ldA, iONE, &n1, ipiv_T, iONE);
106 
107     // A_TR = A_TL \ A_TR
108     BLAS(ctrsm)("L", "L", "N", "U", &n1, &n2, ONE, A_TL, ldA, A_TR, ldA);
109     // A_BR = A_BR - A_BL * A_TR
110     BLAS(cgemm)("N", "N", &m2, &n2, &n1, MONE, A_BL, ldA, A_TR, ldA, ONE, A_BR, ldA);
111 
112     // recursion(A_BR, ipiv_B)
113     RELAPACK_cgetrf_rec(&m2, &n2, A_BR, ldA, ipiv_B, info);
114     if (*info)
115         *info += n1;
116     // apply pivots to A_BL
117     LAPACK(claswp)(&n1, A_BL, ldA, iONE, &n2, ipiv_B, iONE);
118     // shift pivots
119     blasint i;
120     for (i = 0; i < n2; i++)
121         ipiv_B[i] += n1;
122 }
123