1 /*
2  * Copyright (c) 2017, NVIDIA CORPORATION.  All rights reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *
16  */
17 
18 /* clang-format off */
19 
20 /* red_sum.c -- intrinsic reduction function */
21 
22 #include "stdioInterf.h"
23 #include "fioMacros.h"
24 #include "red.h"
25 
/*
 * CSUMFN(NAME, RTYP, ATYP) -- emit the complex SUM kernels for one type.
 *
 *   NAME  suffix of the generated functions (l_NAME and g_NAME)
 *   RTYP  complex element type; accessed through its .r and .i members
 *   ATYP  type of the two running accumulators
 *
 * l_NAME adds n elements of v (element stride vs) into *r.  With ms == 0
 * every element is added; otherwise the element paired with mask entry m[j]
 * (mask stride ms) is added only when (m[j] & GET_DIST_MASK_LOG) is nonzero.
 * The loc, li and ls arguments are accepted but never read.
 *
 * g_NAME adds rr[k] into lr[k], component by component, for the first n
 * elements.  The lv and rv arguments are accepted but never read.
 */
#define CSUMFN(NAME, RTYP, ATYP)                                               \
  static void l_##NAME(RTYP *r, __INT_T n, RTYP *v, __INT_T vs, __LOG_T *m,    \
                       __INT_T ms, __INT_T *loc, __INT_T li, __INT_T ls)       \
  {                                                                            \
    ATYP re = r->r;                                                            \
    ATYP im = r->i;                                                            \
    __INT_T vi = 0;                                                            \
    if (ms != 0) {                                                             \
      __LOG_T true_bits = GET_DIST_MASK_LOG;                                   \
      __INT_T mi = 0;                                                          \
      while (n > 0) {                                                          \
        if (m[mi] & true_bits) {                                               \
          re += v[vi].r;                                                       \
          im += v[vi].i;                                                       \
        }                                                                      \
        vi += vs;                                                              \
        mi += ms;                                                              \
        n--;                                                                   \
      }                                                                        \
    } else {                                                                   \
      while (n > 0) {                                                          \
        re += v[vi].r;                                                         \
        im += v[vi].i;                                                         \
        vi += vs;                                                              \
        n--;                                                                   \
      }                                                                        \
    }                                                                          \
    r->r = re;                                                                 \
    r->i = im;                                                                 \
  }                                                                            \
  static void g_##NAME(__INT_T n, RTYP *lr, RTYP *rr, void *lv, void *rv)      \
  {                                                                            \
    __INT_T k = 0;                                                             \
    while (k < n) {                                                            \
      lr[k].r += rr[k].r;                                                      \
      lr[k].i += rr[k].i;                                                      \
      k++;                                                                     \
    }                                                                          \
  }
57 
/*
 * CSUMFNLKN(NAME, RTYP, ATYP, N) -- emit the local complex SUM kernel for a
 * mask whose element type is __LOG<N>_T.
 *
 *   NAME  base suffix; the generated function is l_<NAME>l<N>
 *   RTYP  complex element type; accessed through its .r and .i members
 *   ATYP  type of the two running accumulators
 *   N     token pasted into __LOG<N>_T and GET_DIST_MASK_LOG<N>
 *
 * The generated function adds n elements of v (element stride vs) into *r.
 * With ms == 0 every element is added; otherwise the element paired with
 * mask entry m[j] (mask stride ms) is added only when
 * (m[j] & GET_DIST_MASK_LOG<N>) is nonzero.  The loc, li and ls arguments
 * are accepted but never read.
 */
#define CSUMFNLKN(NAME, RTYP, ATYP, N)                                         \
  static void l_##NAME##l##N(RTYP *r, __INT_T n, RTYP *v, __INT_T vs,          \
                             __LOG##N##_T *m, __INT_T ms, __INT_T *loc,        \
                             __INT_T li, __INT_T ls)                           \
  {                                                                            \
    ATYP re = r->r;                                                            \
    ATYP im = r->i;                                                            \
    __INT_T vi = 0;                                                            \
    if (ms != 0) {                                                             \
      __LOG##N##_T true_bits = GET_DIST_MASK_LOG##N;                           \
      __INT_T mi = 0;                                                          \
      while (n > 0) {                                                          \
        if (m[mi] & true_bits) {                                               \
          re += v[vi].r;                                                       \
          im += v[vi].i;                                                       \
        }                                                                      \
        vi += vs;                                                              \
        mi += ms;                                                              \
        n--;                                                                   \
      }                                                                        \
    } else {                                                                   \
      while (n > 0) {                                                          \
        re += v[vi].r;                                                         \
        im += v[vi].i;                                                         \
        vi += vs;                                                              \
        n--;                                                                   \
      }                                                                        \
    }                                                                          \
    r->r = re;                                                                 \
    r->i = im;                                                                 \
  }
