1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #ifndef CH4_IMPL_H_INCLUDED
7 #define CH4_IMPL_H_INCLUDED
8 
9 #include "ch4_types.h"
10 #include "mpidig_am.h"
11 #include "mpidu_shm.h"
12 #include "ch4r_proc.h"
13 
/* Declared here, defined outside this header.
 * NOTE(review): presumably the CH4 progress entry point; the meaning of
 * `flags` is not visible from this file -- confirm against the definition. */
int MPIDI_Progress_test(int flags);
/* Maps a context id to the index used below to select an entry of
 * MPIDI_global.comm_req_lists. */
int MPIDIG_get_context_index(uint64_t context_id);
/* Returns a 64-bit window id for a window associated with comm_ptr.
 * NOTE(review): MPIDIG_win_id_to_context() recovers a context id from the
 * low 32 bits of (win_id - 1); presumably this routine encodes it that way --
 * confirm against the definition. */
uint64_t MPIDIG_generate_win_id(MPIR_Comm * comm_ptr);
17 
18 /* Static inlines */
19 
20 /* Reconstruct context offset associated with a persistent request.
21  * Input must be a persistent request. */
MPIDI_prequest_get_context_offset(MPIR_Request * preq)22 MPL_STATIC_INLINE_PREFIX int MPIDI_prequest_get_context_offset(MPIR_Request * preq)
23 {
24     int context_offset;
25 
26     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_PREQUEST_GET_CONTEXT_OFFSET);
27     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_PREQUEST_GET_CONTEXT_OFFSET);
28 
29     MPIR_Assert(preq->kind == MPIR_REQUEST_KIND__PREQUEST_SEND ||
30                 preq->kind == MPIR_REQUEST_KIND__PREQUEST_RECV);
31 
32     context_offset = MPIDI_PREQUEST(preq, context_id) - preq->comm->context_id;
33 
34     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_PREQUEST_GET_CONTEXT_OFFSET);
35 
36     return context_offset;
37 }
38 
MPIDIG_context_id_to_comm(uint64_t context_id)39 MPL_STATIC_INLINE_PREFIX MPIR_Comm *MPIDIG_context_id_to_comm(uint64_t context_id)
40 {
41     int comm_idx = MPIDIG_get_context_index(context_id);
42     int subcomm_type = MPIR_CONTEXT_READ_FIELD(SUBCOMM, context_id);
43     int is_localcomm = MPIR_CONTEXT_READ_FIELD(IS_LOCALCOMM, context_id);
44     MPIR_Comm *ret;
45 
46     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIG_CONTEXT_ID_TO_COMM);
47     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIG_CONTEXT_ID_TO_COMM);
48 
49     MPIR_Assert(subcomm_type <= 3);
50     MPIR_Assert(is_localcomm <= 2);
51     ret = MPIDI_global.comm_req_lists[comm_idx].comm[is_localcomm][subcomm_type];
52 
53     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIG_CONTEXT_ID_TO_COMM);
54     return ret;
55 }
56 
MPIDIG_context_id_to_uelist(uint64_t context_id)57 MPL_STATIC_INLINE_PREFIX MPIDIG_rreq_t **MPIDIG_context_id_to_uelist(uint64_t context_id)
58 {
59     int comm_idx = MPIDIG_get_context_index(context_id);
60     int subcomm_type = MPIR_CONTEXT_READ_FIELD(SUBCOMM, context_id);
61     int is_localcomm = MPIR_CONTEXT_READ_FIELD(IS_LOCALCOMM, context_id);
62     MPIDIG_rreq_t **ret;
63 
64     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIG_CONTEXT_ID_TO_UELIST);
65     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIG_CONTEXT_ID_TO_UELIST);
66 
67     MPIR_Assert(subcomm_type <= 3);
68     MPIR_Assert(is_localcomm <= 2);
69 
70     ret = &MPIDI_global.comm_req_lists[comm_idx].uelist[is_localcomm][subcomm_type];
71 
72     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIG_CONTEXT_ID_TO_UELIST);
73     return ret;
74 }
75 
MPIDIG_win_id_to_context(uint64_t win_id)76 MPL_STATIC_INLINE_PREFIX MPIR_Context_id_t MPIDIG_win_id_to_context(uint64_t win_id)
77 {
78     MPIR_Context_id_t ret;
79 
80     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIG_WIN_ID_TO_CONTEXT);
81     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIG_WIN_ID_TO_CONTEXT);
82 
83     /* pick the lower 32-bit to extract context id */
84     ret = (win_id - 1) & 0xffffffff;
85 
86     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIG_WIN_ID_TO_CONTEXT);
87     return ret;
88 }
89 
MPIDIG_win_to_context(const MPIR_Win * win)90 MPL_STATIC_INLINE_PREFIX MPIR_Context_id_t MPIDIG_win_to_context(const MPIR_Win * win)
91 {
92     MPIR_Context_id_t ret;
93 
94     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIG_WIN_TO_CONTEXT);
95     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIG_WIN_TO_CONTEXT);
96 
97     ret = MPIDIG_win_id_to_context(MPIDIG_WIN(win, win_id));
98 
99     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIG_WIN_TO_CONTEXT);
100     return ret;
101 }
102 
MPIDIU_request_complete(MPIR_Request * req)103 MPL_STATIC_INLINE_PREFIX void MPIDIU_request_complete(MPIR_Request * req)
104 {
105     int incomplete;
106 
107     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIU_REQUEST_COMPLETE);
108     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIU_REQUEST_COMPLETE);
109 
110     MPIR_cc_decr(req->cc_ptr, &incomplete);
111     if (!incomplete) {
112         MPIR_Request_free_unsafe(req);
113     }
114 
115     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIU_REQUEST_COMPLETE);
116 }
117 
MPIDIG_win_target_add(MPIR_Win * win,int rank)118 MPL_STATIC_INLINE_PREFIX MPIDIG_win_target_t *MPIDIG_win_target_add(MPIR_Win * win, int rank)
119 {
120     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIG_WIN_TARGET_ADD);
121     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIG_WIN_TARGET_ADD);
122 
123     MPIDIG_win_target_t *target_ptr = NULL;
124     target_ptr = (MPIDIG_win_target_t *) MPL_malloc(sizeof(MPIDIG_win_target_t), MPL_MEM_RMA);
125     MPIR_Assert(target_ptr);
126     target_ptr->rank = rank;
127     MPIR_cc_set(&target_ptr->local_cmpl_cnts, 0);
128     MPIR_cc_set(&target_ptr->remote_cmpl_cnts, 0);
129     MPIR_cc_set(&target_ptr->remote_acc_cmpl_cnts, 0);
130     target_ptr->sync.lock.locked = 0;
131     target_ptr->sync.access_epoch_type = MPIDIG_EPOTYPE_NONE;
132     target_ptr->sync.assert_mode = 0;
133 
134     HASH_ADD(hash_handle, MPIDIG_WIN(win, targets), rank, sizeof(int), target_ptr, MPL_MEM_RMA);
135 
136     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIG_WIN_TARGET_ADD);
137     return target_ptr;
138 }
139 
MPIDIG_win_target_find(MPIR_Win * win,int rank)140 MPL_STATIC_INLINE_PREFIX MPIDIG_win_target_t *MPIDIG_win_target_find(MPIR_Win * win, int rank)
141 {
142     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIG_WIN_TARGET_FIND);
143     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIG_WIN_TARGET_FIND);
144 
145     MPIDIG_win_target_t *target_ptr = NULL;
146     HASH_FIND(hash_handle, MPIDIG_WIN(win, targets), &rank, sizeof(int), target_ptr);
147 
148     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIG_WIN_TARGET_FIND);
149     return target_ptr;
150 }
151 
MPIDIG_win_target_get(MPIR_Win * win,int rank)152 MPL_STATIC_INLINE_PREFIX MPIDIG_win_target_t *MPIDIG_win_target_get(MPIR_Win * win, int rank)
153 {
154     MPIDIG_win_target_t *target_ptr = MPIDIG_win_target_find(win, rank);
155     if (!target_ptr)
156         target_ptr = MPIDIG_win_target_add(win, rank);
157     return target_ptr;
158 }
159 
MPIDIG_win_target_delete(MPIR_Win * win,MPIDIG_win_target_t * target_ptr)160 MPL_STATIC_INLINE_PREFIX void MPIDIG_win_target_delete(MPIR_Win * win,
161                                                        MPIDIG_win_target_t * target_ptr)
162 {
163     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIG_WIN_TARGET_DELETE);
164     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIG_WIN_TARGET_DELETE);
165 
166     HASH_DELETE(hash_handle, MPIDIG_WIN(win, targets), target_ptr);
167     MPL_free(target_ptr);
168 
169     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIG_WIN_TARGET_DELETE);
170 }
171 
MPIDIG_win_target_cleanall(MPIR_Win * win)172 MPL_STATIC_INLINE_PREFIX void MPIDIG_win_target_cleanall(MPIR_Win * win)
173 {
174     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIG_WIN_TARGET_CLEANALL);
175     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIG_WIN_TARGET_CLEANALL);
176 
177     MPIDIG_win_target_t *target_ptr, *tmp;
178     HASH_ITER(hash_handle, MPIDIG_WIN(win, targets), target_ptr, tmp) {
179         HASH_DELETE(hash_handle, MPIDIG_WIN(win, targets), target_ptr);
180         MPL_free(target_ptr);
181     }
182 
183     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIG_WIN_TARGET_CLEANALL);
184 }
185 
MPIDIG_win_hash_clear(MPIR_Win * win)186 MPL_STATIC_INLINE_PREFIX void MPIDIG_win_hash_clear(MPIR_Win * win)
187 {
188     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIG_WIN_HASH_CLEAR);
189     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIG_WIN_HASH_CLEAR);
190 
191     HASH_CLEAR(hash_handle, MPIDIG_WIN(win, targets));
192 
193     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIG_WIN_HASH_CLEAR);
194 }
195 
/* Report contiguity, total byte size, datatype object and true lower bound
 * for count_ elements of datatype_.
 *   builtin handle : dt_ptr_ = NULL, contiguous, lb 0, size = count * basic size
 *   derived handle : values are read from the datatype object; if the lookup
 *                    yields NULL the outputs fall back to contig=1, lb=0, size=0
 */
