1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpidimpl.h"
7 #include "ucx_impl.h"
8 
9 struct ucx_share {
10     int disp;
11     MPI_Aint addr;
12 };
13 
14 static int win_allgather(MPIR_Win * win, size_t length, uint32_t disp_unit, void **base_ptr);
15 static int win_init(MPIR_Win * win);
16 
/*
 * Map the local window buffer with UCX and exchange remote keys and
 * addressing information with every process of the window's communicator.
 *
 * What the function does, in order:
 *   1. Allocates the per-peer info_table and clears every rkey slot.
 *   2. Maps `length` bytes at *base_ptr with ucp_mem_map (UCX allocates the
 *      memory when *base_ptr is NULL) and packs an rkey for the mapping.
 *      GPU device buffers, and memory UCX reports as unsupported, are left
 *      unmapped and contribute a zero-length rkey.
 *   3. Allgathers the rkey sizes, then allgathers the packed rkeys.
 *   4. Unpacks every peer's rkey; peers without an rkey, or whose rkey is
 *      reported unreachable, keep a NULL rkey.
 *   5. If at least one rkey was unpacked: allgathers (disp_unit, base
 *      address) of all peers into info_table, allocates target_sync and, when
 *      every peer is reachable, sets MPIDI_WINATTR_NM_REACHABLE.
 *      If none was unpacked: releases info_table (set to NULL) and returns
 *      MPI_SUCCESS so that the fallback path is used.
 *
 * Parameters:
 *   win       - window being set up; its UCX-specific state is filled in.
 *   length    - size in bytes of the local window buffer.
 *   disp_unit - local displacement unit, published to all peers.
 *   base_ptr  - in: local buffer (or pointer to NULL to let UCX allocate);
 *               out: address reported by ucp_mem_query when mapping succeeded.
 *
 * Returns MPI_SUCCESS or an MPI error code.  On error every rkey unpacked so
 * far is destroyed and info_table is freed and reset to NULL, so the window
 * never keeps a partially initialized table.  A successful memory mapping
 * stays recorded in the window (mem_mapped / mem_h); it is unmapped by the
 * win_free hook.
 *
 * NOTE(review): a process that jumps to fn_fail (or am_fallback) does not take
 * part in the remaining allgathers; presumably the peers are then left
 * waiting in the collective - confirm how callers are expected to handle this.
 */
