1 /*
2 * Copyright (C) by Argonne National Laboratory
3 * See COPYRIGHT in top-level directory
4 */
5
6 #ifndef UCX_IMPL_H_INCLUDED
7 #define UCX_IMPL_H_INCLUDED
8
9 #include <mpidimpl.h>
10 #include "ucx_types.h"
11 #include "mpidch4r.h"
12 #include "ch4_impl.h"
13
14 #include <ucs/type/status.h>
15
/* Accessors for the UCX netmod fields embedded in communicator and request objects. */
#define MPIDI_UCX_COMM(comm) ((comm)->dev.ch4.netmod.ucx)
#define MPIDI_UCX_REQ(req) ((req)->dev.ch4.netmod.ucx)

/* Rank -> pid through MPIDIU_comm_rank_to_pid(); the last two arguments are passed as NULL. */
#define COMM_TO_INDEX(comm, rank) MPIDIU_comm_rank_to_pid((comm), (rank), NULL, NULL)

/* Endpoint lookup: the dest[vni_src][vni_dst] slot of a UCX address-vector entry,
 * reached either from a (comm, rank) pair or directly from an AV entry. */
#define MPIDI_UCX_COMM_TO_EP(comm, rank, vni_src, vni_dst) \
    MPIDI_UCX_AV(MPIDIU_comm_rank_to_av((comm), (rank))).dest[(vni_src)][(vni_dst)]
#define MPIDI_UCX_AV_TO_EP(av, vni_src, vni_dst) \
    MPIDI_UCX_AV((av)).dest[(vni_src)][(vni_dst)]

/* Window accessors; info_table is indexed by target rank. */
#define MPIDI_UCX_WIN(win) ((win)->dev.netmod.ucx)
#define MPIDI_UCX_WIN_INFO(win, rank) MPIDI_UCX_WIN(win).info_table[(rank)]
25
MPIDI_UCX_init_tag(MPIR_Context_id_t contextid,int source,uint64_t tag)26 MPL_STATIC_INLINE_PREFIX uint64_t MPIDI_UCX_init_tag(MPIR_Context_id_t contextid, int source,
27 uint64_t tag)
28 {
29 uint64_t ucp_tag = 0;
30 ucp_tag = contextid;
31 ucp_tag = (ucp_tag << MPIDI_UCX_SOURCE_SHIFT);
32 ucp_tag |= source;
33 ucp_tag = (ucp_tag << MPIDI_UCX_TAG_SHIFT);
34 ucp_tag |= (MPIDI_UCX_TAG_MASK & tag);
35 return ucp_tag;
36 }
37
MPIDI_UCX_tag_mask(int mpi_tag,int src)38 MPL_STATIC_INLINE_PREFIX uint64_t MPIDI_UCX_tag_mask(int mpi_tag, int src)
39 {
40 uint64_t tag_mask = 0xffffffffffffffff;
41 MPIR_TAG_CLEAR_ERROR_BITS(tag_mask);
42 if (mpi_tag == MPI_ANY_TAG)
43 tag_mask &= ~MPIR_TAG_USABLE_BITS;
44
45 if (src == MPI_ANY_SOURCE)
46 tag_mask &= ~(MPIDI_UCX_SOURCE_MASK);
47
48 return tag_mask;
49 }
50
MPIDI_UCX_recv_tag(int mpi_tag,int src,MPIR_Context_id_t contextid)51 MPL_STATIC_INLINE_PREFIX uint64_t MPIDI_UCX_recv_tag(int mpi_tag, int src,
52 MPIR_Context_id_t contextid)
53 {
54 uint64_t ucp_tag = contextid;
55
56 ucp_tag = (ucp_tag << MPIDI_UCX_SOURCE_SHIFT);
57 if (src != MPI_ANY_SOURCE)
58 ucp_tag |= (src & UCS_MASK(MPIDI_UCX_CONTEXT_RANK_BITS));
59 ucp_tag = ucp_tag << MPIDI_UCX_TAG_SHIFT;
60 if (mpi_tag != MPI_ANY_TAG)
61 ucp_tag |= (MPIDI_UCX_TAG_MASK & mpi_tag);
62 return ucp_tag;
63 }
64
MPIDI_UCX_get_tag(uint64_t match_bits)65 MPL_STATIC_INLINE_PREFIX int MPIDI_UCX_get_tag(uint64_t match_bits)
66 {
67 return ((int) (match_bits & MPIDI_UCX_TAG_MASK));
68 }
69
MPIDI_UCX_get_source(uint64_t match_bits)70 MPL_STATIC_INLINE_PREFIX int MPIDI_UCX_get_source(uint64_t match_bits)
71 {
72 return ((int) ((match_bits & MPIDI_UCX_SOURCE_MASK) >> MPIDI_UCX_TAG_SHIFT));
73 }
74
/* Raise an MPI_ERR_OTHER ("**ucx_nm_status") through MPIR_ERR_CHKANDJUMP4 when a
 * UCX status is neither UCS_OK nor UCS_INPROGRESS.  The caller must have an
 * `mpi_errno` variable; MPIR_ERR_CHKANDJUMP4 is expected to jump to the
 * caller's failure label (fn_fail by MPICH convention -- confirm).
 *
 * STATUS is evaluated exactly once.  It is captured into a local so that an
 * expression argument (e.g. a direct UCX call) is not re-executed by the two
 * comparisons or by the error-message formatting, and so that operator
 * precedence inside the argument cannot change the test. */
