#include "relapack.h"

#include <stddef.h>
2
3 static void RELAPACK_cgemmt_rec(const char *, const char *, const char *,
4 const blasint *, const blasint *, const float *, const float *, const blasint *,
5 const float *, const blasint *, const float *, float *, const blasint *);
6
7 static void RELAPACK_cgemmt_rec2(const char *, const char *, const char *,
8 const blasint *, const blasint *, const float *, const float *, const blasint *,
9 const float *, const blasint *, const float *, float *, const blasint *);
10
11
12 /** CGEMMT computes a matrix-matrix product with general matrices but updates
13 * only the upper or lower triangular part of the result matrix.
14 *
15 * This routine performs the same operation as the BLAS routine
16 * cgemm(transA, transB, n, n, k, alpha, A, ldA, B, ldB, beta, C, ldC)
17 * but only updates the triangular part of C specified by uplo:
18 * If (*uplo == 'L'), only the lower triangular part of C is updated,
19 * otherwise the upper triangular part is updated.
20 * */
RELAPACK_cgemmt(const char * uplo,const char * transA,const char * transB,const blasint * n,const blasint * k,const float * alpha,const float * A,const blasint * ldA,const float * B,const blasint * ldB,const float * beta,float * C,const blasint * ldC)21 void RELAPACK_cgemmt(
22 const char *uplo, const char *transA, const char *transB,
23 const blasint *n, const blasint *k,
24 const float *alpha, const float *A, const blasint *ldA,
25 const float *B, const blasint *ldB,
26 const float *beta, float *C, const blasint *ldC
27 ) {
28
29 #if HAVE_XGEMMT
30 BLAS(cgemmt)(uplo, transA, transB, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);
31 return;
32 #else
33
34 // Check arguments
35 const blasint lower = LAPACK(lsame)(uplo, "L");
36 const blasint upper = LAPACK(lsame)(uplo, "U");
37 const blasint notransA = LAPACK(lsame)(transA, "N");
38 const blasint tranA = LAPACK(lsame)(transA, "T");
39 const blasint ctransA = LAPACK(lsame)(transA, "C");
40 const blasint notransB = LAPACK(lsame)(transB, "N");
41 const blasint tranB = LAPACK(lsame)(transB, "T");
42 const blasint ctransB = LAPACK(lsame)(transB, "C");
43 blasint info = 0;
44 if (!lower && !upper)
45 info = 1;
46 else if (!tranA && !ctransA && !notransA)
47 info = 2;
48 else if (!tranB && !ctransB && !notransB)
49 info = 3;
50 else if (*n < 0)
51 info = 4;
52 else if (*k < 0)
53 info = 5;
54 else if (*ldA < MAX(1, notransA ? *n : *k))
55 info = 8;
56 else if (*ldB < MAX(1, notransB ? *k : *n))
57 info = 10;
58 else if (*ldC < MAX(1, *n))
59 info = 13;
60 if (info) {
61 LAPACK(xerbla)("CGEMMT", &info, strlen("CGEMMT"));
62 return;
63 }
64
65 // Clean char * arguments
66 const char cleanuplo = lower ? 'L' : 'U';
67 const char cleantransA = notransA ? 'N' : (tranA ? 'T' : 'C');
68 const char cleantransB = notransB ? 'N' : (tranB ? 'T' : 'C');
69
70 // Recursive kernel
71 RELAPACK_cgemmt_rec(&cleanuplo, &cleantransA, &cleantransB, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);
72 #endif
73 }
74
75
76 /** cgemmt's recursive compute kernel */
RELAPACK_cgemmt_rec(const char * uplo,const char * transA,const char * transB,const blasint * n,const blasint * k,const float * alpha,const float * A,const blasint * ldA,const float * B,const blasint * ldB,const float * beta,float * C,const blasint * ldC)77 static void RELAPACK_cgemmt_rec(
78 const char *uplo, const char *transA, const char *transB,
79 const blasint *n, const blasint *k,
80 const float *alpha, const float *A, const blasint *ldA,
81 const float *B, const blasint *ldB,
82 const float *beta, float *C, const blasint *ldC
83 ) {
84
85 if (*n <= MAX(CROSSOVER_CGEMMT, 1)) {
86 // Unblocked
87 RELAPACK_cgemmt_rec2(uplo, transA, transB, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);
88 return;
89 }
90
91 // Splitting
92 const blasint n1 = CREC_SPLIT(*n);
93 const blasint n2 = *n - n1;
94
95 // A_T
96 // A_B
97 const float *const A_T = A;
98 const float *const A_B = A + 2 * ((*transA == 'N') ? n1 : *ldA * n1);
99
100 // B_L B_R
101 const float *const B_L = B;
102 const float *const B_R = B + 2 * ((*transB == 'N') ? *ldB * n1 : n1);
103
104 // C_TL C_TR
105 // C_BL C_BR
106 float *const C_TL = C;
107 float *const C_TR = C + 2 * *ldC * n1;
108 float *const C_BL = C + 2 * n1;
109 float *const C_BR = C + 2 * *ldC * n1 + 2 * n1;
110
111 // recursion(C_TL)
112 RELAPACK_cgemmt_rec(uplo, transA, transB, &n1, k, alpha, A_T, ldA, B_L, ldB, beta, C_TL, ldC);
113
114 if (*uplo == 'L')
115 // C_BL = alpha A_B B_L + beta C_BL
116 BLAS(cgemm)(transA, transB, &n2, &n1, k, alpha, A_B, ldA, B_L, ldB, beta, C_BL, ldC);
117 else
118 // C_TR = alpha A_T B_R + beta C_TR
119 BLAS(cgemm)(transA, transB, &n1, &n2, k, alpha, A_T, ldA, B_R, ldB, beta, C_TR, ldC);
120
121 // recursion(C_BR)
122 RELAPACK_cgemmt_rec(uplo, transA, transB, &n2, k, alpha, A_B, ldA, B_R, ldB, beta, C_BR, ldC);
123 }
124
125
126 /** cgemmt's unblocked compute kernel */
RELAPACK_cgemmt_rec2(const char * uplo,const char * transA,const char * transB,const blasint * n,const blasint * k,const float * alpha,const float * A,const blasint * ldA,const float * B,const blasint * ldB,const float * beta,float * C,const blasint * ldC)127 static void RELAPACK_cgemmt_rec2(
128 const char *uplo, const char *transA, const char *transB,
129 const blasint *n, const blasint *k,
130 const float *alpha, const float *A, const blasint *ldA,
131 const float *B, const blasint *ldB,
132 const float *beta, float *C, const blasint *ldC
133 ) {
134
135 const blasint incB = (*transB == 'N') ? 1 : *ldB;
136 const blasint incC = 1;
137
138 blasint i;
139 for (i = 0; i < *n; i++) {
140 // A_0
141 // A_i
142 const float *const A_0 = A;
143 const float *const A_i = A + 2 * ((*transA == 'N') ? i : *ldA * i);
144
145 // * B_i *
146 const float *const B_i = B + 2 * ((*transB == 'N') ? *ldB * i : i);
147
148 // * C_0i *
149 // * C_ii *
150 float *const C_0i = C + 2 * *ldC * i;
151 float *const C_ii = C + 2 * *ldC * i + 2 * i;
152
153 if (*uplo == 'L') {
154 const blasint nmi = *n - i;
155 if (*transA == 'N')
156 BLAS(cgemv)(transA, &nmi, k, alpha, A_i, ldA, B_i, &incB, beta, C_ii, &incC);
157 else
158 BLAS(cgemv)(transA, k, &nmi, alpha, A_i, ldA, B_i, &incB, beta, C_ii, &incC);
159 } else {
160 const blasint ip1 = i + 1;
161 if (*transA == 'N')
162 BLAS(cgemv)(transA, &ip1, k, alpha, A_0, ldA, B_i, &incB, beta, C_0i, &incC);
163 else
164 BLAS(cgemv)(transA, k, &ip1, alpha, A_0, ldA, B_i, &incB, beta, C_0i, &incC);
165 }
166 }
167 }
168