1 #include "relapack.h"
2 #include <math.h>
3 
4 static void RELAPACK_stgsyl_rec(const char *, const blasint *, const blasint *,
5     const blasint *, const float *, const blasint *, const float *, const blasint *,
6     float *, const blasint *, const float *, const blasint *, const float *,
7     const blasint *, float *, const blasint *, float *, float *, float *, blasint *, blasint *,
8     blasint *);
9 
10 
11 /** STGSYL solves the generalized Sylvester equation.
12  *
13  * This routine is functionally equivalent to LAPACK's stgsyl.
14  * For details on its interface, see
15  * http://www.netlib.org/lapack/explore-html/dc/d67/stgsyl_8f.html
16  * */
RELAPACK_stgsyl(const char * trans,const blasint * ijob,const blasint * m,const blasint * n,const float * A,const blasint * ldA,const float * B,const blasint * ldB,float * C,const blasint * ldC,const float * D,const blasint * ldD,const float * E,const blasint * ldE,float * F,const blasint * ldF,float * scale,float * dif,float * Work,const blasint * lWork,blasint * iWork,blasint * info)17 void RELAPACK_stgsyl(
18     const char *trans, const blasint *ijob, const blasint *m, const blasint *n,
19     const float *A, const blasint *ldA, const float *B, const blasint *ldB,
20     float *C, const blasint *ldC,
21     const float *D, const blasint *ldD, const float *E, const blasint *ldE,
22     float *F, const blasint *ldF,
23     float *scale, float *dif,
24     float *Work, const blasint *lWork, blasint *iWork, blasint *info
25 ) {
26 
27     // Parse arguments
28     const blasint notran = LAPACK(lsame)(trans, "N");
29     const blasint tran = LAPACK(lsame)(trans, "T");
30 
31     // Compute work buffer size
32     blasint lwmin = 1;
33     if (notran && (*ijob == 1 || *ijob == 2))
34         lwmin = MAX(1, 2 * *m * *n);
35     *info = 0;
36 
37     // Check arguments
38     if (!tran && !notran)
39         *info = -1;
40     else if (notran && (*ijob < 0 || *ijob > 4))
41         *info = -2;
42     else if (*m <= 0)
43         *info = -3;
44     else if (*n <= 0)
45         *info = -4;
46     else if (*ldA < MAX(1, *m))
47         *info = -6;
48     else if (*ldB < MAX(1, *n))
49         *info = -8;
50     else if (*ldC < MAX(1, *m))
51         *info = -10;
52     else if (*ldD < MAX(1, *m))
53         *info = -12;
54     else if (*ldE < MAX(1, *n))
55         *info = -14;
56     else if (*ldF < MAX(1, *m))
57         *info = -16;
58     else if (*lWork < lwmin && *lWork != -1)
59         *info = -20;
60     if (*info) {
61         const blasint minfo = -*info;
62         LAPACK(xerbla)("STGSYL", &minfo, strlen("STGSYL"));
63         return;
64     }
65 
66     if (*lWork == -1) {
67         // Work size query
68         *Work = lwmin;
69         return;
70     }
71 
72     // Clean char * arguments
73     const char cleantrans = notran ? 'N' : 'T';
74 
75     // Constant
76     const float ZERO[] = { 0. };
77 
78     blasint isolve = 1;
79     blasint ifunc  = 0;
80     if (notran) {
81         if (*ijob >= 3) {
82             ifunc = *ijob - 2;
83             LAPACK(slaset)("F", m, n, ZERO, ZERO, C, ldC);
84             LAPACK(slaset)("F", m, n, ZERO, ZERO, F, ldF);
85         } else if (*ijob >= 1)
86             isolve = 2;
87     }
88 
89     float scale2;
90     blasint iround;
91     for (iround = 1; iround <= isolve; iround++) {
92         *scale = 1;
93         float dscale = 0;
94         float dsum   = 1;
95         blasint pq;
96         RELAPACK_stgsyl_rec(&cleantrans, &ifunc, m, n, A, ldA, B, ldB, C, ldC, D, ldD, E, ldE, F, ldF, scale, &dsum, &dscale, iWork, &pq, info);
97         if (dscale != 0) {
98             if (*ijob == 1 || *ijob == 3)
99                 *dif = sqrt(2 * *m * *n) / (dscale * sqrt(dsum));
100             else
101                 *dif = sqrt(pq) / (dscale * sqrt(dsum));
102         }
103         if (isolve == 2) {
104             if (iround == 1) {
105                 if (notran)
106                     ifunc = *ijob;
107                 scale2 = *scale;
108                 LAPACK(slacpy)("F", m, n, C, ldC, Work, m);
109                 LAPACK(slacpy)("F", m, n, F, ldF, Work + *m * *n, m);
110                 LAPACK(slaset)("F", m, n, ZERO, ZERO, C, ldC);
111                 LAPACK(slaset)("F", m, n, ZERO, ZERO, F, ldF);
112             } else {
113                 LAPACK(slacpy)("F", m, n, Work, m, C, ldC);
114                 LAPACK(slacpy)("F", m, n, Work + *m * *n, m, F, ldF);
115                 *scale = scale2;
116             }
117         }
118     }
119 }
120 
121 
122 /** stgsyl's recursive vompute kernel */
RELAPACK_stgsyl_rec(const char * trans,const blasint * ifunc,const blasint * m,const blasint * n,const float * A,const blasint * ldA,const float * B,const blasint * ldB,float * C,const blasint * ldC,const float * D,const blasint * ldD,const float * E,const blasint * ldE,float * F,const blasint * ldF,float * scale,float * dsum,float * dscale,blasint * iWork,blasint * pq,blasint * info)123 static void RELAPACK_stgsyl_rec(
124     const char *trans, const blasint *ifunc, const blasint *m, const blasint *n,
125     const float *A, const blasint *ldA, const float *B, const blasint *ldB,
126     float *C, const blasint *ldC,
127     const float *D, const blasint *ldD, const float *E, const blasint *ldE,
128     float *F, const blasint *ldF,
129     float *scale, float *dsum, float *dscale,
130     blasint *iWork, blasint *pq, blasint *info
131 ) {
132 
133     if (*m <= MAX(CROSSOVER_STGSYL, 1) && *n <= MAX(CROSSOVER_STGSYL, 1)) {
134         // Unblocked
135         LAPACK(stgsy2)(trans, ifunc, m, n, A, ldA, B, ldB, C, ldC, D, ldD, E, ldE, F, ldF, scale, dsum, dscale, iWork, pq, info);
136         return;
137     }
138 
139     // Constants
140     const float ONE[]  = { 1. };
141     const float MONE[] = { -1. };
142     const blasint   iONE[] = { 1 };
143 
144     // Outputs
145     float scale1[] = { 1. };
146     float scale2[] = { 1. };
147     blasint   info1[]  = { 0 };
148     blasint   info2[]  = { 0 };
149 
150     if (*m > *n) {
151         // Splitting
152         blasint m1 = SREC_SPLIT(*m);
153         if (A[m1 + *ldA * (m1 - 1)])
154             m1++;
155         const blasint m2 = *m - m1;
156 
157         // A_TL A_TR
158         // 0    A_BR
159         const float *const A_TL = A;
160         const float *const A_TR = A + *ldA * m1;
161         const float *const A_BR = A + *ldA * m1 + m1;
162 
163         // C_T
164         // C_B
165         float *const C_T = C;
166         float *const C_B = C + m1;
167 
168         // D_TL D_TR
169         // 0    D_BR
170         const float *const D_TL = D;
171         const float *const D_TR = D + *ldD * m1;
172         const float *const D_BR = D + *ldD * m1 + m1;
173 
174         // F_T
175         // F_B
176         float *const F_T = F;
177         float *const F_B = F + m1;
178 
179         if (*trans == 'N') {
180             // recursion(A_BR, B, C_B, D_BR, E, F_B)
181             RELAPACK_stgsyl_rec(trans, ifunc, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, D_BR, ldD, E, ldE, F_B, ldF, scale1, dsum, dscale, iWork, pq, info1);
182             // C_T = C_T - A_TR * C_B
183             BLAS(sgemm)("N", "N", &m1, n, &m2, MONE, A_TR, ldA, C_B, ldC, scale1, C_T, ldC);
184             // F_T = F_T - D_TR * C_B
185             BLAS(sgemm)("N", "N", &m1, n, &m2, MONE, D_TR, ldD, C_B, ldC, scale1, F_T, ldF);
186             // recursion(A_TL, B, C_T, D_TL, E, F_T)
187             RELAPACK_stgsyl_rec(trans, ifunc, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, D_TL, ldD, E, ldE, F_T, ldF, scale2, dsum, dscale, iWork, pq, info2);
188             // apply scale
189             if (scale2[0] != 1) {
190                 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, &m2, n, C_B, ldC, info);
191                 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, &m2, n, F_B, ldF, info);
192             }
193         } else {
194             // recursion(A_TL, B, C_T, D_TL, E, F_T)
195             RELAPACK_stgsyl_rec(trans, ifunc, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, D_TL, ldD, E, ldE, F_T, ldF, scale1, dsum, dscale, iWork, pq, info1);
196             // apply scale
197             if (scale1[0] != 1)
198                 LAPACK(slascl)("G", iONE, iONE, ONE, scale1, &m2, n, F_B, ldF, info);
199             // C_B = C_B - A_TR^H * C_T
200             BLAS(sgemm)("T", "N", &m2, n, &m1, MONE, A_TR, ldA, C_T, ldC, scale1, C_B, ldC);
201             // C_B = C_B - D_TR^H * F_T
202             BLAS(sgemm)("T", "N", &m2, n, &m1, MONE, D_TR, ldD, F_T, ldC, ONE, C_B, ldC);
203             // recursion(A_BR, B, C_B, D_BR, E, F_B)
204             RELAPACK_stgsyl_rec(trans, ifunc, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, D_BR, ldD, E, ldE, F_B, ldF, scale2, dsum, dscale, iWork, pq, info2);
205             // apply scale
206             if (scale2[0] != 1) {
207                 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, &m1, n, C_T, ldC, info);
208                 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, &m1, n, F_T, ldF, info);
209             }
210         }
211     } else {
212         // Splitting
213         blasint n1 = SREC_SPLIT(*n);
214         if (B[n1 + *ldB * (n1 - 1)])
215             n1++;
216         const blasint n2 = *n - n1;
217 
218         // B_TL B_TR
219         // 0    B_BR
220         const float *const B_TL = B;
221         const float *const B_TR = B + *ldB * n1;
222         const float *const B_BR = B + *ldB * n1 + n1;
223 
224         // C_L C_R
225         float *const C_L = C;
226         float *const C_R = C + *ldC * n1;
227 
228         // E_TL E_TR
229         // 0    E_BR
230         const float *const E_TL = E;
231         const float *const E_TR = E + *ldE * n1;
232         const float *const E_BR = E + *ldE * n1 + n1;
233 
234         // F_L F_R
235         float *const F_L = F;
236         float *const F_R = F + *ldF * n1;
237 
238         if (*trans == 'N') {
239             // recursion(A, B_TL, C_L, D, E_TL, F_L)
240             RELAPACK_stgsyl_rec(trans, ifunc, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, D, ldD, E_TL, ldE, F_L, ldF, scale1, dsum, dscale, iWork, pq, info1);
241             // C_R = C_R + F_L * B_TR
242             BLAS(sgemm)("N", "N", m, &n2, &n1, ONE, F_L, ldF, B_TR, ldB, scale1, C_R, ldC);
243             // F_R = F_R + F_L * E_TR
244             BLAS(sgemm)("N", "N", m, &n2, &n1, ONE, F_L, ldF, E_TR, ldE, scale1, F_R, ldF);
245             // recursion(A, B_BR, C_R, D, E_BR, F_R)
246             RELAPACK_stgsyl_rec(trans, ifunc, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, D, ldD, E_BR, ldE, F_R, ldF, scale2, dsum, dscale, iWork, pq, info2);
247             // apply scale
248             if (scale2[0] != 1) {
249                 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, m, &n1, C_L, ldC, info);
250                 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, m, &n1, F_L, ldF, info);
251             }
252         } else {
253             // recursion(A, B_BR, C_R, D, E_BR, F_R)
254             RELAPACK_stgsyl_rec(trans, ifunc, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, D, ldD, E_BR, ldE, F_R, ldF, scale1, dsum, dscale, iWork, pq, info1);
255             // apply scale
256             if (scale1[0] != 1)
257                 LAPACK(slascl)("G", iONE, iONE, ONE, scale1, m, &n1, C_L, ldC, info);
258             // F_L = F_L + C_R * B_TR
259             BLAS(sgemm)("N", "T", m, &n1, &n2, ONE, C_R, ldC, B_TR, ldB, scale1, F_L, ldF);
260             // F_L = F_L + F_R * E_TR
261             BLAS(sgemm)("N", "T", m, &n1, &n2, ONE, F_R, ldF, E_TR, ldB, ONE, F_L, ldF);
262             // recursion(A, B_TL, C_L, D, E_TL, F_L)
263             RELAPACK_stgsyl_rec(trans, ifunc, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, D, ldD, E_TL, ldE, F_L, ldF, scale2, dsum, dscale, iWork, pq, info2);
264             // apply scale
265             if (scale2[0] != 1) {
266                 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, m, &n2, C_R, ldC, info);
267                 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, m, &n2, F_R, ldF, info);
268             }
269         }
270     }
271 
272     *scale = scale1[0] * scale2[0];
273     *info  = info1[0] || info2[0];
274 }
275