1 #include "relapack.h"
2 
3 static void RELAPACK_strsyl_rec(const char *, const char *, const blasint *,
4     const blasint *, const blasint *, const float *, const blasint *, const float *,
5     const blasint *, float *, const blasint *, float *, blasint *);
6 
7 
8 /** STRSYL solves the real Sylvester matrix equation.
9  *
10  * This routine is functionally equivalent to LAPACK's strsyl.
11  * For details on its interface, see
12  * http://www.netlib.org/lapack/explore-html/d4/d7d/strsyl_8f.html
13  * */
RELAPACK_strsyl(const char * tranA,const char * tranB,const blasint * isgn,const blasint * m,const blasint * n,const float * A,const blasint * ldA,const float * B,const blasint * ldB,float * C,const blasint * ldC,float * scale,blasint * info)14 void RELAPACK_strsyl(
15     const char *tranA, const char *tranB, const blasint *isgn,
16     const blasint *m, const blasint *n,
17     const float *A, const blasint *ldA, const float *B, const blasint *ldB,
18     float *C, const blasint *ldC, float *scale,
19     blasint *info
20 ) {
21 
22     // Check arguments
23     const blasint notransA = LAPACK(lsame)(tranA, "N");
24     const blasint transA = LAPACK(lsame)(tranA, "T");
25     const blasint ctransA = LAPACK(lsame)(tranA, "C");
26     const blasint notransB = LAPACK(lsame)(tranB, "N");
27     const blasint transB = LAPACK(lsame)(tranB, "T");
28     const blasint ctransB = LAPACK(lsame)(tranB, "C");
29     *info = 0;
30     if (!transA && !ctransA && !notransA)
31         *info = -1;
32     else if (!transB && !ctransB && !notransB)
33         *info = -2;
34     else if (*isgn != 1 && *isgn != -1)
35         *info = -3;
36     else if (*m < 0)
37         *info = -4;
38     else if (*n < 0)
39         *info = -5;
40     else if (*ldA < MAX(1, *m))
41         *info = -7;
42     else if (*ldB < MAX(1, *n))
43         *info = -9;
44     else if (*ldC < MAX(1, *m))
45         *info = -11;
46     if (*info) {
47         const blasint minfo = -*info;
48         LAPACK(xerbla)("STRSYL", &minfo, strlen("STRSYL"));
49         return;
50     }
51 
52     if (*m == 0 || *n == 0) {
53         *scale = 1.;
54         return;
55     }
56 
57     // Clean char * arguments
58     const char cleantranA = notransA ? 'N' : (transA ? 'T' : 'C');
59     const char cleantranB = notransB ? 'N' : (transB ? 'T' : 'C');
60 
61     // Recursive kernel
62     RELAPACK_strsyl_rec(&cleantranA, &cleantranB, isgn, m, n, A, ldA, B, ldB, C, ldC, scale, info);
63 }
64 
65 
66 /** strsyl's recursive compute kernel */
RELAPACK_strsyl_rec(const char * tranA,const char * tranB,const blasint * isgn,const blasint * m,const blasint * n,const float * A,const blasint * ldA,const float * B,const blasint * ldB,float * C,const blasint * ldC,float * scale,blasint * info)67 static void RELAPACK_strsyl_rec(
68     const char *tranA, const char *tranB, const blasint *isgn,
69     const blasint *m, const blasint *n,
70     const float *A, const blasint *ldA, const float *B, const blasint *ldB,
71     float *C, const blasint *ldC, float *scale,
72     blasint *info
73 ) {
74 
75     if (*m <= MAX(CROSSOVER_STRSYL, 1) && *n <= MAX(CROSSOVER_STRSYL, 1)) {
76         // Unblocked
77         RELAPACK_strsyl_rec2(tranA, tranB, isgn, m, n, A, ldA, B, ldB, C, ldC, scale, info);
78         return;
79     }
80 
81     // Constants
82     const float ONE[]  = { 1. };
83     const float MONE[] = { -1. };
84     const float MSGN[] = { -*isgn };
85     const blasint   iONE[] = { 1 };
86 
87     // Outputs
88     float scale1[] = { 1. };
89     float scale2[] = { 1. };
90     blasint   info1[]  = { 0 };
91     blasint   info2[]  = { 0 };
92 
93     if (*m > *n) {
94         // Splitting
95         blasint m1 = SREC_SPLIT(*m);
96         if (A[m1 + *ldA * (m1 - 1)])
97             m1++;
98         const blasint m2 = *m - m1;
99 
100         // A_TL A_TR
101         // 0    A_BR
102         const float *const A_TL = A;
103         const float *const A_TR = A + *ldA * m1;
104         const float *const A_BR = A + *ldA * m1 + m1;
105 
106         // C_T
107         // C_B
108         float *const C_T = C;
109         float *const C_B = C + m1;
110 
111         if (*tranA == 'N') {
112             // recusion(A_BR, B, C_B)
113             RELAPACK_strsyl_rec(tranA, tranB, isgn, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, scale1, info1);
114             // C_T = C_T - A_TR * C_B
115             BLAS(sgemm)("N", "N", &m1, n, &m2, MONE, A_TR, ldA, C_B, ldC, scale1, C_T, ldC);
116             // recusion(A_TL, B, C_T)
117             RELAPACK_strsyl_rec(tranA, tranB, isgn, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, scale2, info2);
118             // apply scale
119             if (scale2[0] != 1)
120                 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, &m2, n, C_B, ldC, info);
121         } else {
122             // recusion(A_TL, B, C_T)
123             RELAPACK_strsyl_rec(tranA, tranB, isgn, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, scale1, info1);
124             // C_B = C_B - A_TR' * C_T
125             BLAS(sgemm)("C", "N", &m2, n, &m1, MONE, A_TR, ldA, C_T, ldC, scale1, C_B, ldC);
126             // recusion(A_BR, B, C_B)
127             RELAPACK_strsyl_rec(tranA, tranB, isgn, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, scale2, info2);
128             // apply scale
129             if (scale2[0] != 1)
130                 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, &m1, n, C_B, ldC, info);
131         }
132     } else {
133         // Splitting
134         blasint n1 = SREC_SPLIT(*n);
135         if (B[n1 + *ldB * (n1 - 1)])
136             n1++;
137         const blasint n2 = *n - n1;
138 
139         // B_TL B_TR
140         // 0    B_BR
141         const float *const B_TL = B;
142         const float *const B_TR = B + *ldB * n1;
143         const float *const B_BR = B + *ldB * n1 + n1;
144 
145         // C_L C_R
146         float *const C_L = C;
147         float *const C_R = C + *ldC * n1;
148 
149         if (*tranB == 'N') {
150             // recusion(A, B_TL, C_L)
151             RELAPACK_strsyl_rec(tranA, tranB, isgn, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, scale1, info1);
152             // C_R = C_R -/+ C_L * B_TR
153             BLAS(sgemm)("N", "N", m, &n2, &n1, MSGN, C_L, ldC, B_TR, ldB, scale1, C_R, ldC);
154             // recusion(A, B_BR, C_R)
155             RELAPACK_strsyl_rec(tranA, tranB, isgn, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, scale2, info2);
156             // apply scale
157             if (scale2[0] != 1)
158                 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, m, &n1, C_L, ldC, info);
159         } else {
160             // recusion(A, B_BR, C_R)
161             RELAPACK_strsyl_rec(tranA, tranB, isgn, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, scale1, info1);
162             // C_L = C_L -/+ C_R * B_TR'
163             BLAS(sgemm)("N", "C", m, &n1, &n2, MSGN, C_R, ldC, B_TR, ldB, scale1, C_L, ldC);
164             // recusion(A, B_TL, C_L)
165             RELAPACK_strsyl_rec(tranA, tranB, isgn, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, scale2, info2);
166             // apply scale
167             if (scale2[0] != 1)
168                 LAPACK(slascl)("G", iONE, iONE, ONE, scale2, m, &n2, C_R, ldC, info);
169         }
170     }
171 
172     *scale = scale1[0] * scale2[0];
173     *info  = info1[0] || info2[0];
174 }
175