1 /*
2 * Copyright (C) by Argonne National Laboratory
3 * See COPYRIGHT in top-level directory
4 */
5
6 #include "mpidimpl.h"
7 #include "ucx_impl.h"
8
/* Per-process window metadata exchanged in win_allgather. Each process fills
 * its own slot and the array is allgathered as raw MPI_BYTEs, so the layout
 * of this struct is the exchange format. */
struct ucx_share {
    int disp;                   /* displacement unit of this process's window */
    MPI_Aint addr;              /* base address of this process's window buffer */
};

/* Register the local window buffer with UCX and exchange rkeys, displacement
 * units and base addresses with all processes of win->comm_ptr. */
static int win_allgather(MPIR_Win * win, size_t length, uint32_t disp_unit, void **base_ptr);
/* Zero the UCX-specific window state. */
static int win_init(MPIR_Win * win);
16
/* Exchange the RMA access information of a window among all processes of
 * win->comm_ptr and populate the UCX window state. Collective over
 * win->comm_ptr.
 *
 * The local buffer of `length` bytes at *base_ptr is registered with UCX via
 * ucp_mem_map. If *base_ptr is NULL, UCP_MEM_MAP_ALLOCATE is requested so UCX
 * allocates the memory. When mapping succeeds, *base_ptr is replaced by the
 * address reported by ucp_mem_query. Device (GPU) buffers, and memory that
 * ucp_mem_map reports as UCS_ERR_UNSUPPORTED, are not mapped; such a process
 * contributes an empty rkey.
 *
 * Packed rkeys, displacement units and base addresses are then allgathered so
 * that MPIDI_UCX_WIN_INFO(win, i) describes rank i. If no peer rkey can be
 * unpacked, info_table is released and set to NULL (active-message
 * fallback). If every peer rkey was unpacked, MPIDI_WINATTR_NM_REACHABLE is
 * set on the window.
 *
 * Parameters:
 *   win       - window being created; win->comm_ptr defines the participants.
 *   length    - size in bytes of the local window buffer.
 *   disp_unit - local displacement unit, published to the peers.
 *   base_ptr  - in/out local base address (see above).
 *
 * Returns MPI_SUCCESS or an MPI error code (e.g. "**nomem" on allocation
 * failure, or a UCX status converted by MPIDI_UCX_CHK_STATUS).
 */
