1 #include "relapack.h"
2
3 static void RELAPACK_zgemmt_rec(const char *, const char *, const char *,
4 const blasint *, const blasint *, const double *, const double *, const blasint *,
5 const double *, const blasint *, const double *, double *, const blasint *);
6
7 static void RELAPACK_zgemmt_rec2(const char *, const char *, const char *,
8 const blasint *, const blasint *, const double *, const double *, const blasint *,
9 const double *, const blasint *, const double *, double *, const blasint *);
10
11
12 /** ZGEMMT computes a matrix-matrix product with general matrices but updates
13 * only the upper or lower triangular part of the result matrix.
14 *
15 * This routine performs the same operation as the BLAS routine
16 * zgemm(transA, transB, n, n, k, alpha, A, ldA, B, ldB, beta, C, ldC)
17 * but only updates the triangular part of C specified by uplo:
18 * If (*uplo == 'L'), only the lower triangular part of C is updated,
19 * otherwise the upper triangular part is updated.
20 * */
RELAPACK_zgemmt(const char * uplo,const char * transA,const char * transB,const blasint * n,const blasint * k,const double * alpha,const double * A,const blasint * ldA,const double * B,const blasint * ldB,const double * beta,double * C,const blasint * ldC)21 void RELAPACK_zgemmt(
22 const char *uplo, const char *transA, const char *transB,
23 const blasint *n, const blasint *k,
24 const double *alpha, const double *A, const blasint *ldA,
25 const double *B, const blasint *ldB,
26 const double *beta, double *C, const blasint *ldC
27 ) {
28
29 #if HAVE_XGEMMT
30 BLAS(zgemmt)(uplo, transA, transB, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);
31 return;
32 #else
33
34 // Check arguments
35 const blasint lower = LAPACK(lsame)(uplo, "L");
36 const blasint upper = LAPACK(lsame)(uplo, "U");
37 const blasint notransA = LAPACK(lsame)(transA, "N");
38 const blasint tranA = LAPACK(lsame)(transA, "T");
39 const blasint ctransA = LAPACK(lsame)(transA, "C");
40 const blasint notransB = LAPACK(lsame)(transB, "N");
41 const blasint tranB = LAPACK(lsame)(transB, "T");
42 const blasint ctransB = LAPACK(lsame)(transB, "C");
43 blasint info = 0;
44 if (!lower && !upper)
45 info = 1;
46 else if (!tranA && !ctransA && !notransA)
47 info = 2;
48 else if (!tranB && !ctransB && !notransB)
49 info = 3;
50 else if (*n < 0)
51 info = 4;
52 else if (*k < 0)
53 info = 5;
54 else if (*ldA < MAX(1, notransA ? *n : *k))
55 info = 8;
56 else if (*ldB < MAX(1, notransB ? *k : *n))
57 info = 10;
58 else if (*ldC < MAX(1, *n))
59 info = 13;
60 if (info) {
61 LAPACK(xerbla)("ZGEMMT", &info, strlen("ZGEMMT"));
62 return;
63 }
64
65 // Clean char * arguments
66 const char cleanuplo = lower ? 'L' : 'U';
67 const char cleantransA = notransA ? 'N' : (tranA ? 'T' : 'C');
68 const char cleantransB = notransB ? 'N' : (tranB ? 'T' : 'C');
69
70 // Recursive kernel
71 RELAPACK_zgemmt_rec(&cleanuplo, &cleantransA, &cleantransB, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);
72 #endif
73 }
74
75
76 /** zgemmt's recursive compute kernel */
RELAPACK_zgemmt_rec(const char * uplo,const char * transA,const char * transB,const blasint * n,const blasint * k,const double * alpha,const double * A,const blasint * ldA,const double * B,const blasint * ldB,const double * beta,double * C,const blasint * ldC)77 static void RELAPACK_zgemmt_rec(
78 const char *uplo, const char *transA, const char *transB,
79 const blasint *n, const blasint *k,
80 const double *alpha, const double *A, const blasint *ldA,
81 const double *B, const blasint *ldB,
82 const double *beta, double *C, const blasint *ldC
83 ) {
84
85 if (*n <= MAX(CROSSOVER_ZGEMMT, 1)) {
86 // Unblocked
87 RELAPACK_zgemmt_rec2(uplo, transA, transB, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);
88 return;
89 }
90
91 // Splitting
92 const blasint n1 = ZREC_SPLIT(*n);
93 const blasint n2 = *n - n1;
94
95 // A_T
96 // A_B
97 const double *const A_T = A;
98 const double *const A_B = A + 2 * ((*transA == 'N') ? n1 : *ldA * n1);
99
100 // B_L B_R
101 const double *const B_L = B;
102 const double *const B_R = B + 2 * ((*transB == 'N') ? *ldB * n1 : n1);
103
104 // C_TL C_TR
105 // C_BL C_BR
106 double *const C_TL = C;
107 double *const C_TR = C + 2 * *ldC * n1;
108 double *const C_BL = C + 2 * n1;
109 double *const C_BR = C + 2 * *ldC * n1 + 2 * n1;
110
111 // recursion(C_TL)
112 RELAPACK_zgemmt_rec(uplo, transA, transB, &n1, k, alpha, A_T, ldA, B_L, ldB, beta, C_TL, ldC);
113
114 if (*uplo == 'L')
115 // C_BL = alpha A_B B_L + beta C_BL
116 BLAS(zgemm)(transA, transB, &n2, &n1, k, alpha, A_B, ldA, B_L, ldB, beta, C_BL, ldC);
117 else
118 // C_TR = alpha A_T B_R + beta C_TR
119 BLAS(zgemm)(transA, transB, &n1, &n2, k, alpha, A_T, ldA, B_R, ldB, beta, C_TR, ldC);
120
121 // recursion(C_BR)
122 RELAPACK_zgemmt_rec(uplo, transA, transB, &n2, k, alpha, A_B, ldA, B_R, ldB, beta, C_BR, ldC);
123 }
124
125
126 /** zgemmt's unblocked compute kernel */
RELAPACK_zgemmt_rec2(const char * uplo,const char * transA,const char * transB,const blasint * n,const blasint * k,const double * alpha,const double * A,const blasint * ldA,const double * B,const blasint * ldB,const double * beta,double * C,const blasint * ldC)127 static void RELAPACK_zgemmt_rec2(
128 const char *uplo, const char *transA, const char *transB,
129 const blasint *n, const blasint *k,
130 const double *alpha, const double *A, const blasint *ldA,
131 const double *B, const blasint *ldB,
132 const double *beta, double *C, const blasint *ldC
133 ) {
134
135 const blasint incB = (*transB == 'N') ? 1 : *ldB;
136 const blasint incC = 1;
137
138 blasint i;
139 for (i = 0; i < *n; i++) {
140 // A_0
141 // A_i
142 const double *const A_0 = A;
143 const double *const A_i = A + 2 * ((*transA == 'N') ? i : *ldA * i);
144
145 // * B_i *
146 const double *const B_i = B + 2 * ((*transB == 'N') ? *ldB * i : i);
147
148 // * C_0i *
149 // * C_ii *
150 double *const C_0i = C + 2 * *ldC * i;
151 double *const C_ii = C + 2 * *ldC * i + 2 * i;
152
153 if (*uplo == 'L') {
154 const blasint nmi = *n - i;
155 if (*transA == 'N')
156 BLAS(zgemv)(transA, &nmi, k, alpha, A_i, ldA, B_i, &incB, beta, C_ii, &incC);
157 else
158 BLAS(zgemv)(transA, k, &nmi, alpha, A_i, ldA, B_i, &incB, beta, C_ii, &incC);
159 } else {
160 const blasint ip1 = i + 1;
161 if (*transA == 'N')
162 BLAS(zgemv)(transA, &ip1, k, alpha, A_0, ldA, B_i, &incB, beta, C_0i, &incC);
163 else
164 BLAS(zgemv)(transA, k, &ip1, alpha, A_0, ldA, B_i, &incB, beta, C_0i, &incC);
165 }
166 }
167 }
168