#define MPIDI_Datatype_get_info(count_, datatype_,                      \
                                dt_contig_out_, data_sz_out_,           \
                                dt_ptr_, dt_true_lb_)                   \
    do {                                                                \
        if (!IS_BUILTIN(datatype_)) {                                   \
            MPIR_Datatype_get_ptr((datatype_), (dt_ptr_));              \
            if (!(dt_ptr_)) {                                           \
                (dt_contig_out_) = 1;                                   \
                (dt_true_lb_)    = 0;                                   \
                (data_sz_out_)   = 0;                                   \
            } else {                                                    \
                (dt_contig_out_) = (dt_ptr_)->is_contig;                \
                (dt_true_lb_)    = (dt_ptr_)->true_lb;                  \
                (data_sz_out_)   = (size_t)(count_) *                   \
                    (dt_ptr_)->size;                                    \
            }                                                           \
        } else {                                                        \
            (dt_ptr_)        = NULL;                                    \
            (dt_contig_out_) = TRUE;                                    \
            (dt_true_lb_)    = 0;                                       \
            (data_sz_out_)   = (size_t)(count_) *                       \
                MPIR_Datatype_get_basic_size(datatype_);                \
        }                                                               \
    } while (0)
223 
/* Report the total byte size of count_ elements of datatype_ and the datatype
 * object.  Builtin handles give dt_ptr_ = NULL and count * basic size; derived
 * handles give count * object size, or 0 when the lookup yields NULL. */
#define MPIDI_Datatype_get_size_dt_ptr(count_, datatype_,               \
                                       data_sz_out_, dt_ptr_)           \
    do {                                                                \
        if (!IS_BUILTIN(datatype_)) {                                   \
            MPIR_Datatype_get_ptr((datatype_), (dt_ptr_));              \
            if (dt_ptr_) {                                              \
                (data_sz_out_) = (size_t)(count_) * (dt_ptr_)->size;    \
            } else {                                                    \
                (data_sz_out_) = 0;                                     \
            }                                                           \
        } else {                                                        \
            (dt_ptr_)      = NULL;                                      \
            (data_sz_out_) = (size_t)(count_) *                         \
                MPIR_Datatype_get_basic_size(datatype_);                \
        }                                                               \
    } while (0)
237 
/* Report whether datatype_ is contiguous.  Builtin handles are always
 * contiguous; for derived handles the object's is_contig flag is used, and a
 * NULL lookup result is treated as contiguous. */
#define MPIDI_Datatype_check_contig(datatype_,dt_contig_out_)           \
    do {                                                                \
        if (!IS_BUILTIN(datatype_)) {                                   \
            MPIR_Datatype *dt_ptr_;                                     \
            MPIR_Datatype_get_ptr((datatype_), (dt_ptr_));              \
            if (dt_ptr_) {                                              \
                (dt_contig_out_) = (dt_ptr_)->is_contig;                \
            } else {                                                    \
                (dt_contig_out_) = 1;                                   \
            }                                                           \
        } else {                                                        \
            (dt_contig_out_) = TRUE;                                    \
        }                                                               \
    } while (0)
248 
/* Report contiguity and total byte size for count_ elements of datatype_.
 * A derived handle whose lookup yields NULL reports contig=1 and size=0. */
#define MPIDI_Datatype_check_contig_size(datatype_,count_,              \
                                         dt_contig_out_,                \
                                         data_sz_out_)                  \
    do {                                                                \
        if (!IS_BUILTIN(datatype_)) {                                   \
            MPIR_Datatype *dt_ptr_;                                     \
            MPIR_Datatype_get_ptr((datatype_), (dt_ptr_));              \
            if (!(dt_ptr_)) {                                           \
                (dt_contig_out_) = 1;                                   \
                (data_sz_out_)   = 0;                                   \
            } else {                                                    \
                (dt_contig_out_) = (dt_ptr_)->is_contig;                \
                (data_sz_out_)   = (size_t)(count_) *                   \
                    (dt_ptr_)->size;                                    \
            }                                                           \
        } else {                                                        \
            (dt_contig_out_) = TRUE;                                    \
            (data_sz_out_)   = (size_t)(count_) *                       \
                MPIR_Datatype_get_basic_size(datatype_);                \
        }                                                               \
    } while (0)
270 
/* Report the total byte size of count_ elements of datatype_: count * basic
 * size for builtin handles, count * object size for derived handles, and 0
 * when the derived-handle lookup yields NULL. */
#define MPIDI_Datatype_check_size(datatype_,count_,data_sz_out_)        \
    do {                                                                \
        if (!IS_BUILTIN(datatype_)) {                                   \
            MPIR_Datatype *dt_ptr_;                                     \
            MPIR_Datatype_get_ptr((datatype_), (dt_ptr_));              \
            if (dt_ptr_) {                                              \
                (data_sz_out_) = (size_t)(count_) * (dt_ptr_)->size;    \
            } else {                                                    \
                (data_sz_out_) = 0;                                     \
            }                                                           \
        } else {                                                        \
            (data_sz_out_) = (size_t)(count_) *                         \
                MPIR_Datatype_get_basic_size(datatype_);                \
        }                                                               \
    } while (0)
283 
/* Report total byte size and true lower bound for count_ elements of
 * datatype_.  Builtin handles have lb 0; a derived handle whose lookup yields
 * NULL reports size=0 and lb=0. */
#define MPIDI_Datatype_check_size_lb(datatype_,count_,data_sz_out_,     \
                                     dt_true_lb_)                       \
    do {                                                                \
        if (!IS_BUILTIN(datatype_)) {                                   \
            MPIR_Datatype *dt_ptr_;                                     \
            MPIR_Datatype_get_ptr((datatype_), (dt_ptr_));              \
            if (dt_ptr_) {                                              \
                (data_sz_out_) = (size_t)(count_) * (dt_ptr_)->size;    \
                (dt_true_lb_)  = (dt_ptr_)->true_lb;                    \
            } else {                                                    \
                (data_sz_out_) = 0;                                     \
                (dt_true_lb_)  = 0;                                     \
            }                                                           \
        } else {                                                        \
            (data_sz_out_) = (size_t)(count_) *                         \
                MPIR_Datatype_get_basic_size(datatype_);                \
            (dt_true_lb_)  = 0;                                         \
        }                                                               \
    } while (0)
299 
/* Report contiguity, total byte size and true lower bound for count_ elements
 * of datatype_.  A derived handle whose lookup yields NULL reports contig=1,
 * size=0 and lb=0. */
#define MPIDI_Datatype_check_contig_size_lb(datatype_,count_,           \
                                            dt_contig_out_,             \
                                            data_sz_out_,               \
                                            dt_true_lb_)                \
    do {                                                                \
        if (!IS_BUILTIN(datatype_)) {                                   \
            MPIR_Datatype *dt_ptr_;                                     \
            MPIR_Datatype_get_ptr((datatype_), (dt_ptr_));              \
            if (!(dt_ptr_)) {                                           \
                (dt_contig_out_) = 1;                                   \
                (data_sz_out_)   = 0;                                   \
                (dt_true_lb_)    = 0;                                   \
            } else {                                                    \
                (dt_contig_out_) = (dt_ptr_)->is_contig;                \
                (data_sz_out_)   = (size_t)(count_) *                   \
                    (dt_ptr_)->size;                                    \
                (dt_true_lb_)    = (dt_ptr_)->true_lb;                  \
            }                                                           \
        } else {                                                        \
            (dt_contig_out_) = TRUE;                                    \
            (data_sz_out_)   = (size_t)(count_) *                       \
                MPIR_Datatype_get_basic_size(datatype_);                \
            (dt_true_lb_)    = 0;                                       \
        }                                                               \
    } while (0)
325 
/* Report contiguity and true lower bound of datatype_.  Builtin handles are
 * contiguous with lb 0; a derived handle whose lookup yields NULL reports
 * contig=1 and lb=0. */
#define MPIDI_Datatype_check_contig_lb(datatype_, dt_contig_out_, dt_true_lb_) \
    do {                                                                       \
        if (!IS_BUILTIN(datatype_)) {                                          \
            MPIR_Datatype *dt_ptr_;                                            \
            MPIR_Datatype_get_ptr((datatype_), (dt_ptr_));                     \
            if (!(dt_ptr_)) {                                                  \
                (dt_contig_out_) = 1;                                          \
                (dt_true_lb_)    = 0;                                          \
            } else {                                                           \
                (dt_contig_out_) = (dt_ptr_)->is_contig;                       \
                (dt_true_lb_)    = (dt_ptr_)->true_lb;                         \
            }                                                                  \
        } else {                                                               \
            (dt_contig_out_) = TRUE;                                           \
            (dt_true_lb_)    = 0;                                              \
        }                                                                      \
    } while (0)
