/*
 * Copyright (c) 2017, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

/* clang-format off */

/** \file
 * \brief F90 MATMUL intrinsic for COMPLEX*8 type, where the first argument
 * is transposed: dest = transpose(s1) x s2.
 */

#include "stdioInterf.h"
#include "fioMacros.h"
#include "matmul.h"

void ENTF90(MATMUL_CPLX8,
            matmul_cplx8mxv_t)(char *dest_addr, char *s1_addr, char *s2_addr,
                               int *t_flag, F90_Desc *dest_desc,
                               F90_Desc *s1_desc, F90_Desc *s2_desc)
{

  __CPLX8_T *s1_base;
  __CPLX8_T *s2_base;
  __CPLX8_T *dest_base;

  __CPLX8_T rslt_tmp;
  __CPLX8_T *s1_p;
  __CPLX8_T *s2_p;

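  /*
   * Per-dimension descriptor fields for each operand.  Naming follows the
   * F90_Desc accessors used below: lstride = element stride in the
   * linearized array, sstride = section (slice) stride, lb = lower bound,
   * soffset = section offset.  The dimension-2 fields default to 1/0 so
   * the rank-1 cases fall through the same index arithmetic.
   */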
  __INT_T s1_d1_lstride;
  __INT_T s1_d1_sstride;
  __INT_T s1_d1_lb;
  __INT_T s1_d1_soffset = 0;

  __INT_T s1_d2_lstride = 1;
  __INT_T s1_d2_sstride = 1;
  __INT_T s1_d2_lb = 0;
  __INT_T s1_d2_soffset = 0;

  __INT_T s2_d1_lstride;
  __INT_T s2_d1_sstride;
  __INT_T s2_d1_lb;
  __INT_T s2_d1_soffset = 0;

  __INT_T s2_d2_lstride = 1;
  __INT_T s2_d2_sstride = 1;
  __INT_T s2_d2_lb = 0;
  __INT_T s2_d2_soffset = 0;

  __INT_T d_d1_lstride;
  __INT_T d_d1_sstride;
  __INT_T d_d1_lb;
  __INT_T d_d1_soffset = 0;

  __INT_T d_d2_lstride = 1;
  __INT_T d_d2_sstride = 1;
  __INT_T d_d2_lb = 0;
  __INT_T d_d2_soffset = 0;

  __INT_T d_rank = F90_RANK_G(dest_desc);
  __INT_T s1_rank = F90_RANK_G(s1_desc);
  __INT_T s2_rank = F90_RANK_G(s2_desc);

  __INT_T k_extent = s2_rank == 2 ? F90_DIM_EXTENT_G(s2_desc, 1) : 1;
  __INT_T m_extent = s1_rank == 2 ? F90_DIM_EXTENT_G(s1_desc, 1)
                                  : F90_DIM_EXTENT_G(s1_desc, 0);
  __INT_T n_extent = s1_rank == 2 ? F90_DIM_EXTENT_G(s1_desc, 0) : 1;

  /* mxm
   *  transpose(s1(n,m)) x s2(n,k) -> dest(m,k)
   *  Check
   *   dest_d1 extent == m_extent
   *   dest_d2 extent == k_extent
   *   s2_d1 extent == n_extent
   *
   * mxv
   *  transpose(s1(n,m)) x s2(n) -> dest(m)
   *  Check
   *   dest_d1 extent == m_extent
   *   s2_d1 extent == n_extent
   */
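  /*
   * Worked example (illustrative, not from the original source; assumes
   * the compiler selects this transposed entry point): for
   *   complex(4) :: a(3,2), b(3,4), d(2,4)
   *   d = matmul(transpose(a), b)
   * the extents computed above are n_extent = 3, m_extent = 2, and
   * k_extent = 4.
   */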

  if (d_rank == 2 && s1_rank == 2 && s2_rank == 2) {
    if (F90_DIM_EXTENT_G(dest_desc, 0) != m_extent ||
        F90_DIM_EXTENT_G(dest_desc, 1) != k_extent ||
        F90_DIM_EXTENT_G(s2_desc, 0) != n_extent) {
      __fort_abort("MATMUL: nonconforming array shapes");
    }
  } else if (d_rank == 1 && s1_rank == 2 && s2_rank == 1) {
    if (F90_DIM_EXTENT_G(dest_desc, 0) != m_extent ||
        F90_DIM_EXTENT_G(s2_desc, 0) != n_extent) {
      __fort_abort("MATMUL: nonconforming array shapes");
    }
  } else {
    __fort_abort("MATMUL: nonconforming array shapes");
  }

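  /*
   * For each dimension of each operand, pick up the linearized stride,
   * section stride, and lower bound from the descriptor.  A section
   * offset is computed only when the argument is a non-trivial section
   * (non-unit section stride or a nonzero descriptor offset).
   */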
  s1_d1_lstride = F90_DIM_LSTRIDE_G(s1_desc, 0);
  s1_d1_sstride = F90_DIM_SSTRIDE_G(s1_desc, 0);
  s1_d1_lb = F90_DIM_LBOUND_G(s1_desc, 0);
  if (s1_d1_sstride != 1 || F90_DIM_SOFFSET_G(s1_desc, 0))
    s1_d1_soffset = F90_DIM_SOFFSET_G(s1_desc, 0) + s1_d1_sstride - s1_d1_lb;

  if (s1_rank == 2) {
    s1_d2_lstride = F90_DIM_LSTRIDE_G(s1_desc, 1);
    s1_d2_lb = F90_DIM_LBOUND_G(s1_desc, 1);
    s1_d2_sstride = F90_DIM_SSTRIDE_G(s1_desc, 1);
    if (s1_d2_sstride != 1 || F90_DIM_SOFFSET_G(s1_desc, 1))
      s1_d2_soffset = F90_DIM_SOFFSET_G(s1_desc, 1) + s1_d2_sstride - s1_d2_lb;
  }

  s2_d1_lstride = F90_DIM_LSTRIDE_G(s2_desc, 0);
  s2_d1_lb = F90_DIM_LBOUND_G(s2_desc, 0);
  s2_d1_sstride = F90_DIM_SSTRIDE_G(s2_desc, 0);
  if (s2_d1_sstride != 1 || F90_DIM_SOFFSET_G(s2_desc, 0))
    s2_d1_soffset = F90_DIM_SOFFSET_G(s2_desc, 0) + s2_d1_sstride - s2_d1_lb;

  if (s2_rank == 2) {
    s2_d2_lstride = F90_DIM_LSTRIDE_G(s2_desc, 1);
    s2_d2_lb = F90_DIM_LBOUND_G(s2_desc, 1);
    s2_d2_sstride = F90_DIM_SSTRIDE_G(s2_desc, 1);
    if (s2_d2_sstride != 1 || F90_DIM_SOFFSET_G(s2_desc, 1))
      s2_d2_soffset = F90_DIM_SOFFSET_G(s2_desc, 1) + s2_d2_sstride - s2_d2_lb;
  }

  d_d1_lstride = F90_DIM_LSTRIDE_G(dest_desc, 0);
  d_d1_lb = F90_DIM_LBOUND_G(dest_desc, 0);
  d_d1_sstride = F90_DIM_SSTRIDE_G(dest_desc, 0);
  if (d_d1_sstride != 1 || F90_DIM_SOFFSET_G(dest_desc, 0))
    d_d1_soffset = F90_DIM_SOFFSET_G(dest_desc, 0) + d_d1_sstride - d_d1_lb;

  if (d_rank == 2) {
    d_d2_lstride = F90_DIM_LSTRIDE_G(dest_desc, 1);
    d_d2_lb = F90_DIM_LBOUND_G(dest_desc, 1);
    d_d2_sstride = F90_DIM_SSTRIDE_G(dest_desc, 1);
    if (d_d2_sstride != 1 || F90_DIM_SOFFSET_G(dest_desc, 1))
      d_d2_soffset = F90_DIM_SOFFSET_G(dest_desc, 1) + d_d2_sstride - d_d2_lb;
  }

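  /*
   * Fast path: every section stride is unit and both sources have unit
   * strides in their leading dimension.  Only the matrix-vector case is
   * handled here; the matrix-matrix case falls through to the abort below.
   */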
  if ((s1_d1_sstride == 1) && (s2_d1_sstride == 1) && (d_d1_sstride == 1) &&
      (s1_d2_sstride == 1) && (s2_d2_sstride == 1) && (d_d2_sstride == 1) &&
      (s1_d1_lstride == 1) && (s2_d1_lstride == 1)) {

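    /*
     * Fold the descriptor base offset (lbase), the lower bounds, and the
     * loop-invariant section offsets into the base pointers; the trailing
     * -1 appears to adjust the 1-based lbase down to a C index.
     */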
    s1_base = (__CPLX8_T *)s1_addr + F90_LBASE_G(s1_desc) +
              s1_d2_soffset * s1_d2_lstride + s1_d1_lb * s1_d1_lstride +
              s1_d2_lb * s1_d2_lstride - 1;
    s2_base = (__CPLX8_T *)s2_addr + F90_LBASE_G(s2_desc) +
              s2_d1_soffset * s2_d1_lstride + s2_d1_lb * s2_d1_lstride +
              s2_d2_lb * s2_d2_lstride - 1;
    dest_base = (__CPLX8_T *)dest_addr + F90_LBASE_G(dest_desc) +
                d_d1_lb * d_d1_lstride + d_d2_lb * d_d2_lstride - 1;

    if (s2_rank == 1) {
      F90_MATMUL(cplx8_str1_mxv_t)(dest_base + d_d1_soffset * d_d1_lstride +
                                       d_d2_soffset * d_d2_lstride,
                                   s1_base + s1_d1_soffset * s1_d1_lstride,
                                   s2_base + s2_d2_soffset * s2_d2_lstride,
                                   &n_extent, &m_extent,
                                   &s1_d2_lstride, &d_d1_lstride);

    } else {
      __fort_abort(
          "Internal Error: matrix by matrix matmul/transpose not implemented");
    }
    return;
  }

  /*
   * General (strided) path: swap s1's dimension strides so the index
   * arithmetic walks the transpose of s1, then compute
   * dest = transpose(s1) x s2 with an explicit triple loop over running
   * offsets.
   */
  {
    __INT_T dest_offset;
    __INT_T s1_d1_base, s1_d1_offset, s1_m_delta, s1_d2_base, s1_n_delta,
        s2_d1_base, s2_n_delta, s2_d2_base, s2_k_delta, d_d1_base, d_m_delta,
        d_d2_base, d_k_delta;
    __INT_T k;
    __INT_T l;
    __INT_T m;
    __INT_T n;
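
    /*
     * The *_base variables hold running offsets and the *_delta variables
     * the per-iteration steps for the k, m, and n loop indices; l is only
     * a scratch variable for the stride swap.
     */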

    l = s1_d1_lstride;
    s1_d1_lstride = s1_d2_lstride;
    s1_d2_lstride = l;

    s1_base = (__CPLX8_T *)s1_addr + F90_LBASE_G(s1_desc) +
              s1_d1_lb * s1_d1_lstride + s1_d2_lb * s1_d2_lstride - 1;
    s2_base = (__CPLX8_T *)s2_addr + F90_LBASE_G(s2_desc) +
              s2_d1_lb * s2_d1_lstride + s2_d2_lb * s2_d2_lstride - 1;
    dest_base = (__CPLX8_T *)dest_addr + F90_LBASE_G(dest_desc) +
                d_d1_lb * d_d1_lstride + d_d2_lb * d_d2_lstride - 1;

    d_d1_base = d_d1_soffset * d_d1_lstride;
    d_m_delta = d_d1_sstride * d_d1_lstride;
    d_d2_base = d_d2_soffset * d_d2_lstride;
    d_k_delta = s1_rank == 2 ? d_d2_sstride * d_d2_lstride : d_m_delta;

    s1_d1_base = s1_d1_soffset * s1_d1_lstride;
    s1_d1_offset = s1_d1_base;
    s1_m_delta = s1_d1_sstride * s1_d1_lstride;
    s1_base += s1_d2_soffset * s1_d2_lstride;
    s1_n_delta = s1_rank == 2 ? s1_d2_sstride * s1_d2_lstride : s1_m_delta;

    s2_base += s2_d1_soffset * s2_d1_lstride;
    s2_n_delta = s2_d1_sstride * s2_d1_lstride;
    s2_d2_base = s2_d2_soffset * s2_d2_lstride;
    s2_k_delta = s2_d2_sstride * s2_d2_lstride;

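    /*
     * k walks the columns of s2 (and of dest), m walks the rows of
     * transpose(s1) (the columns of s1 as stored), and the innermost n
     * loop accumulates the dot product of one transposed s1 row with one
     * s2 column.
     */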
    for (k = 0; k < k_extent; k++) {
      dest_offset = d_d1_base + d_d2_base;
      d_d2_base += d_k_delta;
      s1_d1_offset = s1_d1_base;
      for (m = 0; m < m_extent; m++) {
        s1_p = s1_base + s1_d1_offset;
        s1_d1_offset += s1_m_delta;
        s2_p = s2_base + s2_d2_base;
        rslt_tmp.r = 0;
        rslt_tmp.i = 0;
        for (n = 0; n < n_extent; n++) {
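          /* complex multiply-accumulate:
           * (a + bi)(c + di) = (ac - bd) + (ad + bc)i */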
          rslt_tmp.r += s1_p->r * s2_p->r - s1_p->i * s2_p->i;
          rslt_tmp.i += s1_p->r * s2_p->i + s1_p->i * s2_p->r;

          s1_p += s1_n_delta;
          s2_p += s2_n_delta;
        }
        *(dest_base + dest_offset) = rslt_tmp;
        dest_offset += d_m_delta;
      }
      s2_d2_base += s2_k_delta;
    }
  }
}