1 #include "relapack.h"
2 #include <math.h>
3
4 static void RELAPACK_ctgsyl_rec(const char *, const blasint *, const blasint *,
5 const blasint *, const float *, const blasint *, const float *, const blasint *,
6 float *, const blasint *, const float *, const blasint *, const float *,
7 const blasint *, float *, const blasint *, float *, float *, float *, blasint *);
8
9
10 /** CTGSYL solves the generalized Sylvester equation.
11 *
12 * This routine is functionally equivalent to LAPACK's ctgsyl.
13 * For details on its interface, see
14 * http://www.netlib.org/lapack/explore-html/d7/de7/ctgsyl_8f.html
15 * */
RELAPACK_ctgsyl(const char * trans,const blasint * ijob,const blasint * m,const blasint * n,const float * A,const blasint * ldA,const float * B,const blasint * ldB,float * C,const blasint * ldC,const float * D,const blasint * ldD,const float * E,const blasint * ldE,float * F,const blasint * ldF,float * scale,float * dif,float * Work,const blasint * lWork,blasint * iWork,blasint * info)16 void RELAPACK_ctgsyl(
17 const char *trans, const blasint *ijob, const blasint *m, const blasint *n,
18 const float *A, const blasint *ldA, const float *B, const blasint *ldB,
19 float *C, const blasint *ldC,
20 const float *D, const blasint *ldD, const float *E, const blasint *ldE,
21 float *F, const blasint *ldF,
22 float *scale, float *dif,
23 float *Work, const blasint *lWork, blasint *iWork, blasint *info
24 ) {
25
26 // Parse arguments
27 const blasint notran = LAPACK(lsame)(trans, "N");
28 const blasint tran = LAPACK(lsame)(trans, "C");
29
30 // Compute work buffer size
31 blasint lwmin = 1;
32 if (notran && (*ijob == 1 || *ijob == 2))
33 lwmin = MAX(1, 2 * *m * *n);
34 *info = 0;
35
36 // Check arguments
37 if (!tran && !notran)
38 *info = -1;
39 else if (notran && (*ijob < 0 || *ijob > 4))
40 *info = -2;
41 else if (*m <= 0)
42 *info = -3;
43 else if (*n <= 0)
44 *info = -4;
45 else if (*ldA < MAX(1, *m))
46 *info = -6;
47 else if (*ldB < MAX(1, *n))
48 *info = -8;
49 else if (*ldC < MAX(1, *m))
50 *info = -10;
51 else if (*ldD < MAX(1, *m))
52 *info = -12;
53 else if (*ldE < MAX(1, *n))
54 *info = -14;
55 else if (*ldF < MAX(1, *m))
56 *info = -16;
57 else if (*lWork < lwmin && *lWork != -1)
58 *info = -20;
59 if (*info) {
60 const blasint minfo = -*info;
61 LAPACK(xerbla)("CTGSYL", &minfo, strlen("CTGSYL"));
62 return;
63 }
64
65 if (*lWork == -1) {
66 // Work size query
67 *Work = lwmin;
68 return;
69 }
70
71 if ( *m == 0 || *n == 0) {
72 *scale = 1.;
73 if (notran && (*ijob != 0))
74 *dif = 0.;
75 return;
76 }
77
78 // Clean char * arguments
79 const char cleantrans = notran ? 'N' : 'C';
80
81 // Constant
82 const float ZERO[] = { 0., 0. };
83
84 blasint isolve = 1;
85 blasint ifunc = 0;
86 if (notran) {
87 if (*ijob >= 3) {
88 ifunc = *ijob - 2;
89 LAPACK(claset)("F", m, n, ZERO, ZERO, C, ldC);
90 LAPACK(claset)("F", m, n, ZERO, ZERO, F, ldF);
91 } else if (*ijob >= 1)
92 isolve = 2;
93 }
94
95 float scale2;
96 blasint iround;
97 for (iround = 1; iround <= isolve; iround++) {
98 *scale = 1;
99 float dscale = 0;
100 float dsum = 1;
101 RELAPACK_ctgsyl_rec(&cleantrans, &ifunc, m, n, A, ldA, B, ldB, C, ldC, D, ldD, E, ldE, F, ldF, scale, &dsum, &dscale, info);
102 if (dscale != 0) {
103 if (*ijob == 1 || *ijob == 3)
104 *dif = sqrt(2 * *m * *n) / (dscale * sqrt(dsum));
105 else
106 *dif = sqrt(*m * *n) / (dscale * sqrt(dsum));
107 }
108 if (isolve == 2) {
109 if (iround == 1) {
110 if (notran)
111 ifunc = *ijob;
112 scale2 = *scale;
113 LAPACK(clacpy)("F", m, n, C, ldC, Work, m);
114 LAPACK(clacpy)("F", m, n, F, ldF, Work + 2 * *m * *n, m);
115 LAPACK(claset)("F", m, n, ZERO, ZERO, C, ldC);
116 LAPACK(claset)("F", m, n, ZERO, ZERO, F, ldF);
117 } else {
118 LAPACK(clacpy)("F", m, n, Work, m, C, ldC);
119 LAPACK(clacpy)("F", m, n, Work + 2 * *m * *n, m, F, ldF);
120 *scale = scale2;
121 }
122 }
123 }
124 }
125
126
127 /** ctgsyl's recursive vompute kernel */
RELAPACK_ctgsyl_rec(const char * trans,const blasint * ifunc,const blasint * m,const blasint * n,const float * A,const blasint * ldA,const float * B,const blasint * ldB,float * C,const blasint * ldC,const float * D,const blasint * ldD,const float * E,const blasint * ldE,float * F,const blasint * ldF,float * scale,float * dsum,float * dscale,blasint * info)128 static void RELAPACK_ctgsyl_rec(
129 const char *trans, const blasint *ifunc, const blasint *m, const blasint *n,
130 const float *A, const blasint *ldA, const float *B, const blasint *ldB,
131 float *C, const blasint *ldC,
132 const float *D, const blasint *ldD, const float *E, const blasint *ldE,
133 float *F, const blasint *ldF,
134 float *scale, float *dsum, float *dscale,
135 blasint *info
136 ) {
137
138 if (*m <= MAX(CROSSOVER_CTGSYL, 1) && *n <= MAX(CROSSOVER_CTGSYL, 1)) {
139 // Unblocked
140 LAPACK(ctgsy2)(trans, ifunc, m, n, A, ldA, B, ldB, C, ldC, D, ldD, E, ldE, F, ldF, scale, dsum, dscale, info);
141 return;
142 }
143
144 // Constants
145 const float ONE[] = { 1., 0. };
146 const float MONE[] = { -1., 0. };
147 const blasint iONE[] = { 1 };
148
149 // Outputs
150 float scale1[] = { 1., 0. };
151 float scale2[] = { 1., 0. };
152 blasint info1[] = { 0 };
153 blasint info2[] = { 0 };
154
155 if (*m > *n) {
156 // Splitting
157 const blasint m1 = CREC_SPLIT(*m);
158 const blasint m2 = *m - m1;
159
160 // A_TL A_TR
161 // 0 A_BR
162 const float *const A_TL = A;
163 const float *const A_TR = A + 2 * *ldA * m1;
164 const float *const A_BR = A + 2 * *ldA * m1 + 2 * m1;
165
166 // C_T
167 // C_B
168 float *const C_T = C;
169 float *const C_B = C + 2 * m1;
170
171 // D_TL D_TR
172 // 0 D_BR
173 const float *const D_TL = D;
174 const float *const D_TR = D + 2 * *ldD * m1;
175 const float *const D_BR = D + 2 * *ldD * m1 + 2 * m1;
176
177 // F_T
178 // F_B
179 float *const F_T = F;
180 float *const F_B = F + 2 * m1;
181
182 if (*trans == 'N') {
183 // recursion(A_BR, B, C_B, D_BR, E, F_B)
184 RELAPACK_ctgsyl_rec(trans, ifunc, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, D_BR, ldD, E, ldE, F_B, ldF, scale1, dsum, dscale, info1);
185 // C_T = C_T - A_TR * C_B
186 BLAS(cgemm)("N", "N", &m1, n, &m2, MONE, A_TR, ldA, C_B, ldC, scale1, C_T, ldC);
187 // F_T = F_T - D_TR * C_B
188 BLAS(cgemm)("N", "N", &m1, n, &m2, MONE, D_TR, ldD, C_B, ldC, scale1, F_T, ldF);
189 // recursion(A_TL, B, C_T, D_TL, E, F_T)
190 RELAPACK_ctgsyl_rec(trans, ifunc, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, D_TL, ldD, E, ldE, F_T, ldF, scale2, dsum, dscale, info2);
191 // apply scale
192 if (scale2[0] != 1) {
193 LAPACK(clascl)("G", iONE, iONE, ONE, scale2, &m2, n, C_B, ldC, info);
194 LAPACK(clascl)("G", iONE, iONE, ONE, scale2, &m2, n, F_B, ldF, info);
195 }
196 } else {
197 // recursion(A_TL, B, C_T, D_TL, E, F_T)
198 RELAPACK_ctgsyl_rec(trans, ifunc, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, D_TL, ldD, E, ldE, F_T, ldF, scale1, dsum, dscale, info1);
199 // apply scale
200 if (scale1[0] != 1)
201 LAPACK(clascl)("G", iONE, iONE, ONE, scale1, &m2, n, F_B, ldF, info);
202 // C_B = C_B - A_TR^H * C_T
203 BLAS(cgemm)("C", "N", &m2, n, &m1, MONE, A_TR, ldA, C_T, ldC, scale1, C_B, ldC);
204 // C_B = C_B - D_TR^H * F_T
205 BLAS(cgemm)("C", "N", &m2, n, &m1, MONE, D_TR, ldD, F_T, ldC, ONE, C_B, ldC);
206 // recursion(A_BR, B, C_B, D_BR, E, F_B)
207 RELAPACK_ctgsyl_rec(trans, ifunc, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, D_BR, ldD, E, ldE, F_B, ldF, scale2, dsum, dscale, info2);
208 // apply scale
209 if (scale2[0] != 1) {
210 LAPACK(clascl)("G", iONE, iONE, ONE, scale2, &m1, n, C_T, ldC, info);
211 LAPACK(clascl)("G", iONE, iONE, ONE, scale2, &m1, n, F_T, ldF, info);
212 }
213 }
214 } else {
215 // Splitting
216 const blasint n1 = CREC_SPLIT(*n);
217 const blasint n2 = *n - n1;
218
219 // B_TL B_TR
220 // 0 B_BR
221 const float *const B_TL = B;
222 const float *const B_TR = B + 2 * *ldB * n1;
223 const float *const B_BR = B + 2 * *ldB * n1 + 2 * n1;
224
225 // C_L C_R
226 float *const C_L = C;
227 float *const C_R = C + 2 * *ldC * n1;
228
229 // E_TL E_TR
230 // 0 E_BR
231 const float *const E_TL = E;
232 const float *const E_TR = E + 2 * *ldE * n1;
233 const float *const E_BR = E + 2 * *ldE * n1 + 2 * n1;
234
235 // F_L F_R
236 float *const F_L = F;
237 float *const F_R = F + 2 * *ldF * n1;
238
239 if (*trans == 'N') {
240 // recursion(A, B_TL, C_L, D, E_TL, F_L)
241 RELAPACK_ctgsyl_rec(trans, ifunc, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, D, ldD, E_TL, ldE, F_L, ldF, scale1, dsum, dscale, info1);
242 // C_R = C_R + F_L * B_TR
243 BLAS(cgemm)("N", "N", m, &n2, &n1, ONE, F_L, ldF, B_TR, ldB, scale1, C_R, ldC);
244 // F_R = F_R + F_L * E_TR
245 BLAS(cgemm)("N", "N", m, &n2, &n1, ONE, F_L, ldF, E_TR, ldE, scale1, F_R, ldF);
246 // recursion(A, B_BR, C_R, D, E_BR, F_R)
247 RELAPACK_ctgsyl_rec(trans, ifunc, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, D, ldD, E_BR, ldE, F_R, ldF, scale2, dsum, dscale, info2);
248 // apply scale
249 if (scale2[0] != 1) {
250 LAPACK(clascl)("G", iONE, iONE, ONE, scale2, m, &n1, C_L, ldC, info);
251 LAPACK(clascl)("G", iONE, iONE, ONE, scale2, m, &n1, F_L, ldF, info);
252 }
253 } else {
254 // recursion(A, B_BR, C_R, D, E_BR, F_R)
255 RELAPACK_ctgsyl_rec(trans, ifunc, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, D, ldD, E_BR, ldE, F_R, ldF, scale1, dsum, dscale, info1);
256 // apply scale
257 if (scale1[0] != 1)
258 LAPACK(clascl)("G", iONE, iONE, ONE, scale1, m, &n1, C_L, ldC, info);
259 // F_L = F_L + C_R * B_TR
260 BLAS(cgemm)("N", "C", m, &n1, &n2, ONE, C_R, ldC, B_TR, ldB, scale1, F_L, ldF);
261 // F_L = F_L + F_R * E_TR
262 BLAS(cgemm)("N", "C", m, &n1, &n2, ONE, F_R, ldF, E_TR, ldB, ONE, F_L, ldF);
263 // recursion(A, B_TL, C_L, D, E_TL, F_L)
264 RELAPACK_ctgsyl_rec(trans, ifunc, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, D, ldD, E_TL, ldE, F_L, ldF, scale2, dsum, dscale, info2);
265 // apply scale
266 if (scale2[0] != 1) {
267 LAPACK(clascl)("G", iONE, iONE, ONE, scale2, m, &n2, C_R, ldC, info);
268 LAPACK(clascl)("G", iONE, iONE, ONE, scale2, m, &n2, F_R, ldF, info);
269 }
270 }
271 }
272
273 *scale = scale1[0] * scale2[0];
274 *info = info1[0] || info2[0];
275 }
276