1 /*
2 * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 */
17
18 /* clang-format off */
19
20 /* red_sum.c -- intrinsic reduction function */
21
22 #include "stdioInterf.h"
23 #include "fioMacros.h"
24 #include "red.h"
25
/*
 * CSUMFN(NAME, RTYP, ATYP) expands to the pair of SUM reduction helpers
 * for a complex element type RTYP (a struct with .r/.i members), with
 * each part accumulated in ATYP:
 *
 *   l_NAME -- local partial sum: adds n elements of the strided vector
 *     v (stride vs) into *r.  When ms != 0 the logical mask m (stride
 *     ms) is consulted and only elements whose mask word has the
 *     GET_DIST_MASK_LOG bit set contribute.  loc/li/ls are unused here;
 *     they exist so every local reduction function shares one signature
 *     (location-tracking reductions such as MAXLOC use them).
 *
 *   g_NAME -- global combining step: elementwise lr[i] += rr[i] over n
 *     partial results.  lv/rv are likewise unused placeholders kept for
 *     the shared signature.
 */
#define CSUMFN(NAME, RTYP, ATYP) \
static void l_##NAME(RTYP *r, __INT_T n, RTYP *v, __INT_T vs, __LOG_T *m, \
                     __INT_T ms, __INT_T *loc, __INT_T li, __INT_T ls) \
{ \
  __INT_T i, j; \
  ATYP xr = r->r, xi = r->i; \
  __LOG_T mask_log; \
  if (ms == 0) \
    for (i = 0; n > 0; n--, i += vs) { \
      xr += v[i].r; \
      xi += v[i].i; \
    } \
  else { \
    mask_log = GET_DIST_MASK_LOG; \
    for (i = j = 0; n > 0; n--, i += vs, j += ms) \
      if (m[j] & mask_log) { \
        xr += v[i].r; \
        xi += v[i].i; \
      } \
  } \
  r->r = xr; \
  r->i = xi; \
} \
static void g_##NAME(__INT_T n, RTYP *lr, RTYP *rr, void *lv, void *rv) \
{ \
  __INT_T i; \
  for (i = 0; i < n; i++) { \
    lr[i].r += rr[i].r; \
    lr[i].i += rr[i].i; \
  } \
}
57
/*
 * CSUMFNLKN(NAME, RTYP, ATYP, N) expands to l_NAMElN, the local complex
 * SUM specialized for a mask of logical kind N (an N-byte __LOGN_T);
 * the mask bit tested is GET_DIST_MASK_LOGN.  The body is otherwise the
 * same masked/unmasked strided accumulation as the default-kind local
 * function.  No g_ variant is generated: the global combining step
 * never looks at the mask.
 */
#define CSUMFNLKN(NAME, RTYP, ATYP, N) \
static void l_##NAME##l##N(RTYP *r, __INT_T n, RTYP *v, __INT_T vs, \
                           __LOG##N##_T *m, __INT_T ms, __INT_T *loc, \
                           __INT_T li, __INT_T ls) \
{ \
  __INT_T i, j; \
  ATYP xr = r->r, xi = r->i; \
  __LOG##N##_T mask_log; \
  if (ms == 0) \
    for (i = 0; n > 0; n--, i += vs) { \
      xr += v[i].r; \
      xi += v[i].i; \
    } \
  else { \
    mask_log = GET_DIST_MASK_LOG##N; \
    for (i = j = 0; n > 0; n--, i += vs, j += ms) \
      if (m[j] & mask_log) { \
        xr += v[i].r; \
        xi += v[i].i; \
      } \
  } \
  r->r = xr; \
  r->i = xi; \
}
82
/* Instantiate the local (l_) and global (g_) SUM functions for every
 * supported element type.  ARITHFN / ARITHFNLKN come from red.h;
 * integers narrower than 8 bytes accumulate in long. */
ARITHFN(+, sum_int1, __INT1_T, long)
ARITHFN(+, sum_int2, __INT2_T, long)
ARITHFN(+, sum_int4, __INT4_T, long)
ARITHFN(+, sum_int8, __INT8_T, __INT8_T)
ARITHFN(+, sum_real4, __REAL4_T, __REAL4_T)
ARITHFN(+, sum_real8, __REAL8_T, __REAL8_T)
ARITHFN(+, sum_real16, __REAL16_T, __REAL16_T)
CSUMFN(sum_cplx8, __CPLX8_T, __REAL4_T)
CSUMFN(sum_cplx16, __CPLX16_T, __REAL8_T)
CSUMFN(sum_cplx32, __CPLX32_T, __REAL16_T)

