1 /* Implementation of the MATMUL intrinsic
2    Copyright (C) 2002-2013 Free Software Foundation, Inc.
3    Contributed by Paul Brook <paul@nowt.org>
4 
5 This file is part of the GNU Fortran runtime library (libgfortran).
6 
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
11 
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 GNU General Public License for more details.
16 
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
20 
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
24 <http://www.gnu.org/licenses/>.  */
25 
26 #include "libgfortran.h"
27 #include <stdlib.h>
28 #include <string.h>
29 #include <assert.h>
30 
31 
32 #if defined (HAVE_GFC_COMPLEX_4)
33 
34 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
35    passed to us by the front-end, in which case we'll call it for large
36    matrices.  */
37 
38 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
39                           const int *, const GFC_COMPLEX_4 *, const GFC_COMPLEX_4 *,
40                           const int *, const GFC_COMPLEX_4 *, const int *,
41                           const GFC_COMPLEX_4 *, GFC_COMPLEX_4 *, const int *,
42                           int, int);
43 
44 /* The order of loops is different in the case of plain matrix
45    multiplication C=MATMUL(A,B), and in the frequent special case where
46    the argument A is the temporary result of a TRANSPOSE intrinsic:
47    C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
48    looking at their strides.
49 
50    The equivalent Fortran pseudo-code is:
51 
52    DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
53    IF (.NOT.IS_TRANSPOSED(A)) THEN
54      C = 0
55      DO J=1,N
56        DO K=1,COUNT
57          DO I=1,M
58            C(I,J) = C(I,J)+A(I,K)*B(K,J)
59    ELSE
60      DO J=1,N
61        DO I=1,M
62          S = 0
63          DO K=1,COUNT
64            S = S+A(I,K)*B(K,J)
65          C(I,J) = S
66    ENDIF
67 */
68 
69 /* If try_blas is set to a nonzero value, then the matmul function will
70    see if there is a way to perform the matrix multiplication by a call
71    to the BLAS gemm function.  */
72 
73 extern void matmul_c4 (gfc_array_c4 * const restrict retarray,
74 	gfc_array_c4 * const restrict a, gfc_array_c4 * const restrict b, int try_blas,
75 	int blas_limit, blas_call gemm);
76 export_proto(matmul_c4);
77 
78 void
matmul_c4(gfc_array_c4 * const restrict retarray,gfc_array_c4 * const restrict a,gfc_array_c4 * const restrict b,int try_blas,int blas_limit,blas_call gemm)79 matmul_c4 (gfc_array_c4 * const restrict retarray,
80 	gfc_array_c4 * const restrict a, gfc_array_c4 * const restrict b, int try_blas,
81 	int blas_limit, blas_call gemm)
82 {
83   const GFC_COMPLEX_4 * restrict abase;
84   const GFC_COMPLEX_4 * restrict bbase;
85   GFC_COMPLEX_4 * restrict dest;
86 
87   index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
88   index_type x, y, n, count, xcount, ycount;
89 
90   assert (GFC_DESCRIPTOR_RANK (a) == 2
91           || GFC_DESCRIPTOR_RANK (b) == 2);
92 
93 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
94 
95    Either A or B (but not both) can be rank 1:
96 
97    o One-dimensional argument A is implicitly treated as a row matrix
98      dimensioned [1,count], so xcount=1.
99 
100    o One-dimensional argument B is implicitly treated as a column matrix
101      dimensioned [count, 1], so ycount=1.
102   */
103 
104   if (retarray->base_addr == NULL)
105     {
106       if (GFC_DESCRIPTOR_RANK (a) == 1)
107         {
108 	  GFC_DIMENSION_SET(retarray->dim[0], 0,
109 	                    GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
110         }
111       else if (GFC_DESCRIPTOR_RANK (b) == 1)
112         {
113 	  GFC_DIMENSION_SET(retarray->dim[0], 0,
114 	                    GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
115         }
116       else
117         {
118 	  GFC_DIMENSION_SET(retarray->dim[0], 0,
119 	                    GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
120 
121           GFC_DIMENSION_SET(retarray->dim[1], 0,
122 	                    GFC_DESCRIPTOR_EXTENT(b,1) - 1,
123 			    GFC_DESCRIPTOR_EXTENT(retarray,0));
124         }
125 
126       retarray->base_addr
127 	= xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_COMPLEX_4));
128       retarray->offset = 0;
129     }
130     else if (unlikely (compile_options.bounds_check))
131       {
132 	index_type ret_extent, arg_extent;
133 
134 	if (GFC_DESCRIPTOR_RANK (a) == 1)
135 	  {
136 	    arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
137 	    ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
138 	    if (arg_extent != ret_extent)
139 	      runtime_error ("Incorrect extent in return array in"
140 			     " MATMUL intrinsic: is %ld, should be %ld",
141 			     (long int) ret_extent, (long int) arg_extent);
142 	  }
143 	else if (GFC_DESCRIPTOR_RANK (b) == 1)
144 	  {
145 	    arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
146 	    ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
147 	    if (arg_extent != ret_extent)
148 	      runtime_error ("Incorrect extent in return array in"
149 			     " MATMUL intrinsic: is %ld, should be %ld",
150 			     (long int) ret_extent, (long int) arg_extent);
151 	  }
152 	else
153 	  {
154 	    arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
155 	    ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
156 	    if (arg_extent != ret_extent)
157 	      runtime_error ("Incorrect extent in return array in"
158 			     " MATMUL intrinsic for dimension 1:"
159 			     " is %ld, should be %ld",
160 			     (long int) ret_extent, (long int) arg_extent);
161 
162 	    arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
163 	    ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
164 	    if (arg_extent != ret_extent)
165 	      runtime_error ("Incorrect extent in return array in"
166 			     " MATMUL intrinsic for dimension 2:"
167 			     " is %ld, should be %ld",
168 			     (long int) ret_extent, (long int) arg_extent);
169 	  }
170       }
171 
172 
173   if (GFC_DESCRIPTOR_RANK (retarray) == 1)
174     {
175       /* One-dimensional result may be addressed in the code below
176 	 either as a row or a column matrix. We want both cases to
177 	 work. */
178       rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
179     }
180   else
181     {
182       rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
183       rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
184     }
185 
186 
187   if (GFC_DESCRIPTOR_RANK (a) == 1)
188     {
189       /* Treat it as a a row matrix A[1,count]. */
190       axstride = GFC_DESCRIPTOR_STRIDE(a,0);
191       aystride = 1;
192 
193       xcount = 1;
194       count = GFC_DESCRIPTOR_EXTENT(a,0);
195     }
196   else
197     {
198       axstride = GFC_DESCRIPTOR_STRIDE(a,0);
199       aystride = GFC_DESCRIPTOR_STRIDE(a,1);
200 
201       count = GFC_DESCRIPTOR_EXTENT(a,1);
202       xcount = GFC_DESCRIPTOR_EXTENT(a,0);
203     }
204 
205   if (count != GFC_DESCRIPTOR_EXTENT(b,0))
206     {
207       if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
208 	runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
209     }
210 
211   if (GFC_DESCRIPTOR_RANK (b) == 1)
212     {
213       /* Treat it as a column matrix B[count,1] */
214       bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
215 
216       /* bystride should never be used for 1-dimensional b.
217 	 in case it is we want it to cause a segfault, rather than
218 	 an incorrect result. */
219       bystride = 0xDEADBEEF;
220       ycount = 1;
221     }
222   else
223     {
224       bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
225       bystride = GFC_DESCRIPTOR_STRIDE(b,1);
226       ycount = GFC_DESCRIPTOR_EXTENT(b,1);
227     }
228 
229   abase = a->base_addr;
230   bbase = b->base_addr;
231   dest = retarray->base_addr;
232 
233 
234   /* Now that everything is set up, we're performing the multiplication
235      itself.  */
236 
237 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
238 
239   if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
240       && (bxstride == 1 || bystride == 1)
241       && (((float) xcount) * ((float) ycount) * ((float) count)
242           > POW3(blas_limit)))
243   {
244     const int m = xcount, n = ycount, k = count, ldc = rystride;
245     const GFC_COMPLEX_4 one = 1, zero = 0;
246     const int lda = (axstride == 1) ? aystride : axstride,
247               ldb = (bxstride == 1) ? bystride : bxstride;
248 
249     if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
250       {
251         assert (gemm != NULL);
252         gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
253               &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
254         return;
255       }
256   }
257 
258   if (rxstride == 1 && axstride == 1 && bxstride == 1)
259     {
260       const GFC_COMPLEX_4 * restrict bbase_y;
261       GFC_COMPLEX_4 * restrict dest_y;
262       const GFC_COMPLEX_4 * restrict abase_n;
263       GFC_COMPLEX_4 bbase_yn;
264 
265       if (rystride == xcount)
266 	memset (dest, 0, (sizeof (GFC_COMPLEX_4) * xcount * ycount));
267       else
268 	{
269 	  for (y = 0; y < ycount; y++)
270 	    for (x = 0; x < xcount; x++)
271 	      dest[x + y*rystride] = (GFC_COMPLEX_4)0;
272 	}
273 
274       for (y = 0; y < ycount; y++)
275 	{
276 	  bbase_y = bbase + y*bystride;
277 	  dest_y = dest + y*rystride;
278 	  for (n = 0; n < count; n++)
279 	    {
280 	      abase_n = abase + n*aystride;
281 	      bbase_yn = bbase_y[n];
282 	      for (x = 0; x < xcount; x++)
283 		{
284 		  dest_y[x] += abase_n[x] * bbase_yn;
285 		}
286 	    }
287 	}
288     }
289   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
290     {
291       if (GFC_DESCRIPTOR_RANK (a) != 1)
292 	{
293 	  const GFC_COMPLEX_4 *restrict abase_x;
294 	  const GFC_COMPLEX_4 *restrict bbase_y;
295 	  GFC_COMPLEX_4 *restrict dest_y;
296 	  GFC_COMPLEX_4 s;
297 
298 	  for (y = 0; y < ycount; y++)
299 	    {
300 	      bbase_y = &bbase[y*bystride];
301 	      dest_y = &dest[y*rystride];
302 	      for (x = 0; x < xcount; x++)
303 		{
304 		  abase_x = &abase[x*axstride];
305 		  s = (GFC_COMPLEX_4) 0;
306 		  for (n = 0; n < count; n++)
307 		    s += abase_x[n] * bbase_y[n];
308 		  dest_y[x] = s;
309 		}
310 	    }
311 	}
312       else
313 	{
314 	  const GFC_COMPLEX_4 *restrict bbase_y;
315 	  GFC_COMPLEX_4 s;
316 
317 	  for (y = 0; y < ycount; y++)
318 	    {
319 	      bbase_y = &bbase[y*bystride];
320 	      s = (GFC_COMPLEX_4) 0;
321 	      for (n = 0; n < count; n++)
322 		s += abase[n*axstride] * bbase_y[n];
323 	      dest[y*rystride] = s;
324 	    }
325 	}
326     }
327   else if (axstride < aystride)
328     {
329       for (y = 0; y < ycount; y++)
330 	for (x = 0; x < xcount; x++)
331 	  dest[x*rxstride + y*rystride] = (GFC_COMPLEX_4)0;
332 
333       for (y = 0; y < ycount; y++)
334 	for (n = 0; n < count; n++)
335 	  for (x = 0; x < xcount; x++)
336 	    /* dest[x,y] += a[x,n] * b[n,y] */
337 	    dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
338     }
339   else if (GFC_DESCRIPTOR_RANK (a) == 1)
340     {
341       const GFC_COMPLEX_4 *restrict bbase_y;
342       GFC_COMPLEX_4 s;
343 
344       for (y = 0; y < ycount; y++)
345 	{
346 	  bbase_y = &bbase[y*bystride];
347 	  s = (GFC_COMPLEX_4) 0;
348 	  for (n = 0; n < count; n++)
349 	    s += abase[n*axstride] * bbase_y[n*bxstride];
350 	  dest[y*rxstride] = s;
351 	}
352     }
353   else
354     {
355       const GFC_COMPLEX_4 *restrict abase_x;
356       const GFC_COMPLEX_4 *restrict bbase_y;
357       GFC_COMPLEX_4 *restrict dest_y;
358       GFC_COMPLEX_4 s;
359 
360       for (y = 0; y < ycount; y++)
361 	{
362 	  bbase_y = &bbase[y*bystride];
363 	  dest_y = &dest[y*rystride];
364 	  for (x = 0; x < xcount; x++)
365 	    {
366 	      abase_x = &abase[x*axstride];
367 	      s = (GFC_COMPLEX_4) 0;
368 	      for (n = 0; n < count; n++)
369 		s += abase_x[n*aystride] * bbase_y[n*bxstride];
370 	      dest_y[x*rxstride] = s;
371 	    }
372 	}
373     }
374 }
375 
376 #endif
377