#include "relapack.h"
#include <string.h>
#if XSYTRF_ALLOW_MALLOC
#include <stdlib.h>
#endif
5 
6 static void RELAPACK_zsytrf_rook_rec(const char *, const blasint *, const blasint *, blasint *,
7     double *, const blasint *, blasint *, double *, const blasint *, blasint *);
8 
9 
10 /** ZSYTRF_ROOK computes the factorization of a complex symmetric matrix A using the bounded Bunch-Kaufman ("rook") diagonal pivoting method.
11  *
12  * This routine is functionally equivalent to LAPACK's zsytrf_rook.
13  * For details on its interface, see
14  * http://www.netlib.org/lapack/explore-html/d6/d6e/zsytrf__rook_8f.html
15  * */
RELAPACK_zsytrf_rook(const char * uplo,const blasint * n,double * A,const blasint * ldA,blasint * ipiv,double * Work,const blasint * lWork,blasint * info)16 void RELAPACK_zsytrf_rook(
17     const char *uplo, const blasint *n,
18     double *A, const blasint *ldA, blasint *ipiv,
19     double *Work, const blasint *lWork, blasint *info
20 ) {
21 
22     // Required work size
23     const blasint cleanlWork = *n * (*n / 2);
24     blasint minlWork = cleanlWork;
25 #if XSYTRF_ALLOW_MALLOC
26     minlWork = 1;
27 #endif
28 
29     // Check arguments
30     const blasint lower = LAPACK(lsame)(uplo, "L");
31     const blasint upper = LAPACK(lsame)(uplo, "U");
32     *info = 0;
33     if (!lower && !upper)
34         *info = -1;
35     else if (*n < 0)
36         *info = -2;
37     else if (*ldA < MAX(1, *n))
38         *info = -4;
39     else if ((*lWork < 1 || *lWork < minlWork) && *lWork != -1)
40         *info = -7;
41     else if (*lWork == -1) {
42         // Work size query
43         *Work = cleanlWork;
44         return;
45     }
46 
47     // Ensure Work size
48     double *cleanWork = Work;
49 #if XSYTRF_ALLOW_MALLOC
50     if (!*info && *lWork < cleanlWork) {
51         cleanWork = malloc(cleanlWork * 2 * sizeof(double));
52         if (!cleanWork)
53             *info = -7;
54     }
55 #endif
56 
57     if (*info) {
58         const blasint minfo = -*info;
59         LAPACK(xerbla)("ZSYTRF_ROOK", &minfo, strlen("ZSYTRF_ROOK"));
60         return;
61     }
62 
63     // Clean char * arguments
64     const char cleanuplo = lower ? 'L' : 'U';
65 
66     // Dummy argument
67     blasint nout;
68 
69     // Recursive kernel
70     if (*n != 0)
71     RELAPACK_zsytrf_rook_rec(&cleanuplo, n, n, &nout, A, ldA, ipiv, cleanWork, n, info);
72 
73 #if XSYTRF_ALLOW_MALLOC
74     if (cleanWork != Work)
75         free(cleanWork);
76 #endif
77 }
78 
79 
80 /** zsytrf_rook's recursive compute kernel */
RELAPACK_zsytrf_rook_rec(const char * uplo,const blasint * n_full,const blasint * n,blasint * n_out,double * A,const blasint * ldA,blasint * ipiv,double * Work,const blasint * ldWork,blasint * info)81 static void RELAPACK_zsytrf_rook_rec(
82     const char *uplo, const blasint *n_full, const blasint *n, blasint *n_out,
83     double *A, const blasint *ldA, blasint *ipiv,
84     double *Work, const blasint *ldWork, blasint *info
85 ) {
86 
87     // top recursion level?
88     const blasint top = *n_full == *n;
89 
90     if (*n <= MAX(CROSSOVER_ZSYTRF_ROOK, 3)) {
91         // Unblocked
92         if (top) {
93             LAPACK(zsytf2)(uplo, n, A, ldA, ipiv, info);
94             *n_out = *n;
95         } else
96             RELAPACK_zsytrf_rook_rec2(uplo, n_full, n, n_out, A, ldA, ipiv, Work, ldWork, info);
97         return;
98     }
99 
100     blasint info1, info2;
101 
102     // Constants
103     const double ONE[]  = { 1., 0. };
104     const double MONE[] = { -1., 0. };
105     const blasint    iONE[] = { 1 };
106 
107     const blasint n_rest = *n_full - *n;
108 
109     if (*uplo == 'L') {
110         // Splitting (setup)
111         blasint n1 = ZREC_SPLIT(*n);
112         blasint n2 = *n - n1;
113 
114         // Work_L *
115         double *const Work_L = Work;
116 
117         // recursion(A_L)
118         blasint n1_out;
119         RELAPACK_zsytrf_rook_rec(uplo, n_full, &n1, &n1_out, A, ldA, ipiv, Work_L, ldWork, &info1);
120         n1 = n1_out;
121 
122         // Splitting (continued)
123         n2 = *n - n1;
124         const blasint n_full2   = *n_full - n1;
125 
126         // *      *
127         // A_BL   A_BR
128         // A_BL_B A_BR_B
129         double *const A_BL   = A                 + 2 * n1;
130         double *const A_BR   = A + 2 * *ldA * n1 + 2 * n1;
131         double *const A_BL_B = A                 + 2 * *n;
132         double *const A_BR_B = A + 2 * *ldA * n1 + 2 * *n;
133 
134         // *        *
135         // Work_BL Work_BR
136         // *       *
137         // (top recursion level: use Work as Work_BR)
138         double *const Work_BL =              Work                    + 2 * n1;
139         double *const Work_BR = top ? Work : Work + 2 * *ldWork * n1 + 2 * n1;
140         const blasint ldWork_BR = top ? n2 : *ldWork;
141 
142         // ipiv_T
143         // ipiv_B
144         blasint *const ipiv_B = ipiv + n1;
145 
146         // A_BR = A_BR - A_BL Work_BL'
147         RELAPACK_zgemmt(uplo, "N", "T", &n2, &n1, MONE, A_BL, ldA, Work_BL, ldWork, ONE, A_BR, ldA);
148         BLAS(zgemm)("N", "T", &n_rest, &n2, &n1, MONE, A_BL_B, ldA, Work_BL, ldWork, ONE, A_BR_B, ldA);
149 
150         // recursion(A_BR)
151         blasint n2_out;
152         RELAPACK_zsytrf_rook_rec(uplo, &n_full2, &n2, &n2_out, A_BR, ldA, ipiv_B, Work_BR, &ldWork_BR, &info2);
153 
154         if (n2_out != n2) {
155             // undo 1 column of updates
156             const blasint n_restp1 = n_rest + 1;
157 
158             // last column of A_BR
159             double *const A_BR_r = A_BR + 2 * *ldA * n2_out + 2 * n2_out;
160 
161             // last row of A_BL
162             double *const A_BL_b = A_BL + 2 * n2_out;
163 
164             // last row of Work_BL
165             double *const Work_BL_b = Work_BL + 2 * n2_out;
166 
167             // A_BR_r = A_BR_r + A_BL_b Work_BL_b'
168             BLAS(zgemv)("N", &n_restp1, &n1, ONE, A_BL_b, ldA, Work_BL_b, ldWork, ONE, A_BR_r, iONE);
169         }
170         n2 = n2_out;
171 
172         // shift pivots
173         blasint i;
174         for (i = 0; i < n2; i++)
175             if (ipiv_B[i] > 0)
176                 ipiv_B[i] += n1;
177             else
178                 ipiv_B[i] -= n1;
179 
180         *info  = info1 || info2;
181         *n_out = n1 + n2;
182     } else {
183         // Splitting (setup)
184         blasint n2 = ZREC_SPLIT(*n);
185         blasint n1 = *n - n2;
186 
187         // * Work_R
188         // (top recursion level: use Work as Work_R)
189         double *const Work_R = top ? Work : Work + 2 * *ldWork * n1;
190 
191         // recursion(A_R)
192         blasint n2_out;
193         RELAPACK_zsytrf_rook_rec(uplo, n_full, &n2, &n2_out, A, ldA, ipiv, Work_R, ldWork, &info2);
194         const blasint n2_diff = n2 - n2_out;
195         n2 = n2_out;
196 
197         // Splitting (continued)
198         n1 = *n - n2;
199         const blasint n_full1 = *n_full - n2;
200 
201         // * A_TL_T A_TR_T
202         // * A_TL   A_TR
203         // * *      *
204         double *const A_TL_T = A + 2 * *ldA * n_rest;
205         double *const A_TR_T = A + 2 * *ldA * (n_rest + n1);
206         double *const A_TL   = A + 2 * *ldA * n_rest        + 2 * n_rest;
207         double *const A_TR   = A + 2 * *ldA * (n_rest + n1) + 2 * n_rest;
208 
209         // Work_L *
210         // *      Work_TR
211         // *      *
212         // (top recursion level: Work_R was Work)
213         double *const Work_L  = Work;
214         double *const Work_TR = Work + 2 * *ldWork * (top ? n2_diff : n1) + 2 * n_rest;
215         const blasint ldWork_L = top ? n1 : *ldWork;
216 
217         // A_TL = A_TL - A_TR Work_TR'
218         RELAPACK_zgemmt(uplo, "N", "T", &n1, &n2, MONE, A_TR, ldA, Work_TR, ldWork, ONE, A_TL, ldA);
219         BLAS(zgemm)("N", "T", &n_rest, &n1, &n2, MONE, A_TR_T, ldA, Work_TR, ldWork, ONE, A_TL_T, ldA);
220 
221         // recursion(A_TL)
222         blasint n1_out;
223         RELAPACK_zsytrf_rook_rec(uplo, &n_full1, &n1, &n1_out, A, ldA, ipiv, Work_L, &ldWork_L, &info1);
224 
225         if (n1_out != n1) {
226             // undo 1 column of updates
227             const blasint n_restp1 = n_rest + 1;
228 
229             // A_TL_T_l = A_TL_T_l + A_TR_T Work_TR_t'
230             BLAS(zgemv)("N", &n_restp1, &n2, ONE, A_TR_T, ldA, Work_TR, ldWork, ONE, A_TL_T, iONE);
231         }
232         n1 = n1_out;
233 
234         *info  = info2 || info1;
235         *n_out = n1 + n2;
236     }
237 }
238