1 #include "relapack.h"
2 static void RELAPACK_sgetrf_rec(const blasint *, const blasint *, float *, const blasint *,
3 blasint *, blasint *);
4
5
6 /** SGETRF computes an LU factorization of a general M-by-N matrix A using partial pivoting with row interchanges.
7 *
8 * This routine is functionally equivalent to LAPACK's sgetrf.
9 * For details on its interface, see
10 * http://www.netlib.org/lapack/explore-html/de/de2/sgetrf_8f.html
11 * */
RELAPACK_sgetrf(const blasint * m,const blasint * n,float * A,const blasint * ldA,blasint * ipiv,blasint * info)12 void RELAPACK_sgetrf(
13 const blasint *m, const blasint *n,
14 float *A, const blasint *ldA, blasint *ipiv,
15 blasint *info
16 ) {
17 // Check arguments
18 *info = 0;
19 if (*m < 0)
20 *info = -1;
21 else if (*n < 0)
22 *info = -2;
23 else if (*ldA < MAX(1, *m))
24 *info = -4;
25 if (*info) {
26 const blasint minfo = -*info;
27 LAPACK(xerbla)("SGETRF", &minfo, strlen("SGETRF"));
28 return;
29 }
30
31 if (*m == 0 || *n == 0) return;
32
33 const blasint sn = MIN(*m, *n);
34 RELAPACK_sgetrf_rec(m, &sn, A, ldA, ipiv, info);
35
36 // Right remainder
37 if (*m < *n) {
38 // Constants
39 const float ONE[] = { 1. };
40 const blasint iONE[] = { 1 };
41
42 // Splitting
43 const blasint rn = *n - *m;
44
45 // A_L A_R
46 const float *const A_L = A;
47 float *const A_R = A + *ldA * *m;
48
49 // A_R = apply(ipiv, A_R)
50 LAPACK(slaswp)(&rn, A_R, ldA, iONE, m, ipiv, iONE);
51 // A_R = A_L \ A_R
52 BLAS(strsm)("L", "L", "N", "U", m, &rn, ONE, A_L, ldA, A_R, ldA);
53 }
54 }
55
56
57 /** sgetrf's recursive compute kernel */
RELAPACK_sgetrf_rec(const blasint * m,const blasint * n,float * A,const blasint * ldA,blasint * ipiv,blasint * info)58 static void RELAPACK_sgetrf_rec(
59 const blasint *m, const blasint *n,
60 float *A, const blasint *ldA, blasint *ipiv,
61 blasint *info
62 ) {
63
64 if (*m == 0 || *n == 0) return;
65
66 if ( *n <= MAX(CROSSOVER_SGETRF, 1)) {
67 // Unblocked
68 LAPACK(sgetrf2)(m, n, A, ldA, ipiv, info);
69 return;
70 }
71
72 // Constants
73 const float ONE[] = { 1. };
74 const float MONE[] = { -1. };
75 const blasint iONE[] = { 1 };
76
77 // Splitting
78 const blasint n1 = SREC_SPLIT(*n);
79 const blasint n2 = *n - n1;
80 const blasint m2 = *m - n1;
81 // A_L A_R
82 float *const A_L = A;
83 float *const A_R = A + *ldA * n1;
84
85 // A_TL A_TR
86 // A_BL A_BR
87 float *const A_TL = A;
88 float *const A_TR = A + *ldA * n1;
89 float *const A_BL = A + n1;
90 float *const A_BR = A + *ldA * n1 + n1;
91
92 // ipiv_T
93 // ipiv_B
94 blasint *const ipiv_T = ipiv;
95 blasint *const ipiv_B = ipiv + n1;
96
97 // recursion(A_L, ipiv_T)
98 RELAPACK_sgetrf_rec(m, &n1, A_L, ldA, ipiv_T, info);
99 if (*info)
100 return;
101 // apply pivots to A_R
102 LAPACK(slaswp)(&n2, A_R, ldA, iONE, &n1, ipiv_T, iONE);
103
104 // A_TR = A_TL \ A_TR
105 BLAS(strsm)("L", "L", "N", "U", &n1, &n2, ONE, A_TL, ldA, A_TR, ldA);
106 // A_BR = A_BR - A_BL * A_TR
107 BLAS(sgemm)("N", "N", &m2, &n2, &n1, MONE, A_BL, ldA, A_TR, ldA, ONE, A_BR, ldA);
108
109 // recursion(A_BR, ipiv_B)
110 RELAPACK_sgetrf_rec(&m2, &n2, A_BR, ldA, ipiv_B, info);
111 if (*info)
112 *info += n1;
113 // apply pivots to A_BL
114 LAPACK(slaswp)(&n1, A_BL, ldA, iONE, &n2, ipiv_B, iONE);
115 // shift pivots
116 blasint i;
117 for (i = 0; i < n2; i++)
118 ipiv_B[i] += n1;
119 }
120