1 /*
2 * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 */
17
18 /* clang-format off */
19
20 /** \file
21 * \brief F90 MATMUL intrinsics for COMPLEX*8 type
22 */
23
24 #include "stdioInterf.h"
25 #include "fioMacros.h"
26 #include "matmul.h"
27
ENTF90(MATMUL_CPLX8,matmul_cplx8mxv_t)28 void ENTF90(MATMUL_CPLX8,
29 matmul_cplx8mxv_t)(char *dest_addr, char *s1_addr, char *s2_addr,
30 int *t_flag, F90_Desc *dest_desc,
31 F90_Desc *s1_desc, F90_Desc *s2_desc)
32 {
33
34 __CPLX8_T *s1_base;
35 __CPLX8_T *s2_base;
36 __CPLX8_T *dest_base;
37
38 __CPLX8_T rslt_tmp;
39 __CPLX8_T *s1_p;
40 __CPLX8_T *s2_p;
41
42 __INT_T s1_d1_lstride;
43 __INT_T s1_d1_sstride;
44 __INT_T s1_d1_lb;
45 __INT_T s1_d1_soffset = 0;
46
47 __INT_T s1_d2_lstride = 1;
48 __INT_T s1_d2_sstride = 1;
49 __INT_T s1_d2_lb = 0;
50 __INT_T s1_d2_soffset = 0;
51
52 __INT_T s2_d1_lstride;
53 __INT_T s2_d1_sstride;
54 __INT_T s2_d1_lb;
55 __INT_T s2_d1_soffset = 0;
56
57 __INT_T s2_d2_lstride = 1;
58 __INT_T s2_d2_sstride = 1;
59 __INT_T s2_d2_lb = 0;
60 __INT_T s2_d2_soffset = 0;
61
62 __INT_T d_d1_lstride;
63 __INT_T d_d1_sstride;
64 __INT_T d_d1_lb;
65 __INT_T d_d1_soffset = 0;
66
67 __INT_T d_d2_lstride = 1;
68 __INT_T d_d2_sstride = 1;
69 __INT_T d_d2_lb = 0;
70 __INT_T d_d2_soffset = 0;
71
72 __INT_T d_rank = F90_RANK_G(dest_desc);
73 __INT_T s1_rank = F90_RANK_G(s1_desc);
74 __INT_T s2_rank = F90_RANK_G(s2_desc);
75
76 __INT_T k_extent = s2_rank == 2 ? F90_DIM_EXTENT_G(s2_desc, 1) : 1;
77 __INT_T m_extent = s1_rank == 2 ? F90_DIM_EXTENT_G(s1_desc, 1)
78 : F90_DIM_EXTENT_G(s1_desc, 0);
79 __INT_T n_extent = s1_rank == 2 ? F90_DIM_EXTENT_G(s1_desc, 0) : 1;
80
81 /* mxm
82 * transpose(s1(n,m)) x s2(n,k) -> dest(m,k)
83 * Check
84 * dest_d1 extent== m_extnet
85 * dest_d2 extent == k_extent
86 * s2_d1 extent = n_extent
87 *
88 * mxv
89 * transpose(s1(n,m)) x s2(n) -> dest(m)
90 * Check
91 * dest_d1 extent== m_extent
92 * s2_d1 extent == n_extent
93 */
94
95 if (d_rank == 2 && s1_rank == 2 && s2_rank == 2) {
96 if (F90_DIM_EXTENT_G(dest_desc, 0) != m_extent ||
97 F90_DIM_EXTENT_G(dest_desc, 1) != n_extent ||
98 F90_DIM_EXTENT_G(s2_desc, 0) != n_extent) {
99 __fort_abort("MATMUL: nonconforming array shapes");
100 }
101 } else if (d_rank == 1 && s1_rank == 2 && s2_rank == 1) {
102 if (F90_DIM_EXTENT_G(dest_desc, 0) != m_extent ||
103 F90_DIM_EXTENT_G(s2_desc, 0) != n_extent) {
104 __fort_abort("MATMUL: nonconforming array shapes");
105 }
106 } else {
107 __fort_abort("MATMUL: non-conforming array shapes");
108 }
109
110 s1_d1_lstride = F90_DIM_LSTRIDE_G(s1_desc, 0);
111 s1_d1_sstride = F90_DIM_SSTRIDE_G(s1_desc, 0);
112 s1_d1_lb = F90_DIM_LBOUND_G(s1_desc, 0);
113 if (s1_d1_sstride != 1 || F90_DIM_SOFFSET_G(s1_desc, 0))
114 s1_d1_soffset = F90_DIM_SOFFSET_G(s1_desc, 0) + s1_d1_sstride - s1_d1_lb;
115
116 if (s1_rank == 2) {
117 s1_d2_lstride = F90_DIM_LSTRIDE_G(s1_desc, 1);
118 s1_d2_lb = F90_DIM_LBOUND_G(s1_desc, 1);
119 s1_d2_sstride = F90_DIM_SSTRIDE_G(s1_desc, 1);
120 if (s1_d2_sstride != 1 || F90_DIM_SOFFSET_G(s1_desc, 1))
121 s1_d2_soffset = F90_DIM_SOFFSET_G(s1_desc, 1) + s1_d2_sstride - s1_d2_lb;
122 }
123
124 s2_d1_lstride = F90_DIM_LSTRIDE_G(s2_desc, 0);
125 s2_d1_lb = F90_DIM_LBOUND_G(s2_desc, 0);
126 s2_d1_sstride = F90_DIM_SSTRIDE_G(s2_desc, 0);
127 if (s2_d1_sstride != 1 || F90_DIM_SOFFSET_G(s2_desc, 0))
128 s2_d1_soffset = F90_DIM_SOFFSET_G(s2_desc, 0) + s2_d1_sstride - s2_d1_lb;
129
130 if (s2_rank == 2) {
131 s2_d2_lstride = F90_DIM_LSTRIDE_G(s2_desc, 1);
132 s2_d2_lb = F90_DIM_LBOUND_G(s2_desc, 1);
133 s2_d2_sstride = F90_DIM_SSTRIDE_G(s2_desc, 1);
134 if (s2_d2_sstride != 1 || F90_DIM_SOFFSET_G(s2_desc, 1))
135 s2_d2_soffset = F90_DIM_SOFFSET_G(s2_desc, 1) + s2_d2_sstride - s2_d2_lb;
136 }
137
138 d_d1_lstride = F90_DIM_LSTRIDE_G(dest_desc, 0);
139 d_d1_lb = F90_DIM_LBOUND_G(dest_desc, 0);
140 d_d1_sstride = F90_DIM_SSTRIDE_G(dest_desc, 0);
141 if (d_d1_sstride != 1 || F90_DIM_SOFFSET_G(dest_desc, 0))
142 d_d1_soffset = F90_DIM_SOFFSET_G(dest_desc, 0) + d_d1_sstride - d_d1_lb;
143
144 if (d_rank == 2) {
145 d_d2_lstride = F90_DIM_LSTRIDE_G(dest_desc, 1);
146 d_d2_lb = F90_DIM_LBOUND_G(dest_desc, 1);
147 d_d2_sstride = F90_DIM_SSTRIDE_G(dest_desc, 1);
148 if (d_d2_sstride != 1 || F90_DIM_SOFFSET_G(dest_desc, 1))
149 d_d2_soffset = F90_DIM_SOFFSET_G(dest_desc, 1) + d_d2_sstride - d_d2_lb;
150 }
151
152 if ((s1_d1_sstride == 1) && (s2_d1_sstride == 1) && (d_d1_sstride == 1) &&
153 (s1_d2_sstride == 1) && (s2_d2_sstride == 1) && (d_d2_sstride == 1) &&
154 (s1_d1_lstride == 1) && (s2_d1_lstride == 1)) {
155
156 s1_base = (__CPLX8_T *)s1_addr + F90_LBASE_G(s1_desc) +
157 s1_d2_soffset * s1_d2_lstride + s1_d1_lb * s1_d1_lstride +
158 s1_d2_lb * s1_d2_lstride - 1;
159 s2_base = (__CPLX8_T *)s2_addr + F90_LBASE_G(s2_desc) +
160 s2_d1_soffset * s2_d1_lstride + s2_d1_lb * s2_d1_lstride +
161 s2_d2_lb * s2_d2_lstride - 1;
162 dest_base = (__CPLX8_T *)dest_addr + F90_LBASE_G(dest_desc) +
163 d_d1_lb * d_d1_lstride + d_d2_lb * d_d2_lstride - 1;
164
165 if (s2_rank == 1) {
166 F90_MATMUL(cplx8_str1_mxv_t)( dest_base + d_d1_soffset*d_d1_lstride +
167 d_d2_soffset*d_d2_lstride,
168 s1_base + s1_d1_soffset * s1_d1_lstride,
169 s2_base + s2_d2_soffset * s2_d2_lstride,
170 &n_extent,&m_extent,
171 &s1_d2_lstride, &d_d1_lstride);
172
173 } else {
174 __fort_abort(
175 "Internal Error: matrix by matrix matmul/transpose not implemented");
176 }
177 return;
178 }
179
180 /* transpose s1 */
181 {
182 __INT_T dest_offset;
183 __INT_T s1_d1_base, s1_d1_offset, s1_m_delta, s1_d2_base, s1_n_delta,
184 s2_d1_base, s2_n_delta, s2_d2_base, s2_k_delta, d_d1_base, d_m_delta,
185 d_d2_base, d_k_delta;
186 __INT_T k;
187 __INT_T l;
188 __INT_T m;
189 __INT_T n;
190
191 l = s1_d1_lstride;
192 s1_d1_lstride = s1_d2_lstride;
193 s1_d2_lstride = l;
194
195 s1_base = (__CPLX8_T *)s1_addr + F90_LBASE_G(s1_desc) +
196 s1_d1_lb * s1_d1_lstride + s1_d2_lb * s1_d2_lstride - 1;
197 s2_base = (__CPLX8_T *)s2_addr + F90_LBASE_G(s2_desc) +
198 s2_d1_lb * s2_d1_lstride + s2_d2_lb * s2_d2_lstride - 1;
199 dest_base = (__CPLX8_T *)dest_addr + F90_LBASE_G(dest_desc) +
200 d_d1_lb * d_d1_lstride + d_d2_lb * d_d2_lstride - 1;
201
202 d_d1_base = d_d1_soffset * d_d1_lstride;
203 d_m_delta = d_d1_sstride * d_d1_lstride;
204 d_d2_base = d_d2_soffset * d_d2_lstride;
205 d_k_delta = s1_rank == 2 ? d_d2_sstride * d_d2_lstride : d_m_delta;
206
207 s1_d1_base = s1_d1_soffset * s1_d1_lstride;
208 s1_d1_offset = s1_d1_base;
209 s1_m_delta = s1_d1_sstride * s1_d1_lstride;
210 s1_base += s1_d2_soffset * s1_d2_lstride;
211 s1_n_delta = s1_rank == 2 ? s1_d2_sstride * s1_d2_lstride : s1_m_delta;
212
213 s2_base += s2_d1_soffset * s2_d1_lstride;
214 s2_n_delta = s2_d1_sstride * s2_d1_lstride;
215 s2_d2_base = s2_d2_soffset * s2_d2_lstride;
216 s2_k_delta = s2_d2_sstride * s2_d2_lstride;
217
218 for (k = 0; k < k_extent; k++) {
219 dest_offset = d_d1_base + d_d2_base;
220 d_d2_base += d_k_delta;
221 s1_d1_offset = s1_d1_base;
222 for (m = 0; m < m_extent; m++) {
223 s1_p = s1_base + s1_d1_offset;
224 s1_d1_offset += s1_m_delta;
225 s2_p = s2_base + s2_d2_base;
226 rslt_tmp.r = 0;
227 rslt_tmp.i = 0;
228 for (n = 0; n < n_extent; n++) {
229 rslt_tmp.r += s1_p->r * s2_p->r - s1_p->i * s2_p->i;
230 rslt_tmp.i += s1_p->r * s2_p->i + s1_p->i * s2_p->r;
231
232 s1_p += s1_n_delta;
233 s2_p += s2_n_delta;
234 }
235 *(dest_base + dest_offset) = rslt_tmp;
236 dest_offset += d_m_delta;
237 }
238 s2_d2_base += s2_k_delta;
239 }
240 }
241 }
242