1 /*
2  * Copyright (c) 2017, NVIDIA CORPORATION.  All rights reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *
16  */
17 
18 /* clang-format off */
19 
20 /* mmreal4.c -- F90 fast-/dgemm-like MATMUL intrinsics for real*4 type */
21 
22 #include "stdioInterf.h"
23 #include "fioMacros.h"
24 
/*
 * Upper bounds (inclusive) on the operand dimensions for which the MATMUL
 * entry point in this file multiplies the matrices inline instead of calling
 * one of the external ftn_m?ax?b_real4_ kernels.  They also size the on-stack
 * scratch buffers used by that inline path.
 */
#define SMALL_ROWSA 10 /* max rows of op(a), i.e. rows of c */
#define SMALL_ROWSB 10 /* max rows of op(b), i.e. the shared dimension */
#define SMALL_COLSB 10 /* max columns of op(b), i.e. columns of c */
28 
ENTF90(MMUL_REAL4,mmul_real4)29 void ENTF90(MMUL_REAL4, mmul_real4)(int ta, int tb, __POINT_T mra,
30                                     __POINT_T ncb, __POINT_T kab, float *alpha,
31                                     float a[], __POINT_T lda, float b[],
32                                     __POINT_T ldb, float *beta, float c[],
33                                     __POINT_T ldc)
34 
35 {
36   /*
37   *   Notes on parameters
38   *   ta = 0 -> no transpose of matrix a
39   *   tb = 0 -> no transpose of matrix b
40 
41   *   mra = number of rows in matrices a and c ( = m )
42   *   ncb = number of columns in matrices b and c ( = n )
43   *   kab = shared dimension of matrices a and b ( = k, but need k elsewhere )
44   *   a = starting address of matrix a
45   *   b = starting address of matrix b
46   *   c = starting address of matric c
47   *   lda = leading dimension of matrix a
48   *   ldb = leading dimension of matrix b
49   *   ldc = leading dimension of matrix c
50   *   alpha = 1.0
51   *   beta = 0.0
52   *   Note that these last two conditions are inconsitent with the general
53   *   case for dgemm.
54   *   Taken together we have
55   *   c = beta * c + alpha * ( (ta)a * (tb)*b )
56   *   where the meaning of (ta) and (tb) is that if ta = 0 a is not transposed
57   *   and transposed otherwise and if tb = 0, b is not transpose and transposed
58   *   otherwise.
59   */
60 
61   // Local variables
62 
63   int colsa, rowsa, rowsb, colsb;
64   int ar, ac;
65   int ndx, ndxsav, colchunk, colchunks, rowchunk, rowchunks;
66   int colsb_chunks, colsb_end, colsb_strt;
67   int bufr, bufc, loc, lor;
68   int small_size = SMALL_ROWSA * SMALL_ROWSB * SMALL_COLSB;
69   int tindex = 0;
70   float buffera[SMALL_ROWSA * SMALL_ROWSB];
71   float bufferb[SMALL_COLSB * SMALL_ROWSB];
72   float temp;
73   void ftn_mvmul_real4_(), ftn_vmmul_real4_();
74   void ftn_mnaxnb_real4_(), ftn_mnaxtb_real4_();
75   void ftn_mtaxnb_real4_(), ftn_mtaxtb_real4_();
76   float calpha, cbeta;
77   /*
78    * Small matrix multiply variables
79    */
80   int i, ia, ja, j, k, bk;
81   int astrt, bstrt, cstrt, andx, bndx, cndx, indx, indx_strt;
82   /*
83    * tindex has the following meaning:
84    * ta == 0, tb == 0: tindex = 0
85    * ta == 1, tb == 0: tindex = 1
86    * ta == 0, tb == 1; tindex = 2
87    * ta == 1, tb == 1; tindex = 3
88    */
89 
90   /*  if( ( tb == 0 ) && ( ncb == 1 ) && ( ldc == 1 ) ){ */
91   if ((tb == 0) && (ncb == 1)) {
92     /* matrix vector multiply */
93     ftn_mvmul_real4_(&ta, &mra, &kab, alpha, a, &lda, b, beta, c);
94     return;
95   }
96   if ((ta == 0) && (mra == 1) && (ldc == 1)) {
97     /* vector matrix multiply */
98     ftn_vmmul_real4_(&tb, &ncb, &kab, alpha, a, b, &ldb, beta, c);
99     return;
100   }
101   calpha = *alpha;
102   cbeta = *beta;
103   rowsa = mra;
104   colsa = kab;
105   rowsb = kab;
106   colsb = ncb;
107   if (ta == 1)
108     tindex = 1;
109 
110   if (tb == 1)
111     tindex += 2;
112 
113   // Check for really small matrix sizes
114 
115   // Check for really small matrix sizes
116 
117   if ((colsb <= SMALL_COLSB) && (rowsa <= SMALL_ROWSA) &&
118       (rowsb <= SMALL_ROWSB)) {
119     switch (tindex) {
120     case 0: /* matrix a and matrix b normally oriented
121              *
122              * The notation here refers to the Fortran orientation since
123              * that is the origination of these matrices
124              */
125       astrt = 0;
126       bstrt = 0;
127       cstrt = 0;
128       if (cbeta == (float)0.0) {
129         for (i = 0; i < rowsa; i++) {
130           /* Transpose the a row of the a matrix */
131           andx = astrt;
132           indx = 0;
133           for (ja = 0; ja < colsa; ja++) {
134             buffera[indx++] = calpha * a[andx];
135             andx += lda;
136           }
137           astrt++;
138           cndx = cstrt;
139           for (j = 0; j < colsb; j++) {
140             temp = 0.0;
141             bndx = bstrt;
142             for (k = 0; k < rowsb; k++)
143               temp += buffera[k] * b[bndx++];
144             bstrt += ldb;
145             c[cndx] = temp;
146             cndx += ldc;
147           }
148           cstrt++; /* set index for next row of c */
149           bstrt = 0;
150         }
151       } else {
152         for (i = 0; i < rowsa; i++) {
153           /* Transpose the a row of the a matrix */
154           andx = astrt;
155           indx = 0;
156           for (ja = 0; ja < colsa; ja++) {
157             buffera[indx++] = calpha * a[andx];
158             andx += lda;
159           }
160           astrt++;
161           cndx = cstrt;
162           for (j = 0; j < colsb; j++) {
163             temp = 0.0;
164             bndx = bstrt;
165             for (k = 0; k < rowsb; k++)
166               temp += buffera[k] * b[bndx++];
167             bstrt += ldb;
168             c[cndx] = temp + cbeta * c[cndx];
169             cndx += ldc;
170           }
171           cstrt++; /* set index for next row of c */
172           bstrt = 0;
173         }
174       }
175 
176       break;
177     case 1: /* matrix a transpose, matrix b normally oriented */
178       bndx = 0;
179       cstrt = 0;
180       andx = 0;
181       if (cbeta == (float)0.0) {
182         for (j = 0; j < colsb; j++) {
183           cndx = cstrt;
184           for (i = 0; i < rowsa; i++) {
185             /* Matrix a need not be transposed */
186             temp = 0.0;
187             for (k = 0; k < rowsb; k++)
188               temp += a[andx + k] * b[bndx + k];
189             c[cndx] = calpha * temp;
190             andx += lda;
191             cndx++;
192           }
193           cstrt += ldc; /* set index for next column of c */
194           astrt++;      /* set index for next column of a */
195           b += ldb;
196           andx = 0;
197         }
198       } else {
199         for (j = 0; j < colsb; j++) {
200           cndx = cstrt;
201           for (i = 0; i < rowsa; i++) {
202             /* Matrix a need not be transposed */
203             temp = 0.0;
204             for (k = 0; k < rowsb; k++)
205               temp += a[andx + k] * b[bndx + k];
206             c[cndx] = calpha * temp + cbeta * c[cndx];
207             andx += lda;
208             cndx++;
209           }
210           cstrt += ldc; /* set index for next column of c */
211           astrt++;      /* set index for next column of a */
212           b += ldb;
213           andx = 0;
214         }
215       }
216 
217       break;
218     case 2: /* Matrix a normal, b transposed */
219       /* We will transpose b and work with transposed rows of a */
220       /* Transpose matrix b */
221       indx_strt = 0;
222       bstrt = 0;
223       for (j = 0; j < rowsb; j++) {
224         indx = indx_strt;
225         bndx = bstrt;
226         for (i = 0; i < colsb; i++) {
227           bufferb[indx] = calpha * b[bndx++];
228           indx += rowsb;
229         }
230         indx_strt++;
231         bstrt += ldb;
232       }
233       /* All of b is now transposed */
234 
235       astrt = 0;
236       cstrt = 0;
237       if (cbeta == (float)0.0) {
238         for (i = 0; i < rowsa; i++) {
239           /* Transpose the a row of the a matrix */
240           andx = astrt;
241           indx = 0;
242           for (ja = 0; ja < colsa; ja++) {
243             buffera[indx++] = a[andx];
244             andx += lda;
245           }
246           cndx = cstrt;
247           bndx = 0;
248           for (j = 0; j < colsb; j++) {
249             temp = 0.0;
250             for (k = 0; k < rowsb; k++)
251               temp += buffera[k] * bufferb[bndx++];
252             c[cndx] = temp;
253             cndx += ldc;
254           }
255           cstrt++; /* set index for next row of c */
256           astrt++;
257         }
258       } else {
259         for (i = 0; i < rowsa; i++) {
260           /* Transpose the a row of the a matrix */
261           andx = astrt;
262           indx = 0;
263           for (ja = 0; ja < colsa; ja++) {
264             buffera[indx++] = a[andx];
265             andx += lda;
266           }
267           cndx = cstrt;
268           bndx = 0;
269           for (j = 0; j < colsb; j++) {
270             temp = 0.0;
271             for (k = 0; k < rowsb; k++)
272               temp += buffera[k] * bufferb[bndx++];
273             c[cndx] = temp + cbeta * c[cndx];
274             cndx += ldc;
275           }
276           cstrt++; /* set index for next row of c */
277           astrt++;
278         }
279       }
280       break;
281     case 3: /* both matrices tranposed. Combination of cases 1 and 2 */
282       /* Transpose matrix b */
283 
284       indx_strt = 0;
285       bstrt = 0;
286       for (j = 0; j < rowsb; j++) {
287         indx = indx_strt;
288         bndx = bstrt;
289         for (i = 0; i < colsb; i++) {
290           bufferb[indx] = calpha * b[bndx++];
291           indx += rowsb;
292         }
293         indx_strt++;
294         bstrt += ldb;
295       }
296 
297       /* All of b is now transposed */
298       andx = 0;
299       cstrt = 0;
300       bndx = 0;
301       if (cbeta == (float)0.0) {
302         for (i = 0; i < colsb; i++) {
303           /* Matrix a need not be transposed */
304           cndx = cstrt;
305           for (j = 0; j < rowsa; j++) {
306             temp = 0.0;
307             for (k = 0; k < rowsb; k++)
308               temp += a[andx + k] * bufferb[bndx + k];
309             c[cndx] = temp;
310             cndx++;
311             andx += lda;
312           }
313           bndx += rowsb; /* index for next transposed column of b */
314           andx = 0;      /* set index for next column of a */
315           cstrt += ldc;  /* set index for next row of c */
316         }
317       } else {
318         for (i = 0; i < colsb; i++) {
319           /* Matrix a need not be transposed */
320           cndx = cstrt;
321           for (j = 0; j < rowsa; j++) {
322             temp = 0.0;
323             for (k = 0; k < rowsb; k++)
324               temp += a[andx + k] * bufferb[bndx + k];
325             c[cndx] = temp + cbeta * c[cndx];
326             cndx++;
327             andx += lda;
328           }
329           bndx += rowsb; /* index for next transposed column of b */
330           andx = 0;      /* set index for next column of a */
331         }
332       }
333     }
334   } else {
335     switch (tindex) {
336     case 0:
337       ftn_mnaxnb_real4_(&mra, &ncb, &kab, alpha, a, &lda, b, &ldb, beta, c,
338                           &ldc);
339       break;
340     case 1:
341       ftn_mtaxnb_real4_(&mra, &ncb, &kab, alpha, a, &lda, b, &ldb, beta, c,
342                           &ldc);
343       break;
344     case 2:
345       ftn_mnaxtb_real4_(&mra, &ncb, &kab, alpha, a, &lda, b, &ldb, beta, c,
346                           &ldc);
347       break;
348     case 3:
349       ftn_mtaxtb_real4_(&mra, &ncb, &kab, alpha, a, &lda, b, &ldb, beta, c,
350                           &ldc);
351     }
352   }
353 
354 }
355