1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include <mpidimpl.h>
7 #include "mpidu_init_shm.h"
8 #include "mpl_shm.h"
9 #include "mpidimpl.h"
10 #include "mpir_pmi.h"
11 #include "mpidu_shm_seg.h"
12 
13 #ifdef ENABLE_NO_LOCAL
14 /* shared memory disabled, just stubs */
15 
MPIDU_Init_shm_init(void)16 int MPIDU_Init_shm_init(void)
17 {
18     return MPI_SUCCESS;
19 }
20 
MPIDU_Init_shm_finalize(void)21 int MPIDU_Init_shm_finalize(void)
22 {
23     return MPI_SUCCESS;
24 }
25 
MPIDU_Init_shm_barrier(void)26 int MPIDU_Init_shm_barrier(void)
27 {
28     return MPI_SUCCESS;
29 }
30 
/* Correct code should never call the following functions when ENABLE_NO_LOCAL is set. */
MPIDU_Init_shm_put(void * orig,size_t len)32 int MPIDU_Init_shm_put(void *orig, size_t len)
33 {
34     MPIR_Assert(0);
35     return MPI_SUCCESS;
36 }
37 
MPIDU_Init_shm_get(int local_rank,size_t len,void * target)38 int MPIDU_Init_shm_get(int local_rank, size_t len, void *target)
39 {
40     MPIR_Assert(0);
41     return MPI_SUCCESS;
42 }
43 
MPIDU_Init_shm_query(int local_rank,void ** target_addr)44 int MPIDU_Init_shm_query(int local_rank, void **target_addr)
45 {
46     MPIR_Assert(0);
47     return MPI_SUCCESS;
48 }
49 
50 #else /* ENABLE_NO_LOCAL */
/* Shared state of the bootstrap barrier.  Init_shm_barrier_init() places it at
 * the very start of the segment (memory.base_addr), so every local process sees
 * the same instance. */
typedef struct Init_shm_barrier {
    MPL_atomic_int_t val;       /* arrival counter; reset to 0 by the last arrival */
    MPL_atomic_int_t wait;      /* release flag: set to the flipped sense by the last arrival */
} Init_shm_barrier_t;
55 
static int local_size;          /* processes on this node (MPIR_Process.local_size) */
static int my_local_rank;       /* this process's rank on the node (MPIR_Process.local_rank) */
static MPIDU_shm_seg_t memory;  /* the bootstrap region: handle, base address, length */
static Init_shm_barrier_t *barrier;     /* barrier state at the start of memory.base_addr */
static void *baseaddr;          /* first per-rank MPIDU_Init_shm_block_t slot */

