/*
 * Copyright (c) 2017, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

/* clang-format off */

/** \file
 * \brief F90 MATMUL intrinsics for real*8 type
 */

#include "stdioInterf.h"
#include "fioMacros.h"
#include "matmul.h"

void ENTF90(MATMUL_REAL8, matmul_real8)(char *dest_addr, char *s1_addr,
                                        char *s2_addr, F90_Desc *dest_desc,
                                        F90_Desc *s1_desc, F90_Desc *s2_desc)
{

  __REAL8_T *s1_base;
  __REAL8_T *s2_base;
  __REAL8_T *dest_base;
  __REAL8_T *d_elem_p;
  __REAL8_T *s1_elem_p;
  __REAL8_T *s2_elem_p;

  __REAL8_T rslt_tmp;

  __INT_T s1_d1_lstride;
  __INT_T s1_d1_sstride;
  __INT_T s1_d1_lb;
  __INT_T s1_d1_soffset = 0;

  __INT_T s1_d2_lstride = 1;
  __INT_T s1_d2_sstride = 1;
  __INT_T s1_d2_lb = 0;
  __INT_T s1_d2_soffset = 0;

  __INT_T s2_d1_lstride;
  __INT_T s2_d1_sstride;
  __INT_T s2_d1_lb;
  __INT_T s2_d1_soffset = 0;

  __INT_T s2_d2_lstride = 1;
  __INT_T s2_d2_sstride = 1;
  __INT_T s2_d2_lb = 0;
  __INT_T s2_d2_soffset = 0;

  __INT_T d_d1_lstride;
  __INT_T d_d1_sstride;
  __INT_T d_d1_lb;
  __INT_T d_d1_soffset = 0;

  __INT_T d_d2_lstride = 1;
  __INT_T d_d2_sstride = 1;
  __INT_T d_d2_lb = 0;
  __INT_T d_d2_soffset = 0;

  __INT_T d_rank = F90_RANK_G(dest_desc);
  __INT_T s1_rank = F90_RANK_G(s1_desc);
  __INT_T s2_rank = F90_RANK_G(s2_desc);

  __INT_T k_extent = s2_rank == 2 ? F90_DIM_EXTENT_G(s2_desc, 1) : 1;
  __INT_T m_extent = s1_rank == 2 ? F90_DIM_EXTENT_G(s1_desc, 1)
                                  : F90_DIM_EXTENT_G(s1_desc, 0);
  __INT_T n_extent = s1_rank == 2 ? F90_DIM_EXTENT_G(s1_desc, 0) : 1;

  __INT_T dest_offset;

  __INT_T s1_d1_base, s1_d1_delta, s1_d1_offset, s1_d2_base, s1_d2_delta,
      s1_d2_offset, s2_d1_base, s2_d1_delta, s2_d1_offset, s2_d2_base,
      s2_d2_delta, s2_d2_offset, d_d1_base, d_d1_delta, d_d1_offset, d_d2_base,
      d_d2_delta, d_d2_offset;

  __INT_T k;
  __INT_T m;
  __INT_T n;

  /* mxm
   *  s1(n,m) x s2(m,k) -> dest(n,k)
   *  Check
   *   dest_d1 extent == n_extent
   *   dest_d2 extent == k_extent
   *   s2_d1 extent == m_extent
   *
   * mxv
   *  s1(n,m) x s2(m) -> dest(n)
   *  Check
   *   dest_d1 extent == n_extent
   *   s2_d1 extent == m_extent
   *
   * vxm
   *  s1(m) x s2(m,k) -> dest(k)
   *  Check
   *   s2_d1 extent == m_extent
   *   dest_d1 extent == k_extent
   */
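
  /* For illustration only (Fortran view of the three cases above, with
   * example shapes; not part of the runtime interface):
   *   real*8 :: a(3,4), b(4,5), v3(3), v4(4)
   *   matmul(a, b)    ! mxm: n=3, m=4, k=5 -> result shape (3,5)
   *   matmul(a, v4)   ! mxv: n=3, m=4      -> result shape (3)
   *   matmul(v3, a)   ! vxm: m=3, k=4      -> result shape (4)
   */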

  if (d_rank == 2 && s1_rank == 2 && s2_rank == 2) {
    if (F90_DIM_EXTENT_G(dest_desc, 0) != n_extent ||
        F90_DIM_EXTENT_G(dest_desc, 1) != k_extent ||
        F90_DIM_EXTENT_G(s2_desc, 0) != m_extent) {
      __fort_abort("MATMUL: nonconforming array shapes");
    }
  } else if (d_rank == 1 && s1_rank == 2 && s2_rank == 1) {
    if (F90_DIM_EXTENT_G(dest_desc, 0) != n_extent ||
        F90_DIM_EXTENT_G(s2_desc, 0) != m_extent) {
      __fort_abort("MATMUL: nonconforming array shapes");
    }
  } else if (d_rank == 1 && s1_rank == 1 && s2_rank == 2) {
    if (F90_DIM_EXTENT_G(dest_desc, 0) != k_extent ||
        F90_DIM_EXTENT_G(s2_desc, 0) != m_extent) {
      __fort_abort("MATMUL: nonconforming array shapes");
    }
  } else {
    __fort_abort("MATMUL: nonconforming array shapes");
  }

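  /* Pull the local stride (lstride), section stride (sstride), lower bound,
   * and section offset for each dimension out of the three descriptors.
   * Second-dimension values keep their defaults when an operand is rank 1.
   */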
  s1_d1_lstride = F90_DIM_LSTRIDE_G(s1_desc, 0);
  s1_d1_sstride = F90_DIM_SSTRIDE_G(s1_desc, 0);
  s1_d1_lb = F90_DIM_LBOUND_G(s1_desc, 0);
  if (s1_d1_sstride != 1 || F90_DIM_SOFFSET_G(s1_desc, 0))
    s1_d1_soffset = F90_DIM_SOFFSET_G(s1_desc, 0) + s1_d1_sstride - s1_d1_lb;

  if (s1_rank == 2) {
    s1_d2_lstride = F90_DIM_LSTRIDE_G(s1_desc, 1);
    s1_d2_lb = F90_DIM_LBOUND_G(s1_desc, 1);
    s1_d2_sstride = F90_DIM_SSTRIDE_G(s1_desc, 1);
    if (s1_d2_sstride != 1 || F90_DIM_SOFFSET_G(s1_desc, 1))
      s1_d2_soffset = F90_DIM_SOFFSET_G(s1_desc, 1) + s1_d2_sstride - s1_d2_lb;
  }

  s2_d1_lstride = F90_DIM_LSTRIDE_G(s2_desc, 0);
  s2_d1_lb = F90_DIM_LBOUND_G(s2_desc, 0);
  s2_d1_sstride = F90_DIM_SSTRIDE_G(s2_desc, 0);
  if (s2_d1_sstride != 1 || F90_DIM_SOFFSET_G(s2_desc, 0))
    s2_d1_soffset = F90_DIM_SOFFSET_G(s2_desc, 0) + s2_d1_sstride - s2_d1_lb;

  if (s2_rank == 2) {
    s2_d2_lstride = F90_DIM_LSTRIDE_G(s2_desc, 1);
    s2_d2_lb = F90_DIM_LBOUND_G(s2_desc, 1);
    s2_d2_sstride = F90_DIM_SSTRIDE_G(s2_desc, 1);
    if (s2_d2_sstride != 1 || F90_DIM_SOFFSET_G(s2_desc, 1))
      s2_d2_soffset = F90_DIM_SOFFSET_G(s2_desc, 1) + s2_d2_sstride - s2_d2_lb;
  }

  d_d1_lstride = F90_DIM_LSTRIDE_G(dest_desc, 0);
  d_d1_lb = F90_DIM_LBOUND_G(dest_desc, 0);
  d_d1_sstride = F90_DIM_SSTRIDE_G(dest_desc, 0);
  if (d_d1_sstride != 1 || F90_DIM_SOFFSET_G(dest_desc, 0))
    d_d1_soffset = F90_DIM_SOFFSET_G(dest_desc, 0) + d_d1_sstride - d_d1_lb;

  if (d_rank == 2) {
    d_d2_lstride = F90_DIM_LSTRIDE_G(dest_desc, 1);
    d_d2_lb = F90_DIM_LBOUND_G(dest_desc, 1);
    d_d2_sstride = F90_DIM_SSTRIDE_G(dest_desc, 1);
    if (d_d2_sstride != 1 || F90_DIM_SOFFSET_G(dest_desc, 1))
      d_d2_soffset = F90_DIM_SOFFSET_G(dest_desc, 1) + d_d2_sstride - d_d2_lb;
  }
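
  /* Translate each base address to element units and fold in the descriptor's
   * linear base plus the lower-bound offset of each dimension (rank-1
   * operands contribute a zero second-dimension term).
   */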
  s1_base = (__REAL8_T *)s1_addr + F90_LBASE_G(s1_desc) +
            s1_d1_lb * s1_d1_lstride + s1_d2_lb * s1_d2_lstride - 1;
  s2_base = (__REAL8_T *)s2_addr + F90_LBASE_G(s2_desc) +
            s2_d1_lb * s2_d1_lstride + s2_d2_lb * s2_d2_lstride - 1;
  dest_base = (__REAL8_T *)dest_addr + F90_LBASE_G(dest_desc) +
              d_d1_lb * d_d1_lstride + d_d2_lb * d_d2_lstride - 1;

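  /* For every dimension, precompute the starting element offset (base) from
   * the section offset, and the per-iteration increment (delta) as the
   * section stride times the local stride.
   */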
  d_d1_offset = d_d1_base = d_d1_soffset * d_d1_lstride;
  d_d1_delta = d_d1_sstride * d_d1_lstride;

  d_d2_offset = d_d2_base = d_d2_soffset * d_d2_lstride;
  d_d2_delta = s1_rank == 2 ? d_d2_sstride * d_d2_lstride : d_d1_delta;

  s1_d1_offset = s1_d1_base = s1_d1_soffset * s1_d1_lstride;
  s1_d1_delta = s1_d1_sstride * s1_d1_lstride;

  s1_d2_offset = s1_d2_base = s1_d2_soffset * s1_d2_lstride;
  s1_d2_delta = s1_rank == 2 ? s1_d2_sstride * s1_d2_lstride : s1_d1_delta;

  s2_d1_offset = s2_d1_base = s2_d1_soffset * s2_d1_lstride;
  s2_d1_delta = s2_d1_sstride * s2_d1_lstride;

  s2_d2_offset = s2_d2_base = s2_d2_soffset * s2_d2_lstride;
  s2_d2_delta = s2_d2_sstride * s2_d2_lstride;

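  /* Fast path: when every section stride is 1 and both operands have unit
   * leading-dimension stride, dispatch to the specialized unit-stride kernels
   * (the F90_MATMUL real8_str1 variants) instead of the generic strided loops.
   */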
  if ((s1_d1_sstride == 1) && (s2_d1_sstride == 1) && (d_d1_sstride == 1) &&
      (s1_d2_sstride == 1) && (s2_d2_sstride == 1) && (d_d2_sstride == 1) &&
      (s1_d1_lstride == 1) && (s2_d1_lstride == 1)) {

    s1_base += s1_d2_soffset * s1_d2_lstride;
    s2_base += s2_d1_soffset * s2_d1_lstride;
    if (s2_rank == 1) {
      F90_MATMUL(real8_str1_mxv)(dest_base + d_d1_soffset * d_d1_lstride +
                                     d_d2_soffset * d_d2_lstride,
                                 s1_base + s1_d1_soffset * s1_d1_lstride,
                                 s2_base + s2_d2_soffset * s2_d2_lstride,
                                 &n_extent, &m_extent,
                                 &s1_d2_lstride, &d_d1_lstride);
    } else if (s1_rank == 1) {
      F90_MATMUL(real8_str1_vxm)(dest_base + d_d1_soffset * d_d1_lstride +
                                     d_d2_soffset * d_d2_lstride,
                                 s1_base + s1_d1_soffset * s1_d1_lstride,
                                 s2_base + s2_d2_soffset * s2_d2_lstride,
                                 &k_extent, &m_extent,
                                 &s2_d2_lstride, &d_d1_lstride);
    } else {
      F90_MATMUL(real8_str1)(dest_base + d_d1_soffset * d_d1_lstride +
                                 d_d2_soffset * d_d2_lstride,
                             s1_base + s1_d1_soffset * s1_d1_lstride,
                             s2_base + s2_d2_soffset * s2_d2_lstride,
                             &k_extent, &m_extent, &n_extent,
                             &s1_d2_lstride, &s2_d2_lstride, &d_d2_lstride,
                             &d_d1_lstride);
    }
  } else if (s1_rank == 2) {
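    /* General strided matrix case (rank-2 s1): clear the destination first,
     * then accumulate s1(:,m) * s2(m,k) into dest(:,k) column by column.
     */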
    for (k = 0; k < k_extent; k++) {
      d_elem_p = dest_base + d_d1_base + d_d2_offset;
      d_d2_offset += d_d2_delta;
      for (n = 0; n < n_extent; n++) {
        *d_elem_p = 0;
        d_elem_p += d_d1_delta;
      }
    }

    d_d2_offset = d_d2_base;
    for (k = 0; k < k_extent; k++) {
      s2_elem_p = s2_base + s2_d1_base + s2_d2_offset;
      s2_d2_offset += s2_d2_delta;
      s1_d2_offset = s1_d2_base;
      for (m = 0; m < m_extent; m++) {
        s1_elem_p = s1_base + s1_d1_base + s1_d2_offset;
        s1_d2_offset += s1_d2_delta;
        d_elem_p = dest_base + d_d1_base + d_d2_offset;
        for (n = 0; n < n_extent; n++) {
          *d_elem_p += *s1_elem_p * *s2_elem_p;
          d_elem_p += d_d1_delta;
          s1_elem_p += s1_d1_delta;
        }
        s2_elem_p += s2_d1_delta;
      }
      d_d2_offset += d_d2_delta;
    }
  } else {
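    /* Strided vxm case (rank-1 s1): each element of the rank-1 destination is
     * the dot product of s1 with one column of s2, of length m_extent,
     * accumulated in rslt_tmp.
     */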
    s1_base += s1_d1_base;
    s2_base += s2_d1_soffset * s2_d1_lstride;
    dest_offset = d_d1_base;
    for (k = 0; k < k_extent; k++) {
      s1_elem_p = s1_base;
      s2_elem_p = s2_base + s2_d2_base;
      rslt_tmp = 0;
      for (m = 0; m < m_extent; m++) {
        rslt_tmp += *s1_elem_p * *s2_elem_p;
        s1_elem_p += s1_d1_delta;
        s2_elem_p += s2_d1_delta;
      }
      *(dest_base + dest_offset) = rslt_tmp;
      dest_offset += d_d1_delta;
      s2_d2_base += s2_d2_delta;
    }
  }
}