/* Local-function variants for a mask of logical kind 1. */
ARITHFNLKN(+, sum_int1, __INT1_T, long, 1)
ARITHFNLKN(+, sum_int2, __INT2_T, long, 1)
ARITHFNLKN(+, sum_int4, __INT4_T, long, 1)
ARITHFNLKN(+, sum_int8, __INT8_T, __INT8_T, 1)
ARITHFNLKN(+, sum_real4, __REAL4_T, __REAL4_T, 1)
ARITHFNLKN(+, sum_real8, __REAL8_T, __REAL8_T, 1)
ARITHFNLKN(+, sum_real16, __REAL16_T, __REAL16_T, 1)
CSUMFNLKN(sum_cplx8, __CPLX8_T, __REAL4_T, 1)
CSUMFNLKN(sum_cplx16, __CPLX16_T, __REAL8_T, 1)
CSUMFNLKN(sum_cplx32, __CPLX32_T, __REAL16_T, 1)

/* Local-function variants for a mask of logical kind 2. */
ARITHFNLKN(+, sum_int1, __INT1_T, long, 2)
ARITHFNLKN(+, sum_int2, __INT2_T, long, 2)
ARITHFNLKN(+, sum_int4, __INT4_T, long, 2)
ARITHFNLKN(+, sum_int8, __INT8_T, __INT8_T, 2)
ARITHFNLKN(+, sum_real4, __REAL4_T, __REAL4_T, 2)
ARITHFNLKN(+, sum_real8, __REAL8_T, __REAL8_T, 2)
ARITHFNLKN(+, sum_real16, __REAL16_T, __REAL16_T, 2)
CSUMFNLKN(sum_cplx8, __CPLX8_T, __REAL4_T, 2)
CSUMFNLKN(sum_cplx16, __CPLX16_T, __REAL8_T, 2)
CSUMFNLKN(sum_cplx32, __CPLX32_T, __REAL16_T, 2)

/* Local-function variants for a mask of logical kind 4. */
ARITHFNLKN(+, sum_int1, __INT1_T, long, 4)
ARITHFNLKN(+, sum_int2, __INT2_T, long, 4)
ARITHFNLKN(+, sum_int4, __INT4_T, long, 4)
ARITHFNLKN(+, sum_int8, __INT8_T, __INT8_T, 4)
ARITHFNLKN(+, sum_real4, __REAL4_T, __REAL4_T, 4)
ARITHFNLKN(+, sum_real8, __REAL8_T, __REAL8_T, 4)
ARITHFNLKN(+, sum_real16, __REAL16_T, __REAL16_T, 4)
CSUMFNLKN(sum_cplx8, __CPLX8_T, __REAL4_T, 4)
CSUMFNLKN(sum_cplx16, __CPLX16_T, __REAL8_T, 4)
CSUMFNLKN(sum_cplx32, __CPLX32_T, __REAL16_T, 4)

/* Local-function variants for a mask of logical kind 8. */
ARITHFNLKN(+, sum_int1, __INT1_T, long, 8)
ARITHFNLKN(+, sum_int2, __INT2_T, long, 8)
ARITHFNLKN(+, sum_int4, __INT4_T, long, 8)
ARITHFNLKN(+, sum_int8, __INT8_T, __INT8_T, 8)
ARITHFNLKN(+, sum_real4, __REAL4_T, __REAL4_T, 8)
ARITHFNLKN(+, sum_real8, __REAL8_T, __REAL8_T, 8)
ARITHFNLKN(+, sum_real16, __REAL16_T, __REAL16_T, 8)
CSUMFNLKN(sum_cplx8, __CPLX8_T, __REAL4_T, 8)
CSUMFNLKN(sum_cplx16, __CPLX16_T, __REAL8_T, 8)
CSUMFNLKN(sum_cplx32, __CPLX32_T, __REAL16_T, 8)

/* Dispatch tables.  l_sum is indexed by [mask logical-kind shift][element
 * kind]; __fort_g_sum is indexed by [element kind] and is non-static so
 * the global reduction entry points can pass it to __fort_global_reduce. */
