1 #include "relapack.h"
2 #if XSYTRF_ALLOW_MALLOC
3 #include <stdlib.h>
4 #endif
5 static void RELAPACK_ssytrf_rec(const char *, const blasint *, const blasint *, blasint *,
6 float *, const blasint *, blasint *, float *, const blasint *, blasint *);
7
8
9 /** SSYTRF computes the factorization of a complex symmetric matrix A using the Bunch-Kaufman diagonal pivoting method.
10 *
11 * This routine is functionally equivalent to LAPACK's ssytrf.
12 * For details on its interface, see
13 * http://www.netlib.org/lapack/explore-html/da/de9/ssytrf_8f.html
14 * */
RELAPACK_ssytrf(const char * uplo,const blasint * n,float * A,const blasint * ldA,blasint * ipiv,float * Work,const blasint * lWork,blasint * info)15 void RELAPACK_ssytrf(
16 const char *uplo, const blasint *n,
17 float *A, const blasint *ldA, blasint *ipiv,
18 float *Work, const blasint *lWork, blasint *info
19 ) {
20
21 // Required work size
22 const blasint cleanlWork = *n * (*n / 2);
23 blasint minlWork = cleanlWork;
24 #if XSYTRF_ALLOW_MALLOC
25 minlWork = 1;
26 #endif
27
28 // Check arguments
29 const blasint lower = LAPACK(lsame)(uplo, "L");
30 const blasint upper = LAPACK(lsame)(uplo, "U");
31 *info = 0;
32 if (!lower && !upper)
33 *info = -1;
34 else if (*n < 0)
35 *info = -2;
36 else if (*ldA < MAX(1, *n))
37 *info = -4;
38 else if ((*lWork <1 || *lWork < minlWork) && *lWork != -1)
39 *info = -7;
40 else if (*lWork == -1) {
41 // Work size query
42 *Work = cleanlWork;
43 return;
44 }
45
46 // Ensure Work size
47 float *cleanWork = Work;
48 #if XSYTRF_ALLOW_MALLOC
49 if (!*info && *lWork < cleanlWork) {
50 cleanWork = malloc(cleanlWork * sizeof(float));
51 if (!cleanWork)
52 *info = -7;
53 }
54 #endif
55
56 if (*info) {
57 const blasint minfo = -*info;
58 LAPACK(xerbla)("SSYTRF", &minfo, strlen("SSYTRF"));
59 return;
60 }
61
62 // Clean char * arguments
63 const char cleanuplo = lower ? 'L' : 'U';
64
65 // Dummy arguments
66 blasint nout;
67
68 // Recursive kernel
69 if (*n != 0)
70 RELAPACK_ssytrf_rec(&cleanuplo, n, n, &nout, A, ldA, ipiv, cleanWork, n, info);
71
72 #if XSYTRF_ALLOW_MALLOC
73 if (cleanWork != Work)
74 free(cleanWork);
75 #endif
76 }
77
78
79 /** ssytrf's recursive compute kernel */
RELAPACK_ssytrf_rec(const char * uplo,const blasint * n_full,const blasint * n,blasint * n_out,float * A,const blasint * ldA,blasint * ipiv,float * Work,const blasint * ldWork,blasint * info)80 static void RELAPACK_ssytrf_rec(
81 const char *uplo, const blasint *n_full, const blasint *n, blasint *n_out,
82 float *A, const blasint *ldA, blasint *ipiv,
83 float *Work, const blasint *ldWork, blasint *info
84 ) {
85
86 // top recursion level?
87 const blasint top = *n_full == *n;
88
89 if (*n <= MAX(CROSSOVER_SSYTRF, 3)) {
90 // Unblocked
91 if (top) {
92 LAPACK(ssytf2)(uplo, n, A, ldA, ipiv, info);
93 *n_out = *n;
94 } else
95 RELAPACK_ssytrf_rec2(uplo, n_full, n, n_out, A, ldA, ipiv, Work, ldWork, info);
96 return;
97 }
98
99 blasint info1, info2;
100
101 // Constants
102 const float ONE[] = { 1. };
103 const float MONE[] = { -1. };
104 const blasint iONE[] = { 1 };
105
106 // Loop iterator
107 blasint i;
108
109 const blasint n_rest = *n_full - *n;
110
111 if (*uplo == 'L') {
112 // Splitting (setup)
113 blasint n1 = SREC_SPLIT(*n);
114 blasint n2 = *n - n1;
115
116 // Work_L *
117 float *const Work_L = Work;
118
119 // recursion(A_L)
120 blasint n1_out;
121 RELAPACK_ssytrf_rec(uplo, n_full, &n1, &n1_out, A, ldA, ipiv, Work_L, ldWork, &info1);
122 n1 = n1_out;
123
124 // Splitting (continued)
125 n2 = *n - n1;
126 const blasint n_full2 = *n_full - n1;
127
128 // * *
129 // A_BL A_BR
130 // A_BL_B A_BR_B
131 float *const A_BL = A + n1;
132 float *const A_BR = A + *ldA * n1 + n1;
133 float *const A_BL_B = A + *n;
134 float *const A_BR_B = A + *ldA * n1 + *n;
135
136 // * *
137 // Work_BL Work_BR
138 // * *
139 // (top recursion level: use Work as Work_BR)
140 float *const Work_BL = Work + n1;
141 float *const Work_BR = top ? Work : Work + *ldWork * n1 + n1;
142 const blasint ldWork_BR = top ? n2 : *ldWork;
143
144 // ipiv_T
145 // ipiv_B
146 blasint *const ipiv_B = ipiv + n1;
147
148 // A_BR = A_BR - A_BL Work_BL'
149 RELAPACK_sgemmt(uplo, "N", "T", &n2, &n1, MONE, A_BL, ldA, Work_BL, ldWork, ONE, A_BR, ldA);
150 BLAS(sgemm)("N", "T", &n_rest, &n2, &n1, MONE, A_BL_B, ldA, Work_BL, ldWork, ONE, A_BR_B, ldA);
151
152 // recursion(A_BR)
153 blasint n2_out;
154 RELAPACK_ssytrf_rec(uplo, &n_full2, &n2, &n2_out, A_BR, ldA, ipiv_B, Work_BR, &ldWork_BR, &info2);
155
156 if (n2_out != n2) {
157 // undo 1 column of updates
158 const blasint n_restp1 = n_rest + 1;
159
160 // last column of A_BR
161 float *const A_BR_r = A_BR + *ldA * n2_out + n2_out;
162
163 // last row of A_BL
164 float *const A_BL_b = A_BL + n2_out;
165
166 // last row of Work_BL
167 float *const Work_BL_b = Work_BL + n2_out;
168
169 // A_BR_r = A_BR_r + A_BL_b Work_BL_b'
170 BLAS(sgemv)("N", &n_restp1, &n1, ONE, A_BL_b, ldA, Work_BL_b, ldWork, ONE, A_BR_r, iONE);
171 }
172 n2 = n2_out;
173
174 // shift pivots
175 for (i = 0; i < n2; i++)
176 if (ipiv_B[i] > 0)
177 ipiv_B[i] += n1;
178 else
179 ipiv_B[i] -= n1;
180
181 *info = info1 || info2;
182 *n_out = n1 + n2;
183 } else {
184 // Splitting (setup)
185 blasint n2 = SREC_SPLIT(*n);
186 blasint n1 = *n - n2;
187
188 // * Work_R
189 // (top recursion level: use Work as Work_R)
190 float *const Work_R = top ? Work : Work + *ldWork * n1;
191
192 // recursion(A_R)
193 blasint n2_out;
194 RELAPACK_ssytrf_rec(uplo, n_full, &n2, &n2_out, A, ldA, ipiv, Work_R, ldWork, &info2);
195 const blasint n2_diff = n2 - n2_out;
196 n2 = n2_out;
197
198 // Splitting (continued)
199 n1 = *n - n2;
200 const blasint n_full1 = *n_full - n2;
201
202 // * A_TL_T A_TR_T
203 // * A_TL A_TR
204 // * * *
205 float *const A_TL_T = A + *ldA * n_rest;
206 float *const A_TR_T = A + *ldA * (n_rest + n1);
207 float *const A_TL = A + *ldA * n_rest + n_rest;
208 float *const A_TR = A + *ldA * (n_rest + n1) + n_rest;
209
210 // Work_L *
211 // * Work_TR
212 // * *
213 // (top recursion level: Work_R was Work)
214 float *const Work_L = Work;
215 float *const Work_TR = Work + *ldWork * (top ? n2_diff : n1) + n_rest;
216 const blasint ldWork_L = top ? n1 : *ldWork;
217
218 // A_TL = A_TL - A_TR Work_TR'
219 RELAPACK_sgemmt(uplo, "N", "T", &n1, &n2, MONE, A_TR, ldA, Work_TR, ldWork, ONE, A_TL, ldA);
220 BLAS(sgemm)("N", "T", &n_rest, &n1, &n2, MONE, A_TR_T, ldA, Work_TR, ldWork, ONE, A_TL_T, ldA);
221
222 // recursion(A_TL)
223 blasint n1_out;
224 RELAPACK_ssytrf_rec(uplo, &n_full1, &n1, &n1_out, A, ldA, ipiv, Work_L, &ldWork_L, &info1);
225
226 if (n1_out != n1) {
227 // undo 1 column of updates
228 const blasint n_restp1 = n_rest + 1;
229
230 // A_TL_T_l = A_TL_T_l + A_TR_T Work_TR_t'
231 BLAS(sgemv)("N", &n_restp1, &n2, ONE, A_TR_T, ldA, Work_TR, ldWork, ONE, A_TL_T, iONE);
232 }
233 n1 = n1_out;
234
235 *info = info2 || info1;
236 *n_out = n1 + n2;
237 }
238 }
239