1 #include "relapack.h"
2 #include <math.h>
3 
// Forward declaration of the recursive kernel: it is defined further down in
// this file, after the driver routine that calls it.
static void RELAPACK_ztgsyl_rec(const char *, const blasint *, const blasint *,
    const blasint *, const double *, const blasint *, const double *, const blasint *,
    double *, const blasint *, const double *, const blasint *, const double *,
    const blasint *, double *, const blasint *, double *, double *, double *, blasint *);
8 
9 
10 /** ZTGSYL solves the generalized Sylvester equation.
11  *
12  * This routine is functionally equivalent to LAPACK's ztgsyl.
13  * For details on its interface, see
14  * http://www.netlib.org/lapack/explore-html/db/d68/ztgsyl_8f.html
15  * */
RELAPACK_ztgsyl(const char * trans,const blasint * ijob,const blasint * m,const blasint * n,const double * A,const blasint * ldA,const double * B,const blasint * ldB,double * C,const blasint * ldC,const double * D,const blasint * ldD,const double * E,const blasint * ldE,double * F,const blasint * ldF,double * scale,double * dif,double * Work,const blasint * lWork,blasint * iWork,blasint * info)16 void RELAPACK_ztgsyl(
17     const char *trans, const blasint *ijob, const blasint *m, const blasint *n,
18     const double *A, const blasint *ldA, const double *B, const blasint *ldB,
19     double *C, const blasint *ldC,
20     const double *D, const blasint *ldD, const double *E, const blasint *ldE,
21     double *F, const blasint *ldF,
22     double *scale, double *dif,
23     double *Work, const blasint *lWork, blasint *iWork, blasint *info
24 ) {
25 
26     // Parse arguments
27     const blasint notran = LAPACK(lsame)(trans, "N");
28     const blasint tran = LAPACK(lsame)(trans, "C");
29 
30     // Compute work buffer size
31     blasint lwmin = 1;
32     if (notran && (*ijob == 1 || *ijob == 2))
33         lwmin = MAX(1, 2 * *m * *n);
34     *info = 0;
35 
36     // Check arguments
37     if (!tran && !notran)
38         *info = -1;
39     else if (notran && (*ijob < 0 || *ijob > 4))
40         *info = -2;
41     else if (*m <= 0)
42         *info = -3;
43     else if (*n <= 0)
44         *info = -4;
45     else if (*ldA < MAX(1, *m))
46         *info = -6;
47     else if (*ldB < MAX(1, *n))
48         *info = -8;
49     else if (*ldC < MAX(1, *m))
50         *info = -10;
51     else if (*ldD < MAX(1, *m))
52         *info = -12;
53     else if (*ldE < MAX(1, *n))
54         *info = -14;
55     else if (*ldF < MAX(1, *m))
56         *info = -16;
57     else if (*lWork < lwmin && *lWork != -1)
58         *info = -20;
59     if (*info) {
60         const blasint minfo = -*info;
61         LAPACK(xerbla)("ZTGSYL", &minfo, strlen("ZTGSYL"));
62         return;
63     }
64 
65     if (*lWork == -1) {
66         // Work size query
67         *Work = lwmin;
68         return;
69     }
70 
71     // Clean char * arguments
72     const char cleantrans = notran ? 'N' : 'C';
73 
74     // Constant
75     const double ZERO[] = { 0., 0. };
76 
77     blasint isolve = 1;
78     blasint ifunc  = 0;
79     if (notran) {
80         if (*ijob >= 3) {
81             ifunc = *ijob - 2;
82             LAPACK(zlaset)("F", m, n, ZERO, ZERO, C, ldC);
83             LAPACK(zlaset)("F", m, n, ZERO, ZERO, F, ldF);
84         } else if (*ijob >= 1)
85             isolve = 2;
86     }
87 
88     double scale2;
89     blasint iround;
90     for (iround = 1; iround <= isolve; iround++) {
91         *scale = 1;
92         double dscale = 0;
93         double dsum   = 1;
94         RELAPACK_ztgsyl_rec(&cleantrans, &ifunc, m, n, A, ldA, B, ldB, C, ldC, D, ldD, E, ldE, F, ldF, scale, &dsum, &dscale, info);
95         if (dscale != 0) {
96             if (*ijob == 1 || *ijob == 3)
97                 *dif = sqrt(2 * *m * *n) / (dscale * sqrt(dsum));
98             else
99                 *dif = sqrt(*m * *n) / (dscale * sqrt(dsum));
100         }
101         if (isolve == 2) {
102             if (iround == 1) {
103                 if (notran)
104                     ifunc = *ijob;
105                 scale2 = *scale;
106                 LAPACK(zlacpy)("F", m, n, C, ldC, Work, m);
107                 LAPACK(zlacpy)("F", m, n, F, ldF, Work + 2 * *m * *n, m);
108                 LAPACK(zlaset)("F", m, n, ZERO, ZERO, C, ldC);
109                 LAPACK(zlaset)("F", m, n, ZERO, ZERO, F, ldF);
110             } else {
111                 LAPACK(zlacpy)("F", m, n, Work, m, C, ldC);
112                 LAPACK(zlacpy)("F", m, n, Work + 2 * *m * *n, m, F, ldF);
113                 *scale = scale2;
114             }
115         }
116     }
117 }
118 
119 
120 /** ztgsyl's recursive vompute kernel */
RELAPACK_ztgsyl_rec(const char * trans,const blasint * ifunc,const blasint * m,const blasint * n,const double * A,const blasint * ldA,const double * B,const blasint * ldB,double * C,const blasint * ldC,const double * D,const blasint * ldD,const double * E,const blasint * ldE,double * F,const blasint * ldF,double * scale,double * dsum,double * dscale,blasint * info)121 static void RELAPACK_ztgsyl_rec(
122     const char *trans, const blasint *ifunc, const blasint *m, const blasint *n,
123     const double *A, const blasint *ldA, const double *B, const blasint *ldB,
124     double *C, const blasint *ldC,
125     const double *D, const blasint *ldD, const double *E, const blasint *ldE,
126     double *F, const blasint *ldF,
127     double *scale, double *dsum, double *dscale,
128     blasint *info
129 ) {
130 
131     if (*m <= MAX(CROSSOVER_ZTGSYL, 1) && *n <= MAX(CROSSOVER_ZTGSYL, 1)) {
132         // Unblocked
133         LAPACK(ztgsy2)(trans, ifunc, m, n, A, ldA, B, ldB, C, ldC, D, ldD, E, ldE, F, ldF, scale, dsum, dscale, info);
134         return;
135     }
136 
137     // Constants
138     const double ONE[]  = { 1., 0. };
139     const double MONE[] = { -1., 0. };
140     const blasint    iONE[] = { 1 };
141 
142     // Outputs
143     double scale1[] = { 1., 0. };
144     double scale2[] = { 1., 0. };
145     blasint    info1[]  = { 0 };
146     blasint    info2[]  = { 0 };
147 
148     if (*m > *n) {
149         // Splitting
150         const blasint m1 = ZREC_SPLIT(*m);
151         const blasint m2 = *m - m1;
152 
153         // A_TL A_TR
154         // 0    A_BR
155         const double *const A_TL = A;
156         const double *const A_TR = A + 2 * *ldA * m1;
157         const double *const A_BR = A + 2 * *ldA * m1 + 2 * m1;
158 
159         // C_T
160         // C_B
161         double *const C_T = C;
162         double *const C_B = C + 2 * m1;
163 
164         // D_TL D_TR
165         // 0    D_BR
166         const double *const D_TL = D;
167         const double *const D_TR = D + 2 * *ldD * m1;
168         const double *const D_BR = D + 2 * *ldD * m1 + 2 * m1;
169 
170         // F_T
171         // F_B
172         double *const F_T = F;
173         double *const F_B = F + 2 * m1;
174 
175         if (*trans == 'N') {
176             // recursion(A_BR, B, C_B, D_BR, E, F_B)
177             RELAPACK_ztgsyl_rec(trans, ifunc, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, D_BR, ldD, E, ldE, F_B, ldF, scale1, dsum, dscale, info1);
178             // C_T = C_T - A_TR * C_B
179             BLAS(zgemm)("N", "N", &m1, n, &m2, MONE, A_TR, ldA, C_B, ldC, scale1, C_T, ldC);
180             // F_T = F_T - D_TR * C_B
181             BLAS(zgemm)("N", "N", &m1, n, &m2, MONE, D_TR, ldD, C_B, ldC, scale1, F_T, ldF);
182             // recursion(A_TL, B, C_T, D_TL, E, F_T)
183             RELAPACK_ztgsyl_rec(trans, ifunc, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, D_TL, ldD, E, ldE, F_T, ldF, scale2, dsum, dscale, info2);
184             // apply scale
185             if (scale2[0] != 1) {
186                 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, &m2, n, C_B, ldC, info);
187                 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, &m2, n, F_B, ldF, info);
188             }
189         } else {
190             // recursion(A_TL, B, C_T, D_TL, E, F_T)
191             RELAPACK_ztgsyl_rec(trans, ifunc, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, D_TL, ldD, E, ldE, F_T, ldF, scale1, dsum, dscale, info1);
192             // apply scale
193             if (scale1[0] != 1)
194                 LAPACK(zlascl)("G", iONE, iONE, ONE, scale1, &m2, n, F_B, ldF, info);
195             // C_B = C_B - A_TR^H * C_T
196             BLAS(zgemm)("C", "N", &m2, n, &m1, MONE, A_TR, ldA, C_T, ldC, scale1, C_B, ldC);
197             // C_B = C_B - D_TR^H * F_T
198             BLAS(zgemm)("C", "N", &m2, n, &m1, MONE, D_TR, ldD, F_T, ldC, ONE, C_B, ldC);
199             // recursion(A_BR, B, C_B, D_BR, E, F_B)
200             RELAPACK_ztgsyl_rec(trans, ifunc, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, D_BR, ldD, E, ldE, F_B, ldF, scale2, dsum, dscale, info2);
201             // apply scale
202             if (scale2[0] != 1) {
203                 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, &m1, n, C_T, ldC, info);
204                 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, &m1, n, F_T, ldF, info);
205             }
206         }
207     } else {
208         // Splitting
209         const blasint n1 = ZREC_SPLIT(*n);
210         const blasint n2 = *n - n1;
211 
212         // B_TL B_TR
213         // 0    B_BR
214         const double *const B_TL = B;
215         const double *const B_TR = B + 2 * *ldB * n1;
216         const double *const B_BR = B + 2 * *ldB * n1 + 2 * n1;
217 
218         // C_L C_R
219         double *const C_L = C;
220         double *const C_R = C + 2 * *ldC * n1;
221 
222         // E_TL E_TR
223         // 0    E_BR
224         const double *const E_TL = E;
225         const double *const E_TR = E + 2 * *ldE * n1;
226         const double *const E_BR = E + 2 * *ldE * n1 + 2 * n1;
227 
228         // F_L F_R
229         double *const F_L = F;
230         double *const F_R = F + 2 * *ldF * n1;
231 
232         if (*trans == 'N') {
233             // recursion(A, B_TL, C_L, D, E_TL, F_L)
234             RELAPACK_ztgsyl_rec(trans, ifunc, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, D, ldD, E_TL, ldE, F_L, ldF, scale1, dsum, dscale, info1);
235             // C_R = C_R + F_L * B_TR
236             BLAS(zgemm)("N", "N", m, &n2, &n1, ONE, F_L, ldF, B_TR, ldB, scale1, C_R, ldC);
237             // F_R = F_R + F_L * E_TR
238             BLAS(zgemm)("N", "N", m, &n2, &n1, ONE, F_L, ldF, E_TR, ldE, scale1, F_R, ldF);
239             // recursion(A, B_BR, C_R, D, E_BR, F_R)
240             RELAPACK_ztgsyl_rec(trans, ifunc, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, D, ldD, E_BR, ldE, F_R, ldF, scale2, dsum, dscale, info2);
241             // apply scale
242             if (scale2[0] != 1) {
243                 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, m, &n1, C_L, ldC, info);
244                 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, m, &n1, F_L, ldF, info);
245             }
246         } else {
247             // recursion(A, B_BR, C_R, D, E_BR, F_R)
248             RELAPACK_ztgsyl_rec(trans, ifunc, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, D, ldD, E_BR, ldE, F_R, ldF, scale1, dsum, dscale, info1);
249             // apply scale
250             if (scale1[0] != 1)
251                 LAPACK(zlascl)("G", iONE, iONE, ONE, scale1, m, &n1, C_L, ldC, info);
252             // F_L = F_L + C_R * B_TR
253             BLAS(zgemm)("N", "C", m, &n1, &n2, ONE, C_R, ldC, B_TR, ldB, scale1, F_L, ldF);
254             // F_L = F_L + F_R * E_TR
255             BLAS(zgemm)("N", "C", m, &n1, &n2, ONE, F_R, ldF, E_TR, ldB, ONE, F_L, ldF);
256             // recursion(A, B_TL, C_L, D, E_TL, F_L)
257             RELAPACK_ztgsyl_rec(trans, ifunc, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, D, ldD, E_TL, ldE, F_L, ldF, scale2, dsum, dscale, info2);
258             // apply scale
259             if (scale2[0] != 1) {
260                 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, m, &n2, C_R, ldC, info);
261                 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, m, &n2, F_R, ldF, info);
262             }
263         }
264     }
265 
266     *scale = scale1[0] * scale2[0];
267     *info  = info1[0] || info2[0];
268 }
269