1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #ifndef UCX_IMPL_H_INCLUDED
7 #define UCX_IMPL_H_INCLUDED
8 
9 #include <mpidimpl.h>
10 #include "ucx_types.h"
11 #include "mpidch4r.h"
12 #include "ch4_impl.h"
13 
14 #include <ucs/type/status.h>
15 
16 #define MPIDI_UCX_COMM(comm)     ((comm)->dev.ch4.netmod.ucx)
17 #define MPIDI_UCX_REQ(req)       ((req)->dev.ch4.netmod.ucx)
18 #define COMM_TO_INDEX(comm,rank) MPIDIU_comm_rank_to_pid(comm, rank, NULL, NULL)
19 #define MPIDI_UCX_COMM_TO_EP(comm,rank,vni_src,vni_dst) \
20     MPIDI_UCX_AV(MPIDIU_comm_rank_to_av(comm, rank)).dest[vni_src][vni_dst]
21 #define MPIDI_UCX_AV_TO_EP(av,vni_src,vni_dst) MPIDI_UCX_AV((av)).dest[vni_src][vni_dst]
22 
23 #define MPIDI_UCX_WIN(win) ((win)->dev.netmod.ucx)
24 #define MPIDI_UCX_WIN_INFO(win, rank) MPIDI_UCX_WIN(win).info_table[rank]
25 
MPIDI_UCX_init_tag(MPIR_Context_id_t contextid,int source,uint64_t tag)26 MPL_STATIC_INLINE_PREFIX uint64_t MPIDI_UCX_init_tag(MPIR_Context_id_t contextid, int source,
27                                                      uint64_t tag)
28 {
29     uint64_t ucp_tag = 0;
30     ucp_tag = contextid;
31     ucp_tag = (ucp_tag << MPIDI_UCX_SOURCE_SHIFT);
32     ucp_tag |= source;
33     ucp_tag = (ucp_tag << MPIDI_UCX_TAG_SHIFT);
34     ucp_tag |= (MPIDI_UCX_TAG_MASK & tag);
35     return ucp_tag;
36 }
37 
MPIDI_UCX_tag_mask(int mpi_tag,int src)38 MPL_STATIC_INLINE_PREFIX uint64_t MPIDI_UCX_tag_mask(int mpi_tag, int src)
39 {
40     uint64_t tag_mask = 0xffffffffffffffff;
41     MPIR_TAG_CLEAR_ERROR_BITS(tag_mask);
42     if (mpi_tag == MPI_ANY_TAG)
43         tag_mask &= ~MPIR_TAG_USABLE_BITS;
44 
45     if (src == MPI_ANY_SOURCE)
46         tag_mask &= ~(MPIDI_UCX_SOURCE_MASK);
47 
48     return tag_mask;
49 }
50 
MPIDI_UCX_recv_tag(int mpi_tag,int src,MPIR_Context_id_t contextid)51 MPL_STATIC_INLINE_PREFIX uint64_t MPIDI_UCX_recv_tag(int mpi_tag, int src,
52                                                      MPIR_Context_id_t contextid)
53 {
54     uint64_t ucp_tag = contextid;
55 
56     ucp_tag = (ucp_tag << MPIDI_UCX_SOURCE_SHIFT);
57     if (src != MPI_ANY_SOURCE)
58         ucp_tag |= (src & UCS_MASK(MPIDI_UCX_CONTEXT_RANK_BITS));
59     ucp_tag = ucp_tag << MPIDI_UCX_TAG_SHIFT;
60     if (mpi_tag != MPI_ANY_TAG)
61         ucp_tag |= (MPIDI_UCX_TAG_MASK & mpi_tag);
62     return ucp_tag;
63 }
64 
MPIDI_UCX_get_tag(uint64_t match_bits)65 MPL_STATIC_INLINE_PREFIX int MPIDI_UCX_get_tag(uint64_t match_bits)
66 {
67     return ((int) (match_bits & MPIDI_UCX_TAG_MASK));
68 }
69 
MPIDI_UCX_get_source(uint64_t match_bits)70 MPL_STATIC_INLINE_PREFIX int MPIDI_UCX_get_source(uint64_t match_bits)
71 {
72     return ((int) ((match_bits & MPIDI_UCX_SOURCE_MASK) >> MPIDI_UCX_TAG_SHIFT));
73 }
74 
/* Fail through MPIR_ERR_CHKANDJUMP4 (presumably setting mpi_errno and jumping to
 * the caller's fn_fail label) when a UCS status is neither UCS_OK nor
 * UCS_INPROGRESS. The error message records the file, line, function, and the
 * ucs_status_string() text.
 *
 * STATUS is evaluated exactly once and captured in a local. This makes it safe
 * to pass an expression with side effects (e.g. a direct UCX call), not just a
 * plain variable. Requires `mpi_errno` in the enclosing scope. */