static int win_allgather(MPIR_Win * win, size_t length, uint32_t disp_unit, void **base_ptr)
{
    MPIR_Errflag_t err = MPIR_ERR_NONE;
    int mpi_errno = MPI_SUCCESS;
    int rank = 0;
    ucs_status_t status;
    ucp_mem_h mem_h;
    int cntr = 0;
    size_t rkey_size = 0;
    int *rkey_sizes = NULL, *recv_disps = NULL, i;
    char *rkey_buffer = NULL, *rkey_recv_buff = NULL;
    struct ucx_share *share_data = NULL;
    ucp_mem_map_params_t mem_map_params;
    ucp_mem_attr_t mem_attr;
    MPIR_Comm *comm_ptr = win->comm_ptr;
    ucp_context_h ucp_context = MPIDI_UCX_global.context;

    /* Zero-filled so every rkey slot starts out NULL. If we fail part-way
     * through the unpack loop below, the table never holds garbage rkeys
     * (MPIDI_UCX_mpi_win_free_hook destroys every non-NULL rkey it finds). */
    MPIDI_UCX_WIN(win).info_table =
        MPL_calloc(comm_ptr->local_size, sizeof(MPIDI_UCX_win_info_t), MPL_MEM_OTHER);
    MPIR_ERR_CHKANDJUMP(MPIDI_UCX_WIN(win).info_table == NULL, mpi_errno, MPI_ERR_OTHER,
                        "**nomem");

    mem_map_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
        UCP_MEM_MAP_PARAM_FIELD_LENGTH | UCP_MEM_MAP_PARAM_FIELD_FLAGS;
    mem_map_params.address = *base_ptr;
    mem_map_params.length = length;
    mem_map_params.flags = 0;

    /* No caller-provided buffer: ask UCX to allocate the memory as well. */
    if (*base_ptr == NULL)
        mem_map_params.flags |= UCP_MEM_MAP_ALLOCATE;

    MPIDI_UCX_WIN(win).mem_mapped = false;

    /* As of ucx-1.10, mapping with a CUDA device buffer may be successful but
     * later RMA segfaults inside UCX. Thus, MPICH manually disables native RMA
     * for device win buffer for now. */
    if (MPIR_GPU_query_pointer_is_dev(*base_ptr)) {
        status = UCS_ERR_UNSUPPORTED;
    } else {
        status = ucp_mem_map(ucp_context, &mem_map_params, &mem_h);
    }

    /* some memory types cannot be mapped, skip rkey packing */
    if (status != UCS_ERR_UNSUPPORTED) {
        MPIDI_UCX_CHK_STATUS(status);

        /* checked at win_free to unmap mem_h */
        MPIDI_UCX_WIN(win).mem_mapped = true;
        MPIDI_UCX_WIN(win).mem_h = mem_h;

        /* query the (possibly UCX-allocated) address */
        mem_attr.field_mask = UCP_MEM_ATTR_FIELD_ADDRESS | UCP_MEM_ATTR_FIELD_LENGTH;
        status = ucp_mem_query(mem_h, &mem_attr);
        MPIDI_UCX_CHK_STATUS(status);

        *base_ptr = mem_attr.address;
        MPIR_Assert(mem_attr.length >= length);

        /* pack the key; released at fn_exit */
        status = ucp_rkey_pack(ucp_context, mem_h, (void **) &rkey_buffer, &rkey_size);
        MPIDI_UCX_CHK_STATUS(status);
    }

    /* Exchange packed-rkey sizes (0 for processes that did not map memory). */
    rkey_sizes = MPL_malloc(sizeof(int) * comm_ptr->local_size, MPL_MEM_OTHER);
    MPIR_ERR_CHKANDJUMP(rkey_sizes == NULL, mpi_errno, MPI_ERR_OTHER, "**nomem");
    rkey_sizes[comm_ptr->rank] = (int) rkey_size;
    mpi_errno = MPIR_Allgather(MPI_IN_PLACE, 1, MPI_INT, rkey_sizes, 1, MPI_INT, comm_ptr, &err);
    MPIR_ERR_CHECK(mpi_errno);

    recv_disps = MPL_malloc(sizeof(int) * comm_ptr->local_size, MPL_MEM_OTHER);
    MPIR_ERR_CHKANDJUMP(recv_disps == NULL, mpi_errno, MPI_ERR_OTHER, "**nomem");

    /* Prefix sums of the sizes give each rank's offset in the receive buffer. */
    for (i = 0; i < comm_ptr->local_size; i++) {
        recv_disps[i] = cntr;
        cntr += rkey_sizes[i];
    }

    rkey_recv_buff = MPL_malloc(cntr, MPL_MEM_OTHER);
    /* cntr is 0 when no process mapped memory; malloc(0) may then return NULL. */
    MPIR_ERR_CHKANDJUMP(cntr > 0 && rkey_recv_buff == NULL, mpi_errno, MPI_ERR_OTHER,
                        "**nomem");

    /* allgather the packed rkeys */
    mpi_errno = MPIR_Allgatherv(rkey_buffer, rkey_size, MPI_BYTE,
                                rkey_recv_buff, rkey_sizes, recv_disps, MPI_BYTE, comm_ptr, &err);
    MPIR_ERR_CHECK(mpi_errno);

    /* If we use the shared memory support in UCX, we have to distinguish
     * between local and remote windows (at least now). If win_create is used,
     * the key cannot be unpacked - then we need our fallback solution. */
    bool all_reachable = true, none_reachable = true;
    for (i = 0; i < comm_ptr->local_size; i++) {
        /* Skip unmapped remote region. */
        if (rkey_sizes[i] == 0) {
            all_reachable = false;
            MPIDI_UCX_WIN_INFO(win, i).rkey = NULL;
            continue;
        }

        status = ucp_ep_rkey_unpack(MPIDI_UCX_COMM_TO_EP(comm_ptr, i, 0, 0),
                                    &rkey_recv_buff[recv_disps[i]],
                                    &(MPIDI_UCX_WIN_INFO(win, i).rkey));
        if (status == UCS_ERR_UNREACHABLE) {
            all_reachable = false;
            MPIDI_UCX_WIN_INFO(win, i).rkey = NULL;
        } else {
            MPIDI_UCX_CHK_STATUS(status);
            none_reachable = false;
        }
    }

    if (none_reachable)
        goto am_fallback;

    /* Exchange displacement units and base addresses. */
    share_data = MPL_malloc(comm_ptr->local_size * sizeof(struct ucx_share), MPL_MEM_OTHER);
    MPIR_ERR_CHKANDJUMP(share_data == NULL, mpi_errno, MPI_ERR_OTHER, "**nomem");

    share_data[comm_ptr->rank].disp = disp_unit;
    share_data[comm_ptr->rank].addr = (MPI_Aint) *base_ptr;

    mpi_errno =
        MPIR_Allgather(MPI_IN_PLACE, sizeof(struct ucx_share), MPI_BYTE, share_data,
                       sizeof(struct ucx_share), MPI_BYTE, comm_ptr, &err);
    MPIR_ERR_CHECK(mpi_errno);

    for (i = 0; i < comm_ptr->local_size; i++) {
        MPIDI_UCX_WIN_INFO(win, i).disp = share_data[i].disp;
        MPIDI_UCX_WIN_INFO(win, i).addr = share_data[i].addr;
    }

    /* MPIR_Assert may be compiled out, so report OOM as a real error. */
    MPIDI_UCX_WIN(win).target_sync =
        MPL_malloc(sizeof(MPIDI_UCX_win_target_sync_t) * comm_ptr->local_size, MPL_MEM_RMA);
    MPIR_ERR_CHKANDJUMP(MPIDI_UCX_WIN(win).target_sync == NULL, mpi_errno, MPI_ERR_OTHER,
                        "**nomem");
    for (rank = 0; rank < comm_ptr->local_size; rank++)
        MPIDI_UCX_WIN(win).target_sync[rank].need_sync = MPIDI_UCX_WIN_SYNC_UNSET;

    if (all_reachable)
        MPIDI_WIN(win, winattr) |= MPIDI_WINATTR_NM_REACHABLE;

  fn_exit:
    /* buffer release */
    if (rkey_buffer)
        ucp_rkey_buffer_release(rkey_buffer);
    /* free temps */
    MPL_free(share_data);
    MPL_free(rkey_sizes);
    MPL_free(recv_disps);
    MPL_free(rkey_recv_buff);
    return mpi_errno;
  am_fallback:
    /* No rkey was unpacked, so the table holds no rkeys to destroy. */
    MPL_free(MPIDI_UCX_WIN(win).info_table);
    MPIDI_UCX_WIN(win).info_table = NULL;
  fn_fail:
    goto fn_exit;
}
172
/* Clear the UCX-specific part of the window to all zero bytes before any
 * hook populates it (info_table, target_sync, mem_mapped, ...).
 * Always returns MPI_SUCCESS. */