343 
/* Report the true lower bound of datatype_: 0 for builtin handles and for a
 * derived handle whose lookup yields NULL, the object's true_lb otherwise. */
#define MPIDI_Datatype_check_lb(datatype_, dt_true_lb_)                 \
    do {                                                                \
        if (!IS_BUILTIN(datatype_)) {                                   \
            MPIR_Datatype *dt_ptr_;                                     \
            MPIR_Datatype_get_ptr((datatype_), (dt_ptr_));              \
            (dt_true_lb_) = (dt_ptr_) ? (dt_ptr_)->true_lb : 0;         \
        } else {                                                        \
            (dt_true_lb_) = 0;                                          \
        }                                                               \
    } while (0)
358 
/* Report contiguity, total byte size, total extent and true lower bound for
 * count_ elements of datatype_.
 *   builtin handle : contiguous, size = count * basic size, extent = size, lb 0
 *   derived handle : values are read from the datatype object
 *
 * A derived handle is expected to resolve to a datatype object (asserted).
 * When assertions are compiled out and the lookup yields NULL, the outputs
 * fall back to contig=1, size=0, extent=0, lb=0 -- the same defaults the other
 * MPIDI_Datatype_check_* macros use -- instead of dereferencing NULL. */
#define MPIDI_Datatype_check_contig_size_extent_lb(datatype_,count_,    \
                                                   dt_contig_out_,      \
                                                   data_sz_out_,        \
                                                   dt_extent_out_,      \
                                                   dt_true_lb_)         \
    do {                                                                \
        if (IS_BUILTIN(datatype_)) {                                    \
            (dt_contig_out_) = TRUE;                                    \
            (data_sz_out_) = (size_t)(count_) *                         \
                MPIR_Datatype_get_basic_size(datatype_);                \
            (dt_extent_out_) = (data_sz_out_);                          \
            (dt_true_lb_) = 0;                                          \
        } else {                                                        \
            MPIR_Datatype *dt_ptr_;                                     \
            MPIR_Datatype_get_ptr((datatype_), (dt_ptr_));              \
            MPIR_Assert(dt_ptr_);                                       \
            if (dt_ptr_) {                                              \
                (dt_contig_out_) = (dt_ptr_)->is_contig;                \
                (data_sz_out_) = (size_t)(count_) * (dt_ptr_)->size;    \
                (dt_extent_out_) = (size_t)(count_) *                   \
                    (dt_ptr_)->extent;                                  \
                (dt_true_lb_) = (dt_ptr_)->true_lb;                     \
            } else {                                                    \
                (dt_contig_out_) = 1;                                   \
                (data_sz_out_) = 0;                                     \
                (dt_extent_out_) = 0;                                   \
                (dt_true_lb_) = 0;                                      \
            }                                                           \
        }                                                               \
    } while (0)
380 
/* Compute the total byte size of both the origin and the target buffer.
 * The origin size is always computed with MPIDI_Datatype_check_size(); the
 * target size is copied from it when datatype and count are identical, and
 * computed separately otherwise.
 *
 * Every argument is parenthesized where it is used in an expression, so
 * arguments containing low-precedence operators (e.g. `c ? 2 : 3`) are
 * compared and assigned as whole values. */
#define MPIDI_Datatype_check_origin_target_size(o_datatype_, t_datatype_,         \
                                                o_count_, t_count_,               \
                                                o_data_sz_out_, t_data_sz_out_)   \
    do {                                                                          \
        MPIDI_Datatype_check_size(o_datatype_, o_count_, o_data_sz_out_);         \
        if ((t_datatype_) == (o_datatype_) && (t_count_) == (o_count_)) {         \
            (t_data_sz_out_) = (o_data_sz_out_);                                  \
        } else {                                                                  \
            MPIDI_Datatype_check_size(t_datatype_, t_count_, t_data_sz_out_);     \
        }                                                                         \
    } while (0)
393 
/* Compute contiguity, total byte size and true lower bound of both the origin
 * and the target buffer.  The origin values are always computed with
 * MPIDI_Datatype_check_contig_size_lb(); the target values are copied from
 * them when datatype and count are identical, and computed separately
 * otherwise.
 *
 * Every argument is parenthesized where it is used in an expression, so
 * arguments containing low-precedence operators are compared and assigned as
 * whole values. */
#define MPIDI_Datatype_check_origin_target_contig_size_lb(o_datatype_, t_datatype_,             \
                                                          o_count_, t_count_,                   \
                                                          o_dt_contig_out_, t_dt_contig_out_,   \
                                                          o_data_sz_out_, t_data_sz_out_,       \
                                                          o_dt_true_lb_, t_dt_true_lb_)         \
    do {                                                                                        \
        MPIDI_Datatype_check_contig_size_lb(o_datatype_, o_count_, o_dt_contig_out_,            \
                                            o_data_sz_out_, o_dt_true_lb_);                     \
        if ((t_datatype_) == (o_datatype_) && (t_count_) == (o_count_)) {                       \
            (t_dt_contig_out_) = (o_dt_contig_out_);                                            \
            (t_data_sz_out_) = (o_data_sz_out_);                                                \
            (t_dt_true_lb_) = (o_dt_true_lb_);                                                  \
        } else {                                                                                \
            MPIDI_Datatype_check_contig_size_lb(t_datatype_, t_count_, t_dt_contig_out_,        \
                                                t_data_sz_out_, t_dt_true_lb_);                 \
        }                                                                                       \
    } while (0)
413 
/* Shorthand used by the MPIDI_Datatype_* macros: forwards the datatype handle
 * to HANDLE_IS_BUILTIN(). */
#define IS_BUILTIN(_datatype) (HANDLE_IS_BUILTIN(_datatype))
416 
417 /* We assume this routine is never called with rank=MPI_PROC_NULL. */
MPIDIU_valid_group_rank(MPIR_Comm * comm,int rank,MPIR_Group * grp)418 MPL_STATIC_INLINE_PREFIX int MPIDIU_valid_group_rank(MPIR_Comm * comm, int rank, MPIR_Group * grp)
419 {
420     int lpid;
421     int size = grp->size;
422     int z;
423     int ret;
424 
425     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIU_VALID_GROUP_RANK);
426     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIU_VALID_GROUP_RANK);
427 
428     MPIDI_NM_comm_get_lpid(comm, rank, &lpid, FALSE);
429 
430     for (z = 0; z < size && lpid != grp->lrank_to_lpid[z].lpid; ++z) {
431     }
432 
433     ret = (z < size);
434 
435   fn_exit:
436     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIU_VALID_GROUP_RANK);
437     return ret;
438 }
439 
440 /* TODO: Several unbounded loops call this macro. One way to avoid holding the
441  * ALLFUNC_MUTEX lock forever is to insert YIELD in each loop. We choose to
442  * insert it here for simplicity, but this might not be the best place. One
443  * needs to investigate the appropriate place to yield the lock. */
444 /* NOTE: Taking off VCI lock is necessary to avoid recursive locking and allow
445  * more granular per-vci locks */
446 /* TODO: MPIDI_global.vci_lock probably will be changed into granular generic lock
447  */
448 
/* Run one progress test with the VCI 0 lock released.
 *
 * Expands in a scope that provides an `mpi_errno` variable: the result of
 * MPID_Progress_test(NULL) is stored there and checked with MPIR_ERR_CHECK
 * only after the VCI lock has been re-acquired.  The global ALLFUNC mutex is
 * yielded last.
 * NOTE(review): MPIR_ERR_CHECK presumably jumps to the caller's failure label
 * on error, in which case the final yield is skipped -- confirm. */
#define MPIDIU_PROGRESS()                                   \
    do {                                                        \
        MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(0).lock); \
        mpi_errno = MPID_Progress_test(NULL);                       \
        MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(0).lock); \
        MPIR_ERR_CHECK(mpi_errno);  \
        MPID_THREAD_CS_YIELD(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX); \
    } while (0)
457 
/* Optimized versions to avoid excessive locking/unlocking */
/* FIXME: use inline functions rather than macros for cleaner semantics */
460 
/* Call MPID_Progress_test(NULL) repeatedly while `cond` holds, releasing the
 * VCI 0 lock once around the whole loop instead of once per iteration.
 *
 * Expands in a scope that provides an `mpi_errno` variable.  `cond` is
 * re-evaluated before every iteration, while the VCI lock is released.  The
 * loop stops early when a progress call returns a non-zero error; the error
 * is checked with MPIR_ERR_CHECK only after the VCI lock is held again. */
#define MPIDIU_PROGRESS_WHILE(cond)         \
    do {                                        \
        MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(0).lock); \
        while (cond) {                          \
            mpi_errno = MPID_Progress_test(NULL);   \
            if (mpi_errno) break;               \
            MPID_THREAD_CS_YIELD(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX); \
        } \
        MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(0).lock); \
        MPIR_ERR_CHECK(mpi_errno);              \
    } while (0)
472 
/* This macro was refactored from original code that made progress in a
 * do-while loop.
 * NOTE: the caller is already inside the progress lock and is calling progress
 *       again. To avoid recursive locking, we yield the lock here.
 * TODO: Can we consolidate with the previous macro? Double check the reasoning.
 */
/* Same as MPIDIU_PROGRESS_WHILE, except that MPID_Progress_test(NULL) is
 * called at least once: `cond` is evaluated after each iteration rather than
 * before it.  Expands in a scope that provides an `mpi_errno` variable; the
 * error is checked only after the VCI 0 lock has been re-acquired. */
