1 #include "relapack.h"
2 static void RELAPACK_sgetrf_rec(const blasint *, const blasint *, float *, const blasint *,
3     blasint *, blasint *);
4 
5 
6 /** SGETRF computes an LU factorization of a general M-by-N matrix A using partial pivoting with row interchanges.
7  *
8  * This routine is functionally equivalent to LAPACK's sgetrf.
9  * For details on its interface, see
10  * http://www.netlib.org/lapack/explore-html/de/de2/sgetrf_8f.html
11  * */
RELAPACK_sgetrf(const blasint * m,const blasint * n,float * A,const blasint * ldA,blasint * ipiv,blasint * info)12 void RELAPACK_sgetrf(
13     const blasint *m, const blasint *n,
14     float *A, const blasint *ldA, blasint *ipiv,
15     blasint *info
16 ) {
17     // Check arguments
18     *info = 0;
19     if (*m < 0)
20         *info = -1;
21     else if (*n < 0)
22         *info = -2;
23     else if (*ldA < MAX(1, *m))
24         *info = -4;
25     if (*info) {
26         const blasint minfo = -*info;
27         LAPACK(xerbla)("SGETRF", &minfo, strlen("SGETRF"));
28         return;
29     }
30 
31     if (*m == 0 || *n == 0) return;
32 
33     const blasint sn = MIN(*m, *n);
34     RELAPACK_sgetrf_rec(m, &sn, A, ldA, ipiv, info);
35 
36     // Right remainder
37     if (*m < *n) {
38         // Constants
39         const float ONE[] = { 1. };
40         const blasint  iONE[] = { 1 };
41 
42         // Splitting
43         const blasint rn = *n - *m;
44 
45         // A_L A_R
46         const float *const A_L = A;
47         float *const       A_R = A + *ldA * *m;
48 
49         // A_R = apply(ipiv, A_R)
50         LAPACK(slaswp)(&rn, A_R, ldA, iONE, m, ipiv, iONE);
51         // A_R = A_L \ A_R
52         BLAS(strsm)("L", "L", "N", "U", m, &rn, ONE, A_L, ldA, A_R, ldA);
53     }
54 }
55 
56 
57 /** sgetrf's recursive compute kernel */
RELAPACK_sgetrf_rec(const blasint * m,const blasint * n,float * A,const blasint * ldA,blasint * ipiv,blasint * info)58 static void RELAPACK_sgetrf_rec(
59     const blasint *m, const blasint *n,
60     float *A, const blasint *ldA, blasint *ipiv,
61     blasint *info
62 ) {
63 
64     if (*m == 0 || *n == 0) return;
65 
66     if ( *n <= MAX(CROSSOVER_SGETRF, 1)) {
67         // Unblocked
68         LAPACK(sgetrf2)(m, n, A, ldA, ipiv, info);
69         return;
70     }
71 
72     // Constants
73     const float ONE[]  = { 1. };
74     const float MONE[] = { -1. };
75     const blasint   iONE[] = { 1 };
76 
77     // Splitting
78     const blasint n1 = SREC_SPLIT(*n);
79     const blasint n2 = *n - n1;
80     const blasint m2 = *m - n1;
81     // A_L A_R
82     float *const A_L = A;
83     float *const A_R = A + *ldA * n1;
84 
85     // A_TL A_TR
86     // A_BL A_BR
87     float *const A_TL = A;
88     float *const A_TR = A + *ldA * n1;
89     float *const A_BL = A             + n1;
90     float *const A_BR = A + *ldA * n1 + n1;
91 
92     // ipiv_T
93     // ipiv_B
94     blasint *const ipiv_T = ipiv;
95     blasint *const ipiv_B = ipiv + n1;
96 
97     // recursion(A_L, ipiv_T)
98     RELAPACK_sgetrf_rec(m, &n1, A_L, ldA, ipiv_T, info);
99     if (*info)
100 	return;
101     // apply pivots to A_R
102     LAPACK(slaswp)(&n2, A_R, ldA, iONE, &n1, ipiv_T, iONE);
103 
104     // A_TR = A_TL \ A_TR
105     BLAS(strsm)("L", "L", "N", "U", &n1, &n2, ONE, A_TL, ldA, A_TR, ldA);
106     // A_BR = A_BR - A_BL * A_TR
107     BLAS(sgemm)("N", "N", &m2, &n2, &n1, MONE, A_BL, ldA, A_TR, ldA, ONE, A_BR, ldA);
108 
109     // recursion(A_BR, ipiv_B)
110     RELAPACK_sgetrf_rec(&m2, &n2, A_BR, ldA, ipiv_B, info);
111     if (*info)
112         *info += n1;
113     // apply pivots to A_BL
114     LAPACK(slaswp)(&n1, A_BL, ldA, iONE, &n2, ipiv_B, iONE);
115     // shift pivots
116     blasint i;
117     for (i = 0; i < n2; i++)
118         ipiv_B[i] += n1;
119 }
120