/* Local sense of the barrier; flipped after each completed barrier round. */
static int sense;
/* Nonzero once Init_shm_barrier_init() has bound `barrier`; checked by Init_shm_barrier(). */
static int barrier_init = 0;
64 
Init_shm_barrier_init(int is_root)65 static int Init_shm_barrier_init(int is_root)
66 {
67     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_INIT_SHM_BARRIER_INIT);
68 
69     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_INIT_SHM_BARRIER_INIT);
70 
71     barrier = (Init_shm_barrier_t *) memory.base_addr;
72     if (is_root) {
73         MPL_atomic_store_int(&barrier->val, 0);
74         MPL_atomic_store_int(&barrier->wait, 0);
75     }
76     sense = 0;
77     barrier_init = 1;
78 
79     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_INIT_SHM_BARRIER_INIT);
80 
81     return MPI_SUCCESS;
82 }
83 
Init_shm_barrier()84 static int Init_shm_barrier()
85 {
86     int mpi_errno = MPI_SUCCESS;
87     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_INIT_SHM_BARRIER);
88 
89     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_INIT_SHM_BARRIER);
90 
91     if (local_size == 1)
92         goto fn_exit;
93 
94     MPIR_ERR_CHKINTERNAL(!barrier_init, mpi_errno, "barrier not initialized");
95 
96     if (MPL_atomic_fetch_add_int(&barrier->val, 1) == local_size - 1) {
97         MPL_atomic_store_int(&barrier->val, 0);
98         MPL_atomic_store_int(&barrier->wait, 1 - sense);
99     } else {
100         /* wait */
101         while (MPL_atomic_load_int(&barrier->wait) == sense)
102             MPL_sched_yield();  /* skip */
103     }
104     sense = 1 - sense;
105 
106   fn_fail:
107   fn_exit:
108     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_INIT_SHM_BARRIER);
109     return mpi_errno;
110 }
111 
MPIDU_Init_shm_init(void)112 int MPIDU_Init_shm_init(void)
113 {
114     int mpi_errno = MPI_SUCCESS, mpl_err = 0;
115     int local_leader;
116     int rank;
117     MPIR_CHKPMEM_DECL(1);
118     MPIR_CHKLMEM_DECL(1);
119 
120     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDU_INIT_SHM_INIT);
121     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDU_INIT_SHM_INIT);
122 
123     rank = MPIR_Process.rank;
124     local_size = MPIR_Process.local_size;
125     my_local_rank = MPIR_Process.local_rank;
126     local_leader = MPIR_Process.node_local_map[0];
127 
128     size_t segment_len = MPIDU_SHM_CACHE_LINE_LEN + sizeof(MPIDU_Init_shm_block_t) * local_size;
129 
130     char *serialized_hnd = NULL;
131     int serialized_hnd_size = 0;
132 
133     mpl_err = MPL_shm_hnd_init(&(memory.hnd));
134     MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem");
135 
136     memory.segment_len = segment_len;
137 
138     if (local_size == 1) {
139         char *addr;
140 
141         MPIR_CHKPMEM_MALLOC(addr, char *, segment_len + MPIDU_SHM_CACHE_LINE_LEN, mpi_errno,
142                             "segment", MPL_MEM_SHM);
143 
144         memory.base_addr = addr;
145         baseaddr =
146             (char *) (((uintptr_t) addr + (uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1) &
147                       (~((uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1)));
148         memory.symmetrical = 0;
149 
150         mpi_errno = Init_shm_barrier_init(TRUE);
151         MPIR_ERR_CHECK(mpi_errno);
152     } else {
153         if (my_local_rank == 0) {
154             /* root prepare shm segment */
155             mpl_err = MPL_shm_seg_create_and_attach(memory.hnd, memory.segment_len,
156                                                     (void **) &(memory.base_addr), 0);
157             MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem");
158 
159             MPIR_Assert(local_leader == rank);
160 
161             mpl_err = MPL_shm_hnd_get_serialized_by_ref(memory.hnd, &serialized_hnd);
162             MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem");
163             serialized_hnd_size = strlen(serialized_hnd);
164             MPIR_Assert(serialized_hnd_size < MPIR_pmi_max_val_size());
165 
166             mpi_errno = Init_shm_barrier_init(TRUE);
167             MPIR_ERR_CHECK(mpi_errno);
168         } else {
169             /* non-root prepare to recv */
170             serialized_hnd_size = MPIR_pmi_max_val_size();
171             MPIR_CHKLMEM_MALLOC(serialized_hnd, char *, serialized_hnd_size, mpi_errno, "val",
172                                 MPL_MEM_OTHER);
173         }
174     }
175     /* All processes need call MPIR_pmi_bcast. This is because we may need call MPIR_pmi_barrier
176      * inside depend on PMI versions, and all processes need participate.
177      */
178     MPIR_pmi_bcast(serialized_hnd, serialized_hnd_size, MPIR_PMI_DOMAIN_LOCAL);
179     if (local_size != 1) {
180         MPIR_Assert(local_size > 1);
181         if (my_local_rank > 0) {
182             /* non-root attach shm segment */
183             mpl_err = MPL_shm_hnd_deserialize(memory.hnd, serialized_hnd, strlen(serialized_hnd));
184             MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem");
185 
186             mpl_err = MPL_shm_seg_attach(memory.hnd, memory.segment_len,
187                                          (void **) &memory.base_addr, 0);
188             MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**attach_shar_mem");
189 
190             mpi_errno = Init_shm_barrier_init(FALSE);
191             MPIR_ERR_CHECK(mpi_errno);
192         }
193 
194         mpi_errno = Init_shm_barrier();
195         MPIR_ERR_CHECK(mpi_errno);
196 
197         if (my_local_rank == 0) {
198             /* memory->hnd no longer needed */
199             mpl_err = MPL_shm_seg_remove(memory.hnd);
200             MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**remove_shar_mem");
201         }
202 
203         baseaddr = memory.base_addr + MPIDU_SHM_CACHE_LINE_LEN;
204         memory.symmetrical = 0;
205     }
206 
207     mpi_errno = Init_shm_barrier();
208     MPIR_CHKPMEM_COMMIT();
209 
210   fn_exit:
211     MPIR_CHKLMEM_FREEALL();
212     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDU_INIT_SHM_INIT);
213     return mpi_errno;
214   fn_fail:
215     MPIR_CHKPMEM_REAP();
216     goto fn_exit;
217 }
218 
MPIDU_Init_shm_finalize(void)219 int MPIDU_Init_shm_finalize(void)
220 {
221     int mpi_errno = MPI_SUCCESS, mpl_err;
222 
223     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDU_INIT_SHM_FINALIZE);
224     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDU_INIT_SHM_FINALIZE);
225 
226     mpi_errno = Init_shm_barrier();
227     MPIR_ERR_CHECK(mpi_errno);
228 
229     if (local_size == 1)
230         MPL_free(memory.base_addr);
231     else {
232         mpl_err = MPL_shm_seg_detach(memory.hnd, (void **) &(memory.base_addr), memory.segment_len);
233         MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem");
234     }
235 
236   fn_exit:
237     MPL_shm_hnd_finalize(&(memory.hnd));
238     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDU_INIT_SHM_FINALIZE);
239     return mpi_errno;
240   fn_fail:
241     goto fn_exit;
242 }
243 
MPIDU_Init_shm_barrier(void)244 int MPIDU_Init_shm_barrier(void)
245 {
246     int mpi_errno = MPI_SUCCESS;
247 
248     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDU_INIT_SHM_BARRIER);
249     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDU_INIT_SHM_BARRIER);
250 
251     mpi_errno = Init_shm_barrier();
252 
253     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDU_INIT_SHM_BARRIER);
254 
255     return mpi_errno;
256 }
257 
MPIDU_Init_shm_put(void * orig,size_t len)258 int MPIDU_Init_shm_put(void *orig, size_t len)
259 {
260     int mpi_errno = MPI_SUCCESS;
261 
262     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDU_INIT_SHM_PUT);
263     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDU_INIT_SHM_PUT);
264 
265     MPIR_Assert(len <= sizeof(MPIDU_Init_shm_block_t));
266     MPIR_Memcpy((char *) baseaddr + my_local_rank * sizeof(MPIDU_Init_shm_block_t), orig, len);
267 
268     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDU_INIT_SHM_PUT);
269 
270     return mpi_errno;
271 }
272 
MPIDU_Init_shm_get(int local_rank,size_t len,void * target)273 int MPIDU_Init_shm_get(int local_rank, size_t len, void *target)
274 {
275     int mpi_errno = MPI_SUCCESS;
276 
277     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDU_INIT_SHM_GET);
278     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDU_INIT_SHM_GET);
279 
280     MPIR_Assert(local_rank < local_size && len <= sizeof(MPIDU_Init_shm_block_t));
281     MPIR_Memcpy(target, (char *) baseaddr + local_rank * sizeof(MPIDU_Init_shm_block_t), len);
282 
283     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDU_INIT_SHM_GET);
284 
285     return mpi_errno;
286 }
287 
MPIDU_Init_shm_query(int local_rank,void ** target_addr)288 int MPIDU_Init_shm_query(int local_rank, void **target_addr)
289 {
290     int mpi_errno = MPI_SUCCESS;
291 
292     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDU_INIT_SHM_QUERY);
293     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDU_INIT_SHM_QUERY);
294 
295     MPIR_Assert(local_rank < local_size);
296     *target_addr = (char *) baseaddr + local_rank * sizeof(MPIDU_Init_shm_block_t);
297 
298     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDU_INIT_SHM_QUERY);
299 
300     return mpi_errno;
301 }
302 
303 #endif /* ENABLE_NO_LOCAL */
304