1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpidimpl.h"
7 #include "mpidig_am.h"
8 #include "mpidch4r.h"
9 #include "mpidu_genq.h"
10 
11 /*
12 === BEGIN_MPI_T_CVAR_INFO_BLOCK ===
13 
14 cvars:
15     - name        : MPIR_CVAR_CH4_AM_PACK_BUFFER_SIZE
16       category    : CH4
17       type        : int
18       default     : 69632
19       class       : none
20       verbosity   : MPI_T_VERBOSITY_USER_BASIC
21       scope       : MPI_T_SCOPE_LOCAL
22       description : >-
        Specifies the size, in bytes, of each buffer used for packing/unpacking
        active messages in the pool. The size should be greater than or equal
        to the max of the eager buffer limits of SHM and NETMOD.
26 
27     - name        : MPIR_CVAR_CH4_NUM_AM_PACK_BUFFERS_PER_CHUNK
28       category    : CH4
29       type        : int
30       default     : 16
31       class       : none
32       verbosity   : MPI_T_VERBOSITY_USER_BASIC
33       scope       : MPI_T_SCOPE_LOCAL
34       description : >-
35         Specifies the number of buffers for packing/unpacking active messages in
36         each block of the pool.
37 
38     - name        : MPIR_CVAR_CH4_MAX_AM_UNEXPECTED_PACK_BUFFERS_SIZE_BYTE
39       category    : CH4
40       type        : int
41       default     : 8388608
42       class       : none
43       verbosity   : MPI_T_VERBOSITY_USER_BASIC
44       scope       : MPI_T_SCOPE_LOCAL
45       description : >-
        Specifies the max total size, in bytes, of the buffers used for
        packing/unpacking unexpected active messages in the pool.
48 
49 === END_MPI_T_CVAR_INFO_BLOCK ===
50 */
51 
52 static int dynamic_am_handler_id = MPIDIG_HANDLER_STATIC_MAX;
53 
54 static void *host_alloc(uintptr_t size);
55 static void *host_alloc_buffer_registered(uintptr_t size);
56 static void host_free(void *ptr);
57 static void host_free_buffer_registered(void *ptr);
58 
host_alloc(uintptr_t size)59 static void *host_alloc(uintptr_t size)
60 {
61     return MPL_malloc(size, MPL_MEM_BUFFER);
62 }
63 
host_alloc_buffer_registered(uintptr_t size)64 static void *host_alloc_buffer_registered(uintptr_t size)
65 {
66     void *ptr = MPL_malloc(size, MPL_MEM_BUFFER);
67     MPIR_Assert(ptr);
68     MPL_gpu_register_host(ptr, size);
69     return ptr;
70 }
71 
/* Counterpart of host_alloc(): hands the memory back to MPL_free. */
static void host_free(void *ptr)
{
    void *mem = ptr;
    MPL_free(mem);
}
76 
/* Counterpart of host_alloc_buffer_registered(): the GPU registration is
 * dropped first, while the buffer is still allocated, and only then is the
 * memory released with MPL_free. */
static void host_free_buffer_registered(void *ptr)
{
    void *buf = ptr;

    MPL_gpu_unregister_host(buf);
    MPL_free(buf);
}
82 
MPIDIG_am_check_init(void)83 int MPIDIG_am_check_init(void)
84 {
85     int mpi_errno = MPI_SUCCESS;
86     size_t buf_size_limit = 0;
87 #ifdef MPIDI_CH4_DIRECT_NETMOD
88     buf_size_limit = MPIDI_NM_am_eager_buf_limit();
89 #else
90     buf_size_limit = MPL_MAX(MPIDI_SHM_am_eager_buf_limit(), MPIDI_NM_am_eager_buf_limit());
91 #endif
92     MPIR_Assert(MPIR_CVAR_CH4_AM_PACK_BUFFER_SIZE >= buf_size_limit);
93     return mpi_errno;
94 }
95 
/* Register an origin/target callback pair in the next free dynamic slot.
 * Returns the handler id assigned to the pair, or -1 once all
 * MPIDI_AM_HANDLERS_MAX slots are taken.  The slot counter is a file-static
 * int that is not protected by any lock in this function. */
int MPIDIG_am_reg_cb_dynamic(MPIDIG_am_origin_cb origin_cb, MPIDIG_am_target_msg_cb target_msg_cb)
{
    int new_id = dynamic_am_handler_id;

    if (new_id >= MPIDI_AM_HANDLERS_MAX) {
        /* no free slot left */
        return -1;
    }

    MPIDIG_am_reg_cb(new_id, origin_cb, target_msg_cb);
    dynamic_am_handler_id = new_id + 1;
    return new_id;
}
106 
/* Install the callbacks for handler_id in the global AM dispatch tables
 * MPIDIG_global.target_msg_cbs and MPIDIG_global.origin_cbs.  origin_cb may
 * be NULL (several built-in handlers registered in MPIDIG_am_init() pass
 * NULL).
 *
 * handler_id indexes the tables directly, so it is asserted to lie in
 * [0, MPIDI_AM_HANDLERS_MAX), the same bound MPIDIG_am_reg_cb_dynamic()
 * enforces.  NOTE(review): assumes both tables have MPIDI_AM_HANDLERS_MAX
 * entries -- confirm against their declaration. */