#define MPIDIU_PROGRESS_DO_WHILE(cond) \
    do {                                        \
        MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(0).lock); \
        do {                          \
            mpi_errno = MPID_Progress_test(NULL);   \
            if (mpi_errno) break;               \
            MPID_THREAD_CS_YIELD(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX); \
        } while (cond); \
        MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(0).lock); \
        MPIR_ERR_CHECK(mpi_errno);              \
    } while (0)
489 
490 #ifdef HAVE_ERROR_CHECKING
/* Flag an RMA synchronization error when the window's access epoch type is
 * NONE or POST.  The error is raised through MPIR_ERR_SETANDSTMT with
 * MPI_ERR_RMA_SYNC / "**rmasync", handing it mpi_errno and stmt. */
#define MPIDIG_EPOCH_CHECK_SYNC(win, mpi_errno, stmt)                            \
    do {                                                                         \
        MPID_BEGIN_ERROR_CHECKS;                                                 \
        if (MPIDIG_WIN(win, sync).access_epoch_type == MPIDIG_EPOTYPE_NONE ||    \
            MPIDIG_WIN(win, sync).access_epoch_type == MPIDIG_EPOTYPE_POST) {    \
            MPIR_ERR_SETANDSTMT(mpi_errno, MPI_ERR_RMA_SYNC, stmt, "**rmasync"); \
        }                                                                        \
        MPID_END_ERROR_CHECKS;                                                   \
    } while (0)
500 
/* Checking per-target sync status for pscw or lock epoch.
 *
 * An RMA synchronization error is flagged when either
 *   - the window is in a START access epoch and target_rank is not a member
 *     of the start group (MPIDIU_valid_group_rank), or
 *   - the window is in a LOCK access epoch, a target record exists for
 *     target_rank, and that record is not itself in a LOCK epoch.
 *
 * `win` and `target_rank` are parenthesized at every use, and the macro-local
 * target pointer has a macro-private name so that caller expressions such as
 * `target_ptr->rank` are not captured by it. */
#define MPIDIG_EPOCH_CHECK_TARGET_SYNC(win,target_rank,mpi_errno,stmt)             \
    do {                                                                           \
        MPID_BEGIN_ERROR_CHECKS;                                                   \
        MPIDIG_win_target_t *epoch_chk_target_ =                                   \
            MPIDIG_win_target_find((win), (target_rank));                          \
        if ((MPIDIG_WIN(win, sync).access_epoch_type == MPIDIG_EPOTYPE_START &&    \
             !MPIDIU_valid_group_rank((win)->comm_ptr, (target_rank),              \
                                      MPIDIG_WIN(win, sync).sc.group)) ||          \
            (epoch_chk_target_ != NULL &&                                          \
             MPIDIG_WIN(win, sync).access_epoch_type == MPIDIG_EPOTYPE_LOCK &&     \
             epoch_chk_target_->sync.access_epoch_type != MPIDIG_EPOTYPE_LOCK)) {  \
            MPIR_ERR_SETANDSTMT(mpi_errno,                                         \
                                MPI_ERR_RMA_SYNC,                                  \
                                stmt,                                              \
                                "**rmasync");                                      \
        }                                                                          \
        MPID_END_ERROR_CHECKS;                                                     \
    } while (0)
518 
/* Flag an RMA synchronization error unless the window's access epoch type is
 * LOCK or LOCK_ALL.  The error is raised through MPIR_ERR_SETANDSTMT with
 * MPI_ERR_RMA_SYNC / "**rmasync". */
#define MPIDIG_EPOCH_CHECK_PASSIVE(win,mpi_errno,stmt)                              \
    do {                                                                            \
        MPID_BEGIN_ERROR_CHECKS;                                                    \
        if ((MPIDIG_WIN(win, sync).access_epoch_type != MPIDIG_EPOTYPE_LOCK) &&     \
            (MPIDIG_WIN(win, sync).access_epoch_type != MPIDIG_EPOTYPE_LOCK_ALL)) { \
            MPIR_ERR_SETANDSTMT(mpi_errno, MPI_ERR_RMA_SYNC, stmt, "**rmasync");    \
        }                                                                           \
        MPID_END_ERROR_CHECKS;                                                      \
    } while (0)
528 
/* NOTE: unlock/flush/flush_local needs to check per-target passive epoch (lock)
 *
 * Flags an RMA synchronization error unless the given target record is in a
 * LOCK access epoch.  target_ptr is dereferenced unconditionally, so it must
 * not be NULL.  The argument is parenthesized before `->` so that pointer
 * expressions (e.g. `base + i`) are dereferenced as a whole. */
#define MPIDIG_EPOCH_CHECK_TARGET_LOCK(target_ptr,mpi_errno,stmt)                \
    do {                                                                         \
        MPID_BEGIN_ERROR_CHECKS;                                                 \
        if ((target_ptr)->sync.access_epoch_type != MPIDIG_EPOTYPE_LOCK) {       \
            MPIR_ERR_SETANDSTMT(mpi_errno, MPI_ERR_RMA_SYNC,                     \
                                stmt, "**rmasync");                              \
        }                                                                        \
        MPID_END_ERROR_CHECKS;                                                   \
    } while (0)
538 
/* Flag an RMA synchronization error unless the window's access epoch type is
 * NONE or REFENCE.  The error is raised through MPIR_ERR_SETANDSTMT with
 * MPI_ERR_RMA_SYNC / "**rmasync". */
#define MPIDIG_ACCESS_EPOCH_CHECK_NONE(win,mpi_errno,stmt)                         \
    do {                                                                           \
        MPID_BEGIN_ERROR_CHECKS;                                                   \
        if (MPIDIG_WIN(win, sync).access_epoch_type != MPIDIG_EPOTYPE_NONE &&      \
            MPIDIG_WIN(win, sync).access_epoch_type != MPIDIG_EPOTYPE_REFENCE) {   \
            MPIR_ERR_SETANDSTMT(mpi_errno, MPI_ERR_RMA_SYNC, stmt, "**rmasync");   \
        }                                                                          \
        MPID_END_ERROR_CHECKS;                                                     \
    } while (0)
548 
/* Flag an RMA synchronization error unless the window's exposure epoch type
 * is NONE or REFENCE.  The error is raised through MPIR_ERR_SETANDSTMT with
 * MPI_ERR_RMA_SYNC / "**rmasync". */
#define MPIDIG_EXPOSURE_EPOCH_CHECK_NONE(win,mpi_errno,stmt)                       \
    do {                                                                           \
        MPID_BEGIN_ERROR_CHECKS;                                                   \
        if (MPIDIG_WIN(win, sync).exposure_epoch_type != MPIDIG_EPOTYPE_NONE &&    \
            MPIDIG_WIN(win, sync).exposure_epoch_type != MPIDIG_EPOTYPE_REFENCE) { \
            MPIR_ERR_SETANDSTMT(mpi_errno, MPI_ERR_RMA_SYNC, stmt, "**rmasync");   \
        }                                                                          \
        MPID_END_ERROR_CHECKS;                                                     \
    } while (0)
558 
/* NOTE: multiple lock access epochs can occur simultaneously, as long as
 * target to different processes
 *
 * No error is flagged when the window's access epoch type is NONE or REFENCE,
 * or when it is LOCK and either no target record exists for `rank` or
 * (with MPIR_CVAR_CH4_RMA_MEM_EFFICIENT off) the record is not locked.
 * Every other state raises MPI_ERR_RMA_SYNC / "**rmasync".
 *
 * The macro-local target pointer has a macro-private name so that caller
 * expressions such as `target_ptr->rank` are not captured by it. */
#define MPIDIG_LOCK_EPOCH_CHECK_NONE(win,rank,mpi_errno,stmt)                      \
    do {                                                                           \
        MPID_BEGIN_ERROR_CHECKS;                                                   \
        MPIDIG_win_target_t *epoch_chk_target_ =                                   \
            MPIDIG_win_target_find((win), (rank));                                 \
        if (MPIDIG_WIN(win, sync).access_epoch_type != MPIDIG_EPOTYPE_NONE &&      \
            MPIDIG_WIN(win, sync).access_epoch_type != MPIDIG_EPOTYPE_REFENCE &&   \
            !(MPIDIG_WIN(win, sync).access_epoch_type == MPIDIG_EPOTYPE_LOCK &&    \
              (epoch_chk_target_ == NULL ||                                        \
               (!MPIR_CVAR_CH4_RMA_MEM_EFFICIENT &&                                \
                epoch_chk_target_->sync.lock.locked == 0)))) {                     \
            MPIR_ERR_SETANDSTMT(mpi_errno, MPI_ERR_RMA_SYNC,                       \
                                stmt, "**rmasync");                                \
        }                                                                          \
        MPID_END_ERROR_CHECKS;                                                     \
    } while (0)
574 
/* Verify that `win` is in a state compatible with a fence.
 * Both the exposure and the access epoch type must be one of FENCE, REFENCE
 * or NONE; if either of them is anything else, MPI_ERR_RMA_SYNC ("**rmasync")
 * is reported through MPIR_ERR_SETANDSTMT, which is handed `mpi_errno` and
 * `stmt`. */
