1 #include "relapack.h"
2 #include <math.h>
3
4 static void RELAPACK_ztgsyl_rec(const char *, const blasint *, const blasint *,
5 const blasint *, const double *, const blasint *, const double *, const blasint *,
6 double *, const blasint *, const double *, const blasint *, const double *,
7 const blasint *, double *, const blasint *, double *, double *, double *, blasint *);
8
9
10 /** ZTGSYL solves the generalized Sylvester equation.
11 *
12 * This routine is functionally equivalent to LAPACK's ztgsyl.
13 * For details on its interface, see
14 * http://www.netlib.org/lapack/explore-html/db/d68/ztgsyl_8f.html
15 * */
RELAPACK_ztgsyl(const char * trans,const blasint * ijob,const blasint * m,const blasint * n,const double * A,const blasint * ldA,const double * B,const blasint * ldB,double * C,const blasint * ldC,const double * D,const blasint * ldD,const double * E,const blasint * ldE,double * F,const blasint * ldF,double * scale,double * dif,double * Work,const blasint * lWork,blasint * iWork,blasint * info)16 void RELAPACK_ztgsyl(
17 const char *trans, const blasint *ijob, const blasint *m, const blasint *n,
18 const double *A, const blasint *ldA, const double *B, const blasint *ldB,
19 double *C, const blasint *ldC,
20 const double *D, const blasint *ldD, const double *E, const blasint *ldE,
21 double *F, const blasint *ldF,
22 double *scale, double *dif,
23 double *Work, const blasint *lWork, blasint *iWork, blasint *info
24 ) {
25
26 // Parse arguments
27 const blasint notran = LAPACK(lsame)(trans, "N");
28 const blasint tran = LAPACK(lsame)(trans, "C");
29
30 // Compute work buffer size
31 blasint lwmin = 1;
32 if (notran && (*ijob == 1 || *ijob == 2))
33 lwmin = MAX(1, 2 * *m * *n);
34 *info = 0;
35
36 // Check arguments
37 if (!tran && !notran)
38 *info = -1;
39 else if (notran && (*ijob < 0 || *ijob > 4))
40 *info = -2;
41 else if (*m <= 0)
42 *info = -3;
43 else if (*n <= 0)
44 *info = -4;
45 else if (*ldA < MAX(1, *m))
46 *info = -6;
47 else if (*ldB < MAX(1, *n))
48 *info = -8;
49 else if (*ldC < MAX(1, *m))
50 *info = -10;
51 else if (*ldD < MAX(1, *m))
52 *info = -12;
53 else if (*ldE < MAX(1, *n))
54 *info = -14;
55 else if (*ldF < MAX(1, *m))
56 *info = -16;
57 else if (*lWork < lwmin && *lWork != -1)
58 *info = -20;
59 if (*info) {
60 const blasint minfo = -*info;
61 LAPACK(xerbla)("ZTGSYL", &minfo, strlen("ZTGSYL"));
62 return;
63 }
64
65 if (*lWork == -1) {
66 // Work size query
67 *Work = lwmin;
68 return;
69 }
70
71 // Clean char * arguments
72 const char cleantrans = notran ? 'N' : 'C';
73
74 // Constant
75 const double ZERO[] = { 0., 0. };
76
77 blasint isolve = 1;
78 blasint ifunc = 0;
79 if (notran) {
80 if (*ijob >= 3) {
81 ifunc = *ijob - 2;
82 LAPACK(zlaset)("F", m, n, ZERO, ZERO, C, ldC);
83 LAPACK(zlaset)("F", m, n, ZERO, ZERO, F, ldF);
84 } else if (*ijob >= 1)
85 isolve = 2;
86 }
87
88 double scale2;
89 blasint iround;
90 for (iround = 1; iround <= isolve; iround++) {
91 *scale = 1;
92 double dscale = 0;
93 double dsum = 1;
94 RELAPACK_ztgsyl_rec(&cleantrans, &ifunc, m, n, A, ldA, B, ldB, C, ldC, D, ldD, E, ldE, F, ldF, scale, &dsum, &dscale, info);
95 if (dscale != 0) {
96 if (*ijob == 1 || *ijob == 3)
97 *dif = sqrt(2 * *m * *n) / (dscale * sqrt(dsum));
98 else
99 *dif = sqrt(*m * *n) / (dscale * sqrt(dsum));
100 }
101 if (isolve == 2) {
102 if (iround == 1) {
103 if (notran)
104 ifunc = *ijob;
105 scale2 = *scale;
106 LAPACK(zlacpy)("F", m, n, C, ldC, Work, m);
107 LAPACK(zlacpy)("F", m, n, F, ldF, Work + 2 * *m * *n, m);
108 LAPACK(zlaset)("F", m, n, ZERO, ZERO, C, ldC);
109 LAPACK(zlaset)("F", m, n, ZERO, ZERO, F, ldF);
110 } else {
111 LAPACK(zlacpy)("F", m, n, Work, m, C, ldC);
112 LAPACK(zlacpy)("F", m, n, Work + 2 * *m * *n, m, F, ldF);
113 *scale = scale2;
114 }
115 }
116 }
117 }
118
119
120 /** ztgsyl's recursive vompute kernel */
RELAPACK_ztgsyl_rec(const char * trans,const blasint * ifunc,const blasint * m,const blasint * n,const double * A,const blasint * ldA,const double * B,const blasint * ldB,double * C,const blasint * ldC,const double * D,const blasint * ldD,const double * E,const blasint * ldE,double * F,const blasint * ldF,double * scale,double * dsum,double * dscale,blasint * info)121 static void RELAPACK_ztgsyl_rec(
122 const char *trans, const blasint *ifunc, const blasint *m, const blasint *n,
123 const double *A, const blasint *ldA, const double *B, const blasint *ldB,
124 double *C, const blasint *ldC,
125 const double *D, const blasint *ldD, const double *E, const blasint *ldE,
126 double *F, const blasint *ldF,
127 double *scale, double *dsum, double *dscale,
128 blasint *info
129 ) {
130
131 if (*m <= MAX(CROSSOVER_ZTGSYL, 1) && *n <= MAX(CROSSOVER_ZTGSYL, 1)) {
132 // Unblocked
133 LAPACK(ztgsy2)(trans, ifunc, m, n, A, ldA, B, ldB, C, ldC, D, ldD, E, ldE, F, ldF, scale, dsum, dscale, info);
134 return;
135 }
136
137 // Constants
138 const double ONE[] = { 1., 0. };
139 const double MONE[] = { -1., 0. };
140 const blasint iONE[] = { 1 };
141
142 // Outputs
143 double scale1[] = { 1., 0. };
144 double scale2[] = { 1., 0. };
145 blasint info1[] = { 0 };
146 blasint info2[] = { 0 };
147
148 if (*m > *n) {
149 // Splitting
150 const blasint m1 = ZREC_SPLIT(*m);
151 const blasint m2 = *m - m1;
152
153 // A_TL A_TR
154 // 0 A_BR
155 const double *const A_TL = A;
156 const double *const A_TR = A + 2 * *ldA * m1;
157 const double *const A_BR = A + 2 * *ldA * m1 + 2 * m1;
158
159 // C_T
160 // C_B
161 double *const C_T = C;
162 double *const C_B = C + 2 * m1;
163
164 // D_TL D_TR
165 // 0 D_BR
166 const double *const D_TL = D;
167 const double *const D_TR = D + 2 * *ldD * m1;
168 const double *const D_BR = D + 2 * *ldD * m1 + 2 * m1;
169
170 // F_T
171 // F_B
172 double *const F_T = F;
173 double *const F_B = F + 2 * m1;
174
175 if (*trans == 'N') {
176 // recursion(A_BR, B, C_B, D_BR, E, F_B)
177 RELAPACK_ztgsyl_rec(trans, ifunc, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, D_BR, ldD, E, ldE, F_B, ldF, scale1, dsum, dscale, info1);
178 // C_T = C_T - A_TR * C_B
179 BLAS(zgemm)("N", "N", &m1, n, &m2, MONE, A_TR, ldA, C_B, ldC, scale1, C_T, ldC);
180 // F_T = F_T - D_TR * C_B
181 BLAS(zgemm)("N", "N", &m1, n, &m2, MONE, D_TR, ldD, C_B, ldC, scale1, F_T, ldF);
182 // recursion(A_TL, B, C_T, D_TL, E, F_T)
183 RELAPACK_ztgsyl_rec(trans, ifunc, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, D_TL, ldD, E, ldE, F_T, ldF, scale2, dsum, dscale, info2);
184 // apply scale
185 if (scale2[0] != 1) {
186 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, &m2, n, C_B, ldC, info);
187 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, &m2, n, F_B, ldF, info);
188 }
189 } else {
190 // recursion(A_TL, B, C_T, D_TL, E, F_T)
191 RELAPACK_ztgsyl_rec(trans, ifunc, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, D_TL, ldD, E, ldE, F_T, ldF, scale1, dsum, dscale, info1);
192 // apply scale
193 if (scale1[0] != 1)
194 LAPACK(zlascl)("G", iONE, iONE, ONE, scale1, &m2, n, F_B, ldF, info);
195 // C_B = C_B - A_TR^H * C_T
196 BLAS(zgemm)("C", "N", &m2, n, &m1, MONE, A_TR, ldA, C_T, ldC, scale1, C_B, ldC);
197 // C_B = C_B - D_TR^H * F_T
198 BLAS(zgemm)("C", "N", &m2, n, &m1, MONE, D_TR, ldD, F_T, ldC, ONE, C_B, ldC);
199 // recursion(A_BR, B, C_B, D_BR, E, F_B)
200 RELAPACK_ztgsyl_rec(trans, ifunc, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, D_BR, ldD, E, ldE, F_B, ldF, scale2, dsum, dscale, info2);
201 // apply scale
202 if (scale2[0] != 1) {
203 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, &m1, n, C_T, ldC, info);
204 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, &m1, n, F_T, ldF, info);
205 }
206 }
207 } else {
208 // Splitting
209 const blasint n1 = ZREC_SPLIT(*n);
210 const blasint n2 = *n - n1;
211
212 // B_TL B_TR
213 // 0 B_BR
214 const double *const B_TL = B;
215 const double *const B_TR = B + 2 * *ldB * n1;
216 const double *const B_BR = B + 2 * *ldB * n1 + 2 * n1;
217
218 // C_L C_R
219 double *const C_L = C;
220 double *const C_R = C + 2 * *ldC * n1;
221
222 // E_TL E_TR
223 // 0 E_BR
224 const double *const E_TL = E;
225 const double *const E_TR = E + 2 * *ldE * n1;
226 const double *const E_BR = E + 2 * *ldE * n1 + 2 * n1;
227
228 // F_L F_R
229 double *const F_L = F;
230 double *const F_R = F + 2 * *ldF * n1;
231
232 if (*trans == 'N') {
233 // recursion(A, B_TL, C_L, D, E_TL, F_L)
234 RELAPACK_ztgsyl_rec(trans, ifunc, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, D, ldD, E_TL, ldE, F_L, ldF, scale1, dsum, dscale, info1);
235 // C_R = C_R + F_L * B_TR
236 BLAS(zgemm)("N", "N", m, &n2, &n1, ONE, F_L, ldF, B_TR, ldB, scale1, C_R, ldC);
237 // F_R = F_R + F_L * E_TR
238 BLAS(zgemm)("N", "N", m, &n2, &n1, ONE, F_L, ldF, E_TR, ldE, scale1, F_R, ldF);
239 // recursion(A, B_BR, C_R, D, E_BR, F_R)
240 RELAPACK_ztgsyl_rec(trans, ifunc, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, D, ldD, E_BR, ldE, F_R, ldF, scale2, dsum, dscale, info2);
241 // apply scale
242 if (scale2[0] != 1) {
243 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, m, &n1, C_L, ldC, info);
244 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, m, &n1, F_L, ldF, info);
245 }
246 } else {
247 // recursion(A, B_BR, C_R, D, E_BR, F_R)
248 RELAPACK_ztgsyl_rec(trans, ifunc, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, D, ldD, E_BR, ldE, F_R, ldF, scale1, dsum, dscale, info1);
249 // apply scale
250 if (scale1[0] != 1)
251 LAPACK(zlascl)("G", iONE, iONE, ONE, scale1, m, &n1, C_L, ldC, info);
252 // F_L = F_L + C_R * B_TR
253 BLAS(zgemm)("N", "C", m, &n1, &n2, ONE, C_R, ldC, B_TR, ldB, scale1, F_L, ldF);
254 // F_L = F_L + F_R * E_TR
255 BLAS(zgemm)("N", "C", m, &n1, &n2, ONE, F_R, ldF, E_TR, ldB, ONE, F_L, ldF);
256 // recursion(A, B_TL, C_L, D, E_TL, F_L)
257 RELAPACK_ztgsyl_rec(trans, ifunc, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, D, ldD, E_TL, ldE, F_L, ldF, scale2, dsum, dscale, info2);
258 // apply scale
259 if (scale2[0] != 1) {
260 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, m, &n2, C_R, ldC, info);
261 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, m, &n2, F_R, ldF, info);
262 }
263 }
264 }
265
266 *scale = scale1[0] * scale2[0];
267 *info = info1[0] || info2[0];
268 }
269