1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #ifndef UCX_RMA_H_INCLUDED
7 #define UCX_RMA_H_INCLUDED
8 
9 #include "ucx_impl.h"
10 
MPIDI_UCX_rma_cmpl_cb(void * request,ucs_status_t status)11 MPL_STATIC_INLINE_PREFIX void MPIDI_UCX_rma_cmpl_cb(void *request, ucs_status_t status)
12 {
13     MPIDI_UCX_ucp_request_t *ucp_request = (MPIDI_UCX_ucp_request_t *) request;
14     MPIR_Request *req = ucp_request->req;
15 
16     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_UCX_RMA_CMPL_CB);
17     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_UCX_RMA_CMPL_CB);
18 
19     MPIR_Assert(status == UCS_OK);
20     if (req) {
21         MPID_Request_complete(req);
22         ucp_request->req = NULL;
23     }
24     ucp_request_free(ucp_request);
25     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_UCX_RMA_CMPL_CB);
26 }
27 
MPIDI_UCX_contig_put(const void * origin_addr,size_t size,int target_rank,MPI_Aint target_disp,MPI_Aint true_lb,MPIR_Win * win,MPIDI_av_entry_t * addr,MPIR_Request ** reqptr)28 MPL_STATIC_INLINE_PREFIX int MPIDI_UCX_contig_put(const void *origin_addr,
29                                                   size_t size,
30                                                   int target_rank,
31                                                   MPI_Aint target_disp, MPI_Aint true_lb,
32                                                   MPIR_Win * win, MPIDI_av_entry_t * addr,
33                                                   MPIR_Request ** reqptr)
34 {
35 
36     MPIDI_UCX_win_info_t *win_info = &(MPIDI_UCX_WIN_INFO(win, target_rank));
37     size_t offset;
38     uint64_t base;
39     int mpi_errno = MPI_SUCCESS;
40     MPIDI_UCX_ucp_request_t *ucp_request ATTRIBUTE((unused)) = NULL;
41     ucp_ep_h ep = MPIDI_UCX_AV_TO_EP(addr, 0, 0);
42 
43     base = win_info->addr;
44     offset = target_disp * win_info->disp + true_lb;
45 
46     /* Put without request */
47     if (likely(!reqptr)) {
48         ucs_status_t status;
49         status = ucp_put_nbi(ep, origin_addr, size, base + offset, win_info->rkey);
50         if (status == UCS_INPROGRESS)
51             MPIDI_UCX_WIN(win).target_sync[target_rank].need_sync = MPIDI_UCX_WIN_SYNC_FLUSH_LOCAL;
52         else if (status == UCS_OK)
53             /* UCX 1.4 spec: completed immediately if returns UCS_OK.
54              * FIXME: is it local completion or remote ? Now we assume local,
55              * so we need flush to ensure remote completion.*/
56             MPIDI_UCX_WIN(win).target_sync[target_rank].need_sync = MPIDI_UCX_WIN_SYNC_FLUSH;
57         else
58             MPIDI_UCX_CHK_STATUS(status);
59         goto fn_exit;
60     }
61 #ifdef HAVE_UCP_PUT_NB  /* ucp_put_nb is provided since UCX 1.4 */
62     /* Put with request */
63     ucp_request = ucp_put_nb(ep, origin_addr, size, base + offset, win_info->rkey,
64                              MPIDI_UCX_rma_cmpl_cb);
65     if (ucp_request == UCS_OK) {
66         /* UCX 1.4 spec: completed immediately if returns UCS_OK.
67          * FIXME: is it local completion or remote ? Now we assume local,
68          * so we need flush to ensure remote completion.*/
69         MPIDI_UCX_WIN(win).target_sync[target_rank].need_sync = MPIDI_UCX_WIN_SYNC_FLUSH;
70     } else {
71         MPIDI_UCX_CHK_REQUEST(ucp_request);
72 
73         /* Create an MPI request and return. The completion cb will complete
74          * the request and release ucp_request. */
75         MPIR_Request *req = NULL;
76         req = MPIR_Request_create_from_pool(MPIR_REQUEST_KIND__RMA, 0);
77         MPIR_ERR_CHKANDSTMT(req == NULL, mpi_errno, MPIX_ERR_NOREQ, goto fn_fail, "**nomemreq");
78         MPIR_Request_add_ref(req);
79         ucp_request->req = req;
80         *reqptr = req;
81 
82         MPIDI_UCX_WIN(win).target_sync[target_rank].need_sync = MPIDI_UCX_WIN_SYNC_FLUSH_LOCAL;
83     }
84 #endif
85 
86   fn_exit:
87     return mpi_errno;
88   fn_fail:
89     goto fn_exit;
90 }
91 
MPIDI_UCX_noncontig_put(const void * origin_addr,int origin_count,MPI_Datatype origin_datatype,int target_rank,size_t size,MPI_Aint target_disp,MPI_Aint true_lb,MPIR_Win * win,MPIDI_av_entry_t * addr,MPIR_Request ** reqptr ATTRIBUTE ((unused)))92 MPL_STATIC_INLINE_PREFIX int MPIDI_UCX_noncontig_put(const void *origin_addr,
93                                                      int origin_count, MPI_Datatype origin_datatype,
94                                                      int target_rank, size_t size,
95                                                      MPI_Aint target_disp, MPI_Aint true_lb,
96                                                      MPIR_Win * win, MPIDI_av_entry_t * addr,
97                                                      MPIR_Request ** reqptr ATTRIBUTE((unused)))
98 {
99     MPIDI_UCX_win_info_t *win_info = &(MPIDI_UCX_WIN_INFO(win, target_rank));
100     size_t base, offset;
101     int mpi_errno = MPI_SUCCESS;
102     ucs_status_t status;
103     char *buffer = NULL;
104     ucp_ep_h ep = MPIDI_UCX_AV_TO_EP(addr, 0, 0);
105 
106     buffer = MPL_malloc(size, MPL_MEM_BUFFER);
107     MPIR_Assert(buffer);
108 
109     MPI_Aint actual_pack_bytes;
110     mpi_errno = MPIR_Typerep_pack(origin_addr, origin_count, origin_datatype, 0, buffer, size,
111                                   &actual_pack_bytes);
112     MPIR_ERR_CHECK(mpi_errno);
113     MPIR_Assert(actual_pack_bytes == size);
114 
115     base = win_info->addr;
116     offset = target_disp * win_info->disp + true_lb;
117     /* We use the blocking put here - should be faster than send/recv - ucp_put returns when it is
118      * locally completed. In reality this means, when the data are copied to the internal UCP-buffer */
119     status = ucp_put(ep, buffer, size, base + offset, win_info->rkey);
120     MPIDI_UCX_CHK_STATUS(status);
121 
122     /* Only need remote flush */
123     MPIDI_UCX_WIN(win).target_sync[target_rank].need_sync = MPIDI_UCX_WIN_SYNC_FLUSH;
124 
125   fn_exit:
126     MPL_free(buffer);
127     return mpi_errno;
128   fn_fail:
129     goto fn_exit;
130 }
131 
MPIDI_UCX_contig_get(void * origin_addr,size_t size,int target_rank,MPI_Aint target_disp,MPI_Aint true_lb,MPIR_Win * win,MPIDI_av_entry_t * addr,MPIR_Request ** reqptr)132 MPL_STATIC_INLINE_PREFIX int MPIDI_UCX_contig_get(void *origin_addr,
133                                                   size_t size,
134                                                   int target_rank,
135                                                   MPI_Aint target_disp, MPI_Aint true_lb,
136                                                   MPIR_Win * win, MPIDI_av_entry_t * addr,
137                                                   MPIR_Request ** reqptr)
138 {
139 
140     MPIDI_UCX_win_info_t *win_info = &(MPIDI_UCX_WIN_INFO(win, target_rank));
141     size_t base, offset;
142     int mpi_errno = MPI_SUCCESS;
143     MPIDI_UCX_ucp_request_t *ucp_request ATTRIBUTE((unused)) = NULL;
144     ucp_ep_h ep = MPIDI_UCX_AV_TO_EP(addr, 0, 0);
145 
146     base = win_info->addr;
147     offset = target_disp * win_info->disp + true_lb;
148 
149     /* Get without request */
150     if (likely(!reqptr)) {
151         ucs_status_t status;
152         status = ucp_get_nbi(ep, origin_addr, size, base + offset, win_info->rkey);
153         MPIDI_UCX_CHK_STATUS(status);
154 
155         /* UCX 1.4 spec: ucp_get_nbi always returns immediately and does not
156          * guarantee completion */
157         MPIDI_UCX_WIN(win).target_sync[target_rank].need_sync = MPIDI_UCX_WIN_SYNC_FLUSH_LOCAL;
158         goto fn_exit;
159     }
160 #ifdef HAVE_UCP_GET_NB  /* ucp_get_nb is provided since UCX 1.4 */
161     /* Get with request */
162     ucp_request = ucp_get_nb(ep, origin_addr, size, base + offset, win_info->rkey,
163                              MPIDI_UCX_rma_cmpl_cb);
164     if (ucp_request != UCS_OK) {
165         MPIDI_UCX_CHK_REQUEST(ucp_request);
166 
167         /* Create an MPI request and return. The completion cb will complete
168          * the request and release ucp_request. */
169         MPIR_Request *req = NULL;
170         req = MPIR_Request_create_from_pool(MPIR_REQUEST_KIND__RMA, 0);
171         MPIR_ERR_CHKANDSTMT(req == NULL, mpi_errno, MPIX_ERR_NOREQ, goto fn_fail, "**nomemreq");
172         MPIR_Request_add_ref(req);
173         ucp_request->req = req;
174         *reqptr = req;
175 
176         MPIDI_UCX_WIN(win).target_sync[target_rank].need_sync = MPIDI_UCX_WIN_SYNC_FLUSH_LOCAL;
177     }
178     /* otherwise completed immediately */
179 #endif
180 
181   fn_exit:
182     return mpi_errno;
183   fn_fail:
184     goto fn_exit;
185 }
186 
MPIDI_UCX_do_put(const void * origin_addr,int origin_count,MPI_Datatype origin_datatype,int target_rank,MPI_Aint target_disp,int target_count,MPI_Datatype target_datatype,MPIR_Win * win,MPIDI_av_entry_t * addr,MPIDI_winattr_t winattr,MPIR_Request ** reqptr)187 MPL_STATIC_INLINE_PREFIX int MPIDI_UCX_do_put(const void *origin_addr,
188                                               int origin_count,
189                                               MPI_Datatype origin_datatype,
190                                               int target_rank,
191                                               MPI_Aint target_disp,
192                                               int target_count, MPI_Datatype target_datatype,
193                                               MPIR_Win * win, MPIDI_av_entry_t * addr,
194                                               MPIDI_winattr_t winattr, MPIR_Request ** reqptr)
195 {
196     int mpi_errno = MPI_SUCCESS;
197     int target_contig, origin_contig;
198     size_t target_bytes, origin_bytes;
199     MPI_Aint origin_true_lb, target_true_lb;
200     size_t offset;
201 
202     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_UCX_DO_PUT);
203     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_UCX_DO_PUT);
204 
205     MPIDIG_RMA_OP_CHECK_SYNC(target_rank, win);
206     MPIDI_Datatype_check_origin_target_contig_size_lb(origin_datatype, target_datatype,
207                                                       origin_count, target_count,
208                                                       origin_contig, target_contig,
209                                                       origin_bytes, target_bytes,
210                                                       origin_true_lb, target_true_lb);
211 
212     if (unlikely(origin_bytes == 0))
213         goto fn_exit;
214 
215     if (target_rank == MPIDIU_win_comm_rank(win, winattr)) {
216         offset = win->disp_unit * target_disp;
217         mpi_errno = MPIR_Localcopy(origin_addr,
218                                    origin_count,
219                                    origin_datatype,
220                                    (char *) win->base + offset, target_count, target_datatype);
221         goto fn_exit;
222     }
223 
224     if (origin_contig && target_contig) {
225         mpi_errno = MPIDI_UCX_contig_put((char *) origin_addr + origin_true_lb, origin_bytes,
226                                          target_rank, target_disp, target_true_lb, win, addr,
227                                          reqptr);
228     } else if (target_contig) {
229         mpi_errno = MPIDI_UCX_noncontig_put(origin_addr, origin_count, origin_datatype, target_rank,
230                                             target_bytes, target_disp, target_true_lb, win, addr,
231                                             reqptr);
232     } else {
233         mpi_errno = MPIDIG_mpi_put(origin_addr, origin_count, origin_datatype, target_rank,
234                                    target_disp, target_count, target_datatype, win);
235     }
236 
237   fn_exit:
238     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_UCX_DO_PUT);
239     return mpi_errno;
240   fn_fail:
241     goto fn_exit;
242 }
243 
MPIDI_UCX_do_get(void * origin_addr,int origin_count,MPI_Datatype origin_datatype,int target_rank,MPI_Aint target_disp,int target_count,MPI_Datatype target_datatype,MPIR_Win * win,MPIDI_av_entry_t * addr,MPIDI_winattr_t winattr,MPIR_Request ** reqptr)244 MPL_STATIC_INLINE_PREFIX int MPIDI_UCX_do_get(void *origin_addr,
245                                               int origin_count,
246                                               MPI_Datatype origin_datatype,
247                                               int target_rank,
248                                               MPI_Aint target_disp,
249                                               int target_count, MPI_Datatype target_datatype,
250                                               MPIR_Win * win, MPIDI_av_entry_t * addr,
251                                               MPIDI_winattr_t winattr, MPIR_Request ** reqptr)
252 {
253     int mpi_errno = MPI_SUCCESS;
254     int origin_contig, target_contig;
255     size_t origin_bytes, target_bytes ATTRIBUTE((unused));
256     size_t offset;
257     MPI_Aint origin_true_lb, target_true_lb;
258 
259     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_UCX_DO_GET);
260     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_UCX_DO_GET);
261 
262     MPIDIG_RMA_OP_CHECK_SYNC(target_rank, win);
263     MPIDI_Datatype_check_origin_target_contig_size_lb(origin_datatype, target_datatype,
264                                                       origin_count, target_count,
265                                                       origin_contig, target_contig,
266                                                       origin_bytes, target_bytes,
267                                                       origin_true_lb, target_true_lb);
268 
269     if (unlikely(origin_bytes == 0))
270         goto fn_exit;
271 
272     if (target_rank == MPIDIU_win_comm_rank(win, winattr)) {
273         offset = target_disp * win->disp_unit;
274         mpi_errno = MPIR_Localcopy((char *) win->base + offset,
275                                    target_count,
276                                    target_datatype, origin_addr, origin_count, origin_datatype);
277         goto fn_exit;
278     }
279 
280     if (origin_contig && target_contig) {
281         mpi_errno = MPIDI_UCX_contig_get((char *) origin_addr + origin_true_lb, origin_bytes,
282                                          target_rank, target_disp, target_true_lb, win, addr,
283                                          reqptr);
284     } else {
285         mpi_errno = MPIDIG_mpi_get(origin_addr, origin_count, origin_datatype, target_rank,
286                                    target_disp, target_count, target_datatype, win);
287     }
288 
289   fn_exit:
290     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_UCX_DO_GET);
291     return mpi_errno;
292   fn_fail:
293     goto fn_exit;
294 }
295 
MPIDI_NM_mpi_put(const void * origin_addr,int origin_count,MPI_Datatype origin_datatype,int target_rank,MPI_Aint target_disp,int target_count,MPI_Datatype target_datatype,MPIR_Win * win,MPIDI_av_entry_t * addr,MPIDI_winattr_t winattr)296 MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_put(const void *origin_addr,
297                                               int origin_count,
298                                               MPI_Datatype origin_datatype,
299                                               int target_rank,
300                                               MPI_Aint target_disp,
301                                               int target_count, MPI_Datatype target_datatype,
302                                               MPIR_Win * win, MPIDI_av_entry_t * addr,
303                                               MPIDI_winattr_t winattr)
304 {
305     int mpi_errno = MPI_SUCCESS;
306     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_NM_MPI_PUT);
307     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_NM_MPI_PUT);
308 
309     if (!MPIDI_UCX_is_reachable_target(target_rank, win, winattr) ||
310         MPIR_GPU_query_pointer_is_dev(origin_addr)) {
311         mpi_errno = MPIDIG_mpi_put(origin_addr, origin_count, origin_datatype, target_rank,
312                                    target_disp, target_count, target_datatype, win);
313     } else {
314         mpi_errno = MPIDI_UCX_do_put(origin_addr, origin_count, origin_datatype,
315                                      target_rank, target_disp, target_count, target_datatype,
316                                      win, addr, winattr, NULL);
317     }
318 
319     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_NM_MPI_PUT);
320     return mpi_errno;
321 }
322 
MPIDI_NM_mpi_get(void * origin_addr,int origin_count,MPI_Datatype origin_datatype,int target_rank,MPI_Aint target_disp,int target_count,MPI_Datatype target_datatype,MPIR_Win * win,MPIDI_av_entry_t * addr,MPIDI_winattr_t winattr)323 MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_get(void *origin_addr,
324                                               int origin_count,
325                                               MPI_Datatype origin_datatype,
326                                               int target_rank,
327                                               MPI_Aint target_disp,
328                                               int target_count, MPI_Datatype target_datatype,
329                                               MPIR_Win * win, MPIDI_av_entry_t * addr,
330                                               MPIDI_winattr_t winattr)
331 {
332     int mpi_errno = MPI_SUCCESS;
333     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_NM_MPI_GET);
334     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_NM_MPI_GET);
335 
336     if (!MPIDI_UCX_is_reachable_target(target_rank, win, winattr) ||
337         MPIR_GPU_query_pointer_is_dev(origin_addr)) {
338         mpi_errno = MPIDIG_mpi_get(origin_addr, origin_count, origin_datatype, target_rank,
339                                    target_disp, target_count, target_datatype, win);
340     } else {
341         mpi_errno = MPIDI_UCX_do_get(origin_addr, origin_count, origin_datatype,
342                                      target_rank, target_disp, target_count, target_datatype,
343                                      win, addr, winattr, NULL);
344     }
345 
346     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_NM_MPI_GET);
347     return mpi_errno;
348 }
349 
MPIDI_NM_mpi_rput(const void * origin_addr,int origin_count,MPI_Datatype origin_datatype,int target_rank,MPI_Aint target_disp,int target_count,MPI_Datatype target_datatype,MPIR_Win * win,MPIDI_av_entry_t * addr,MPIDI_winattr_t winattr,MPIR_Request ** request)350 MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_rput(const void *origin_addr,
351                                                int origin_count,
352                                                MPI_Datatype origin_datatype,
353                                                int target_rank,
354                                                MPI_Aint target_disp,
355                                                int target_count,
356                                                MPI_Datatype target_datatype,
357                                                MPIR_Win * win,
358                                                MPIDI_av_entry_t * addr, MPIDI_winattr_t winattr,
359                                                MPIR_Request ** request)
360 {
361     /* request based PUT relies on UCX 1.4 function ucp_put_nb */
362 #if !defined(HAVE_UCP_PUT_NB)
363     return MPIDIG_mpi_rput(origin_addr, origin_count, origin_datatype, target_rank,
364                            target_disp, target_count, target_datatype, win, request);
365 #else
366     int mpi_errno = MPI_SUCCESS;
367     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_NM_MPI_RPUT);
368     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_NM_MPI_RPUT);
369 
370     if (!MPIDI_UCX_is_reachable_target(target_rank, win, winattr) ||
371         MPIR_GPU_query_pointer_is_dev(origin_addr)) {
372         mpi_errno = MPIDIG_mpi_rput(origin_addr, origin_count, origin_datatype, target_rank,
373                                     target_disp, target_count, target_datatype, win, request);
374     } else {
375         MPIR_Request *sreq = NULL;
376 
377         mpi_errno = MPIDI_UCX_do_put(origin_addr, origin_count, origin_datatype,
378                                      target_rank, target_disp, target_count, target_datatype,
379                                      win, addr, winattr, &sreq);
380         MPIR_ERR_CHECK(mpi_errno);
381 
382         if (sreq == NULL) {
383             /* create a completed request for user if issuing is completed immediately. */
384             sreq = MPIR_Request_create_complete(MPIR_REQUEST_KIND__RMA);
385             MPIR_ERR_CHKANDSTMT((sreq) == NULL, mpi_errno, MPIX_ERR_NOREQ, goto fn_fail,
386                                 "**nomemreq");
387         }
388         *request = sreq;
389     }
390 
391   fn_exit:
392     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_NM_MPI_RPUT);
393     return mpi_errno;
394   fn_fail:
395     goto fn_exit;
396 #endif
397 }
398 
399 
MPIDI_NM_mpi_compare_and_swap(const void * origin_addr,const void * compare_addr,void * result_addr,MPI_Datatype datatype,int target_rank,MPI_Aint target_disp,MPIR_Win * win,MPIDI_av_entry_t * addr,MPIDI_winattr_t winattr)400 MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_compare_and_swap(const void *origin_addr,
401                                                            const void *compare_addr,
402                                                            void *result_addr,
403                                                            MPI_Datatype datatype,
404                                                            int target_rank, MPI_Aint target_disp,
405                                                            MPIR_Win * win, MPIDI_av_entry_t * addr,
406                                                            MPIDI_winattr_t winattr)
407 {
408     return MPIDIG_mpi_compare_and_swap(origin_addr, compare_addr, result_addr, datatype,
409                                        target_rank, target_disp, win);
410 }
411 
MPIDI_NM_mpi_raccumulate(const void * origin_addr,int origin_count,MPI_Datatype origin_datatype,int target_rank,MPI_Aint target_disp,int target_count,MPI_Datatype target_datatype,MPI_Op op,MPIR_Win * win,MPIDI_av_entry_t * addr,MPIDI_winattr_t winattr,MPIR_Request ** request)412 MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_raccumulate(const void *origin_addr,
413                                                       int origin_count,
414                                                       MPI_Datatype origin_datatype,
415                                                       int target_rank,
416                                                       MPI_Aint target_disp,
417                                                       int target_count,
418                                                       MPI_Datatype target_datatype,
419                                                       MPI_Op op, MPIR_Win * win,
420                                                       MPIDI_av_entry_t * addr,
421                                                       MPIDI_winattr_t winattr,
422                                                       MPIR_Request ** request)
423 {
424     return MPIDIG_mpi_raccumulate(origin_addr, origin_count, origin_datatype, target_rank,
425                                   target_disp, target_count, target_datatype, op, win, request);
426 }
427 
MPIDI_NM_mpi_rget_accumulate(const void * origin_addr,int origin_count,MPI_Datatype origin_datatype,void * result_addr,int result_count,MPI_Datatype result_datatype,int target_rank,MPI_Aint target_disp,int target_count,MPI_Datatype target_datatype,MPI_Op op,MPIR_Win * win,MPIDI_av_entry_t * addr,MPIDI_winattr_t winattr,MPIR_Request ** request)428 MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_rget_accumulate(const void *origin_addr,
429                                                           int origin_count,
430                                                           MPI_Datatype origin_datatype,
431                                                           void *result_addr,
432                                                           int result_count,
433                                                           MPI_Datatype result_datatype,
434                                                           int target_rank,
435                                                           MPI_Aint target_disp,
436                                                           int target_count,
437                                                           MPI_Datatype target_datatype,
438                                                           MPI_Op op, MPIR_Win * win,
439                                                           MPIDI_av_entry_t * addr,
440                                                           MPIDI_winattr_t winattr,
441                                                           MPIR_Request ** request)
442 {
443     return MPIDIG_mpi_rget_accumulate(origin_addr, origin_count, origin_datatype, result_addr,
444                                       result_count, result_datatype, target_rank, target_disp,
445                                       target_count, target_datatype, op, win, request);
446 }
447 
MPIDI_NM_mpi_fetch_and_op(const void * origin_addr,void * result_addr,MPI_Datatype datatype,int target_rank,MPI_Aint target_disp,MPI_Op op,MPIR_Win * win,MPIDI_av_entry_t * addr,MPIDI_winattr_t winattr)448 MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_fetch_and_op(const void *origin_addr,
449                                                        void *result_addr,
450                                                        MPI_Datatype datatype,
451                                                        int target_rank,
452                                                        MPI_Aint target_disp, MPI_Op op,
453                                                        MPIR_Win * win, MPIDI_av_entry_t * addr,
454                                                        MPIDI_winattr_t winattr)
455 {
456     return MPIDIG_mpi_fetch_and_op(origin_addr, result_addr, datatype, target_rank, target_disp, op,
457                                    win);
458 }
459 
460 
MPIDI_NM_mpi_rget(void * origin_addr,int origin_count,MPI_Datatype origin_datatype,int target_rank,MPI_Aint target_disp,int target_count,MPI_Datatype target_datatype,MPIR_Win * win,MPIDI_av_entry_t * addr,MPIDI_winattr_t winattr,MPIR_Request ** request)461 MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_rget(void *origin_addr,
462                                                int origin_count,
463                                                MPI_Datatype origin_datatype,
464                                                int target_rank,
465                                                MPI_Aint target_disp,
466                                                int target_count,
467                                                MPI_Datatype target_datatype,
468                                                MPIR_Win * win,
469                                                MPIDI_av_entry_t * addr, MPIDI_winattr_t winattr,
470                                                MPIR_Request ** request)
471 {
472     /* request based GET relies on UCX 1.4 function ucp_get_nb */
473 #if !defined(HAVE_UCP_GET_NB)
474     return MPIDIG_mpi_rget(origin_addr, origin_count, origin_datatype, target_rank,
475                            target_disp, target_count, target_datatype, win, request);
476 #else
477     int mpi_errno = MPI_SUCCESS;
478     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_NM_MPI_RGET);
479     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_NM_MPI_RGET);
480 
481     if (!MPIDI_UCX_is_reachable_target(target_rank, win, winattr) ||
482         MPIR_GPU_query_pointer_is_dev(origin_addr)) {
483         mpi_errno = MPIDIG_mpi_rget(origin_addr, origin_count, origin_datatype, target_rank,
484                                     target_disp, target_count, target_datatype, win, request);
485     } else {
486         MPIR_Request *sreq = NULL;
487 
488         mpi_errno = MPIDI_UCX_do_get(origin_addr, origin_count, origin_datatype,
489                                      target_rank, target_disp, target_count, target_datatype,
490                                      win, addr, winattr, &sreq);
491         MPIR_ERR_CHECK(mpi_errno);
492 
493         if (sreq == NULL) {
494             /* create a completed request for user if issuing is completed immediately. */
495             sreq = MPIR_Request_create_complete(MPIR_REQUEST_KIND__RMA);
496             MPIR_ERR_CHKANDSTMT((sreq) == NULL, mpi_errno, MPIX_ERR_NOREQ, goto fn_fail,
497                                 "**nomemreq");
498         }
499         *request = sreq;
500     }
501 
502   fn_exit:
503     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_NM_MPI_RGET);
504     return mpi_errno;
505   fn_fail:
506     goto fn_exit;
507 #endif
508 }
509 
510 
MPIDI_NM_mpi_get_accumulate(const void * origin_addr,int origin_count,MPI_Datatype origin_datatype,void * result_addr,int result_count,MPI_Datatype result_datatype,int target_rank,MPI_Aint target_disp,int target_count,MPI_Datatype target_datatype,MPI_Op op,MPIR_Win * win,MPIDI_av_entry_t * addr,MPIDI_winattr_t winattr)511 MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_get_accumulate(const void *origin_addr,
512                                                          int origin_count,
513                                                          MPI_Datatype origin_datatype,
514                                                          void *result_addr,
515                                                          int result_count,
516                                                          MPI_Datatype result_datatype,
517                                                          int target_rank,
518                                                          MPI_Aint target_disp,
519                                                          int target_count,
520                                                          MPI_Datatype target_datatype, MPI_Op op,
521                                                          MPIR_Win * win, MPIDI_av_entry_t * addr,
522                                                          MPIDI_winattr_t winattr)
523 {
524     return MPIDIG_mpi_get_accumulate(origin_addr, origin_count, origin_datatype, result_addr,
525                                      result_count, result_datatype, target_rank, target_disp,
526                                      target_count, target_datatype, op, win);
527 }
528 
MPIDI_NM_mpi_accumulate(const void * origin_addr,int origin_count,MPI_Datatype origin_datatype,int target_rank,MPI_Aint target_disp,int target_count,MPI_Datatype target_datatype,MPI_Op op,MPIR_Win * win,MPIDI_av_entry_t * addr,MPIDI_winattr_t winattr)529 MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_accumulate(const void *origin_addr,
530                                                      int origin_count,
531                                                      MPI_Datatype origin_datatype,
532                                                      int target_rank,
533                                                      MPI_Aint target_disp,
534                                                      int target_count,
535                                                      MPI_Datatype target_datatype, MPI_Op op,
536                                                      MPIR_Win * win, MPIDI_av_entry_t * addr,
537                                                      MPIDI_winattr_t winattr)
538 {
539     return MPIDIG_mpi_accumulate(origin_addr, origin_count, origin_datatype, target_rank,
540                                  target_disp, target_count, target_datatype, op, win);
541 }
542 
543 #endif /* UCX_RMA_H_INCLUDED */
544