static int win_init(MPIR_Win * win)
{
    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_WIN_INIT);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_WIN_INIT);

    memset(&MPIDI_UCX_WIN(win), 0, sizeof(MPIDI_UCX_win_t));

    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_WIN_INIT);
    return MPI_SUCCESS;
}
187
/* MPI_Win_set_info for the UCX netmod: delegated unchanged to MPIDIG. */
int MPIDI_UCX_mpi_win_set_info(MPIR_Win * win, MPIR_Info * info)
{
    int mpi_errno = MPIDIG_mpi_win_set_info(win, info);
    return mpi_errno;
}
192
/* MPI_Win_get_info for the UCX netmod: delegated unchanged to MPIDIG. */
int MPIDI_UCX_mpi_win_get_info(MPIR_Win * win, MPIR_Info ** info_p_p)
{
    int mpi_errno = MPIDIG_mpi_win_get_info(win, info_p_p);
    return mpi_errno;
}
197
/* MPI_Win_free for the UCX netmod: delegated unchanged to MPIDIG. */
int MPIDI_UCX_mpi_win_free(MPIR_Win ** win_ptr)
{
    int mpi_errno = MPIDIG_mpi_win_free(win_ptr);
    return mpi_errno;
}
202
/* MPI_Win_create for the UCX netmod: delegated unchanged to MPIDIG. */
int MPIDI_UCX_mpi_win_create(void *base, MPI_Aint length, int disp_unit, MPIR_Info * info,
                             MPIR_Comm * comm_ptr, MPIR_Win ** win_ptr)
{
    int mpi_errno = MPIDIG_mpi_win_create(base, length, disp_unit, info, comm_ptr, win_ptr);
    return mpi_errno;
}
208
/* MPI_Win_attach for the UCX netmod: delegated unchanged to MPIDIG. */
int MPIDI_UCX_mpi_win_attach(MPIR_Win * win, void *base, MPI_Aint size)
{
    int mpi_errno = MPIDIG_mpi_win_attach(win, base, size);
    return mpi_errno;
}
213
/* MPI_Win_allocate_shared for the UCX netmod: delegated unchanged to MPIDIG. */
int MPIDI_UCX_mpi_win_allocate_shared(MPI_Aint size, int disp_unit, MPIR_Info * info_ptr,
                                      MPIR_Comm * comm_ptr, void **base_ptr, MPIR_Win ** win_ptr)
{
    int mpi_errno =
        MPIDIG_mpi_win_allocate_shared(size, disp_unit, info_ptr, comm_ptr, base_ptr, win_ptr);
    return mpi_errno;
}
219
/* MPI_Win_detach for the UCX netmod: delegated unchanged to MPIDIG. */
int MPIDI_UCX_mpi_win_detach(MPIR_Win * win, const void *base)
{
    int mpi_errno = MPIDIG_mpi_win_detach(win, base);
    return mpi_errno;
}
224
/* MPI_Win_allocate for the UCX netmod: delegated unchanged to MPIDIG. */
int MPIDI_UCX_mpi_win_allocate(MPI_Aint length, int disp_unit, MPIR_Info * info,
                               MPIR_Comm * comm_ptr, void *baseptr, MPIR_Win ** win_ptr)
{
    int mpi_errno = MPIDIG_mpi_win_allocate(length, disp_unit, info, comm_ptr, baseptr, win_ptr);
    return mpi_errno;
}
230
/* MPI_Win_create_dynamic for the UCX netmod: delegated unchanged to MPIDIG. */
int MPIDI_UCX_mpi_win_create_dynamic(MPIR_Info * info, MPIR_Comm * comm, MPIR_Win ** win)
{
    int mpi_errno = MPIDIG_mpi_win_create_dynamic(info, comm, win);
    return mpi_errno;
}
235
/* Netmod hook for MPI_Win_create: reset the UCX window state, then register
 * win->base (win->size bytes) with UCX and exchange access info with peers.
 * Returns the first non-success code from those steps, else MPI_SUCCESS. */