static int win_allgather(MPIR_Win * win, size_t length, uint32_t disp_unit, void **base_ptr)
{
    MPIR_Errflag_t err = MPIR_ERR_NONE;
    int mpi_errno = MPI_SUCCESS;
    ucs_status_t status;
    ucp_mem_h mem_h;
    int cntr = 0;
    size_t rkey_size = 0;
    int *rkey_sizes = NULL, *recv_disps = NULL, i;
    char *rkey_buffer = NULL, *rkey_recv_buff = NULL;
    struct ucx_share *share_data = NULL;
    ucp_mem_map_params_t mem_map_params;
    ucp_mem_attr_t mem_attr;
    bool all_reachable = true, none_reachable = true;
    MPIR_Comm *comm_ptr = win->comm_ptr;
    const int comm_size = comm_ptr->local_size;
    ucp_context_h ucp_context = MPIDI_UCX_global.context;

    MPIDI_UCX_WIN(win).info_table =
        MPL_malloc(sizeof(MPIDI_UCX_win_info_t) * comm_size, MPL_MEM_OTHER);
    MPIR_ERR_CHKANDJUMP(MPIDI_UCX_WIN(win).info_table == NULL, mpi_errno, MPI_ERR_OTHER,
                        "**nomem");

    /* Clear all rkey slots up front so that every cleanup path (fn_fail here,
     * or the win_free hook) only ever sees NULL or a valid unpacked rkey. */
    for (i = 0; i < comm_size; i++)
        MPIDI_UCX_WIN_INFO(win, i).rkey = NULL;

    mem_map_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
        UCP_MEM_MAP_PARAM_FIELD_LENGTH | UCP_MEM_MAP_PARAM_FIELD_FLAGS;

    mem_map_params.address = *base_ptr;
    mem_map_params.length = length;
    mem_map_params.flags = 0;

    /* no user buffer: ask UCX to allocate the window memory */
    if (*base_ptr == NULL)
        mem_map_params.flags |= UCP_MEM_MAP_ALLOCATE;

    MPIDI_UCX_WIN(win).mem_mapped = false;

    /* As of ucx-1.10, mapping with a CUDA device buffer may be successful but
     * later RMA segfaults inside UCX. Thus, MPICH manually disables native RMA
     * for device win buffer for now. */
    if (MPIR_GPU_query_pointer_is_dev(*base_ptr)) {
        status = UCS_ERR_UNSUPPORTED;
    } else {
        status = ucp_mem_map(ucp_context, &mem_map_params, &mem_h);
    }

    /* some memory types cannot be mapped, skip rkey packing */
    if (status != UCS_ERR_UNSUPPORTED) {
        MPIDI_UCX_CHK_STATUS(status);

        /* checked at win_free to unmap mem_h */
        MPIDI_UCX_WIN(win).mem_mapped = true;
        MPIDI_UCX_WIN(win).mem_h = mem_h;

        /* query allocated address. */
        mem_attr.field_mask = UCP_MEM_ATTR_FIELD_ADDRESS | UCP_MEM_ATTR_FIELD_LENGTH;
        status = ucp_mem_query(mem_h, &mem_attr);
        MPIDI_UCX_CHK_STATUS(status);

        *base_ptr = mem_attr.address;
        MPIR_Assert(mem_attr.length >= length);

        /* pack the key */
        status = ucp_rkey_pack(ucp_context, mem_h, (void **) &rkey_buffer, &rkey_size);
        MPIDI_UCX_CHK_STATUS(status);
    }

    /* exchange the size of every process' packed rkey (0 means "not mapped") */
    rkey_sizes = (int *) MPL_malloc(sizeof(int) * comm_size, MPL_MEM_OTHER);
    MPIR_ERR_CHKANDJUMP(rkey_sizes == NULL, mpi_errno, MPI_ERR_OTHER, "**nomem");
    rkey_sizes[comm_ptr->rank] = (int) rkey_size;
    mpi_errno = MPIR_Allgather(MPI_IN_PLACE, 1, MPI_INT, rkey_sizes, 1, MPI_INT, comm_ptr, &err);
    MPIR_ERR_CHECK(mpi_errno);

    /* displacement of each rkey inside the concatenated receive buffer */
    recv_disps = (int *) MPL_malloc(sizeof(int) * comm_size, MPL_MEM_OTHER);
    MPIR_ERR_CHKANDJUMP(recv_disps == NULL, mpi_errno, MPI_ERR_OTHER, "**nomem");
    for (i = 0; i < comm_size; i++) {
        recv_disps[i] = cntr;
        cntr += rkey_sizes[i];
    }

    /* cntr is 0 when no process mapped its buffer; a NULL result is only an
     * allocation failure for a non-empty request */
    rkey_recv_buff = MPL_malloc(cntr, MPL_MEM_OTHER);
    MPIR_ERR_CHKANDJUMP(cntr > 0 && rkey_recv_buff == NULL, mpi_errno, MPI_ERR_OTHER, "**nomem");

    /* allgather the packed rkeys */
    mpi_errno = MPIR_Allgatherv(rkey_buffer, rkey_size, MPI_BYTE,
                                rkey_recv_buff, rkey_sizes, recv_disps, MPI_BYTE, comm_ptr, &err);
    MPIR_ERR_CHECK(mpi_errno);

    /* If we use the shared memory support in UCX, we have to distinguish between local
     * and remote windows (at least now). If win_create is used, the key cannot be unpacked -
     * then we need our fallback solution. */
    for (i = 0; i < comm_size; i++) {
        ucp_rkey_h rkey = NULL;

        /* Skip unmapped remote region; its slot is already NULL. */
        if (rkey_sizes[i] == 0) {
            all_reachable = false;
            continue;
        }

        status = ucp_ep_rkey_unpack(MPIDI_UCX_COMM_TO_EP(comm_ptr, i, 0, 0),
                                    &rkey_recv_buff[recv_disps[i]], &rkey);
        if (status == UCS_ERR_UNREACHABLE) {
            all_reachable = false;
            continue;
        }
        MPIDI_UCX_CHK_STATUS(status);

        /* store only successfully unpacked keys in the table */
        MPIDI_UCX_WIN_INFO(win, i).rkey = rkey;
        none_reachable = false;
    }

    if (none_reachable)
        goto am_fallback;

    /* publish (disp_unit, base address) of every process */
    share_data = MPL_malloc(comm_size * sizeof(struct ucx_share), MPL_MEM_OTHER);
    MPIR_ERR_CHKANDJUMP(share_data == NULL, mpi_errno, MPI_ERR_OTHER, "**nomem");

    share_data[comm_ptr->rank].disp = disp_unit;
    share_data[comm_ptr->rank].addr = (MPI_Aint) * base_ptr;

    mpi_errno =
        MPIR_Allgather(MPI_IN_PLACE, sizeof(struct ucx_share), MPI_BYTE, share_data,
                       sizeof(struct ucx_share), MPI_BYTE, comm_ptr, &err);
    MPIR_ERR_CHECK(mpi_errno);

    for (i = 0; i < comm_size; i++) {
        MPIDI_UCX_WIN_INFO(win, i).disp = share_data[i].disp;
        MPIDI_UCX_WIN_INFO(win, i).addr = share_data[i].addr;
    }

    MPIDI_UCX_WIN(win).target_sync =
        MPL_malloc(sizeof(MPIDI_UCX_win_target_sync_t) * comm_size, MPL_MEM_RMA);
    MPIR_ERR_CHKANDJUMP(MPIDI_UCX_WIN(win).target_sync == NULL, mpi_errno, MPI_ERR_OTHER,
                        "**nomem");
    for (i = 0; i < comm_size; i++)
        MPIDI_UCX_WIN(win).target_sync[i].need_sync = MPIDI_UCX_WIN_SYNC_UNSET;

    if (all_reachable)
        MPIDI_WIN(win, winattr) |= MPIDI_WINATTR_NM_REACHABLE;

  fn_exit:
    /* buffer release */
    if (rkey_buffer)
        ucp_rkey_buffer_release(rkey_buffer);
    /* free temps */
    MPL_free(share_data);
    MPL_free(rkey_sizes);
    MPL_free(recv_disps);
    MPL_free(rkey_recv_buff);
    return mpi_errno;
  am_fallback:
    /* No rkey was unpacked, so the table holds nothing to destroy. Dropping
     * it leaves info_table == NULL; mpi_errno stays MPI_SUCCESS. */
    MPL_free(MPIDI_UCX_WIN(win).info_table);
    MPIDI_UCX_WIN(win).info_table = NULL;
    goto fn_exit;
  fn_fail:
    /* Undo the partially built table: destroy the rkeys unpacked so far and
     * make sure the window does not keep a pointer to freed memory. */
    if (MPIDI_UCX_WIN(win).info_table) {
        for (i = 0; i < comm_size; i++) {
            if (MPIDI_UCX_WIN_INFO(win, i).rkey)
                ucp_rkey_destroy(MPIDI_UCX_WIN_INFO(win, i).rkey);
        }
        MPL_free(MPIDI_UCX_WIN(win).info_table);
        MPIDI_UCX_WIN(win).info_table = NULL;
    }
    goto fn_exit;
}
172 
/*
 * Reset the UCX-specific part of a window to all-zero bytes.
 *
 * Always returns MPI_SUCCESS.
 */
