1 #include "relapack.h"
2 #if XSYTRF_ALLOW_MALLOC
3 #include <stdlib.h>
4 #endif
5
6 static void RELAPACK_zsytrf_rec(const char *, const blasint *, const blasint *, blasint *,
7 double *, const blasint *, blasint *, double *, const blasint *, blasint *);
8
9
10 /** ZSYTRF computes the factorization of a complex symmetric matrix A using the Bunch-Kaufman diagonal pivoting method.
11 *
12 * This routine is functionally equivalent to LAPACK's zsytrf.
13 * For details on its interface, see
14 * http://www.netlib.org/lapack/explore-html/da/d94/zsytrf_8f.html
15 * */
RELAPACK_zsytrf(const char * uplo,const blasint * n,double * A,const blasint * ldA,blasint * ipiv,double * Work,const blasint * lWork,blasint * info)16 void RELAPACK_zsytrf(
17 const char *uplo, const blasint *n,
18 double *A, const blasint *ldA, blasint *ipiv,
19 double *Work, const blasint *lWork, blasint *info
20 ) {
21
22 // Required work size
23 const blasint cleanlWork = *n * (*n / 2);
24 blasint minlWork = cleanlWork;
25 #if XSYTRF_ALLOW_MALLOC
26 minlWork = 1;
27 #endif
28
29 // Check arguments
30 const blasint lower = LAPACK(lsame)(uplo, "L");
31 const blasint upper = LAPACK(lsame)(uplo, "U");
32 *info = 0;
33 if (!lower && !upper)
34 *info = -1;
35 else if (*n < 0)
36 *info = -2;
37 else if (*ldA < MAX(1, *n))
38 *info = -4;
39 else if ((*lWork < 1 || *lWork < minlWork) && *lWork != -1)
40 *info = -7;
41 else if (*lWork == -1) {
42 // Work size query
43 *Work = cleanlWork;
44 return;
45 }
46
47 // Ensure Work size
48 double *cleanWork = Work;
49 #if XSYTRF_ALLOW_MALLOC
50 if (!*info && *lWork < cleanlWork) {
51 cleanWork = malloc(cleanlWork * 2 * sizeof(double));
52 if (!cleanWork)
53 *info = -7;
54 }
55 #endif
56
57 if (*info) {
58 const blasint minfo = -*info;
59 LAPACK(xerbla)("ZSYTRF", &minfo, strlen("ZSYTRF"));
60 return;
61 }
62
63 // Clean char * arguments
64 const char cleanuplo = lower ? 'L' : 'U';
65
66 // Dummy arguments
67 blasint nout;
68
69 // Recursive kernel
70 if (*n != 0)
71 RELAPACK_zsytrf_rec(&cleanuplo, n, n, &nout, A, ldA, ipiv, cleanWork, n, info);
72
73 #if XSYTRF_ALLOW_MALLOC
74 if (cleanWork != Work)
75 free(cleanWork);
76 #endif
77 }
78
79
80 /** zsytrf's recursive compute kernel */
RELAPACK_zsytrf_rec(const char * uplo,const blasint * n_full,const blasint * n,blasint * n_out,double * A,const blasint * ldA,blasint * ipiv,double * Work,const blasint * ldWork,blasint * info)81 static void RELAPACK_zsytrf_rec(
82 const char *uplo, const blasint *n_full, const blasint *n, blasint *n_out,
83 double *A, const blasint *ldA, blasint *ipiv,
84 double *Work, const blasint *ldWork, blasint *info
85 ) {
86
87 // top recursion level?
88 const blasint top = *n_full == *n;
89
90 if (*n <= MAX(CROSSOVER_ZSYTRF, 3)) {
91 // Unblocked
92 if (top) {
93 LAPACK(zsytf2)(uplo, n, A, ldA, ipiv, info);
94 *n_out = *n;
95 } else
96 RELAPACK_zsytrf_rec2(uplo, n_full, n, n_out, A, ldA, ipiv, Work, ldWork, info);
97 return;
98 }
99
100 blasint info1, info2;
101
102 // Constants
103 const double ONE[] = { 1., 0. };
104 const double MONE[] = { -1., 0. };
105 const blasint iONE[] = { 1 };
106
107 // Loop iterator
108 blasint i;
109
110 const blasint n_rest = *n_full - *n;
111
112 if (*uplo == 'L') {
113 // Splitting (setup)
114 blasint n1 = ZREC_SPLIT(*n);
115 blasint n2 = *n - n1;
116
117 // Work_L *
118 double *const Work_L = Work;
119
120 // recursion(A_L)
121 blasint n1_out;
122 RELAPACK_zsytrf_rec(uplo, n_full, &n1, &n1_out, A, ldA, ipiv, Work_L, ldWork, &info1);
123 n1 = n1_out;
124
125 // Splitting (continued)
126 n2 = *n - n1;
127 const blasint n_full2 = *n_full - n1;
128
129 // * *
130 // A_BL A_BR
131 // A_BL_B A_BR_B
132 double *const A_BL = A + 2 * n1;
133 double *const A_BR = A + 2 * *ldA * n1 + 2 * n1;
134 double *const A_BL_B = A + 2 * *n;
135 double *const A_BR_B = A + 2 * *ldA * n1 + 2 * *n;
136
137 // * *
138 // Work_BL Work_BR
139 // * *
140 // (top recursion level: use Work as Work_BR)
141 double *const Work_BL = Work + 2 * n1;
142 double *const Work_BR = top ? Work : Work + 2 * *ldWork * n1 + 2 * n1;
143 const blasint ldWork_BR = top ? n2 : *ldWork;
144
145 // ipiv_T
146 // ipiv_B
147 blasint *const ipiv_B = ipiv + n1;
148
149 // A_BR = A_BR - A_BL Work_BL'
150 RELAPACK_zgemmt(uplo, "N", "T", &n2, &n1, MONE, A_BL, ldA, Work_BL, ldWork, ONE, A_BR, ldA);
151 BLAS(zgemm)("N", "T", &n_rest, &n2, &n1, MONE, A_BL_B, ldA, Work_BL, ldWork, ONE, A_BR_B, ldA);
152
153 // recursion(A_BR)
154 blasint n2_out;
155 RELAPACK_zsytrf_rec(uplo, &n_full2, &n2, &n2_out, A_BR, ldA, ipiv_B, Work_BR, &ldWork_BR, &info2);
156
157 if (n2_out != n2) {
158 // undo 1 column of updates
159 const blasint n_restp1 = n_rest + 1;
160
161 // last column of A_BR
162 double *const A_BR_r = A_BR + 2 * *ldA * n2_out + 2 * n2_out;
163
164 // last row of A_BL
165 double *const A_BL_b = A_BL + 2 * n2_out;
166
167 // last row of Work_BL
168 double *const Work_BL_b = Work_BL + 2 * n2_out;
169
170 // A_BR_r = A_BR_r + A_BL_b Work_BL_b'
171 BLAS(zgemv)("N", &n_restp1, &n1, ONE, A_BL_b, ldA, Work_BL_b, ldWork, ONE, A_BR_r, iONE);
172 }
173 n2 = n2_out;
174
175 // shift pivots
176 for (i = 0; i < n2; i++)
177 if (ipiv_B[i] > 0)
178 ipiv_B[i] += n1;
179 else
180 ipiv_B[i] -= n1;
181
182 *info = info1 || info2;
183 *n_out = n1 + n2;
184 } else {
185 // Splitting (setup)
186 blasint n2 = ZREC_SPLIT(*n);
187 blasint n1 = *n - n2;
188
189 // * Work_R
190 // (top recursion level: use Work as Work_R)
191 double *const Work_R = top ? Work : Work + 2 * *ldWork * n1;
192
193 // recursion(A_R)
194 blasint n2_out;
195 RELAPACK_zsytrf_rec(uplo, n_full, &n2, &n2_out, A, ldA, ipiv, Work_R, ldWork, &info2);
196 const blasint n2_diff = n2 - n2_out;
197 n2 = n2_out;
198
199 // Splitting (continued)
200 n1 = *n - n2;
201 const blasint n_full1 = *n_full - n2;
202
203 // * A_TL_T A_TR_T
204 // * A_TL A_TR
205 // * * *
206 double *const A_TL_T = A + 2 * *ldA * n_rest;
207 double *const A_TR_T = A + 2 * *ldA * (n_rest + n1);
208 double *const A_TL = A + 2 * *ldA * n_rest + 2 * n_rest;
209 double *const A_TR = A + 2 * *ldA * (n_rest + n1) + 2 * n_rest;
210
211 // Work_L *
212 // * Work_TR
213 // * *
214 // (top recursion level: Work_R was Work)
215 double *const Work_L = Work;
216 double *const Work_TR = Work + 2 * *ldWork * (top ? n2_diff : n1) + 2 * n_rest;
217 const blasint ldWork_L = top ? n1 : *ldWork;
218
219 // A_TL = A_TL - A_TR Work_TR'
220 RELAPACK_zgemmt(uplo, "N", "T", &n1, &n2, MONE, A_TR, ldA, Work_TR, ldWork, ONE, A_TL, ldA);
221 BLAS(zgemm)("N", "T", &n_rest, &n1, &n2, MONE, A_TR_T, ldA, Work_TR, ldWork, ONE, A_TL_T, ldA);
222
223 // recursion(A_TL)
224 blasint n1_out;
225 RELAPACK_zsytrf_rec(uplo, &n_full1, &n1, &n1_out, A, ldA, ipiv, Work_L, &ldWork_L, &info1);
226
227 if (n1_out != n1) {
228 // undo 1 column of updates
229 const blasint n_restp1 = n_rest + 1;
230
231 // A_TL_T_l = A_TL_T_l + A_TR_T Work_TR_t'
232 BLAS(zgemv)("N", &n_restp1, &n2, ONE, A_TR_T, ldA, Work_TR, ldWork, ONE, A_TL_T, iONE);
233 }
234 n1 = n1_out;
235
236 *info = info2 || info1;
237 *n_out = n1 + n2;
238 }
239 }
240