int MPIDI_UCX_mpi_win_create_hook(MPIR_Win * win)
{
    int mpi_errno;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_UCX_MPI_WIN_CREATE_HOOK);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_UCX_MPI_WIN_CREATE_HOOK);

    mpi_errno = win_init(win);
    if (mpi_errno == MPI_SUCCESS)
        mpi_errno = win_allgather(win, win->size, win->disp_unit, &win->base);

    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_UCX_MPI_WIN_CREATE_HOOK);
    return mpi_errno;
}
257
/* Netmod hook for MPI_Win_allocate: reset the UCX window state, then register
 * win->base (win->size bytes) with UCX and exchange access info with peers.
 * Returns the first non-success code from those steps, else MPI_SUCCESS. */
int MPIDI_UCX_mpi_win_allocate_hook(MPIR_Win * win)
{
    int mpi_errno;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_UCX_MPI_WIN_ALLOCATE_HOOK);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_UCX_MPI_WIN_ALLOCATE_HOOK);

    mpi_errno = win_init(win);
    if (mpi_errno == MPI_SUCCESS)
        mpi_errno = win_allgather(win, win->size, win->disp_unit, &win->base);

    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_UCX_MPI_WIN_ALLOCATE_HOOK);
    return mpi_errno;
}
279
/* Netmod hook for MPI_Win_allocate_shared: only the UCX window state is
 * reset; no memory is registered with UCX here. */