static int win_init(MPIR_Win * win)
{
    int mpi_errno = MPI_SUCCESS;
    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_WIN_INIT);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_WIN_INIT);

    /* start from a clean slate: every field of the UCX window state is zeroed */
    memset(&MPIDI_UCX_WIN(win), 0, sizeof(MPIDI_UCX_win_t));

    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_WIN_INIT);
    return mpi_errno;
}
187 
/* Forwards the request unchanged to MPIDIG_mpi_win_set_info and returns its result. */
int MPIDI_UCX_mpi_win_set_info(MPIR_Win * win, MPIR_Info * info)
{
    int mpi_errno = MPIDIG_mpi_win_set_info(win, info);

    return mpi_errno;
}
192 
/* Forwards the request unchanged to MPIDIG_mpi_win_get_info and returns its result. */
int MPIDI_UCX_mpi_win_get_info(MPIR_Win * win, MPIR_Info ** info_p_p)
{
    int mpi_errno = MPIDIG_mpi_win_get_info(win, info_p_p);

    return mpi_errno;
}
197 
/* Forwards the request unchanged to MPIDIG_mpi_win_free and returns its result. */
int MPIDI_UCX_mpi_win_free(MPIR_Win ** win_ptr)
{
    int mpi_errno = MPIDIG_mpi_win_free(win_ptr);

    return mpi_errno;
}
202 
/* Forwards all arguments unchanged to MPIDIG_mpi_win_create and returns its result. */
int MPIDI_UCX_mpi_win_create(void *base, MPI_Aint length, int disp_unit, MPIR_Info * info,
                             MPIR_Comm * comm_ptr, MPIR_Win ** win_ptr)
{
    int mpi_errno = MPIDIG_mpi_win_create(base, length, disp_unit, info, comm_ptr, win_ptr);

    return mpi_errno;
}
208 
/* Forwards the request unchanged to MPIDIG_mpi_win_attach and returns its result. */
int MPIDI_UCX_mpi_win_attach(MPIR_Win * win, void *base, MPI_Aint size)
{
    int mpi_errno = MPIDIG_mpi_win_attach(win, base, size);

    return mpi_errno;
}
213 
/* Forwards all arguments unchanged to MPIDIG_mpi_win_allocate_shared and returns its result. */
int MPIDI_UCX_mpi_win_allocate_shared(MPI_Aint size, int disp_unit, MPIR_Info * info_ptr,
                                      MPIR_Comm * comm_ptr, void **base_ptr, MPIR_Win ** win_ptr)
{
    int mpi_errno =
        MPIDIG_mpi_win_allocate_shared(size, disp_unit, info_ptr, comm_ptr, base_ptr, win_ptr);

    return mpi_errno;
}
219 
/* Forwards the request unchanged to MPIDIG_mpi_win_detach and returns its result. */
int MPIDI_UCX_mpi_win_detach(MPIR_Win * win, const void *base)
{
    int mpi_errno = MPIDIG_mpi_win_detach(win, base);

    return mpi_errno;
}
224 
/* Forwards all arguments unchanged to MPIDIG_mpi_win_allocate and returns its result. */
int MPIDI_UCX_mpi_win_allocate(MPI_Aint length, int disp_unit, MPIR_Info * info,
                               MPIR_Comm * comm_ptr, void *baseptr, MPIR_Win ** win_ptr)
{
    int mpi_errno = MPIDIG_mpi_win_allocate(length, disp_unit, info, comm_ptr, baseptr, win_ptr);

    return mpi_errno;
}
230 
/* Forwards the request unchanged to MPIDIG_mpi_win_create_dynamic and returns its result. */
int MPIDI_UCX_mpi_win_create_dynamic(MPIR_Info * info, MPIR_Comm * comm, MPIR_Win ** win)
{
    int mpi_errno = MPIDIG_mpi_win_create_dynamic(info, comm, win);

    return mpi_errno;
}
235 
/*
 * Hook for window creation: zeroes the UCX window state (win_init) and, if
 * that succeeded, runs win_allgather on the window's size, displacement unit
 * and base pointer.
 *
 * Returns the first non-success code produced by either step, else MPI_SUCCESS.
 */
