1 #include "relapack.h"
2 
3 static void RELAPACK_ztrsyl_rec(const char *, const char *, const blasint *,
4     const blasint *, const blasint *, const double *, const blasint *, const double *,
5     const blasint *, double *, const blasint *, double *, blasint *);
6 
7 
8 /** ZTRSYL solves the complex Sylvester matrix equation.
9  *
10  * This routine is functionally equivalent to LAPACK's ztrsyl.
11  * For details on its interface, see
12  * http://www.netlib.org/lapack/explore-html/d1/d36/ztrsyl_8f.html
13  * */
RELAPACK_ztrsyl(const char * tranA,const char * tranB,const blasint * isgn,const blasint * m,const blasint * n,const double * A,const blasint * ldA,const double * B,const blasint * ldB,double * C,const blasint * ldC,double * scale,blasint * info)14 void RELAPACK_ztrsyl(
15     const char *tranA, const char *tranB, const blasint *isgn,
16     const blasint *m, const blasint *n,
17     const double *A, const blasint *ldA, const double *B, const blasint *ldB,
18     double *C, const blasint *ldC, double *scale,
19     blasint *info
20 ) {
21 
22     // Check arguments
23     const blasint notransA = LAPACK(lsame)(tranA, "N");
24     const blasint ctransA = LAPACK(lsame)(tranA, "C");
25     const blasint notransB = LAPACK(lsame)(tranB, "N");
26     const blasint ctransB = LAPACK(lsame)(tranB, "C");
27     *info = 0;
28     if (!ctransA && !notransA)
29         *info = -1;
30     else if (!ctransB && !notransB)
31         *info = -2;
32     else if (*isgn != 1 && *isgn != -1)
33         *info = -3;
34     else if (*m < 0)
35         *info = -4;
36     else if (*n < 0)
37         *info = -5;
38     else if (*ldA < MAX(1, *m))
39         *info = -7;
40     else if (*ldB < MAX(1, *n))
41         *info = -9;
42     else if (*ldC < MAX(1, *m))
43         *info = -11;
44     if (*info) {
45         const blasint minfo = -*info;
46         LAPACK(xerbla)("ZTRSYL", &minfo, strlen("ZTRSYL"));
47         return;
48     }
49 
50     if (*m == 0 || *n == 0) {
51         *scale = 1.;
52         return;
53     }
54 
55     // Clean char * arguments
56     const char cleantranA = notransA ? 'N' : 'C';
57     const char cleantranB = notransB ? 'N' : 'C';
58 
59     // Recursive kernel
60     RELAPACK_ztrsyl_rec(&cleantranA, &cleantranB, isgn, m, n, A, ldA, B, ldB, C, ldC, scale, info);
61 }
62 
63 
64 /** ztrsyl's recursive compute kernel */
RELAPACK_ztrsyl_rec(const char * tranA,const char * tranB,const blasint * isgn,const blasint * m,const blasint * n,const double * A,const blasint * ldA,const double * B,const blasint * ldB,double * C,const blasint * ldC,double * scale,blasint * info)65 static void RELAPACK_ztrsyl_rec(
66     const char *tranA, const char *tranB, const blasint *isgn,
67     const blasint *m, const blasint *n,
68     const double *A, const blasint *ldA, const double *B, const blasint *ldB,
69     double *C, const blasint *ldC, double *scale,
70     blasint *info
71 ) {
72 
73     if (*m <= MAX(CROSSOVER_ZTRSYL, 1) && *n <= MAX(CROSSOVER_ZTRSYL, 1)) {
74         // Unblocked
75         RELAPACK_ztrsyl_rec2(tranA, tranB, isgn, m, n, A, ldA, B, ldB, C, ldC, scale, info);
76         return;
77     }
78 
79     // Constants
80     const double ONE[]  = { 1., 0. };
81     const double MONE[] = { -1., 0. };
82     const double MSGN[] = { -*isgn, 0. };
83     const blasint    iONE[] = { 1 };
84 
85     // Outputs
86     double scale1[] = { 1., 0. };
87     double scale2[] = { 1., 0. };
88     blasint    info1[]  = { 0 };
89     blasint    info2[]  = { 0 };
90 
91     if (*m > *n) {
92         // Splitting
93         const blasint m1 = ZREC_SPLIT(*m);
94         const blasint m2 = *m - m1;
95 
96         // A_TL A_TR
97         // 0    A_BR
98         const double *const A_TL = A;
99         const double *const A_TR = A + 2 * *ldA * m1;
100         const double *const A_BR = A + 2 * *ldA * m1 + 2 * m1;
101 
102         // C_T
103         // C_B
104         double *const C_T = C;
105         double *const C_B = C + 2 * m1;
106 
107         if (*tranA == 'N') {
108             // recusion(A_BR, B, C_B)
109             RELAPACK_ztrsyl_rec(tranA, tranB, isgn, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, scale1, info1);
110             // C_T = C_T - A_TR * C_B
111             BLAS(zgemm)("N", "N", &m1, n, &m2, MONE, A_TR, ldA, C_B, ldC, scale1, C_T, ldC);
112             // recusion(A_TL, B, C_T)
113             RELAPACK_ztrsyl_rec(tranA, tranB, isgn, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, scale2, info2);
114             // apply scale
115             if (scale2[0] != 1)
116                 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, &m2, n, C_B, ldC, info);
117         } else {
118             // recusion(A_TL, B, C_T)
119             RELAPACK_ztrsyl_rec(tranA, tranB, isgn, &m1, n, A_TL, ldA, B, ldB, C_T, ldC, scale1, info1);
120             // C_B = C_B - A_TR' * C_T
121             BLAS(zgemm)("C", "N", &m2, n, &m1, MONE, A_TR, ldA, C_T, ldC, scale1, C_B, ldC);
122             // recusion(A_BR, B, C_B)
123             RELAPACK_ztrsyl_rec(tranA, tranB, isgn, &m2, n, A_BR, ldA, B, ldB, C_B, ldC, scale2, info2);
124             // apply scale
125             if (scale2[0] != 1)
126                 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, &m1, n, C_B, ldC, info);
127         }
128     } else {
129         // Splitting
130         const blasint n1 = ZREC_SPLIT(*n);
131         const blasint n2 = *n - n1;
132 
133         // B_TL B_TR
134         // 0    B_BR
135         const double *const B_TL = B;
136         const double *const B_TR = B + 2 * *ldB * n1;
137         const double *const B_BR = B + 2 * *ldB * n1 + 2 * n1;
138 
139         // C_L C_R
140         double *const C_L = C;
141         double *const C_R = C + 2 * *ldC * n1;
142 
143         if (*tranB == 'N') {
144             // recusion(A, B_TL, C_L)
145             RELAPACK_ztrsyl_rec(tranA, tranB, isgn, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, scale1, info1);
146             // C_R = C_R -/+ C_L * B_TR
147             BLAS(zgemm)("N", "N", m, &n2, &n1, MSGN, C_L, ldC, B_TR, ldB, scale1, C_R, ldC);
148             // recusion(A, B_BR, C_R)
149             RELAPACK_ztrsyl_rec(tranA, tranB, isgn, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, scale2, info2);
150             // apply scale
151             if (scale2[0] != 1)
152                 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, m, &n1, C_L, ldC, info);
153         } else {
154             // recusion(A, B_BR, C_R)
155             RELAPACK_ztrsyl_rec(tranA, tranB, isgn, m, &n2, A, ldA, B_BR, ldB, C_R, ldC, scale1, info1);
156             // C_L = C_L -/+ C_R * B_TR'
157             BLAS(zgemm)("N", "C", m, &n1, &n2, MSGN, C_R, ldC, B_TR, ldB, scale1, C_L, ldC);
158             // recusion(A, B_TL, C_L)
159             RELAPACK_ztrsyl_rec(tranA, tranB, isgn, m, &n1, A, ldA, B_TL, ldB, C_L, ldC, scale2, info2);
160             // apply scale
161             if (scale2[0] != 1)
162                 LAPACK(zlascl)("G", iONE, iONE, ONE, scale2, m, &n2, C_R, ldC, info);
163         }
164     }
165 
166     *scale = scale1[0] * scale2[0];
167     *info  = info1[0] || info2[0];
168 }
169