1 #include "relapack.h"
2
3 static void RELAPACK_zgetrf_rec(const blasint *, const blasint *, double *,
4 const blasint *, blasint *, blasint *);
5
6
7 /** ZGETRF computes an LU factorization of a general M-by-N matrix A using partial pivoting with row interchanges.
8 *
9 * This routine is functionally equivalent to LAPACK's zgetrf.
10 * For details on its interface, see
11 * http://www.netlib.org/lapack/explore-html/dd/dd1/zgetrf_8f.html
12 * */
RELAPACK_zgetrf(const blasint * m,const blasint * n,double * A,const blasint * ldA,blasint * ipiv,blasint * info)13 void RELAPACK_zgetrf(
14 const blasint *m, const blasint *n,
15 double *A, const blasint *ldA, blasint *ipiv,
16 blasint *info
17 ) {
18
19 // Check arguments
20 *info = 0;
21 if (*m < 0)
22 *info = -1;
23 else if (*n < 0)
24 *info = -2;
25 else if (*ldA < MAX(1, *m))
26 *info = -4;
27 if (*info) {
28 const blasint minfo = -*info;
29 LAPACK(xerbla)("ZGETRF", &minfo, strlen("ZGETRF"));
30 return;
31 }
32
33 if (*m == 0 || *n == 0) return;
34 const blasint sn = MIN(*m, *n);
35
36 RELAPACK_zgetrf_rec(m, &sn, A, ldA, ipiv, info);
37
38 // Right remainder
39 if (*m < *n) {
40 // Constants
41 const double ONE[] = { 1., 0. };
42 const blasint iONE[] = { 1 };
43
44 // Splitting
45 const blasint rn = *n - *m;
46
47 // A_L A_R
48 const double *const A_L = A;
49 double *const A_R = A + 2 * *ldA * *m;
50
51 // A_R = apply(ipiv, A_R)
52 LAPACK(zlaswp)(&rn, A_R, ldA, iONE, m, ipiv, iONE);
53 // A_R = A_L \ A_R
54 BLAS(ztrsm)("L", "L", "N", "U", m, &rn, ONE, A_L, ldA, A_R, ldA);
55 }
56 }
57
58
59 /** zgetrf's recursive compute kernel */
RELAPACK_zgetrf_rec(const blasint * m,const blasint * n,double * A,const blasint * ldA,blasint * ipiv,blasint * info)60 static void RELAPACK_zgetrf_rec(
61 const blasint *m, const blasint *n,
62 double *A, const blasint *ldA, blasint *ipiv,
63 blasint *info
64 ) {
65
66 if (*m == 0 || *n == 0) return;
67
68 if ( *n <= MAX(CROSSOVER_ZGETRF, 1)) {
69 // Unblocked
70 LAPACK(zgetrf2)(m, n, A, ldA, ipiv, info);
71 return;
72 }
73
74 // Constants
75 const double ONE[] = { 1., 0. };
76 const double MONE[] = { -1., 0. };
77 const blasint iONE[] = { 1. };
78
79 // Splitting
80 const blasint n1 = ZREC_SPLIT(*n);
81 const blasint n2 = *n - n1;
82 const blasint m2 = *m - n1;
83
84 // A_L A_R
85 double *const A_L = A;
86 double *const A_R = A + 2 * *ldA * n1;
87
88 // A_TL A_TR
89 // A_BL A_BR
90 double *const A_TL = A;
91 double *const A_TR = A + 2 * *ldA * n1;
92 double *const A_BL = A + 2 * n1;
93 double *const A_BR = A + 2 * *ldA * n1 + 2 * n1;
94
95 // ipiv_T
96 // ipiv_B
97 blasint *const ipiv_T = ipiv;
98 blasint *const ipiv_B = ipiv + n1;
99
100 // recursion(A_L, ipiv_T)
101 RELAPACK_zgetrf_rec(m, &n1, A_L, ldA, ipiv_T, info);
102 if (*info) return;
103
104 // apply pivots to A_R
105 LAPACK(zlaswp)(&n2, A_R, ldA, iONE, &n1, ipiv_T, iONE);
106
107 // A_TR = A_TL \ A_TR
108 BLAS(ztrsm)("L", "L", "N", "U", &n1, &n2, ONE, A_TL, ldA, A_TR, ldA);
109 // A_BR = A_BR - A_BL * A_TR
110 BLAS(zgemm)("N", "N", &m2, &n2, &n1, MONE, A_BL, ldA, A_TR, ldA, ONE, A_BR, ldA);
111
112 // recursion(A_BR, ipiv_B)
113 RELAPACK_zgetrf_rec(&m2, &n2, A_BR, ldA, ipiv_B, info);
114 if (*info)
115 *info += n1;
116 // apply pivots to A_BL
117 LAPACK(zlaswp)(&n1, A_BL, ldA, iONE, &n2, ipiv_B, iONE);
118 // shift pivots
119 blasint i;
120 for (i = 0; i < n2; i++)
121 ipiv_B[i] += n1;
122 }
123