1 #include "relapack.h"
2 #include <math.h>
3 
// Forward declaration of the recursive kernel used by RELAPACK_ctgsyl;
// its definition (with parameter names) appears later in this file.
static void RELAPACK_ctgsyl_rec(const char *, const blasint *, const blasint *,
    const blasint *, const float *, const blasint *, const float *, const blasint *,
    float *, const blasint *, const float *, const blasint *, const float *,
    const blasint *, float *, const blasint *, float *, float *, float *, blasint *);
8 
9 
10 /** CTGSYL solves the generalized Sylvester equation.
11  *
12  * This routine is functionally equivalent to LAPACK's ctgsyl.
13  * For details on its interface, see
14  * http://www.netlib.org/lapack/explore-html/d7/de7/ctgsyl_8f.html
15  * */
RELAPACK_ctgsyl(const char * trans,const blasint * ijob,const blasint * m,const blasint * n,const float * A,const blasint * ldA,const float * B,const blasint * ldB,float * C,const blasint * ldC,const float * D,const blasint * ldD,const float * E,const blasint * ldE,float * F,const blasint * ldF,float * scale,float * dif,float * Work,const blasint * lWork,blasint * iWork,blasint * info)16 void RELAPACK_ctgsyl(
17     const char *trans, const blasint *ijob, const blasint *m, const blasint *n,
18     const float *A, const blasint *ldA, const float *B, const blasint *ldB,
19     float *C, const blasint *ldC,
20     const float *D, const blasint *ldD, const float *E, const blasint *ldE,
21     float *F, const blasint *ldF,
22     float *scale, float *dif,
23     float *Work, const blasint *lWork, blasint *iWork, blasint *info
24 ) {
25 
26     // Parse arguments
27     const blasint notran = LAPACK(lsame)(trans, "N");
28     const blasint tran = LAPACK(lsame)(trans, "C");
29 
30     // Compute work buffer size
31     blasint lwmin = 1;
32     if (notran && (*ijob == 1 || *ijob == 2))
33         lwmin = MAX(1, 2 * *m * *n);
34     *info = 0;
35 
36     // Check arguments
37     if (!tran && !notran)
38         *info = -1;
39     else if (notran && (*ijob < 0 || *ijob > 4))
40         *info = -2;
41     else if (*m <= 0)
42         *info = -3;
43     else if (*n <= 0)
44         *info = -4;
45     else if (*ldA < MAX(1, *m))
46         *info = -6;
47     else if (*ldB < MAX(1, *n))
48         *info = -8;
49     else if (*ldC < MAX(1, *m))
50         *info = -10;
51     else if (*ldD < MAX(1, *m))
52         *info = -12;
53     else if (*ldE < MAX(1, *n))
54         *info = -14;
55     else if (*ldF < MAX(1, *m))
56         *info = -16;
57     else if (*lWork < lwmin && *lWork != -1)
58         *info = -20;
59     if (*info) {
60         const blasint minfo = -*info;
61         LAPACK(xerbla)("CTGSYL", &minfo, strlen("CTGSYL"));
62         return;
63     }
64 
65     if (*lWork == -1) {
66         // Work size query
67         *Work = lwmin;
68         return;
69     }
70 
71     if ( *m == 0 || *n == 0) {
72       *scale = 1.;
73       if (notran && (*ijob != 0))
74         *dif = 0.;
75       return;
76     }
77 
78     // Clean char * arguments
79     const char cleantrans = notran ? 'N' : 'C';
80 
81     // Constant
82     const float ZERO[] = { 0., 0. };
83 
84     blasint isolve = 1;
85     blasint ifunc  = 0;
86     if (notran) {
87         if (*ijob >= 3) {
88             ifunc = *ijob - 2;
89             LAPACK(claset)("F", m, n, ZERO, ZERO, C, ldC);
90             LAPACK(claset)("F", m, n, ZERO, ZERO, F, ldF);
91         } else if (*ijob >= 1)
92             isolve = 2;
93     }
94 
95     float scale2;
96     blasint iround;
97     for (iround = 1; iround <= isolve; iround++) {
98         *scale = 1;
99         float dscale = 0;
100         float dsum   = 1;
101         RELAPACK_ctgsyl_rec(&cleantrans, &ifunc, m, n, A, ldA, B, ldB, C, ldC, D, ldD, E, ldE, F, ldF, scale, &dsum, &dscale, info);
102         if (dscale != 0) {
103             if (*ijob == 1 || *ijob == 3)
104                 *dif = sqrt(2 * *m * *n) / (dscale * sqrt(dsum));
105             else
106                 *dif = sqrt(*m * *n) / (dscale * sqrt(dsum));
107         }
108         if (isolve == 2) {
109             if (iround == 1) {
110                 if (notran)
111                     ifunc = *ijob;
112                 scale2 = *scale;
113                 LAPACK(clacpy)("F", m, n, C, ldC, Work, m);
114                 LAPACK(clacpy)("F", m, n, F, ldF, Work + 2 * *m * *n, m);
115                 LAPACK(claset)("F", m, n, ZERO, ZERO, C, ldC);
116                 LAPACK(claset)("F", m, n, ZERO, ZERO, F, ldF);
117             } else {
118                 LAPACK(clacpy)("F", m, n, Work, m, C, ldC);
119                 LAPACK(clacpy)("F", m, n, Work + 2 * *m * *n, m, F, ldF);
120                 *scale = scale2;
121             }
122         }
123     }
124 }
125 
126 
127 /** ctgsyl's recursive vompute kernel */
RELAPACK_ctgsyl_rec(const char * trans,const blasint * ifunc,const blasint * m,const blasint * n,const float * A,const blasint * ldA,const float * B,const blasint * ldB,float * C,const blasint * ldC,const float * D,const blasint * ldD,const float * E,const blasint * ldE,float * F,const blasint * ldF,float * scale,float * dsum,float * dscale,blasint * info)128 static void RELAPACK_ctgsyl_rec(
129     const char *trans, const blasint *ifunc, const blasint *m, const blasint *n,
130     const float *A, const blasint *ldA, const float *B, const blasint *ldB,
131     float *C, const blasint *ldC,
132     const float *D, const blasint *ldD, const float *E, const blasint *ldE,
133     float *F, const blasint *ldF,
134     float *scale, float *dsum, float *dscale,
135     blasint *info
136 ) {
137 
138     if (*m <= MAX(CROSSOVER_CTGSYL, 1) && *n <= MAX(CROSSOVER_CTGSYL, 1)) {
139         // Unblocked
140         LAPACK(ctgsy2)(trans, ifunc, m, n, A, ldA, B, ldB, C, ldC, D, ldD, E, ldE, F, ldF, scale, dsum, dscale, info);
141         return;
142     }
143 
144     // Constants
145     const float ONE[]  = { 1., 0. };
146     const float MONE[] = { -1., 0. };
147     const blasint   iONE[] = { 1 };
148 
149     // Outputs
150     float scale1[] = { 1., 0. };
151     float scale2[] = { 1., 0. };
152     blasint   info1[]  = { 0 };
153     blasint   info2[]  = { 0 };
154 
155     if (*m > *n) {
156         // Splitting
157         const blasint m1 = CREC_SPLIT(*m);
158         const blasint m2 = *m - m1;
159 
160         // A_TL A_TR
161         // 0    A_BR
162         const float *const A_TL = A;
163         const float *const A_TR = A + 2 * *ldA * m1;
164         const float *const A_BR = A + 2 * *ldA * m1 + 2 * m1;
165 
166         // C_T
167         // C_B
168         float *const C_T = C;
169         float *const C_B = C + 2 * m1;
170 
171         // D_TL D_TR
172         // 0    D_BR
173         const float *const D_TL = D;
174         const float *const D_TR = D + 2 * *ldD * m1;
175         const float *const D_BR = D + 2 * *ldD * m1 + 2 * m1;
176 
177         // F_T
178         // F_B
179         float *const F_T = F;
180         float *const F_B = F + 2 * m1;
181 
182         if (*trans == 'N') {
183             // recursion(A_BR, B, C_B, D_BR, E, F_B)
184             RELAPACK_ctgsyl_rec(trans, ifunc, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, D_BR, ldD, E, ldE, F_B, ldF, scale1, dsum, dscale, info1);
185             // C_T = C_T - A_TR * C_B
186             BLAS(cgemm)("N", "N", &m1, n, &m2, MONE, A_TR, ldA, C_B, ldC, scale1, C_T, ldC);
187             // F_T = F_T - D_TR * C_B
188             BLAS(cgemm)("N", "N", &m1, n, &m2, MONE, D_TR, ldD, C_B, ldC, scale1, F_T, ldF);
189             // recursion(A_TL, B, C_T, D_TL, E, F_T)
190             RELAPACK_ctgsyl_rec(trans, ifunc, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, D_TL, ldD, E, ldE, F_T, ldF, scale2, dsum, dscale, info2);
191             // apply scale
192             if (scale2[0] != 1) {
193                 LAPACK(clascl)("G", iONE, iONE, ONE, scale2, &m2, n, C_B, ldC, info);
194                 LAPACK(clascl)("G", iONE, iONE, ONE, scale2, &m2, n, F_B, ldF, info);
195             }
196         } else {
197             // recursion(A_TL, B, C_T, D_TL, E, F_T)
198             RELAPACK_ctgsyl_rec(trans, ifunc, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, D_TL, ldD, E, ldE, F_T, ldF, scale1, dsum, dscale, info1);
199             // apply scale
200             if (scale1[0] != 1)
201                 LAPACK(clascl)("G", iONE, iONE, ONE, scale1, &m2, n, F_B, ldF, info);
202             // C_B = C_B - A_TR^H * C_T
203             BLAS(cgemm)("C", "N", &m2, n, &m1, MONE, A_TR, ldA, C_T, ldC, scale1, C_B, ldC);
204             // C_B = C_B - D_TR^H * F_T
205             BLAS(cgemm)("C", "N", &m2, n, &m1, MONE, D_TR, ldD, F_T, ldC, ONE, C_B, ldC);
206             // recursion(A_BR, B, C_B, D_BR, E, F_B)
207             RELAPACK_ctgsyl_rec(trans, ifunc, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, D_BR, ldD, E, ldE, F_B, ldF, scale2, dsum, dscale, info2);
208             // apply scale
209             if (scale2[0] != 1) {
210                 LAPACK(clascl)("G", iONE, iONE, ONE, scale2, &m1, n, C_T, ldC, info);
211                 LAPACK(clascl)("G", iONE, iONE, ONE, scale2, &m1, n, F_T, ldF, info);
212             }
213         }
214     } else {
215         // Splitting
216         const blasint n1 = CREC_SPLIT(*n);
217         const blasint n2 = *n - n1;
218 
219         // B_TL B_TR
220         // 0    B_BR
221         const float *const B_TL = B;
222         const float *const B_TR = B + 2 * *ldB * n1;
223         const float *const B_BR = B + 2 * *ldB * n1 + 2 * n1;
224 
225         // C_L C_R
226         float *const C_L = C;
227         float *const C_R = C + 2 * *ldC * n1;
228 
229         // E_TL E_TR
230         // 0    E_BR
231         const float *const E_TL = E;
232         const float *const E_TR = E + 2 * *ldE * n1;
233         const float *const E_BR = E + 2 * *ldE * n1 + 2 * n1;
234 
235         // F_L F_R
236         float *const F_L = F;
237         float *const F_R = F + 2 * *ldF * n1;
238 
239         if (*trans == 'N') {
240             // recursion(A, B_TL, C_L, D, E_TL, F_L)
241             RELAPACK_ctgsyl_rec(trans, ifunc, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, D, ldD, E_TL, ldE, F_L, ldF, scale1, dsum, dscale, info1);
242             // C_R = C_R + F_L * B_TR
243             BLAS(cgemm)("N", "N", m, &n2, &n1, ONE, F_L, ldF, B_TR, ldB, scale1, C_R, ldC);
244             // F_R = F_R + F_L * E_TR
245             BLAS(cgemm)("N", "N", m, &n2, &n1, ONE, F_L, ldF, E_TR, ldE, scale1, F_R, ldF);
246             // recursion(A, B_BR, C_R, D, E_BR, F_R)
247             RELAPACK_ctgsyl_rec(trans, ifunc, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, D, ldD, E_BR, ldE, F_R, ldF, scale2, dsum, dscale, info2);
248             // apply scale
249             if (scale2[0] != 1) {
250                 LAPACK(clascl)("G", iONE, iONE, ONE, scale2, m, &n1, C_L, ldC, info);
251                 LAPACK(clascl)("G", iONE, iONE, ONE, scale2, m, &n1, F_L, ldF, info);
252             }
253         } else {
254             // recursion(A, B_BR, C_R, D, E_BR, F_R)
255             RELAPACK_ctgsyl_rec(trans, ifunc, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, D, ldD, E_BR, ldE, F_R, ldF, scale1, dsum, dscale, info1);
256             // apply scale
257             if (scale1[0] != 1)
258                 LAPACK(clascl)("G", iONE, iONE, ONE, scale1, m, &n1, C_L, ldC, info);
259             // F_L = F_L + C_R * B_TR
260             BLAS(cgemm)("N", "C", m, &n1, &n2, ONE, C_R, ldC, B_TR, ldB, scale1, F_L, ldF);
261             // F_L = F_L + F_R * E_TR
262             BLAS(cgemm)("N", "C", m, &n1, &n2, ONE, F_R, ldF, E_TR, ldB, ONE, F_L, ldF);
263             // recursion(A, B_TL, C_L, D, E_TL, F_L)
264             RELAPACK_ctgsyl_rec(trans, ifunc, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, D, ldD, E_TL, ldE, F_L, ldF, scale2, dsum, dscale, info2);
265             // apply scale
266             if (scale2[0] != 1) {
267                 LAPACK(clascl)("G", iONE, iONE, ONE, scale2, m, &n2, C_R, ldC, info);
268                 LAPACK(clascl)("G", iONE, iONE, ONE, scale2, m, &n2, F_R, ldF, info);
269             }
270         }
271     }
272 
273     *scale = scale1[0] * scale2[0];
274     *info  = info1[0] || info2[0];
275 }
276