#include "relapack.h"
#include <stddef.h>
2 
3 static void RELAPACK_cgemmt_rec(const char *, const char *, const char *,
4     const blasint *, const blasint *, const float *, const float *, const blasint *,
5     const float *, const blasint *, const float *, float *, const blasint *);
6 
7 static void RELAPACK_cgemmt_rec2(const char *, const char *, const char *,
8     const blasint *, const blasint *, const float *, const float *, const blasint *,
9     const float *, const blasint *, const float *, float *, const blasint *);
10 
11 
12 /** CGEMMT computes a matrix-matrix product with general matrices but updates
13  * only the upper or lower triangular part of the result matrix.
14  *
15  * This routine performs the same operation as the BLAS routine
16  * cgemm(transA, transB, n, n, k, alpha, A, ldA, B, ldB, beta, C, ldC)
17  * but only updates the triangular part of C specified by uplo:
18  * If (*uplo == 'L'), only the lower triangular part of C is updated,
19  * otherwise the upper triangular part is updated.
20  * */
RELAPACK_cgemmt(const char * uplo,const char * transA,const char * transB,const blasint * n,const blasint * k,const float * alpha,const float * A,const blasint * ldA,const float * B,const blasint * ldB,const float * beta,float * C,const blasint * ldC)21 void RELAPACK_cgemmt(
22     const char *uplo, const char *transA, const char *transB,
23     const blasint *n, const blasint *k,
24     const float *alpha, const float *A, const blasint *ldA,
25     const float *B, const blasint *ldB,
26     const float *beta, float *C, const blasint *ldC
27 ) {
28 
29 #if HAVE_XGEMMT
30     BLAS(cgemmt)(uplo, transA, transB, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);
31     return;
32 #else
33 
34     // Check arguments
35     const blasint lower = LAPACK(lsame)(uplo, "L");
36     const blasint upper = LAPACK(lsame)(uplo, "U");
37     const blasint notransA = LAPACK(lsame)(transA, "N");
38     const blasint tranA = LAPACK(lsame)(transA, "T");
39     const blasint ctransA = LAPACK(lsame)(transA, "C");
40     const blasint notransB = LAPACK(lsame)(transB, "N");
41     const blasint tranB = LAPACK(lsame)(transB, "T");
42     const blasint ctransB = LAPACK(lsame)(transB, "C");
43     blasint info = 0;
44     if (!lower && !upper)
45         info = 1;
46     else if (!tranA && !ctransA && !notransA)
47         info = 2;
48     else if (!tranB && !ctransB && !notransB)
49         info = 3;
50     else if (*n < 0)
51         info = 4;
52     else if (*k < 0)
53         info = 5;
54     else if (*ldA < MAX(1, notransA ? *n : *k))
55         info = 8;
56     else if (*ldB < MAX(1, notransB ? *k : *n))
57         info = 10;
58     else if (*ldC < MAX(1, *n))
59         info = 13;
60     if (info) {
61         LAPACK(xerbla)("CGEMMT", &info, strlen("CGEMMT"));
62         return;
63     }
64 
65     // Clean char * arguments
66     const char cleanuplo = lower ? 'L' : 'U';
67     const char cleantransA = notransA ? 'N' : (tranA ? 'T' : 'C');
68     const char cleantransB = notransB ? 'N' : (tranB ? 'T' : 'C');
69 
70     // Recursive kernel
71     RELAPACK_cgemmt_rec(&cleanuplo, &cleantransA, &cleantransB, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);
72 #endif
73 }
74 
75 
76 /** cgemmt's recursive compute kernel */
RELAPACK_cgemmt_rec(const char * uplo,const char * transA,const char * transB,const blasint * n,const blasint * k,const float * alpha,const float * A,const blasint * ldA,const float * B,const blasint * ldB,const float * beta,float * C,const blasint * ldC)77 static void RELAPACK_cgemmt_rec(
78     const char *uplo, const char *transA, const char *transB,
79     const blasint *n, const blasint *k,
80     const float *alpha, const float *A, const blasint *ldA,
81     const float *B, const blasint *ldB,
82     const float *beta, float *C, const blasint *ldC
83 ) {
84 
85     if (*n <= MAX(CROSSOVER_CGEMMT, 1)) {
86         // Unblocked
87         RELAPACK_cgemmt_rec2(uplo, transA, transB, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);
88         return;
89     }
90 
91     // Splitting
92     const blasint n1 = CREC_SPLIT(*n);
93     const blasint n2 = *n - n1;
94 
95     // A_T
96     // A_B
97     const float *const A_T = A;
98     const float *const A_B = A + 2 * ((*transA == 'N') ? n1 : *ldA * n1);
99 
100     // B_L B_R
101     const float *const B_L = B;
102     const float *const B_R = B + 2 * ((*transB == 'N') ? *ldB * n1 : n1);
103 
104     // C_TL C_TR
105     // C_BL C_BR
106     float *const C_TL = C;
107     float *const C_TR = C + 2 * *ldC * n1;
108     float *const C_BL = C                 + 2 * n1;
109     float *const C_BR = C + 2 * *ldC * n1 + 2 * n1;
110 
111     // recursion(C_TL)
112     RELAPACK_cgemmt_rec(uplo, transA, transB, &n1, k, alpha, A_T, ldA, B_L, ldB, beta, C_TL, ldC);
113 
114     if (*uplo == 'L')
115         // C_BL = alpha A_B B_L + beta C_BL
116         BLAS(cgemm)(transA, transB, &n2, &n1, k, alpha, A_B, ldA, B_L, ldB, beta, C_BL, ldC);
117     else
118         // C_TR = alpha A_T B_R + beta C_TR
119         BLAS(cgemm)(transA, transB, &n1, &n2, k, alpha, A_T, ldA, B_R, ldB, beta, C_TR, ldC);
120 
121     // recursion(C_BR)
122     RELAPACK_cgemmt_rec(uplo, transA, transB, &n2, k, alpha, A_B, ldA, B_R, ldB, beta, C_BR, ldC);
123 }
124 
125 
126 /** cgemmt's unblocked compute kernel */
RELAPACK_cgemmt_rec2(const char * uplo,const char * transA,const char * transB,const blasint * n,const blasint * k,const float * alpha,const float * A,const blasint * ldA,const float * B,const blasint * ldB,const float * beta,float * C,const blasint * ldC)127 static void RELAPACK_cgemmt_rec2(
128     const char *uplo, const char *transA, const char *transB,
129     const blasint *n, const blasint *k,
130     const float *alpha, const float *A, const blasint *ldA,
131     const float *B, const blasint *ldB,
132     const float *beta, float *C, const blasint *ldC
133 ) {
134 
135     const blasint incB = (*transB == 'N') ? 1 : *ldB;
136     const blasint incC = 1;
137 
138     blasint i;
139     for (i = 0; i < *n; i++) {
140         // A_0
141         // A_i
142         const float *const A_0 = A;
143         const float *const A_i = A + 2 * ((*transA == 'N') ? i : *ldA * i);
144 
145         // * B_i *
146         const float *const B_i = B + 2 * ((*transB == 'N') ? *ldB * i : i);
147 
148         // * C_0i *
149         // * C_ii *
150         float *const C_0i = C + 2 * *ldC * i;
151         float *const C_ii = C + 2 * *ldC * i + 2 * i;
152 
153         if (*uplo == 'L') {
154             const blasint nmi = *n - i;
155             if (*transA == 'N')
156                 BLAS(cgemv)(transA, &nmi, k, alpha, A_i, ldA, B_i, &incB, beta, C_ii, &incC);
157             else
158                 BLAS(cgemv)(transA, k, &nmi, alpha, A_i, ldA, B_i, &incB, beta, C_ii, &incC);
159         } else {
160             const blasint ip1 = i + 1;
161             if (*transA == 'N')
162                 BLAS(cgemv)(transA, &ip1, k, alpha, A_0, ldA, B_i, &incB, beta, C_0i, &incC);
163             else
164                 BLAS(cgemv)(transA, k, &ip1, alpha, A_0, ldA, B_i, &incB, beta, C_0i, &incC);
165         }
166     }
167 }
168