82 
83 ARITHFN(+, sum_int1, __INT1_T, long)
84 ARITHFN(+, sum_int2, __INT2_T, long)
85 ARITHFN(+, sum_int4, __INT4_T, long)
86 ARITHFN(+, sum_int8, __INT8_T, __INT8_T)
87 ARITHFN(+, sum_real4, __REAL4_T, __REAL4_T)
88 ARITHFN(+, sum_real8, __REAL8_T, __REAL8_T)
89 ARITHFN(+, sum_real16, __REAL16_T, __REAL16_T)
90 CSUMFN(sum_cplx8, __CPLX8_T, __REAL4_T)
91 CSUMFN(sum_cplx16, __CPLX16_T, __REAL8_T)
92 CSUMFN(sum_cplx32, __CPLX32_T, __REAL16_T)
93 
94 ARITHFNLKN(+, sum_int1, __INT1_T, long, 1)
95 ARITHFNLKN(+, sum_int2, __INT2_T, long, 1)
96 ARITHFNLKN(+, sum_int4, __INT4_T, long, 1)
97 ARITHFNLKN(+, sum_int8, __INT8_T, __INT8_T, 1)
98 ARITHFNLKN(+, sum_real4, __REAL4_T, __REAL4_T, 1)
99 ARITHFNLKN(+, sum_real8, __REAL8_T, __REAL8_T, 1)
100 ARITHFNLKN(+, sum_real16, __REAL16_T, __REAL16_T, 1)
101 CSUMFNLKN(sum_cplx8, __CPLX8_T, __REAL4_T, 1)
102 CSUMFNLKN(sum_cplx16, __CPLX16_T, __REAL8_T, 1)
103 CSUMFNLKN(sum_cplx32, __CPLX32_T, __REAL16_T, 1)
104 
105 ARITHFNLKN(+, sum_int1, __INT1_T, long, 2)
106 ARITHFNLKN(+, sum_int2, __INT2_T, long, 2)
107 ARITHFNLKN(+, sum_int4, __INT4_T, long, 2)
108 ARITHFNLKN(+, sum_int8, __INT8_T, __INT8_T, 2)
109 ARITHFNLKN(+, sum_real4, __REAL4_T, __REAL4_T, 2)
110 ARITHFNLKN(+, sum_real8, __REAL8_T, __REAL8_T, 2)
111 ARITHFNLKN(+, sum_real16, __REAL16_T, __REAL16_T, 2)
112 CSUMFNLKN(sum_cplx8, __CPLX8_T, __REAL4_T, 2)
113 CSUMFNLKN(sum_cplx16, __CPLX16_T, __REAL8_T, 2)
114 CSUMFNLKN(sum_cplx32, __CPLX32_T, __REAL16_T, 2)
115 
116 ARITHFNLKN(+, sum_int1, __INT1_T, long, 4)
117 ARITHFNLKN(+, sum_int2, __INT2_T, long, 4)
118 ARITHFNLKN(+, sum_int4, __INT4_T, long, 4)
119 ARITHFNLKN(+, sum_int8, __INT8_T, __INT8_T, 4)
120 ARITHFNLKN(+, sum_real4, __REAL4_T, __REAL4_T, 4)
121 ARITHFNLKN(+, sum_real8, __REAL8_T, __REAL8_T, 4)
122 ARITHFNLKN(+, sum_real16, __REAL16_T, __REAL16_T, 4)
123 CSUMFNLKN(sum_cplx8, __CPLX8_T, __REAL4_T, 4)
124 CSUMFNLKN(sum_cplx16, __CPLX16_T, __REAL8_T, 4)
125 CSUMFNLKN(sum_cplx32, __CPLX32_T, __REAL16_T, 4)
126 
127 ARITHFNLKN(+, sum_int1, __INT1_T, long, 8)
128 ARITHFNLKN(+, sum_int2, __INT2_T, long, 8)
129 ARITHFNLKN(+, sum_int4, __INT4_T, long, 8)
130 ARITHFNLKN(+, sum_int8, __INT8_T, __INT8_T, 8)
131 ARITHFNLKN(+, sum_real4, __REAL4_T, __REAL4_T, 8)
132 ARITHFNLKN(+, sum_real8, __REAL8_T, __REAL8_T, 8)
133 ARITHFNLKN(+, sum_real16, __REAL16_T, __REAL16_T, 8)
134 CSUMFNLKN(sum_cplx8, __CPLX8_T, __REAL4_T, 8)
135 CSUMFNLKN(sum_cplx16, __CPLX16_T, __REAL8_T, 8)
136 CSUMFNLKN(sum_cplx32, __CPLX32_T, __REAL16_T, 8)
137 
/*
 * Local SUM kernels, selected in this file as l_sum[lk_shift][kind].
 * The initializer is produced by TYPELIST1LK(l_sum_), which is not defined
 * in this file.
 * NOTE(review): the four rows presumably line up with the mask sizes
 * 1, 2, 4 and 8 generated above -- confirm against TYPELIST1LK and
 * GET_DIST_SHIFTS.
 */