int MPIDI_UCX_mpi_win_create_hook(MPIR_Win * win)
{
    int mpi_errno = MPI_SUCCESS;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_UCX_MPI_WIN_CREATE_HOOK);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_UCX_MPI_WIN_CREATE_HOOK);

    mpi_errno = win_init(win);
    if (mpi_errno == MPI_SUCCESS) {
        /* win->base is passed by address: win_allgather may update it */
        mpi_errno = win_allgather(win, win->size, win->disp_unit, &win->base);
    }

    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_UCX_MPI_WIN_CREATE_HOOK);
    return mpi_errno;
}
257 
/*
 * Hook for window allocation: zeroes the UCX window state (win_init) and, if
 * that succeeded, runs win_allgather on the window's size, displacement unit
 * and base pointer.
 *
 * Returns the first non-success code produced by either step, else MPI_SUCCESS.
 */
int MPIDI_UCX_mpi_win_allocate_hook(MPIR_Win * win)
{
    int mpi_errno = MPI_SUCCESS;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_UCX_MPI_WIN_ALLOCATE_HOOK);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_UCX_MPI_WIN_ALLOCATE_HOOK);

    mpi_errno = win_init(win);
    if (mpi_errno == MPI_SUCCESS) {
        /* win->base is passed by address: win_allgather may update it */
        mpi_errno = win_allgather(win, win->size, win->disp_unit, &win->base);
    }

    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_UCX_MPI_WIN_ALLOCATE_HOOK);
    return mpi_errno;
}
279 
/* Hook for shared-window allocation: only resets the UCX window state via win_init. */
int MPIDI_UCX_mpi_win_allocate_shared_hook(MPIR_Win * win)
{
    int mpi_errno = win_init(win);

    return mpi_errno;
}
284 
/* Hook for dynamic-window creation: only resets the UCX window state via win_init. */
int MPIDI_UCX_mpi_win_create_dynamic_hook(MPIR_Win * win)
{
    int mpi_errno = win_init(win);

    return mpi_errno;
}
289 
/* Attach hook: performs no work and ignores its arguments; always returns MPI_SUCCESS. */
int MPIDI_UCX_mpi_win_attach_hook(MPIR_Win * win, void *base, MPI_Aint size)
{
    int mpi_errno = MPI_SUCCESS;

    return mpi_errno;
}
294 
/* Detach hook: performs no work and ignores its arguments; always returns MPI_SUCCESS. */
int MPIDI_UCX_mpi_win_detach_hook(MPIR_Win * win, const void *base)
{
    int mpi_errno = MPI_SUCCESS;

    return mpi_errno;
}
299 
/*
 * Hook for window destruction: releases the UCX resources held by the window.
 *
 * In order: destroys every non-NULL rkey stored in info_table (if the table
 * exists), unmaps the window memory when it was mapped, then frees
 * info_table and target_sync.
 *
 * Always returns MPI_SUCCESS.
 */