int MPIDI_UCX_mpi_win_allocate_shared_hook(MPIR_Win * win)
{
    int mpi_errno = win_init(win);
    return mpi_errno;
}
284
/* Netmod hook for MPI_Win_create_dynamic: only the UCX window state is
 * reset; no memory is registered with UCX here. */
int MPIDI_UCX_mpi_win_create_dynamic_hook(MPIR_Win * win)
{
    int mpi_errno = win_init(win);
    return mpi_errno;
}
289
/* Netmod hook for MPI_Win_attach: the UCX netmod does nothing here. */
int MPIDI_UCX_mpi_win_attach_hook(MPIR_Win * win, void *base, MPI_Aint size)
{
    (void) win;
    (void) base;
    (void) size;
    return MPI_SUCCESS;
}
294
/* Netmod hook for MPI_Win_detach: the UCX netmod does nothing here. */
int MPIDI_UCX_mpi_win_detach_hook(MPIR_Win * win, const void *base)
{
    (void) win;
    (void) base;
    return MPI_SUCCESS;
}
299
/* Netmod hook for MPI_Win_free: release everything the create/allocate hooks
 * set up in the UCX window state.
 *
 * - destroys every non-NULL rkey in info_table;
 * - unmaps mem_h if the window memory was mapped (mem_mapped);
 * - frees info_table and target_sync and resets them to NULL, and clears
 *   mem_mapped, so a repeated call cannot double-free or double-unmap.
 *
 * All cleanup is performed first; a failing ucp_mem_unmap is then reported
 * as an MPI error instead of being silently ignored. */
int MPIDI_UCX_mpi_win_free_hook(MPIR_Win * win)
{
    int mpi_errno = MPI_SUCCESS;
    ucs_status_t status = UCS_OK;
    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_UCX_MPI_WIN_FREE_HOOK);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_UCX_MPI_WIN_FREE_HOOK);

    if (MPIDI_UCX_WIN(win).info_table) {
        int i;
        for (i = 0; i < win->comm_ptr->local_size; i++) {
            if (MPIDI_UCX_WIN_INFO(win, i).rkey) {
                ucp_rkey_destroy(MPIDI_UCX_WIN_INFO(win, i).rkey);
            }
        }
    }

    /* Skip unmap for unsupported mem type */
    if (MPIDI_UCX_WIN(win).mem_mapped) {
        status = ucp_mem_unmap(MPIDI_UCX_global.context, MPIDI_UCX_WIN(win).mem_h);
        MPIDI_UCX_WIN(win).mem_mapped = false;
    }

    MPL_free(MPIDI_UCX_WIN(win).info_table);
    MPIDI_UCX_WIN(win).info_table = NULL;
    MPL_free(MPIDI_UCX_WIN(win).target_sync);
    MPIDI_UCX_WIN(win).target_sync = NULL;

    /* Report an unmap failure only after all resources were released. */
    MPIDI_UCX_CHK_STATUS(status);

  fn_exit:
    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_UCX_MPI_WIN_FREE_HOOK);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
324