#define MPIDIG_FENCE_EPOCH_CHECK(win,mpi_errno,stmt)                \
    do {                                                                \
        MPID_BEGIN_ERROR_CHECKS;                                        \
        if (!((MPIDIG_WIN(win, sync).exposure_epoch_type == MPIDIG_EPOTYPE_FENCE || \
               MPIDIG_WIN(win, sync).exposure_epoch_type == MPIDIG_EPOTYPE_REFENCE || \
               MPIDIG_WIN(win, sync).exposure_epoch_type == MPIDIG_EPOTYPE_NONE) && \
              (MPIDIG_WIN(win, sync).access_epoch_type == MPIDIG_EPOTYPE_FENCE || \
               MPIDIG_WIN(win, sync).access_epoch_type == MPIDIG_EPOTYPE_REFENCE || \
               MPIDIG_WIN(win, sync).access_epoch_type == MPIDIG_EPOTYPE_NONE))) { \
            MPIR_ERR_SETANDSTMT(mpi_errno, MPI_ERR_RMA_SYNC,            \
                                stmt, "**rmasync");                     \
        }                                                               \
        MPID_END_ERROR_CHECKS;                                          \
    } while (0)
588 
/* Verify that the access epoch type of `win` equals `epoch_type`.
 * On mismatch MPI_ERR_RMA_SYNC ("**rmasync") is reported through
 * MPIR_ERR_SETANDSTMT, which is handed `mpi_errno` and `stmt`.
 * `epoch_type` is parenthesized so that an expression argument is compared
 * as a whole instead of being split by the precedence of `!=`. */
#define MPIDIG_ACCESS_EPOCH_CHECK(win, epoch_type, mpi_errno, stmt) \
    do {                                                                \
        MPID_BEGIN_ERROR_CHECKS;                                        \
        if (MPIDIG_WIN(win, sync).access_epoch_type != (epoch_type))    \
            MPIR_ERR_SETANDSTMT(mpi_errno, MPI_ERR_RMA_SYNC,            \
                                stmt, "**rmasync");                     \
        MPID_END_ERROR_CHECKS;                                          \
    } while (0)
597 
/* Verify that the exposure epoch type of `win` equals `epoch_type`.
 * On mismatch MPI_ERR_RMA_SYNC ("**rmasync") is reported through
 * MPIR_ERR_SETANDSTMT, which is handed `mpi_errno` and `stmt`.
 * `epoch_type` is parenthesized so that an expression argument is compared
 * as a whole instead of being split by the precedence of `!=`. */
#define MPIDIG_EXPOSURE_EPOCH_CHECK(win, epoch_type, mpi_errno, stmt) \
    do {                                                                \
        MPID_BEGIN_ERROR_CHECKS;                                        \
        if (MPIDIG_WIN(win, sync).exposure_epoch_type != (epoch_type))  \
            MPIR_ERR_SETANDSTMT(mpi_errno, MPI_ERR_RMA_SYNC,            \
                                stmt, "**rmasync");                     \
        MPID_END_ERROR_CHECKS;                                          \
    } while (0)
606 
/* Promote a REFENCE state to FENCE.
 * Only when BOTH the access and the exposure epoch type of `win` are
 * MPIDIG_EPOTYPE_REFENCE are both switched to MPIDIG_EPOTYPE_FENCE; in every
 * other state the window is left untouched. */
#define MPIDIG_EPOCH_OP_REFENCE(win)                                \
    do {                                                                \
        if (MPIDIG_WIN(win, sync).access_epoch_type == MPIDIG_EPOTYPE_REFENCE) { \
            if (MPIDIG_WIN(win, sync).exposure_epoch_type == MPIDIG_EPOTYPE_REFENCE) { \
                MPIDIG_WIN(win, sync).access_epoch_type = MPIDIG_EPOTYPE_FENCE; \
                MPIDIG_WIN(win, sync).exposure_epoch_type = MPIDIG_EPOTYPE_FENCE; \
            }                                                           \
        }                                                               \
    } while (0)
616 
/* Record the epoch state of `win` after a fence.
 * When `massert` has the MPI_MODE_NOSUCCEED bit set, both the access and the
 * exposure epoch type become MPIDIG_EPOTYPE_NONE; otherwise both become
 * MPIDIG_EPOTYPE_REFENCE.
 * `massert` is parenthesized so that an expression argument such as `a | b`
 * is tested as a whole (`&` binds tighter than `|`). */
#define MPIDIG_EPOCH_FENCE_EVENT(win, massert)                      \
    do {                                                                \
        if ((massert) & MPI_MODE_NOSUCCEED)                             \
        {                                                               \
            MPIDIG_WIN(win, sync).access_epoch_type = MPIDIG_EPOTYPE_NONE; \
            MPIDIG_WIN(win, sync).exposure_epoch_type = MPIDIG_EPOTYPE_NONE; \
        }                                                               \
        else                                                            \
        {                                                               \
            MPIDIG_WIN(win, sync).access_epoch_type = MPIDIG_EPOTYPE_REFENCE; \
            MPIDIG_WIN(win, sync).exposure_epoch_type = MPIDIG_EPOTYPE_REFENCE; \
        }                                                               \
    } while (0)
630 
/* Generic routine for checking synchronization at every RMA operation.
 * Assuming no RMA operation with target_rank=PROC_NULL will call it.
 *
 * Runs, in order: the window-level epoch check, the REFENCE -> FENCE
 * promotion, and the per-target synchronization check.  The expansion uses a
 * variable named `mpi_errno` and a label named `fn_fail`, both of which must
 * exist in the calling function.
 *
 * The expansion deliberately does NOT end in a semicolon: the caller supplies
 * it, so the macro behaves as a single statement (e.g. as the unbraced body
 * of an if/else). */
#define MPIDIG_RMA_OP_CHECK_SYNC(target_rank, win)                                 \
    do {                                                                               \
        MPIDIG_EPOCH_CHECK_SYNC(win, mpi_errno, goto fn_fail);                     \
        MPIDIG_EPOCH_OP_REFENCE(win);                                              \
        /* Check target sync status for target_rank. */       \
        MPIDIG_EPOCH_CHECK_TARGET_SYNC(win, target_rank, mpi_errno, goto fn_fail); \
    } while (0)
640 
641 #else /* HAVE_ERROR_CHECKING */
/* Fallback definitions used when HAVE_ERROR_CHECKING is not defined: none of
 * these macros performs a check.
 *
 * All but the last expand to `if (0) goto fn_fail;`, a statement that never
 * jumps but still mentions the `fn_fail` label (presumably so that the
 * callers' fn_fail labels are not reported as unused -- TODO confirm).  The
 * macro arguments, including `stmt`, are ignored, so the calling function
 * must provide an `fn_fail` label.
 *
 * NOTE(review): each of those expansions is a bare `if` that already ends in
 * `;`.  Using one as the unbraced body of an if/else is hazardous. */