#define MPIDI_UCX_CHK_STATUS(STATUS)                                            \
    do {                                                                        \
        ucs_status_t mpidi_ucx_chk_status_ = (STATUS);                          \
        MPIR_ERR_CHKANDJUMP4((mpidi_ucx_chk_status_ != UCS_OK &&                \
                              mpidi_ucx_chk_status_ != UCS_INPROGRESS),         \
                             mpi_errno,                                         \
                             MPI_ERR_OTHER,                                     \
                             "**ucx_nm_status",                                 \
                             "**ucx_nm_status %s %d %s %s",                     \
                             __SHORT_FILE__,                                    \
                             __LINE__,                                          \
                             __func__,                                          \
                             ucs_status_string(mpidi_ucx_chk_status_));         \
    } while (0)
87
/* Raise an MPI_ERR_OTHER ("**ucx_nm_rq_error") through MPIR_ERR_CHKANDJUMP4 when
 * a UCX request pointer encodes an error (UCS_PTR_IS_ERR).  The caller must
 * have an `mpi_errno` variable; MPIR_ERR_CHKANDJUMP4 is expected to jump to
 * the caller's failure label (fn_fail by MPICH convention -- confirm).
 *
 * _req is evaluated exactly once.  It is captured into a local (as in
 * MPIDI_UCX_CHK_STATUS) so that the error test and the status decoding in
 * the error message do not each re-evaluate the argument. */
#define MPIDI_UCX_CHK_REQUEST(_req)                                             \
    do {                                                                        \
        const void *mpidi_ucx_chk_req_ = (_req);                                \
        MPIR_ERR_CHKANDJUMP4(UCS_PTR_IS_ERR(mpidi_ucx_chk_req_),                \
                             mpi_errno,                                         \
                             MPI_ERR_OTHER,                                     \
                             "**ucx_nm_rq_error",                               \
                             "**ucx_nm_rq_error %s %d %s %s",                   \
                             __SHORT_FILE__,                                    \
                             __LINE__,                                          \
                             __func__,                                          \
                             ucs_status_string(UCS_PTR_STATUS(mpidi_ucx_chk_req_))); \
    } while (0)
100
MPIDI_UCX_is_reachable_target(int rank,MPIR_Win * win,MPIDI_winattr_t winattr)101 MPL_STATIC_INLINE_PREFIX bool MPIDI_UCX_is_reachable_target(int rank, MPIR_Win * win,
102 MPIDI_winattr_t winattr)
103 {
104 /* unmapped win target does not have rkey. */
105 return (winattr & MPIDI_WINATTR_NM_REACHABLE) || (MPIDI_UCX_WIN(win).info_table &&
106 MPIDI_UCX_WIN_INFO(win, rank).rkey != NULL);
107 }
108
109 /* This function implements netmod vci to vni(context) mapping.
110 * It returns -1 if the vci does not have a mapping.
111 */
MPIDI_UCX_vci_to_vni(int vci)112 MPL_STATIC_INLINE_PREFIX int MPIDI_UCX_vci_to_vni(int vci)
113 {
114 return vci < MPIDI_UCX_global.num_vnis ? vci : -1;
115 }
116
117 /* vni mapping */
118 /* NOTE: concerned by the modulo? If we restrict num_vnis to power of 2,
119 * we may get away with bit mask */
MPIDI_UCX_get_vni(int flag,MPIR_Comm * comm_ptr,int src_rank,int dst_rank,int tag)120 MPL_STATIC_INLINE_PREFIX int MPIDI_UCX_get_vni(int flag, MPIR_Comm * comm_ptr,
121 int src_rank, int dst_rank, int tag)
122 {
123 return MPIDI_get_vci(flag, comm_ptr, src_rank, dst_rank, tag) % MPIDI_UCX_global.num_vnis;
124 }
125
126 #endif /* UCX_IMPL_H_INCLUDED */
127