1 #include "relapack.h"
2
// Forward declaration of the recursive kernel defined further down in this
// file; arguments are, in order: m, n, A, ldA, ipiv, info.
static void RELAPACK_dgetrf_rec(const blasint *, const blasint *, double *,
    const blasint *, blasint *, blasint *);
5
6
7 /** DGETRF computes an LU factorization of a general M-by-N matrix A using partial pivoting with row interchanges.
8 *
9 * This routine is functionally equivalent to LAPACK's dgetrf.
10 * For details on its interface, see
11 * http://www.netlib.org/lapack/explore-html/d3/d6a/dgetrf_8f.html
12 * */
RELAPACK_dgetrf(const blasint * m,const blasint * n,double * A,const blasint * ldA,blasint * ipiv,blasint * info)13 void RELAPACK_dgetrf(
14 const blasint *m, const blasint *n,
15 double *A, const blasint *ldA, blasint *ipiv,
16 blasint *info
17 ) {
18 // Check arguments
19 *info = 0;
20 if (*m < 0)
21 *info = -1;
22 else if (*n < 0)
23 *info = -2;
24 else if (*ldA < MAX(1, *m))
25 *info = -4;
26 if (*info!=0) {
27 const blasint minfo = -*info;
28 LAPACK(xerbla)("DGETRF", &minfo, strlen("DGETRF"));
29 return;
30 }
31
32 if (*m == 0 || *n == 0) return;
33
34 const blasint sn = MIN(*m, *n);
35 RELAPACK_dgetrf_rec(m, &sn, A, ldA, ipiv, info);
36
37 // Right remainder
38 if (*m < *n) {
39 // Constants
40 const double ONE[] = { 1. };
41 const blasint iONE[] = { 1 };
42
43 // Splitting
44 const blasint rn = *n - *m;
45
46 // A_L A_R
47 const double *const A_L = A;
48 double *const A_R = A + *ldA * *m;
49
50 // A_R = apply(ipiv, A_R)
51 LAPACK(dlaswp)(&rn, A_R, ldA, iONE, m, ipiv, iONE);
52 // A_R = A_S \ A_R
53 BLAS(dtrsm)("L", "L", "N", "U", m, &rn, ONE, A_L, ldA, A_R, ldA);
54 }
55 }
56
57
58 /** dgetrf's recursive compute kernel */
RELAPACK_dgetrf_rec(const blasint * m,const blasint * n,double * A,const blasint * ldA,blasint * ipiv,blasint * info)59 static void RELAPACK_dgetrf_rec(
60 const blasint *m, const blasint *n,
61 double *A, const blasint *ldA, blasint *ipiv,
62 blasint *info
63 ) {
64 if ( *n <= MAX(CROSSOVER_DGETRF, 1)) {
65 // Unblocked
66 LAPACK(dgetrf2)(m, n, A, ldA, ipiv, info);
67 return;
68 }
69 // Constants
70 const double ONE[] = { 1. };
71 const double MONE[] = { -1. };
72 const blasint iONE[] = { 1 };
73
74 // Splitting
75 const blasint n1 = DREC_SPLIT(*n);
76 const blasint n2 = *n - n1;
77 const blasint m2 = *m - n1;
78
79 // A_L A_R
80 double *const A_L = A;
81 double *const A_R = A + *ldA * n1;
82
83 // A_TL A_TR
84 // A_BL A_BR
85 double *const A_TL = A;
86 double *const A_TR = A + *ldA * n1;
87 double *const A_BL = A + n1;
88 double *const A_BR = A + *ldA * n1 + n1;
89
90 // ipiv_T
91 // ipiv_B
92 blasint *const ipiv_T = ipiv;
93 blasint *const ipiv_B = ipiv + n1;
94
95 // recursion(A_L, ipiv_T)
96 RELAPACK_dgetrf_rec(m, &n1, A_L, ldA, ipiv_T, info);
97 if (*info) return;
98 // apply pivots to A_R
99 LAPACK(dlaswp)(&n2, A_R, ldA, iONE, &n1, ipiv_T, iONE);
100
101 // A_TR = A_TL \ A_TR
102 BLAS(dtrsm)("L", "L", "N", "U", &n1, &n2, ONE, A_TL, ldA, A_TR, ldA);
103 // A_BR = A_BR - A_BL * A_TR
104 BLAS(dgemm)("N", "N", &m2, &n2, &n1, MONE, A_BL, ldA, A_TR, ldA, ONE, A_BR, ldA);
105
106 // recursion(A_BR, ipiv_B)
107 RELAPACK_dgetrf_rec(&m2, &n2, A_BR, ldA, ipiv_B, info);
108 if (*info)
109 *info += n1;
110 // apply pivots to A_BL
111 LAPACK(dlaswp)(&n1, A_BL, ldA, iONE, &n2, ipiv_B, iONE);
112 // shift pivots
113 blasint i;
114 for (i = 0; i < n2; i++)
115 ipiv_B[i] += n1;
116 }
117