1 /*
2 * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 */
17
18 /* clang-format off */
19
20 /* mmreal4.c -- F90 fast-/dgemm-like MATMUL intrinsics for real*4 type */
21
22 #include "stdioInterf.h"
23 #include "fioMacros.h"
24
25 #define SMALL_ROWSA 10
26 #define SMALL_ROWSB 10
27 #define SMALL_COLSB 10
28
ENTF90(MMUL_REAL4,mmul_real4)29 void ENTF90(MMUL_REAL4, mmul_real4)(int ta, int tb, __POINT_T mra,
30 __POINT_T ncb, __POINT_T kab, float *alpha,
31 float a[], __POINT_T lda, float b[],
32 __POINT_T ldb, float *beta, float c[],
33 __POINT_T ldc)
34
35 {
36 /*
37 * Notes on parameters
38 * ta = 0 -> no transpose of matrix a
39 * tb = 0 -> no transpose of matrix b
40
41 * mra = number of rows in matrices a and c ( = m )
42 * ncb = number of columns in matrices b and c ( = n )
43 * kab = shared dimension of matrices a and b ( = k, but need k elsewhere )
44 * a = starting address of matrix a
45 * b = starting address of matrix b
46 * c = starting address of matric c
47 * lda = leading dimension of matrix a
48 * ldb = leading dimension of matrix b
49 * ldc = leading dimension of matrix c
50 * alpha = 1.0
51 * beta = 0.0
52 * Note that these last two conditions are inconsitent with the general
53 * case for dgemm.
54 * Taken together we have
55 * c = beta * c + alpha * ( (ta)a * (tb)*b )
56 * where the meaning of (ta) and (tb) is that if ta = 0 a is not transposed
57 * and transposed otherwise and if tb = 0, b is not transpose and transposed
58 * otherwise.
59 */
60
61 // Local variables
62
63 int colsa, rowsa, rowsb, colsb;
64 int ar, ac;
65 int ndx, ndxsav, colchunk, colchunks, rowchunk, rowchunks;
66 int colsb_chunks, colsb_end, colsb_strt;
67 int bufr, bufc, loc, lor;
68 int small_size = SMALL_ROWSA * SMALL_ROWSB * SMALL_COLSB;
69 int tindex = 0;
70 float buffera[SMALL_ROWSA * SMALL_ROWSB];
71 float bufferb[SMALL_COLSB * SMALL_ROWSB];
72 float temp;
73 void ftn_mvmul_real4_(), ftn_vmmul_real4_();
74 void ftn_mnaxnb_real4_(), ftn_mnaxtb_real4_();
75 void ftn_mtaxnb_real4_(), ftn_mtaxtb_real4_();
76 float calpha, cbeta;
77 /*
78 * Small matrix multiply variables
79 */
80 int i, ia, ja, j, k, bk;
81 int astrt, bstrt, cstrt, andx, bndx, cndx, indx, indx_strt;
82 /*
83 * tindex has the following meaning:
84 * ta == 0, tb == 0: tindex = 0
85 * ta == 1, tb == 0: tindex = 1
86 * ta == 0, tb == 1; tindex = 2
87 * ta == 1, tb == 1; tindex = 3
88 */
89
90 /* if( ( tb == 0 ) && ( ncb == 1 ) && ( ldc == 1 ) ){ */
91 if ((tb == 0) && (ncb == 1)) {
92 /* matrix vector multiply */
93 ftn_mvmul_real4_(&ta, &mra, &kab, alpha, a, &lda, b, beta, c);
94 return;
95 }
96 if ((ta == 0) && (mra == 1) && (ldc == 1)) {
97 /* vector matrix multiply */
98 ftn_vmmul_real4_(&tb, &ncb, &kab, alpha, a, b, &ldb, beta, c);
99 return;
100 }
101 calpha = *alpha;
102 cbeta = *beta;
103 rowsa = mra;
104 colsa = kab;
105 rowsb = kab;
106 colsb = ncb;
107 if (ta == 1)
108 tindex = 1;
109
110 if (tb == 1)
111 tindex += 2;
112
113 // Check for really small matrix sizes
114
115 // Check for really small matrix sizes
116
117 if ((colsb <= SMALL_COLSB) && (rowsa <= SMALL_ROWSA) &&
118 (rowsb <= SMALL_ROWSB)) {
119 switch (tindex) {
120 case 0: /* matrix a and matrix b normally oriented
121 *
122 * The notation here refers to the Fortran orientation since
123 * that is the origination of these matrices
124 */
125 astrt = 0;
126 bstrt = 0;
127 cstrt = 0;
128 if (cbeta == (float)0.0) {
129 for (i = 0; i < rowsa; i++) {
130 /* Transpose the a row of the a matrix */
131 andx = astrt;
132 indx = 0;
133 for (ja = 0; ja < colsa; ja++) {
134 buffera[indx++] = calpha * a[andx];
135 andx += lda;
136 }
137 astrt++;
138 cndx = cstrt;
139 for (j = 0; j < colsb; j++) {
140 temp = 0.0;
141 bndx = bstrt;
142 for (k = 0; k < rowsb; k++)
143 temp += buffera[k] * b[bndx++];
144 bstrt += ldb;
145 c[cndx] = temp;
146 cndx += ldc;
147 }
148 cstrt++; /* set index for next row of c */
149 bstrt = 0;
150 }
151 } else {
152 for (i = 0; i < rowsa; i++) {
153 /* Transpose the a row of the a matrix */
154 andx = astrt;
155 indx = 0;
156 for (ja = 0; ja < colsa; ja++) {
157 buffera[indx++] = calpha * a[andx];
158 andx += lda;
159 }
160 astrt++;
161 cndx = cstrt;
162 for (j = 0; j < colsb; j++) {
163 temp = 0.0;
164 bndx = bstrt;
165 for (k = 0; k < rowsb; k++)
166 temp += buffera[k] * b[bndx++];
167 bstrt += ldb;
168 c[cndx] = temp + cbeta * c[cndx];
169 cndx += ldc;
170 }
171 cstrt++; /* set index for next row of c */
172 bstrt = 0;
173 }
174 }
175
176 break;
177 case 1: /* matrix a transpose, matrix b normally oriented */
178 bndx = 0;
179 cstrt = 0;
180 andx = 0;
181 if (cbeta == (float)0.0) {
182 for (j = 0; j < colsb; j++) {
183 cndx = cstrt;
184 for (i = 0; i < rowsa; i++) {
185 /* Matrix a need not be transposed */
186 temp = 0.0;
187 for (k = 0; k < rowsb; k++)
188 temp += a[andx + k] * b[bndx + k];
189 c[cndx] = calpha * temp;
190 andx += lda;
191 cndx++;
192 }
193 cstrt += ldc; /* set index for next column of c */
194 astrt++; /* set index for next column of a */
195 b += ldb;
196 andx = 0;
197 }
198 } else {
199 for (j = 0; j < colsb; j++) {
200 cndx = cstrt;
201 for (i = 0; i < rowsa; i++) {
202 /* Matrix a need not be transposed */
203 temp = 0.0;
204 for (k = 0; k < rowsb; k++)
205 temp += a[andx + k] * b[bndx + k];
206 c[cndx] = calpha * temp + cbeta * c[cndx];
207 andx += lda;
208 cndx++;
209 }
210 cstrt += ldc; /* set index for next column of c */
211 astrt++; /* set index for next column of a */
212 b += ldb;
213 andx = 0;
214 }
215 }
216
217 break;
218 case 2: /* Matrix a normal, b transposed */
219 /* We will transpose b and work with transposed rows of a */
220 /* Transpose matrix b */
221 indx_strt = 0;
222 bstrt = 0;
223 for (j = 0; j < rowsb; j++) {
224 indx = indx_strt;
225 bndx = bstrt;
226 for (i = 0; i < colsb; i++) {
227 bufferb[indx] = calpha * b[bndx++];
228 indx += rowsb;
229 }
230 indx_strt++;
231 bstrt += ldb;
232 }
233 /* All of b is now transposed */
234
235 astrt = 0;
236 cstrt = 0;
237 if (cbeta == (float)0.0) {
238 for (i = 0; i < rowsa; i++) {
239 /* Transpose the a row of the a matrix */
240 andx = astrt;
241 indx = 0;
242 for (ja = 0; ja < colsa; ja++) {
243 buffera[indx++] = a[andx];
244 andx += lda;
245 }
246 cndx = cstrt;
247 bndx = 0;
248 for (j = 0; j < colsb; j++) {
249 temp = 0.0;
250 for (k = 0; k < rowsb; k++)
251 temp += buffera[k] * bufferb[bndx++];
252 c[cndx] = temp;
253 cndx += ldc;
254 }
255 cstrt++; /* set index for next row of c */
256 astrt++;
257 }
258 } else {
259 for (i = 0; i < rowsa; i++) {
260 /* Transpose the a row of the a matrix */
261 andx = astrt;
262 indx = 0;
263 for (ja = 0; ja < colsa; ja++) {
264 buffera[indx++] = a[andx];
265 andx += lda;
266 }
267 cndx = cstrt;
268 bndx = 0;
269 for (j = 0; j < colsb; j++) {
270 temp = 0.0;
271 for (k = 0; k < rowsb; k++)
272 temp += buffera[k] * bufferb[bndx++];
273 c[cndx] = temp + cbeta * c[cndx];
274 cndx += ldc;
275 }
276 cstrt++; /* set index for next row of c */
277 astrt++;
278 }
279 }
280 break;
281 case 3: /* both matrices tranposed. Combination of cases 1 and 2 */
282 /* Transpose matrix b */
283
284 indx_strt = 0;
285 bstrt = 0;
286 for (j = 0; j < rowsb; j++) {
287 indx = indx_strt;
288 bndx = bstrt;
289 for (i = 0; i < colsb; i++) {
290 bufferb[indx] = calpha * b[bndx++];
291 indx += rowsb;
292 }
293 indx_strt++;
294 bstrt += ldb;
295 }
296
297 /* All of b is now transposed */
298 andx = 0;
299 cstrt = 0;
300 bndx = 0;
301 if (cbeta == (float)0.0) {
302 for (i = 0; i < colsb; i++) {
303 /* Matrix a need not be transposed */
304 cndx = cstrt;
305 for (j = 0; j < rowsa; j++) {
306 temp = 0.0;
307 for (k = 0; k < rowsb; k++)
308 temp += a[andx + k] * bufferb[bndx + k];
309 c[cndx] = temp;
310 cndx++;
311 andx += lda;
312 }
313 bndx += rowsb; /* index for next transposed column of b */
314 andx = 0; /* set index for next column of a */
315 cstrt += ldc; /* set index for next row of c */
316 }
317 } else {
318 for (i = 0; i < colsb; i++) {
319 /* Matrix a need not be transposed */
320 cndx = cstrt;
321 for (j = 0; j < rowsa; j++) {
322 temp = 0.0;
323 for (k = 0; k < rowsb; k++)
324 temp += a[andx + k] * bufferb[bndx + k];
325 c[cndx] = temp + cbeta * c[cndx];
326 cndx++;
327 andx += lda;
328 }
329 bndx += rowsb; /* index for next transposed column of b */
330 andx = 0; /* set index for next column of a */
331 }
332 }
333 }
334 } else {
335 switch (tindex) {
336 case 0:
337 ftn_mnaxnb_real4_(&mra, &ncb, &kab, alpha, a, &lda, b, &ldb, beta, c,
338 &ldc);
339 break;
340 case 1:
341 ftn_mtaxnb_real4_(&mra, &ncb, &kab, alpha, a, &lda, b, &ldb, beta, c,
342 &ldc);
343 break;
344 case 2:
345 ftn_mnaxtb_real4_(&mra, &ncb, &kab, alpha, a, &lda, b, &ldb, beta, c,
346 &ldc);
347 break;
348 case 3:
349 ftn_mtaxtb_real4_(&mra, &ncb, &kab, alpha, a, &lda, b, &ldb, beta, c,
350 &ldc);
351 }
352 }
353
354 }
355