#define MPIDIG_EPOCH_CHECK_SYNC(win, mpi_errno, stmt)               if (0) goto fn_fail;
#define MPIDIG_EPOCH_CHECK_TARGET_SYNC(win, target_rank, mpi_errno, stmt)              if (0) goto fn_fail;
#define MPIDIG_EPOCH_CHECK_PASSIVE(win, mpi_errno, stmt)            if (0) goto fn_fail;
#define MPIDIG_EPOCH_CHECK_TARGET_LOCK(target_ptr, mpi_errno, stmt)  if (0) goto fn_fail;
#define MPIDIG_ACCESS_EPOCH_CHECK_NONE(win, mpi_errno, stmt)        if (0) goto fn_fail;
#define MPIDIG_EXPOSURE_EPOCH_CHECK_NONE(win, mpi_errno, stmt)           if (0) goto fn_fail;
#define MPIDIG_LOCK_EPOCH_CHECK_NONE(win,rank,mpi_errno,stmt)       if (0) goto fn_fail;
#define MPIDIG_FENCE_EPOCH_CHECK(win, mpi_errno, stmt)              if (0) goto fn_fail;
#define MPIDIG_ACCESS_EPOCH_CHECK(win, epoch_type, mpi_errno, stmt) if (0) goto fn_fail;
#define MPIDIG_EXPOSURE_EPOCH_CHECK(win, epoch_type, mpi_errno, stmt)    if (0) goto fn_fail;
#define MPIDIG_RMA_OP_CHECK_SYNC(target_rank, win) if (0) goto fn_fail;
/* The fence event is an empty statement in this configuration. */
#define MPIDIG_EPOCH_FENCE_EVENT(win, massert) do {} while (0)
654 #endif /* HAVE_ERROR_CHECKING */
655 
656 /*
657   Calculate base address of the target window at the origin side
658   Return zero to let the target side calculate the actual address
659   (only offset from window base is given to the target in this case)
660 */
MPIDIG_win_base_at_origin(const MPIR_Win * win,int target_rank)661 MPL_STATIC_INLINE_PREFIX uintptr_t MPIDIG_win_base_at_origin(const MPIR_Win * win, int target_rank)
662 {
663     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIG_WIN_BASE_AT_ORIGIN);
664     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIG_WIN_BASE_AT_ORIGIN);
665 
666     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIG_WIN_BASE_AT_ORIGIN);
667 
668     /* TODO: In future we may want to calculate the full virtual address
669      * in the target at the origin side. It can be done by looking at
670      * MPIDIG_WINFO(win, target_rank)->base_addr */
671     return 0;
672 }
673 
674 /*
675   Calculate base address of the window at the target side
676   If MPIDIG_win_base_at_origin calculates the full virtual address
677   this function must return zero
678 */
MPIDIG_win_base_at_target(const MPIR_Win * win)679 MPL_STATIC_INLINE_PREFIX uintptr_t MPIDIG_win_base_at_target(const MPIR_Win * win)
680 {
681     uintptr_t ret;
682 
683     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIG_WIN_BASE_AT_TARGET);
684     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIG_WIN_BASE_AT_TARGET);
685 
686     ret = (uintptr_t) win->base;
687 
688     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIG_WIN_BASE_AT_TARGET);
689     return ret;
690 }
691 
MPIDIG_win_cmpl_cnts_incr(MPIR_Win * win,int target_rank,MPIR_cc_t ** local_cmpl_cnts_ptr)692 MPL_STATIC_INLINE_PREFIX void MPIDIG_win_cmpl_cnts_incr(MPIR_Win * win, int target_rank,
693                                                         MPIR_cc_t ** local_cmpl_cnts_ptr)
694 {
695     int c = 0;
696 
697     /* Increase per-window counters for fence, and per-target counters for
698      * all other synchronization. */
699     switch (MPIDIG_WIN(win, sync).access_epoch_type) {
700         case MPIDIG_EPOTYPE_LOCK:
701             /* FIXME: now we simply set per-target counters for lockall in case
702              * user flushes per target, but this should be optimized. */
703         case MPIDIG_EPOTYPE_LOCK_ALL:
704             /* FIXME: now we simply set per-target counters for PSCW, can it be optimized ? */
705         case MPIDIG_EPOTYPE_START:
706             {
707                 MPIDIG_win_target_t *target_ptr = MPIDIG_win_target_get(win, target_rank);
708 
709                 MPIR_cc_incr(&target_ptr->local_cmpl_cnts, &c);
710                 MPIR_cc_incr(&target_ptr->remote_cmpl_cnts, &c);
711 
712                 *local_cmpl_cnts_ptr = &target_ptr->local_cmpl_cnts;
713                 break;
714             }
715         default:
716             MPIR_cc_incr(&MPIDIG_WIN(win, local_cmpl_cnts), &c);
717             MPIR_cc_incr(&MPIDIG_WIN(win, remote_cmpl_cnts), &c);
718 
719             *local_cmpl_cnts_ptr = &MPIDIG_WIN(win, local_cmpl_cnts);
720             break;
721     }
722 }
723 
724 /* Increase counter for active message acc ops. */
MPIDIG_win_remote_acc_cmpl_cnt_incr(MPIR_Win * win,int target_rank)725 MPL_STATIC_INLINE_PREFIX void MPIDIG_win_remote_acc_cmpl_cnt_incr(MPIR_Win * win, int target_rank)
726 {
727     int c = 0;
728     switch (MPIDIG_WIN(win, sync).access_epoch_type) {
729         case MPIDIG_EPOTYPE_LOCK:
730         case MPIDIG_EPOTYPE_LOCK_ALL:
731         case MPIDIG_EPOTYPE_START:
732             {
733                 MPIDIG_win_target_t *target_ptr = MPIDIG_win_target_get(win, target_rank);
734                 MPIR_cc_incr(&target_ptr->remote_acc_cmpl_cnts, &c);
735                 break;
736             }
737         default:
738             MPIR_cc_incr(&MPIDIG_WIN(win, remote_acc_cmpl_cnts), &c);
739             break;
740     }
741 }
742 
743 /* Decrease counter for active message acc ops. */
MPIDIG_win_remote_acc_cmpl_cnt_decr(MPIR_Win * win,int target_rank)744 MPL_STATIC_INLINE_PREFIX void MPIDIG_win_remote_acc_cmpl_cnt_decr(MPIR_Win * win, int target_rank)
745 {
746     int c = 0;
747     switch (MPIDIG_WIN(win, sync).access_epoch_type) {
748         case MPIDIG_EPOTYPE_LOCK:
749         case MPIDIG_EPOTYPE_LOCK_ALL:
750         case MPIDIG_EPOTYPE_START:
751             {
752                 MPIDIG_win_target_t *target_ptr = MPIDIG_win_target_find(win, target_rank);
753                 MPIR_Assert(target_ptr);
754                 MPIR_cc_decr(&target_ptr->remote_acc_cmpl_cnts, &c);
755                 break;
756             }
757         default:
758             MPIR_cc_decr(&MPIDIG_WIN(win, remote_acc_cmpl_cnts), &c);
759             break;
760     }
761 
762 }
763 
MPIDIG_win_remote_cmpl_cnt_decr(MPIR_Win * win,int target_rank)764 MPL_STATIC_INLINE_PREFIX void MPIDIG_win_remote_cmpl_cnt_decr(MPIR_Win * win, int target_rank)
765 {
766     int c = 0;
767 
768     /* Decrease per-window counter for fence, and per-target counters for
769      * all other synchronization. */
770     switch (MPIDIG_WIN(win, sync).access_epoch_type) {
771         case MPIDIG_EPOTYPE_LOCK:
772         case MPIDIG_EPOTYPE_LOCK_ALL:
773         case MPIDIG_EPOTYPE_START:
774             {
775                 MPIDIG_win_target_t *target_ptr = MPIDIG_win_target_find(win, target_rank);
776                 MPIR_Assert(target_ptr);
777                 MPIR_cc_decr(&target_ptr->remote_cmpl_cnts, &c);
778                 break;
779             }
780         default:
781             MPIR_cc_decr(&MPIDIG_WIN(win, remote_cmpl_cnts), &c);
782             break;
783     }
784 }
785 
MPIDIG_win_check_all_targets_remote_completed(MPIR_Win * win)786 MPL_STATIC_INLINE_PREFIX bool MPIDIG_win_check_all_targets_remote_completed(MPIR_Win * win)
787 {
788     int rank = 0;
789 
790     if (!MPIDIG_WIN(win, targets))
791         return true;
792 
793     bool allcompleted = true;
794     MPIDIG_win_target_t *target_ptr = NULL;
795     for (rank = 0; rank < win->comm_ptr->local_size; rank++) {
796         target_ptr = MPIDIG_win_target_find(win, rank);
797         if (!target_ptr)
798             continue;
799         if (MPIR_cc_get(target_ptr->remote_cmpl_cnts) != 0 ||
800             MPIR_cc_get(target_ptr->remote_acc_cmpl_cnts) != 0) {
801             allcompleted = false;
802             break;
803         }
804     }
805     return allcompleted;
806 }
807 
MPIDIG_win_check_all_targets_local_completed(MPIR_Win * win)808 MPL_STATIC_INLINE_PREFIX bool MPIDIG_win_check_all_targets_local_completed(MPIR_Win * win)
809 {
810     int rank = 0;
811 
812     if (!MPIDIG_WIN(win, targets))
813         return true;
814 
815     bool allcompleted = true;
816     MPIDIG_win_target_t *target_ptr = NULL;
817     for (rank = 0; rank < win->comm_ptr->local_size; rank++) {
818         target_ptr = MPIDIG_win_target_find(win, rank);
819         if (!target_ptr)
820             continue;
821         if (MPIR_cc_get(target_ptr->local_cmpl_cnts) != 0) {
822             allcompleted = false;
823             break;
824         }
825     }
826     return allcompleted;
827 }
828 
MPIDIG_win_check_group_local_completed(MPIR_Win * win,int * ranks_in_win_grp,int grp_siz)829 MPL_STATIC_INLINE_PREFIX bool MPIDIG_win_check_group_local_completed(MPIR_Win * win,
830                                                                      int *ranks_in_win_grp,
831                                                                      int grp_siz)
832 {
833     int i = 0;
834 
835     if (!MPIDIG_WIN(win, targets))
836         return true;
837 
838     bool allcompleted = true;
839     MPIDIG_win_target_t *target_ptr = NULL;
840     for (i = 0; i < grp_siz; i++) {
841         int rank = ranks_in_win_grp[i];
842         target_ptr = MPIDIG_win_target_find(win, rank);
843         if (!target_ptr)
844             continue;
845         if (MPIR_cc_get(target_ptr->local_cmpl_cnts) != 0) {
846             allcompleted = false;
847             break;
848         }
849     }
850     return allcompleted;
851 }
852 
853 /* Map function interfaces in CH4 level */
MPIDIU_map_create(void ** out_map,MPL_memory_class class)854 MPL_STATIC_INLINE_PREFIX void MPIDIU_map_create(void **out_map, MPL_memory_class class)
855 {
856     MPIDIU_map_t *map;
857     map = MPL_malloc(sizeof(MPIDIU_map_t), class);
858     MPIR_Assert(map != NULL);
859     map->head = NULL;
860     *out_map = map;
861 }
862 
MPIDIU_map_destroy(void * in_map)863 MPL_STATIC_INLINE_PREFIX void MPIDIU_map_destroy(void *in_map)
864 {
865     MPIDIU_map_t *map = in_map;
866     MPIDIU_map_entry_t *e, *etmp;
867     HASH_ITER(hh, map->head, e, etmp) {
868         /* Free all remaining entries in the hash */
869         HASH_DELETE(hh, map->head, e);
870         MPL_free(e);
871     }
872     HASH_CLEAR(hh, map->head);
873     MPL_free(map);
874 }
875 
MPIDIU_map_set_unsafe(void * in_map,uint64_t id,void * val,MPL_memory_class class)876 MPL_STATIC_INLINE_PREFIX void MPIDIU_map_set_unsafe(void *in_map, uint64_t id, void *val,
877                                                     MPL_memory_class class)
878 {
879     MPIDIU_map_t *map;
880     MPIDIU_map_entry_t *map_entry;
881     /* MPIDIU_MAP_NOT_FOUND may be used as a special value to indicate an error. */
882     MPIR_Assert(val != MPIDIU_MAP_NOT_FOUND);
883     map = (MPIDIU_map_t *) in_map;
884     map_entry = MPL_malloc(sizeof(MPIDIU_map_entry_t), class);
885     MPIR_Assert(map_entry != NULL);
886     map_entry->key = id;
887     map_entry->value = val;
888     HASH_ADD(hh, map->head, key, sizeof(uint64_t), map_entry, class);
889 }
890 
891 /* Sets a (id -> val) pair into the map, assuming there's no entry with `id`. */
MPIDIU_map_set(void * in_map,uint64_t id,void * val,MPL_memory_class class)892 MPL_STATIC_INLINE_PREFIX void MPIDIU_map_set(void *in_map, uint64_t id, void *val,
893                                              MPL_memory_class class)
894 {
895     MPID_THREAD_CS_ENTER(POBJ, MPIDIU_THREAD_UTIL_MUTEX);
896     MPIDIU_map_set_unsafe(in_map, id, val, class);
897     MPID_THREAD_CS_EXIT(POBJ, MPIDIU_THREAD_UTIL_MUTEX);
898 }
899 
MPIDIU_map_erase(void * in_map,uint64_t id)900 MPL_STATIC_INLINE_PREFIX void MPIDIU_map_erase(void *in_map, uint64_t id)
901 {
902     MPIDIU_map_t *map;
903     MPIDIU_map_entry_t *map_entry;
904     map = (MPIDIU_map_t *) in_map;
905     HASH_FIND(hh, map->head, &id, sizeof(uint64_t), map_entry);
906     MPIR_Assert(map_entry != NULL);
907     HASH_DELETE(hh, map->head, map_entry);
908     MPL_free(map_entry);
909 }
910 
MPIDIU_map_lookup(void * in_map,uint64_t id)911 MPL_STATIC_INLINE_PREFIX void *MPIDIU_map_lookup(void *in_map, uint64_t id)
912 {
913     void *rc;
914     MPIDIU_map_t *map;
915     MPIDIU_map_entry_t *map_entry;
916 
917     map = (MPIDIU_map_t *) in_map;
918     HASH_FIND(hh, map->head, &id, sizeof(uint64_t), map_entry);
919     if (map_entry == NULL)
920         rc = MPIDIU_MAP_NOT_FOUND;
921     else
922         rc = map_entry->value;
923     return rc;
924 }
925 
926 /* Updates a value in the map which has `id` as a key.
927    If `id` does not exist in the map, it will be added. Returns the old value. */
MPIDIU_map_update(void * in_map,uint64_t id,void * new_val,MPL_memory_class class)928 MPL_STATIC_INLINE_PREFIX void *MPIDIU_map_update(void *in_map, uint64_t id, void *new_val,
929                                                  MPL_memory_class class)
930 {
931     void *rc;
932     MPIDIU_map_t *map;
933     MPIDIU_map_entry_t *map_entry;
934 
935     MPID_THREAD_CS_ENTER(POBJ, MPIDI_THREAD_UTIL_MUTEX);
936     map = (MPIDIU_map_t *) in_map;
937     HASH_FIND(hh, map->head, &id, sizeof(uint64_t), map_entry);
938     if (map_entry == NULL) {
939         rc = MPIDIU_MAP_NOT_FOUND;
940         MPIDIU_map_set_unsafe(in_map, id, new_val, class);
941     } else {
942         rc = map_entry->value;
943         map_entry->value = new_val;
944     }
945     MPID_THREAD_CS_EXIT(POBJ, MPIDI_THREAD_UTIL_MUTEX);
946     return rc;
947 }
948 
949 /* Return the associated av for a RMA target.
950  * This is an optimized path for direct intra comm (comm_world or dup from comm_world) by
951  * eliminating pointer dereferences into dynamic allocated objects (i.e., win->comm_ptr).*/
MPIDIU_win_rank_to_av(MPIR_Win * win,int rank,MPIDI_winattr_t winattr)952 MPL_STATIC_INLINE_PREFIX MPIDI_av_entry_t *MPIDIU_win_rank_to_av(MPIR_Win * win, int rank,
953                                                                  MPIDI_winattr_t winattr)
954 {
955     MPIDI_av_entry_t *av = NULL;
956 
957     if (winattr & MPIDI_WINATTR_DIRECT_INTRA_COMM) {
958         av = &MPIDI_av_table0->table[rank];
959     } else
960         av = MPIDIU_comm_rank_to_av(win->comm_ptr, rank);
961     return av;
962 }
963 
964 /* Return the local process's rank in the window.
965  * This is an optimized path for direct intra comm (comm_world or dup from comm_world) by
966  * eliminating pointer dereferences into dynamic allocated objects (i.e., win->comm_ptr).*/
MPIDIU_win_comm_rank(MPIR_Win * win,MPIDI_winattr_t winattr)967 MPL_STATIC_INLINE_PREFIX int MPIDIU_win_comm_rank(MPIR_Win * win, MPIDI_winattr_t winattr)
968 {
969     if (winattr & MPIDI_WINATTR_DIRECT_INTRA_COMM)
970         return MPIR_Process.comm_world->rank;
971     else
972         return win->comm_ptr->rank;
973 }
974 
975 /* Return the corresponding rank in intranode for a RMA target.
976  * This is an optimized path for direct intra comm (comm_world or dup from comm_world) by
977  * eliminating pointer dereferences into dynamic allocated objects (i.e., win->comm_ptr).*/
MPIDIU_win_rank_to_intra_rank(MPIR_Win * win,int rank,MPIDI_winattr_t winattr)978 MPL_STATIC_INLINE_PREFIX int MPIDIU_win_rank_to_intra_rank(MPIR_Win * win, int rank,
979                                                            MPIDI_winattr_t winattr)
980 {
981     if (winattr & MPIDI_WINATTR_DIRECT_INTRA_COMM)
982         return MPIR_Process.comm_world->intranode_table[rank];
983     else
984         return win->comm_ptr->intranode_table[rank];
985 }
986 
987 /* Wait until active message acc ops are done. */
MPIDIG_wait_am_acc(MPIR_Win * win,int target_rank)988 MPL_STATIC_INLINE_PREFIX int MPIDIG_wait_am_acc(MPIR_Win * win, int target_rank)
989 {
990     int mpi_errno = MPI_SUCCESS;
991     MPIDIG_win_target_t *target_ptr = MPIDIG_win_target_find(win, target_rank);
992     while ((target_ptr && MPIR_cc_get(target_ptr->remote_acc_cmpl_cnts) != 0) ||
993            MPIR_cc_get(MPIDIG_WIN(win, remote_acc_cmpl_cnts)) != 0) {
994         MPIDIU_PROGRESS();
995     }
996   fn_exit:
997     return mpi_errno;
998 
999   fn_fail:
1000     goto fn_exit;
1001 }
1002 
1003 /* Compute accumulate operation.
1004  * The source datatype can be only predefined; the target datatype can be
1005  * predefined or derived. If the source buffer has been packed by the caller,
1006  * src_kind must be set to MPIDIG_ACC_SRCBUF_PACKED.*/
MPIDIG_compute_acc_op(void * source_buf,int source_count,MPI_Datatype source_dtp,void * target_buf,int target_count,MPI_Datatype target_dtp,MPI_Op acc_op,int src_kind)1007 MPL_STATIC_INLINE_PREFIX int MPIDIG_compute_acc_op(void *source_buf, int source_count,
1008                                                    MPI_Datatype source_dtp, void *target_buf,
1009                                                    int target_count, MPI_Datatype target_dtp,
1010                                                    MPI_Op acc_op, int src_kind)
1011 {
1012     int mpi_errno = MPI_SUCCESS;
1013     MPI_User_function *uop = NULL;
1014     MPI_Aint source_dtp_size = 0, source_dtp_extent = 0;
1015     int is_empty_source = FALSE;
1016     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIG_COMPUTE_ACC_OP);
1017 
1018     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIG_COMPUTE_ACC_OP);
1019 
1020     /* first Judge if source buffer is empty */
1021     if (acc_op == MPI_NO_OP)
1022         is_empty_source = TRUE;
1023 
1024     if (is_empty_source == FALSE) {
1025         MPIR_Assert(MPIR_DATATYPE_IS_PREDEFINED(source_dtp));
1026         MPIR_Datatype_get_size_macro(source_dtp, source_dtp_size);
1027         MPIR_Datatype_get_extent_macro(source_dtp, source_dtp_extent);
1028     }
1029 
1030     if ((HANDLE_IS_BUILTIN(acc_op))
1031         && ((*MPIR_OP_HDL_TO_DTYPE_FN(acc_op)) (source_dtp) == MPI_SUCCESS)) {
1032         /* get the function by indexing into the op table */
1033         uop = MPIR_OP_HDL_TO_FN(acc_op);
1034     } else {
1035         /* --BEGIN ERROR HANDLING-- */
1036         mpi_errno = MPIR_Err_create_code(MPI_SUCCESS, MPIR_ERR_RECOVERABLE,
1037                                          __func__, __LINE__, MPI_ERR_OP,
1038                                          "**opnotpredefined", "**opnotpredefined %d", acc_op);
1039         return mpi_errno;
1040         /* --END ERROR HANDLING-- */
1041     }
1042 
1043     void *in_targetbuf = target_buf;
1044     void *host_targetbuf = NULL;
1045     MPL_pointer_attr_t attr;
1046     MPIR_GPU_query_pointer_attr(target_buf, &attr);
1047     /* FIXME: use typerep/yaksa GPU-aware accumulate when available */
1048     if (attr.type == MPL_GPU_POINTER_DEV) {
1049         MPI_Aint extent, true_extent;
1050         MPI_Aint true_lb;
1051 
1052         MPIR_Datatype_get_extent_macro(target_dtp, extent);
1053         MPIR_Type_get_true_extent_impl(target_dtp, &true_lb, &true_extent);
1054         extent = MPL_MAX(extent, true_extent);
1055 
1056         host_targetbuf = MPL_malloc(extent * target_count, MPL_MEM_RMA);
1057         MPIR_Assert(host_targetbuf);
1058         MPIR_Localcopy(target_buf, target_count, target_dtp, host_targetbuf, target_count,
1059                        target_dtp);
1060         target_buf = host_targetbuf;
1061     }
1062 
1063     if (is_empty_source == TRUE || HANDLE_IS_BUILTIN(target_dtp)) {
1064         /* directly apply op if target dtp is predefined dtp OR source buffer is empty */
1065         (*uop) (source_buf, target_buf, &source_count, &source_dtp);
1066     } else {
1067         /* derived datatype */
1068         struct iovec *typerep_vec;
1069         int i, count;
1070         MPI_Aint vec_len, type_extent, type_size, src_type_stride;
1071         MPI_Datatype type;
1072         MPIR_Datatype *dtp;
1073         MPI_Aint curr_len;
1074         void *curr_loc;
1075         int accumulated_count;
1076 
1077         MPIR_Datatype_get_ptr(target_dtp, dtp);
1078         MPIR_Assert(dtp != NULL);
1079         vec_len = dtp->typerep.num_contig_blocks * target_count + 1;
1080         /* +1 needed because Rob says so */
1081         typerep_vec = (struct iovec *)
1082             MPL_malloc(vec_len * sizeof(struct iovec), MPL_MEM_RMA);
1083         /* --BEGIN ERROR HANDLING-- */
1084         if (!typerep_vec) {
1085             mpi_errno =
1086                 MPIR_Err_create_code(MPI_SUCCESS, MPIR_ERR_RECOVERABLE, __func__, __LINE__,
1087                                      MPI_ERR_OTHER, "**nomem", 0);
1088             goto fn_exit;
1089         }
1090         /* --END ERROR HANDLING-- */
1091 
1092         MPI_Aint actual_iov_len, actual_iov_bytes;
1093         MPIR_Typerep_to_iov(NULL, target_count, target_dtp, 0, typerep_vec, vec_len,
1094                             source_count * source_dtp_size, &actual_iov_len, &actual_iov_bytes);
1095         vec_len = actual_iov_len;
1096 
1097         type = dtp->basic_type;
1098         MPIR_Assert(type != MPI_DATATYPE_NULL);
1099 
1100         MPIR_Assert(type == source_dtp);
1101         type_size = source_dtp_size;
1102         type_extent = source_dtp_extent;
1103         /* If the source buffer has been packed by the caller, the distance between
1104          * two elements can be smaller than extent. E.g., predefined pairtype may
1105          * have larger extent than size.*/
1106         /* when predefined pairtype have larger extent than size, we'll end up
1107          * missaligned access. Memcpy the source to workaround the alignment issue.
1108          */
1109         char *src_ptr = NULL;
1110         if (src_kind == MPIDIG_ACC_SRCBUF_PACKED) {
1111             src_type_stride = source_dtp_size;
1112             if (source_dtp_size < source_dtp_extent) {
1113                 src_ptr = MPL_malloc(source_dtp_extent, MPL_MEM_OTHER);
1114             }
1115         } else {
1116             src_type_stride = source_dtp_extent;
1117         }
1118 
1119         i = 0;
1120         curr_loc = typerep_vec[0].iov_base;
1121         curr_len = typerep_vec[0].iov_len;
1122         accumulated_count = 0;
1123         while (i != vec_len) {
1124             if (curr_len < type_size) {
1125                 MPIR_Assert(i != vec_len);
1126                 i++;
1127                 curr_len += typerep_vec[i].iov_len;
1128                 continue;
1129             }
1130 
1131             MPIR_Assign_trunc(count, curr_len / type_size, int);
1132 
1133             if (src_ptr) {
1134                 MPI_Aint unpacked_size;
1135                 MPIR_Typerep_unpack((char *) source_buf + src_type_stride * accumulated_count,
1136                                     source_dtp_size, src_ptr, 1, source_dtp, 0, &unpacked_size);
1137                 (*uop) (src_ptr, (char *) target_buf + MPIR_Ptr_to_aint(curr_loc), &count, &type);
1138             } else {
1139                 (*uop) ((char *) source_buf + src_type_stride * accumulated_count,
1140                         (char *) target_buf + MPIR_Ptr_to_aint(curr_loc), &count, &type);
1141             }
1142 
1143             if (curr_len % type_size == 0) {
1144                 i++;
1145                 if (i != vec_len) {
1146                     curr_loc = typerep_vec[i].iov_base;
1147                     curr_len = typerep_vec[i].iov_len;
1148                 }
1149             } else {
1150                 curr_loc = (void *) ((char *) curr_loc + type_extent * count);
1151                 curr_len -= type_size * count;
1152             }
1153 
1154             accumulated_count += count;
1155         }
1156 
1157         MPL_free(src_ptr);
1158         MPL_free(typerep_vec);
1159     }
1160 
1161     if (host_targetbuf) {
1162         target_buf = in_targetbuf;
1163         MPIR_Localcopy(host_targetbuf, target_count, target_dtp, target_buf, target_count,
1164                        target_dtp);
1165         MPL_free(host_targetbuf);
1166     }
1167 
1168   fn_exit:
1169     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIG_COMPUTE_ACC_OP);
1170     return mpi_errno;
1171 }
1172 
MPIDIU_win_acc_op_get_index(MPI_Op op)1173 MPL_STATIC_INLINE_PREFIX int MPIDIU_win_acc_op_get_index(MPI_Op op)
1174 {
1175     if (op == MPI_OP_NULL) {
1176         /* Builtin index is from 0 to MPIR_OP_N_BUILTIN-1.
1177          * Thus use MPIR_OP_N_BUILTIN as index for special OP_NULL as RMA cswap */
1178         return MPIR_OP_N_BUILTIN - 1;
1179     } else {
1180         return MPIR_Op_builtin_get_index(op);
1181     }
1182 }
1183 
MPIDIU_win_acc_get_op(int index)1184 MPL_STATIC_INLINE_PREFIX MPI_Op MPIDIU_win_acc_get_op(int index)
1185 {
1186     if (index == MPIR_OP_N_BUILTIN - 1) {
1187         /* Builtin index is from 0 to MPIR_OP_N_BUILTIN-1.
1188          * Thus use MPIR_OP_N_BUILTIN as index for special OP_NULL as RMA cswap */
1189         return MPI_OP_NULL;
1190     } else {
1191         return MPIR_Op_builtin_get_op(index);
1192     }
1193 }
1194 
1195 /* Determine whether need poll progress for RMA target-side active message.
1196  * The polling interval is set globally as we don't distinguish target-side
1197  * AM handling per-window.  */
MPIDIG_rma_need_poll_am(void)1198 MPL_STATIC_INLINE_PREFIX bool MPIDIG_rma_need_poll_am(void)
1199 {
1200     bool poll_flag = false;
1201 
1202     if (MPIR_CVAR_CH4_RMA_ENABLE_DYNAMIC_AM_PROGRESS) {
1203         int interval;
1204         MPIR_cc_incr(&MPIDIG_global.rma_am_poll_cntr, &interval);
1205 
1206         /* Always poll if any RMA target-side AM has arrived because
1207          * we expect more incoming AM now. */
1208         if (MPL_atomic_load_int(&MPIDIG_global.rma_am_flag)) {
1209             poll_flag = true;
1210         } else {
1211             /* Otherwise poll with low frequency to reduce latency */
1212             poll_flag = ((interval + 1) % MPIR_CVAR_CH4_RMA_AM_PROGRESS_LOW_FREQ_INTERVAL
1213                          == 0) ? true : false;
1214         }
1215     } else if (MPIR_CVAR_CH4_RMA_AM_PROGRESS_INTERVAL > 1) {
1216         int interval;
1217         MPIR_cc_incr(&MPIDIG_global.rma_am_poll_cntr, &interval);
1218 
1219         /* User explicitly controls the polling frequency */
1220         poll_flag = ((interval + 1) % MPIR_CVAR_CH4_RMA_AM_PROGRESS_INTERVAL == 0) ? true : false;
1221     } else if (MPIR_CVAR_CH4_RMA_AM_PROGRESS_INTERVAL == 1) {
1222         /* Skip cntr update when interval == 1, as we always poll (default)  */
1223         poll_flag = true;
1224     } else {
1225         /* User explicitly disables polling */
1226         poll_flag = false;
1227     }
1228 
1229     return poll_flag;
1230 }
1231 
1232 /* Set flag to indicate a target-side AM has arrived. */
MPIDIG_rma_set_am_flag(void)1233 MPL_STATIC_INLINE_PREFIX void MPIDIG_rma_set_am_flag(void)
1234 {
1235     MPL_atomic_store_int(&MPIDIG_global.rma_am_flag, 1);
1236 }
1237 
1238 #endif /* CH4_IMPL_H_INCLUDED */
1239