#define MPIDI_UCX_CHK_STATUS(STATUS)                                    \
    do {                                                                \
        ucs_status_t mpidi_ucx_chk_status_ = (STATUS);                  \
        MPIR_ERR_CHKANDJUMP4((mpidi_ucx_chk_status_ != UCS_OK &&        \
                              mpidi_ucx_chk_status_ != UCS_INPROGRESS), \
                             mpi_errno,                                 \
                             MPI_ERR_OTHER,                             \
                             "**ucx_nm_status",                         \
                             "**ucx_nm_status %s %d %s %s",             \
                             __SHORT_FILE__,                            \
                             __LINE__,                                  \
                             __func__,                                  \
                             ucs_status_string(mpidi_ucx_chk_status_)); \
    } while (0)
87 
/* Fail through MPIR_ERR_CHKANDJUMP4 (presumably setting mpi_errno and jumping to
 * the caller's fn_fail label) when a UCX request handle encodes an error, as
 * reported by UCS_PTR_IS_ERR(). The error message records the file, line,
 * function, and the status text.
 *
 * _req is evaluated exactly once and captured as a ucs_status_ptr_t, so an
 * expression with side effects may be passed. Requires `mpi_errno` in the
 * enclosing scope. */
#define MPIDI_UCX_CHK_REQUEST(_req)                                     \
    do {                                                                \
        ucs_status_ptr_t mpidi_ucx_chk_req_ = (_req);                   \
        MPIR_ERR_CHKANDJUMP4(UCS_PTR_IS_ERR(mpidi_ucx_chk_req_),        \
                             mpi_errno,                                 \
                             MPI_ERR_OTHER,                             \
                             "**ucx_nm_rq_error",                       \
                             "**ucx_nm_rq_error %s %d %s %s",           \
                             __SHORT_FILE__,                            \
                             __LINE__,                                  \
                             __func__,                                  \
                             ucs_status_string(UCS_PTR_STATUS(mpidi_ucx_chk_req_))); \
    } while (0)
100 
MPIDI_UCX_is_reachable_target(int rank,MPIR_Win * win,MPIDI_winattr_t winattr)101 MPL_STATIC_INLINE_PREFIX bool MPIDI_UCX_is_reachable_target(int rank, MPIR_Win * win,
102                                                             MPIDI_winattr_t winattr)
103 {
104     /* unmapped win target does not have rkey. */
105     return (winattr & MPIDI_WINATTR_NM_REACHABLE) || (MPIDI_UCX_WIN(win).info_table &&
106                                                       MPIDI_UCX_WIN_INFO(win, rank).rkey != NULL);
107 }
108 
109 /* This function implements netmod vci to vni(context) mapping.
110  * It returns -1 if the vci does not have a mapping.
111  */
MPIDI_UCX_vci_to_vni(int vci)112 MPL_STATIC_INLINE_PREFIX int MPIDI_UCX_vci_to_vni(int vci)
113 {
114     return vci < MPIDI_UCX_global.num_vnis ? vci : -1;
115 }
116 
117 /* vni mapping */
118 /* NOTE: concerned by the modulo? If we restrict num_vnis to power of 2,
119  * we may get away with bit mask */
MPIDI_UCX_get_vni(int flag,MPIR_Comm * comm_ptr,int src_rank,int dst_rank,int tag)120 MPL_STATIC_INLINE_PREFIX int MPIDI_UCX_get_vni(int flag, MPIR_Comm * comm_ptr,
121                                                int src_rank, int dst_rank, int tag)
122 {
123     return MPIDI_get_vci(flag, comm_ptr, src_rank, dst_rank, tag) % MPIDI_UCX_global.num_vnis;
124 }
125 
126 #endif /* UCX_IMPL_H_INCLUDED */
127