static void (*l_sum[4][__NTYPES])() = TYPELIST1LK(l_sum_);
/*
 * Global (cross-partial-result) SUM combiners, selected by kind.  This table
 * has external linkage and is also handed to __fort_global_reduce by the
 * REDUCE_SUM / GLOBAL_SUM entry points.  The initializer is produced by
 * TYPELIST1(g_sum_), which is not defined in this file.
 */
void (*I8(__fort_g_sum)[__NTYPES])() = TYPELIST1(g_sum_);
140 
141 /* dim absent */
142 
ENTFTN(SUMS,sums)143 void ENTFTN(SUMS, sums)(char *rb, char *ab, char *mb, DECL_HDR_PTRS(rs),
144                         F90_Desc *as, F90_Desc *ms)
145 {
146   red_parm z;
147 
148   INIT_RED_PARM(z);
149   __fort_red_what = "SUM";
150 
151   z.kind = F90_KIND_G(as);
152   z.len = F90_LEN_G(as);
153   z.mask_present = (F90_TAG_G(ms) == __DESC && F90_RANK_G(ms) > 0);
154   if (!z.mask_present) {
155     z.lk_shift = GET_DIST_SHIFTS(__LOG);
156   } else {
157     z.lk_shift = GET_DIST_SHIFTS(F90_KIND_G(ms));
158   }
159   z.l_fn = l_sum[z.lk_shift][z.kind];
160   z.g_fn = I8(__fort_g_sum)[z.kind];
161   z.zb = GET_DIST_ZED;
162   I8(__fort_red_scalar)(&z, rb, ab, mb, rs, as, ms, NULL, __SUM);
163 }
164 
165 /* dim present */
166 
ENTFTN(SUM,sum)167 void ENTFTN(SUM, sum)(char *rb, char *ab, char *mb, char *db, DECL_HDR_PTRS(rs),
168                       F90_Desc *as, F90_Desc *ms, F90_Desc *ds)
169 {
170   red_parm z;
171 
172   INIT_RED_PARM(z);
173   __fort_red_what = "SUM";
174 
175   z.kind = F90_KIND_G(as);
176   z.len = F90_LEN_G(as);
177   z.mask_present = (F90_TAG_G(ms) == __DESC && F90_RANK_G(ms) > 0);
178   if (!z.mask_present) {
179     z.lk_shift = GET_DIST_SHIFTS(__LOG);
180   } else {
181     z.lk_shift = GET_DIST_SHIFTS(F90_KIND_G(ms));
182   }
183   z.l_fn = l_sum[z.lk_shift][z.kind];
184   z.g_fn = I8(__fort_g_sum)[z.kind];
185   z.zb = GET_DIST_ZED;
186   if (ISSCALAR(ms)) {
187     DECL_HDR_VARS(ms2);
188 
189     mb = (char *)I8(__fort_create_conforming_mask_array)(__fort_red_what, ab, mb,
190                                                         as, ms, ms2);
191     I8(__fort_red_array)(&z, rb, ab, mb, db, rs, as, ms2, ds, __SUM);
192     __fort_gfree(mb);
193   } else {
194     I8(__fort_red_array)(&z, rb, ab, mb, db, rs, as, ms, ds, __SUM);
195   }
196 }
197 
198 /* global SUM accumulation */
199 
ENTFTN(REDUCE_SUM,reduce_sum)200 void ENTFTN(REDUCE_SUM, reduce_sum)(char *hb, __INT_T *dimsb, __INT_T *nargb,
201                                     char *rb, DECL_HDR_PTRS(hd),
202                                     F90_Desc *dimsd, F90_Desc *nargd,
203                                     F90_Desc *rd)
204 {
205 #if defined(DEBUG)
206   if (dimsd == NULL || dimsd->tag != __INT)
207     __fort_abort("GLOBAL_SUM: invalid dims descriptor");
208   if (nargd == NULL || nargd->tag != __INT)
209     __fort_abort("REDUCE_SUM: invalid arg count descriptor");
210   if (*nargb != 1)
211     __fort_abort("REDUCE_SUM: arg count not 1");
212 #endif
213   I8(__fort_global_reduce)(rb, hb, *dimsb, rd, hd, "SUM", I8(__fort_g_sum));
214 }
215 
/*
 * GLOBAL_SUM entry point: global SUM accumulation.
 *
 * Forwards directly to __fort_global_reduce with the operation name "SUM"
 * and the __fort_g_sum combiner table, passing rb / hb, the value *dimsb and
 * the descriptors rd / hd.  The dimsd descriptor is accepted but not used
 * here, and no argument validation is performed.
 */
void ENTFTN(GLOBAL_SUM, global_sum)(char *rb, char *hb, __INT_T *dimsb,
                                    DECL_HDR_PTRS(rd), F90_Desc *hd,
                                    F90_Desc *dimsd)
{
  I8(__fort_global_reduce)(rb, hb, *dimsb, rd, hd, "SUM", I8(__fort_g_sum));
}
222