1 #include "relapack.h"
2 #include <math.h>
3
4 static void RELAPACK_stgsyl_rec(const char *, const blasint *, const blasint *,
5 const blasint *, const float *, const blasint *, const float *, const blasint *,
6 float *, const blasint *, const float *, const blasint *, const float *,
7 const blasint *, float *, const blasint *, float *, float *, float *, blasint *, blasint *,
8 blasint *);
9
10
11 /** STGSYL solves the generalized Sylvester equation.
12 *
13 * This routine is functionally equivalent to LAPACK's stgsyl.
14 * For details on its interface, see
15 * http://www.netlib.org/lapack/explore-html/dc/d67/stgsyl_8f.html
16 * */
RELAPACK_stgsyl(const char * trans,const blasint * ijob,const blasint * m,const blasint * n,const float * A,const blasint * ldA,const float * B,const blasint * ldB,float * C,const blasint * ldC,const float * D,const blasint * ldD,const float * E,const blasint * ldE,float * F,const blasint * ldF,float * scale,float * dif,float * Work,const blasint * lWork,blasint * iWork,blasint * info)17 void RELAPACK_stgsyl(
18 const char *trans, const blasint *ijob, const blasint *m, const blasint *n,
19 const float *A, const blasint *ldA, const float *B, const blasint *ldB,
20 float *C, const blasint *ldC,
21 const float *D, const blasint *ldD, const float *E, const blasint *ldE,
22 float *F, const blasint *ldF,
23 float *scale, float *dif,
24 float *Work, const blasint *lWork, blasint *iWork, blasint *info
25 ) {
26
27 // Parse arguments
28 const blasint notran = LAPACK(lsame)(trans, "N");
29 const blasint tran = LAPACK(lsame)(trans, "T");
30
31 // Compute work buffer size
32 blasint lwmin = 1;
33 if (notran && (*ijob == 1 || *ijob == 2))
34 lwmin = MAX(1, 2 * *m * *n);
35 *info = 0;
36
37 // Check arguments
38 if (!tran && !notran)
39 *info = -1;
40 else if (notran && (*ijob < 0 || *ijob > 4))
41 *info = -2;
42 else if (*m <= 0)
43 *info = -3;
44 else if (*n <= 0)
45 *info = -4;
46 else if (*ldA < MAX(1, *m))
47 *info = -6;
48 else if (*ldB < MAX(1, *n))
49 *info = -8;
50 else if (*ldC < MAX(1, *m))
51 *info = -10;
52 else if (*ldD < MAX(1, *m))
53 *info = -12;
54 else if (*ldE < MAX(1, *n))
55 *info = -14;
56 else if (*ldF < MAX(1, *m))
57 *info = -16;
58 else if (*lWork < lwmin && *lWork != -1)
59 *info = -20;
60 if (*info) {
61 const blasint minfo = -*info;
62 LAPACK(xerbla)("STGSYL", &minfo, strlen("STGSYL"));
63 return;
64 }
65
66 if (*lWork == -1) {
67 // Work size query
68 *Work = lwmin;
69 return;
70 }
71
72 // Clean char * arguments
73 const char cleantrans = notran ? 'N' : 'T';
74
75 // Constant
76 const float ZERO[] = { 0. };
77
78 blasint isolve = 1;
79 blasint ifunc = 0;
80 if (notran) {
81 if (*ijob >= 3) {
82 ifunc = *ijob - 2;
83 LAPACK(slaset)("F", m, n, ZERO, ZERO, C, ldC);
84 LAPACK(slaset)("F", m, n, ZERO, ZERO, F, ldF);
85 } else if (*ijob >= 1)
86 isolve = 2;
87 }
88
89 float scale2;
90 blasint iround;
91 for (iround = 1; iround <= isolve; iround++) {
92 *scale = 1;
93 float dscale = 0;
94 float dsum = 1;
95 blasint pq;
96 RELAPACK_stgsyl_rec(&cleantrans, &ifunc, m, n, A, ldA, B, ldB, C, ldC, D, ldD, E, ldE, F, ldF, scale, &dsum, &dscale, iWork, &pq, info);
97 if (dscale != 0) {
98 if (*ijob == 1 || *ijob == 3)
99 *dif = sqrt(2 * *m * *n) / (dscale * sqrt(dsum));
100 else
101 *dif = sqrt(pq) / (dscale * sqrt(dsum));
102 }
103 if (isolve == 2) {
104 if (iround == 1) {
105 if (notran)
106 ifunc = *ijob;
107 scale2 = *scale;
108 LAPACK(slacpy)("F", m, n, C, ldC, Work, m);
109 LAPACK(slacpy)("F", m, n, F, ldF, Work + *m * *n, m);
110 LAPACK(slaset)("F", m, n, ZERO, ZERO, C, ldC);
111 LAPACK(slaset)("F", m, n, ZERO, ZERO, F, ldF);
112 } else {
113 LAPACK(slacpy)("F", m, n, Work, m, C, ldC);
114 LAPACK(slacpy)("F", m, n, Work + *m * *n, m, F, ldF);
115 *scale = scale2;
116 }
117 }
118 }
119 }
120
121
122 /** stgsyl's recursive vompute kernel */
RELAPACK_stgsyl_rec(const char * trans,const blasint * ifunc,const blasint * m,const blasint * n,const float * A,const blasint * ldA,const float * B,const blasint * ldB,float * C,const blasint * ldC,const float * D,const blasint * ldD,const float * E,const blasint * ldE,float * F,const blasint * ldF,float * scale,float * dsum,float * dscale,blasint * iWork,blasint * pq,blasint * info)123 static void RELAPACK_stgsyl_rec(
124 const char *trans, const blasint *ifunc, const blasint *m, const blasint *n,
125 const float *A, const blasint *ldA, const float *B, const blasint *ldB,
126 float *C, const blasint *ldC,
127 const float *D, const blasint *ldD, const float *E, const blasint *ldE,
128 float *F, const blasint *ldF,
129 float *scale, float *dsum, float *dscale,
130 blasint *iWork, blasint *pq, blasint *info
131 ) {
132
133 if (*m <= MAX(CROSSOVER_STGSYL, 1) && *n <= MAX(CROSSOVER_STGSYL, 1)) {
134 // Unblocked
135 LAPACK(stgsy2)(trans, ifunc, m, n, A, ldA, B, ldB, C, ldC, D, ldD, E, ldE, F, ldF, scale, dsum, dscale, iWork, pq, info);
136 return;
137 }
138
139 // Constants
140 const float ONE[] = { 1. };
141 const float MONE[] = { -1. };
142 const blasint iONE[] = { 1 };
143
144 // Outputs
145 float scale1[] = { 1. };
146 float scale2[] = { 1. };
147 blasint info1[] = { 0 };
148 blasint info2[] = { 0 };
149
150 if (*m > *n) {
151 // Splitting
152 blasint m1 = SREC_SPLIT(*m);
153 if (A[m1 + *ldA * (m1 - 1)])
154 m1++;
155 const blasint m2 = *m - m1;
156
157 // A_TL A_TR
158 // 0 A_BR
159 const float *const A_TL = A;
160 const float *const A_TR = A + *ldA * m1;
161 const float *const A_BR = A + *ldA * m1 + m1;
162
163 // C_T
164 // C_B
165 float *const C_T = C;
166 float *const C_B = C + m1;
167
168 // D_TL D_TR
169 // 0 D_BR
170 const float *const D_TL = D;
171 const float *const D_TR = D + *ldD * m1;
172 const float *const D_BR = D + *ldD * m1 + m1;
173
174 // F_T
175 // F_B
176 float *const F_T = F;
177 float *const F_B = F + m1;
178
179 if (*trans == 'N') {
180 // recursion(A_BR, B, C_B, D_BR, E, F_B)
181 RELAPACK_stgsyl_rec(trans, ifunc, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, D_BR, ldD, E, ldE, F_B, ldF, scale1, dsum, dscale, iWork, pq, info1);
182 // C_T = C_T - A_TR * C_B
183 BLAS(sgemm)("N", "N", &m1, n, &m2, MONE, A_TR, ldA, C_B, ldC, scale1, C_T, ldC);
184 // F_T = F_T - D_TR * C_B
185 BLAS(sgemm)("N", "N", &m1, n, &m2, MONE, D_TR, ldD, C_B, ldC, scale1, F_T, ldF);
186 // recursion(A_TL, B, C_T, D_TL, E, F_T)
187 RELAPACK_stgsyl_rec(trans, ifunc, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, D_TL, ldD, E, ldE, F_T, ldF, scale2, dsum, dscale, iWork, pq, info2);
188 // apply scale
189 if (scale2[0] != 1) {
190 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, &m2, n, C_B, ldC, info);
191 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, &m2, n, F_B, ldF, info);
192 }
193 } else {
194 // recursion(A_TL, B, C_T, D_TL, E, F_T)
195 RELAPACK_stgsyl_rec(trans, ifunc, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, D_TL, ldD, E, ldE, F_T, ldF, scale1, dsum, dscale, iWork, pq, info1);
196 // apply scale
197 if (scale1[0] != 1)
198 LAPACK(slascl)("G", iONE, iONE, ONE, scale1, &m2, n, F_B, ldF, info);
199 // C_B = C_B - A_TR^H * C_T
200 BLAS(sgemm)("T", "N", &m2, n, &m1, MONE, A_TR, ldA, C_T, ldC, scale1, C_B, ldC);
201 // C_B = C_B - D_TR^H * F_T
202 BLAS(sgemm)("T", "N", &m2, n, &m1, MONE, D_TR, ldD, F_T, ldC, ONE, C_B, ldC);
203 // recursion(A_BR, B, C_B, D_BR, E, F_B)
204 RELAPACK_stgsyl_rec(trans, ifunc, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, D_BR, ldD, E, ldE, F_B, ldF, scale2, dsum, dscale, iWork, pq, info2);
205 // apply scale
206 if (scale2[0] != 1) {
207 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, &m1, n, C_T, ldC, info);
208 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, &m1, n, F_T, ldF, info);
209 }
210 }
211 } else {
212 // Splitting
213 blasint n1 = SREC_SPLIT(*n);
214 if (B[n1 + *ldB * (n1 - 1)])
215 n1++;
216 const blasint n2 = *n - n1;
217
218 // B_TL B_TR
219 // 0 B_BR
220 const float *const B_TL = B;
221 const float *const B_TR = B + *ldB * n1;
222 const float *const B_BR = B + *ldB * n1 + n1;
223
224 // C_L C_R
225 float *const C_L = C;
226 float *const C_R = C + *ldC * n1;
227
228 // E_TL E_TR
229 // 0 E_BR
230 const float *const E_TL = E;
231 const float *const E_TR = E + *ldE * n1;
232 const float *const E_BR = E + *ldE * n1 + n1;
233
234 // F_L F_R
235 float *const F_L = F;
236 float *const F_R = F + *ldF * n1;
237
238 if (*trans == 'N') {
239 // recursion(A, B_TL, C_L, D, E_TL, F_L)
240 RELAPACK_stgsyl_rec(trans, ifunc, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, D, ldD, E_TL, ldE, F_L, ldF, scale1, dsum, dscale, iWork, pq, info1);
241 // C_R = C_R + F_L * B_TR
242 BLAS(sgemm)("N", "N", m, &n2, &n1, ONE, F_L, ldF, B_TR, ldB, scale1, C_R, ldC);
243 // F_R = F_R + F_L * E_TR
244 BLAS(sgemm)("N", "N", m, &n2, &n1, ONE, F_L, ldF, E_TR, ldE, scale1, F_R, ldF);
245 // recursion(A, B_BR, C_R, D, E_BR, F_R)
246 RELAPACK_stgsyl_rec(trans, ifunc, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, D, ldD, E_BR, ldE, F_R, ldF, scale2, dsum, dscale, iWork, pq, info2);
247 // apply scale
248 if (scale2[0] != 1) {
249 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, m, &n1, C_L, ldC, info);
250 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, m, &n1, F_L, ldF, info);
251 }
252 } else {
253 // recursion(A, B_BR, C_R, D, E_BR, F_R)
254 RELAPACK_stgsyl_rec(trans, ifunc, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, D, ldD, E_BR, ldE, F_R, ldF, scale1, dsum, dscale, iWork, pq, info1);
255 // apply scale
256 if (scale1[0] != 1)
257 LAPACK(slascl)("G", iONE, iONE, ONE, scale1, m, &n1, C_L, ldC, info);
258 // F_L = F_L + C_R * B_TR
259 BLAS(sgemm)("N", "T", m, &n1, &n2, ONE, C_R, ldC, B_TR, ldB, scale1, F_L, ldF);
260 // F_L = F_L + F_R * E_TR
261 BLAS(sgemm)("N", "T", m, &n1, &n2, ONE, F_R, ldF, E_TR, ldB, ONE, F_L, ldF);
262 // recursion(A, B_TL, C_L, D, E_TL, F_L)
263 RELAPACK_stgsyl_rec(trans, ifunc, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, D, ldD, E_TL, ldE, F_L, ldF, scale2, dsum, dscale, iWork, pq, info2);
264 // apply scale
265 if (scale2[0] != 1) {
266 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, m, &n2, C_R, ldC, info);
267 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, m, &n2, F_R, ldF, info);
268 }
269 }
270 }
271
272 *scale = scale1[0] * scale2[0];
273 *info = info1[0] || info2[0];
274 }
275