void MPIDIG_am_reg_cb(int handler_id,
                      MPIDIG_am_origin_cb origin_cb, MPIDIG_am_target_msg_cb target_msg_cb)
{
    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIG_AM_REG_CB);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIG_AM_REG_CB);

    MPIR_Assert(handler_id >= 0 && handler_id < MPIDI_AM_HANDLERS_MAX);

    MPIDIG_global.target_msg_cbs[handler_id] = target_msg_cb;
    MPIDIG_global.origin_cbs[handler_id] = origin_cb;

    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIG_AM_REG_CB);
}
118 
MPIDIG_am_init(void)119 int MPIDIG_am_init(void)
120 {
121     int mpi_errno = MPI_SUCCESS;
122     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIG_AM_INIT);
123     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIG_AM_INIT);
124 
125     MPIDI_global.comm_req_lists = (MPIDIG_comm_req_list_t *)
126         MPL_calloc(MPIR_MAX_CONTEXT_MASK * MPIR_CONTEXT_INT_BITS,
127                    sizeof(MPIDIG_comm_req_list_t), MPL_MEM_OTHER);
128 #ifndef MPIDI_CH4U_USE_PER_COMM_QUEUE
129     MPIDI_global.posted_list = NULL;
130     MPIDI_global.unexp_list = NULL;
131 #endif
132 
133     MPIDI_global.cmpl_list = NULL;
134     MPL_atomic_store_uint64(&MPIDI_global.exp_seq_no, 0);
135     MPL_atomic_store_uint64(&MPIDI_global.nxt_seq_no, 0);
136 
137     MPL_atomic_store_int(&MPIDIG_global.rma_am_flag, 0);
138     MPIR_cc_set(&MPIDIG_global.rma_am_poll_cntr, 0);
139 
140     mpi_errno =
141         MPIDU_genq_private_pool_create_unsafe(MPIDIU_REQUEST_POOL_CELL_SIZE,
142                                               MPIDIU_REQUEST_POOL_NUM_CELLS_PER_CHUNK,
143                                               MPIDIU_REQUEST_POOL_MAX_NUM_CELLS, host_alloc,
144                                               host_free, &MPIDI_global.request_pool);
145     MPIR_ERR_CHECK(mpi_errno);
146     /* The cell size need to match the send side (ofi short msg size) */
147     mpi_errno = MPIDU_genq_private_pool_create_unsafe(MPIR_CVAR_CH4_AM_PACK_BUFFER_SIZE,
148                                                       MPIR_CVAR_CH4_NUM_AM_PACK_BUFFERS_PER_CHUNK,
149                                                       INT_MAX,
150                                                       host_alloc_buffer_registered,
151                                                       host_free_buffer_registered,
152                                                       &MPIDI_global.unexp_pack_buf_pool);
153     MPIR_ERR_CHECK(mpi_errno);
154 
155     MPIR_Assert(MPIDIG_HANDLER_STATIC_MAX <= MPIDI_AM_HANDLERS_MAX);
156 
157     MPIDIG_am_reg_cb(MPIDIG_SEND, &MPIDIG_send_origin_cb, &MPIDIG_send_target_msg_cb);
158     MPIDIG_am_reg_cb(MPIDIG_SEND_CTS, NULL, &MPIDIG_send_cts_target_msg_cb);
159     MPIDIG_am_reg_cb(MPIDIG_SEND_DATA,
160                      &MPIDIG_send_data_origin_cb, &MPIDIG_send_data_target_msg_cb);
161 
162     MPIDIG_am_reg_cb(MPIDIG_SSEND_ACK, NULL, &MPIDIG_ssend_ack_target_msg_cb);
163     MPIDIG_am_reg_cb(MPIDIG_PUT_REQ, &MPIDIG_put_origin_cb, &MPIDIG_put_target_msg_cb);
164     MPIDIG_am_reg_cb(MPIDIG_PUT_ACK, NULL, &MPIDIG_put_ack_target_msg_cb);
165     MPIDIG_am_reg_cb(MPIDIG_GET_REQ, &MPIDIG_get_origin_cb, &MPIDIG_get_target_msg_cb);
166     MPIDIG_am_reg_cb(MPIDIG_GET_ACK, &MPIDIG_get_ack_origin_cb, &MPIDIG_get_ack_target_msg_cb);
167     MPIDIG_am_reg_cb(MPIDIG_CSWAP_REQ, &MPIDIG_cswap_origin_cb, &MPIDIG_cswap_target_msg_cb);
168     MPIDIG_am_reg_cb(MPIDIG_CSWAP_ACK,
169                      &MPIDIG_cswap_ack_origin_cb, &MPIDIG_cswap_ack_target_msg_cb);
170     MPIDIG_am_reg_cb(MPIDIG_ACC_REQ, &MPIDIG_acc_origin_cb, &MPIDIG_acc_target_msg_cb);
171     MPIDIG_am_reg_cb(MPIDIG_GET_ACC_REQ, &MPIDIG_get_acc_origin_cb, &MPIDIG_get_acc_target_msg_cb);
172     MPIDIG_am_reg_cb(MPIDIG_ACC_ACK, NULL, &MPIDIG_acc_ack_target_msg_cb);
173     MPIDIG_am_reg_cb(MPIDIG_GET_ACC_ACK,
174                      &MPIDIG_get_acc_ack_origin_cb, &MPIDIG_get_acc_ack_target_msg_cb);
175     MPIDIG_am_reg_cb(MPIDIG_WIN_COMPLETE, NULL, &MPIDIG_win_ctrl_target_msg_cb);
176     MPIDIG_am_reg_cb(MPIDIG_WIN_POST, NULL, &MPIDIG_win_ctrl_target_msg_cb);
177     MPIDIG_am_reg_cb(MPIDIG_WIN_LOCK, NULL, &MPIDIG_win_ctrl_target_msg_cb);
178     MPIDIG_am_reg_cb(MPIDIG_WIN_LOCK_ACK, NULL, &MPIDIG_win_ctrl_target_msg_cb);
179     MPIDIG_am_reg_cb(MPIDIG_WIN_UNLOCK, NULL, &MPIDIG_win_ctrl_target_msg_cb);
180     MPIDIG_am_reg_cb(MPIDIG_WIN_UNLOCK_ACK, NULL, &MPIDIG_win_ctrl_target_msg_cb);
181     MPIDIG_am_reg_cb(MPIDIG_WIN_LOCKALL, NULL, &MPIDIG_win_ctrl_target_msg_cb);
182     MPIDIG_am_reg_cb(MPIDIG_WIN_LOCKALL_ACK, NULL, &MPIDIG_win_ctrl_target_msg_cb);
183     MPIDIG_am_reg_cb(MPIDIG_WIN_UNLOCKALL, NULL, &MPIDIG_win_ctrl_target_msg_cb);
184     MPIDIG_am_reg_cb(MPIDIG_WIN_UNLOCKALL_ACK, NULL, &MPIDIG_win_ctrl_target_msg_cb);
185     MPIDIG_am_reg_cb(MPIDIG_PUT_DT_REQ, &MPIDIG_put_dt_origin_cb, &MPIDIG_put_dt_target_msg_cb);
186     MPIDIG_am_reg_cb(MPIDIG_PUT_DT_ACK, NULL, &MPIDIG_put_dt_ack_target_msg_cb);
187     MPIDIG_am_reg_cb(MPIDIG_PUT_DAT_REQ,
188                      &MPIDIG_put_data_origin_cb, &MPIDIG_put_data_target_msg_cb);
189     MPIDIG_am_reg_cb(MPIDIG_ACC_DT_REQ, &MPIDIG_acc_dt_origin_cb, &MPIDIG_acc_dt_target_msg_cb);
190     MPIDIG_am_reg_cb(MPIDIG_GET_ACC_DT_REQ,
191                      &MPIDIG_get_acc_dt_origin_cb, &MPIDIG_get_acc_dt_target_msg_cb);
192     MPIDIG_am_reg_cb(MPIDIG_ACC_DT_ACK, NULL, &MPIDIG_acc_dt_ack_target_msg_cb);
193     MPIDIG_am_reg_cb(MPIDIG_GET_ACC_DT_ACK, NULL, &MPIDIG_get_acc_dt_ack_target_msg_cb);
194     MPIDIG_am_reg_cb(MPIDIG_ACC_DAT_REQ,
195                      &MPIDIG_acc_data_origin_cb, &MPIDIG_acc_data_target_msg_cb);
196     MPIDIG_am_reg_cb(MPIDIG_GET_ACC_DAT_REQ,
197                      &MPIDIG_get_acc_data_origin_cb, &MPIDIG_get_acc_data_target_msg_cb);
198 
199     MPIDIG_am_comm_abort_init();
200 
201     mpi_errno = MPIDIG_RMA_Init_sync_pvars();
202     MPIR_ERR_CHECK(mpi_errno);
203 
204     mpi_errno = MPIDIG_RMA_Init_targetcb_pvars();
205     MPIR_ERR_CHECK(mpi_errno);
206 
207     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIG_AM_INIT);
208 
209   fn_exit:
210     return mpi_errno;
211   fn_fail:
212     goto fn_exit;
213 }
214 
MPIDIG_am_finalize(void)215 void MPIDIG_am_finalize(void)
216 {
217     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDIG_AM_FINALIZE);
218     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDIG_AM_FINALIZE);
219 
220     MPIDIU_map_destroy(MPIDI_global.win_map);
221     MPIDU_genq_private_pool_destroy_unsafe(MPIDI_global.request_pool);
222     MPIDU_genq_private_pool_destroy_unsafe(MPIDI_global.unexp_pack_buf_pool);
223     MPL_free(MPIDI_global.comm_req_lists);
224 
225     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDIG_AM_FINALIZE);
226 }
227