static void (*l_sum[4][__NTYPES])() = TYPELIST1LK(l_sum_);
void (*I8(__fort_g_sum)[__NTYPES])() = TYPELIST1(g_sum_);
140
141 /* dim absent */
142
ENTFTN(SUMS,sums)143 void ENTFTN(SUMS, sums)(char *rb, char *ab, char *mb, DECL_HDR_PTRS(rs),
144 F90_Desc *as, F90_Desc *ms)
145 {
146 red_parm z;
147
148 INIT_RED_PARM(z);
149 __fort_red_what = "SUM";
150
151 z.kind = F90_KIND_G(as);
152 z.len = F90_LEN_G(as);
153 z.mask_present = (F90_TAG_G(ms) == __DESC && F90_RANK_G(ms) > 0);
154 if (!z.mask_present) {
155 z.lk_shift = GET_DIST_SHIFTS(__LOG);
156 } else {
157 z.lk_shift = GET_DIST_SHIFTS(F90_KIND_G(ms));
158 }
159 z.l_fn = l_sum[z.lk_shift][z.kind];
160 z.g_fn = I8(__fort_g_sum)[z.kind];
161 z.zb = GET_DIST_ZED;
162 I8(__fort_red_scalar)(&z, rb, ab, mb, rs, as, ms, NULL, __SUM);
163 }
164
165 /* dim present */
166
ENTFTN(SUM,sum)167 void ENTFTN(SUM, sum)(char *rb, char *ab, char *mb, char *db, DECL_HDR_PTRS(rs),
168 F90_Desc *as, F90_Desc *ms, F90_Desc *ds)
169 {
170 red_parm z;
171
172 INIT_RED_PARM(z);
173 __fort_red_what = "SUM";
174
175 z.kind = F90_KIND_G(as);
176 z.len = F90_LEN_G(as);
177 z.mask_present = (F90_TAG_G(ms) == __DESC && F90_RANK_G(ms) > 0);
178 if (!z.mask_present) {
179 z.lk_shift = GET_DIST_SHIFTS(__LOG);
180 } else {
181 z.lk_shift = GET_DIST_SHIFTS(F90_KIND_G(ms));
182 }
183 z.l_fn = l_sum[z.lk_shift][z.kind];
184 z.g_fn = I8(__fort_g_sum)[z.kind];
185 z.zb = GET_DIST_ZED;
186 if (ISSCALAR(ms)) {
187 DECL_HDR_VARS(ms2);
188
189 mb = (char *)I8(__fort_create_conforming_mask_array)(__fort_red_what, ab, mb,
190 as, ms, ms2);
191 I8(__fort_red_array)(&z, rb, ab, mb, db, rs, as, ms2, ds, __SUM);
192 __fort_gfree(mb);
193 } else {
194 I8(__fort_red_array)(&z, rb, ab, mb, db, rs, as, ms, ds, __SUM);
195 }
196 }
197
198 /* global SUM accumulation */
199
ENTFTN(REDUCE_SUM,reduce_sum)200 void ENTFTN(REDUCE_SUM, reduce_sum)(char *hb, __INT_T *dimsb, __INT_T *nargb,
201 char *rb, DECL_HDR_PTRS(hd),
202 F90_Desc *dimsd, F90_Desc *nargd,
203 F90_Desc *rd)
204 {
205 #if defined(DEBUG)
206 if (dimsd == NULL || dimsd->tag != __INT)
207 __fort_abort("GLOBAL_SUM: invalid dims descriptor");
208 if (nargd == NULL || nargd->tag != __INT)
209 __fort_abort("REDUCE_SUM: invalid arg count descriptor");
210 if (*nargb != 1)
211 __fort_abort("REDUCE_SUM: arg count not 1");
212 #endif
213 I8(__fort_global_reduce)(rb, hb, *dimsb, rd, hd, "SUM", I8(__fort_g_sum));
214 }
215
/* GLOBAL_SUM: global SUM accumulation of hb over dimension *dimsb into
 * rb.  Thin wrapper around __fort_global_reduce using the elementwise
 * g_sum combiners; unlike REDUCE_SUM it takes no arg-count argument
 * and performs no DEBUG validation. */
void ENTFTN(GLOBAL_SUM, global_sum)(char *rb, char *hb, __INT_T *dimsb,
                                    DECL_HDR_PTRS(rd), F90_Desc *hd,
                                    F90_Desc *dimsd)
{
  I8(__fort_global_reduce)(rb, hb, *dimsb, rd, hd, "SUM", I8(__fort_g_sum));
}
222