int MPIDI_UCX_mpi_win_free_hook(MPIR_Win * win)
{
    int mpi_errno = MPI_SUCCESS;
    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_UCX_MPI_WIN_FREE_HOOK);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_UCX_MPI_WIN_FREE_HOOK);

    /* info_table is absent when the window never got one (or dropped it) */
    if (MPIDI_UCX_WIN(win).info_table != NULL) {
        const int nprocs = win->comm_ptr->local_size;
        int peer;

        for (peer = 0; peer < nprocs; peer++) {
            /* entries without an unpacked key hold NULL and are skipped */
            if (MPIDI_UCX_WIN_INFO(win, peer).rkey == NULL)
                continue;
            ucp_rkey_destroy(MPIDI_UCX_WIN_INFO(win, peer).rkey);
        }
    }

    /* Skip unmap for unsupported mem type */
    if (MPIDI_UCX_WIN(win).mem_mapped) {
        ucp_mem_unmap(MPIDI_UCX_global.context, MPIDI_UCX_WIN(win).mem_h);
    }

    MPL_free(MPIDI_UCX_WIN(win).info_table);
    MPL_free(MPIDI_UCX_WIN(win).target_sync);

    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_UCX_MPI_WIN_FREE_HOOK);
    return mpi_errno;
}
324