1 /*
2 * Copyright (C) by Argonne National Laboratory
3 * See COPYRIGHT in top-level directory
4 */
5
6 #include "mpiimpl.h"
7 #include "mpidimpl.h"
8 #include "hcoll.h"
9 #include "hcoll/api/hcoll_dte.h"
10 #include "hcoll_dtypes.h"
11
12 static int recv_nb(dte_data_representation_t data,
13 uint32_t count,
14 void *buffer,
15 rte_ec_handle_t, rte_grp_handle_t, uint32_t tag, rte_request_handle_t * req);
16
17 static int send_nb(dte_data_representation_t data,
18 uint32_t count,
19 void *buffer,
20 rte_ec_handle_t ec_h,
21 rte_grp_handle_t grp_h, uint32_t tag, rte_request_handle_t * req);
22
23 static int test(rte_request_handle_t * request, int *completed);
24
25 static int ec_handle_compare(rte_ec_handle_t handle_1,
26 rte_grp_handle_t
27 group_handle_1,
28 rte_ec_handle_t handle_2, rte_grp_handle_t group_handle_2);
29
30 static int get_ec_handles(int num_ec,
31 int *ec_indexes, rte_grp_handle_t, rte_ec_handle_t * ec_handles);
32
33 static int group_size(rte_grp_handle_t group);
34 static int my_rank(rte_grp_handle_t grp_h);
35 static int ec_on_local_node(rte_ec_handle_t ec, rte_grp_handle_t group);
36 static rte_grp_handle_t get_world_group_handle(void);
37 static uint32_t jobid(void);
38
39 static void *get_coll_handle(void);
40 static int coll_handle_test(void *handle);
41 static void coll_handle_free(void *handle);
42 static void coll_handle_complete(void *handle);
43 static int group_id(rte_grp_handle_t group);
44
45 static int world_rank(rte_grp_handle_t grp_h, rte_ec_handle_t ec);
46
progress(void)47 static void progress(void)
48 {
49 int ret;
50 int made_progress;
51
52 if (0 == world_comm_destroying) {
53 MPID_Progress_test(NULL);
54 } else {
55 /* FIXME: The hcoll library needs to be updated to return
56 * error codes. The progress function pointer right now
57 * expects that the function returns void. */
58 ret = hcoll_do_progress(&made_progress);
59 MPIR_Assert(ret == MPI_SUCCESS);
60 }
61 }
62
#if HCOLL_API >= HCOLL_VERSION(3,6)
/* Datatype-introspection callbacks, available since hcoll API 3.6.
 * They let libhcoll walk MPI derived datatypes (envelope + contents)
 * and translate them into its own dte representations. */
static int get_mpi_type_envelope(void *mpi_type, int *num_integers,
                                 int *num_addresses, int *num_datatypes,
                                 hcoll_mpi_type_combiner_t * combiner);
static int get_mpi_type_contents(void *mpi_type, int max_integers, int max_addresses,
                                 int max_datatypes, int *array_of_integers,
                                 void *array_of_addresses, void *array_of_datatypes);
static int get_hcoll_type(void *mpi_type, dte_data_representation_t * hcoll_type);
static int set_hcoll_type(void *mpi_type, dte_data_representation_t hcoll_type);
static int get_mpi_constants(size_t * mpi_datatype_size,
                             int *mpi_order_c, int *mpi_order_fortran,
                             int *mpi_distribute_block,
                             int *mpi_distribute_cyclic,
                             int *mpi_distribute_none, int *mpi_distribute_dflt_darg);
#endif
78
init_module_fns(void)79 static void init_module_fns(void)
80 {
81 hcoll_rte_functions.send_fn = send_nb;
82 hcoll_rte_functions.recv_fn = recv_nb;
83 hcoll_rte_functions.ec_cmp_fn = ec_handle_compare;
84 hcoll_rte_functions.get_ec_handles_fn = get_ec_handles;
85 hcoll_rte_functions.rte_group_size_fn = group_size;
86 hcoll_rte_functions.test_fn = test;
87 hcoll_rte_functions.rte_my_rank_fn = my_rank;
88 hcoll_rte_functions.rte_ec_on_local_node_fn = ec_on_local_node;
89 hcoll_rte_functions.rte_world_group_fn = get_world_group_handle;
90 hcoll_rte_functions.rte_jobid_fn = jobid;
91 hcoll_rte_functions.rte_progress_fn = progress;
92 hcoll_rte_functions.rte_get_coll_handle_fn = get_coll_handle;
93 hcoll_rte_functions.rte_coll_handle_test_fn = coll_handle_test;
94 hcoll_rte_functions.rte_coll_handle_free_fn = coll_handle_free;
95 hcoll_rte_functions.rte_coll_handle_complete_fn = coll_handle_complete;
96 hcoll_rte_functions.rte_group_id_fn = group_id;
97 hcoll_rte_functions.rte_world_rank_fn = world_rank;
98 #if HCOLL_API >= HCOLL_VERSION(3,6)
99 hcoll_rte_functions.rte_get_mpi_type_envelope_fn = get_mpi_type_envelope;
100 hcoll_rte_functions.rte_get_mpi_type_contents_fn = get_mpi_type_contents;
101 hcoll_rte_functions.rte_get_hcoll_type_fn = get_hcoll_type;
102 hcoll_rte_functions.rte_set_hcoll_type_fn = set_hcoll_type;
103 hcoll_rte_functions.rte_get_mpi_constants_fn = get_mpi_constants;
104 #endif
105 }
106
/* Public entry point: install MPICH's RTE callbacks into libhcoll.
 * NOTE(review): presumably invoked once from the hcoll initialization
 * path elsewhere in the device code — confirm against caller. */
void hcoll_rte_fns_setup(void)
{
    init_module_fns();
}
111
recv_nb(struct dte_data_representation_t data,uint32_t count,void * buffer,rte_ec_handle_t ec_h,rte_grp_handle_t grp_h,uint32_t tag,rte_request_handle_t * req)112 static int recv_nb(struct dte_data_representation_t data,
113 uint32_t count,
114 void *buffer,
115 rte_ec_handle_t ec_h,
116 rte_grp_handle_t grp_h, uint32_t tag, rte_request_handle_t * req)
117 {
118 int mpi_errno;
119 MPI_Datatype dtype;
120 MPIR_Request *request;
121 MPIR_Comm *comm;
122 size_t size;
123 mpi_errno = MPI_SUCCESS;
124 comm = (MPIR_Comm *) grp_h;
125 if (!ec_h.handle) {
126 MPIR_ERR_SETANDJUMP2(mpi_errno, MPI_ERR_OTHER, "**hcoll_wrong_arg",
127 "**hcoll_wrong_arg %p %d", ec_h.handle, ec_h.rank);
128 }
129
130 MPIR_Assert(HCOL_DTE_IS_INLINE(data));
131 if (!buffer && !HCOL_DTE_IS_ZERO(data)) {
132 MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**null_buff_ptr");
133 }
134 size = (size_t) data.rep.in_line_rep.data_handle.in_line.packed_size * count / 8;
135 dtype = MPI_CHAR;
136 request = NULL;
137 mpi_errno = MPIC_Irecv(buffer, size, dtype, ec_h.rank, tag, comm, &request);
138 MPIR_Assert(request);
139 req->data = (void *) request;
140 req->status = HCOLRTE_REQUEST_ACTIVE;
141 fn_exit:
142 return mpi_errno;
143 fn_fail:
144 return HCOLL_ERROR;
145 }
146
/* Non-blocking send callback registered with the hcoll RTE.
 * Mirrors recv_nb(): validates the endpoint and buffer, converts hcoll's
 * bit-based inline size to bytes, and posts an MPIC_Isend.
 * Returns MPI_SUCCESS on success, HCOLL_ERROR on any failure. */
static int send_nb(dte_data_representation_t data,
                   uint32_t count,
                   void *buffer,
                   rte_ec_handle_t ec_h,
                   rte_grp_handle_t grp_h, uint32_t tag, rte_request_handle_t * req)
{
    int mpi_errno = MPI_SUCCESS;
    MPIR_Comm *comm = (MPIR_Comm *) grp_h;
    MPIR_Request *request = NULL;
    MPI_Datatype dtype;
    size_t size;

    if (!ec_h.handle) {
        MPIR_ERR_SETANDJUMP2(mpi_errno, MPI_ERR_OTHER, "**hcoll_wrong_arg",
                             "**hcoll_wrong_arg %p %d", ec_h.handle, ec_h.rank);
    }

    /* only inline dte representations are expected from hcoll here */
    MPIR_Assert(HCOL_DTE_IS_INLINE(data));
    if (!buffer && !HCOL_DTE_IS_ZERO(data)) {
        MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**null_buff_ptr");
    }
    /* packed_size is expressed in bits; transfer as raw bytes */
    size = (size_t) data.rep.in_line_rep.data_handle.in_line.packed_size * count / 8;
    dtype = MPI_CHAR;
    MPIR_Errflag_t err = MPIR_ERR_NONE;
    mpi_errno = MPIC_Isend(buffer, size, dtype, ec_h.rank, tag, comm, &request, &err);
    if (mpi_errno != MPI_SUCCESS) {
        /* translate the MPI error into hcoll's error space instead of
         * leaking a raw MPICH error code back to libhcoll */
        goto fn_fail;
    }
    MPIR_Assert(request);
    req->data = (void *) request;
    req->status = HCOLRTE_REQUEST_ACTIVE;
  fn_exit:
    return mpi_errno;
  fn_fail:
    return HCOLL_ERROR;
}
182
/* Completion-test callback for a request previously created by
 * send_nb()/recv_nb().  Sets *completed and, once complete, releases the
 * underlying MPIR_Request and marks the handle DONE. */
static int test(rte_request_handle_t * request, int *completed)
{
    /* a handle that is no longer active is reported complete immediately */
    if (request->status != HCOLRTE_REQUEST_ACTIVE) {
        *completed = true;
        return HCOLL_SUCCESS;
    }

    MPIR_Request *mpich_req = (MPIR_Request *) request->data;
    *completed = (int) MPIR_Request_is_complete(mpich_req);
    if (*completed) {
        MPIR_Request_free(mpich_req);
        request->status = HCOLRTE_REQUEST_DONE;
    }

    return HCOLL_SUCCESS;
}
200
/* Endpoint equality for libhcoll.  Identity is carried entirely by the
 * opaque handle pointer; the group handles take no part in the compare. */
static int ec_handle_compare(rte_ec_handle_t handle_1,
                             rte_grp_handle_t
                             group_handle_1,
                             rte_ec_handle_t handle_2, rte_grp_handle_t group_handle_2)
{
    int same = (handle_1.handle == handle_2.handle);
    return same;
}
208
/* Fill ec_handles[] for the requested ranks of the group (a communicator).
 * Each entry records the rank and a device-level endpoint pointer. */
static int get_ec_handles(int num_ec,
                          int *ec_indexes, rte_grp_handle_t grp_h, rte_ec_handle_t * ec_handles)
{
    MPIR_Comm *comm = (MPIR_Comm *) grp_h;
    int idx;

    for (idx = 0; idx < num_ec; idx++) {
        int rank = ec_indexes[idx];
        ec_handles[idx].rank = rank;
#ifdef MPIDCH4_H_INCLUDED
        /* CH4: the endpoint handle is the rank's address-vector entry */
        ec_handles[idx].handle = (void *) (MPIDIU_comm_rank_to_av(comm, rank));
#else
        /* CH3: the endpoint handle is the rank's virtual connection */
        ec_handles[idx].handle = (void *) (comm->dev.vcrt->vcr_table[rank]);
#endif
    }
    return HCOLL_SUCCESS;
}
225
/* Number of processes in the hcoll group (i.e. communicator size). */
static int group_size(rte_grp_handle_t grp_h)
{
    MPIR_Comm *comm = (MPIR_Comm *) grp_h;
    return MPIR_Comm_size(comm);
}
230
/* Calling process's rank within the hcoll group (communicator). */
static int my_rank(rte_grp_handle_t grp_h)
{
    MPIR_Comm *comm = (MPIR_Comm *) grp_h;
    return MPIR_Comm_rank(comm);
}
235
/* Return nonzero when the endpoint 'ec' resides on the same node as the
 * calling process, by comparing MPID node ids. */
static int ec_on_local_node(rte_ec_handle_t ec, rte_grp_handle_t group)
{
    MPIR_Comm *comm;
    int ec_nodeid, my_nodeid;
    int self_rank;      /* renamed: the old name 'my_rank' shadowed the
                         * static my_rank() function defined above */
    comm = (MPIR_Comm *) group;
    MPID_Get_node_id(comm, ec.rank, &ec_nodeid);
    self_rank = MPIR_Comm_rank(comm);
    MPID_Get_node_id(comm, self_rank, &my_nodeid);
    return (ec_nodeid == my_nodeid);
}
247
248
get_world_group_handle(void)249 static rte_grp_handle_t get_world_group_handle(void)
250 {
251 return (rte_grp_handle_t) (MPIR_Process.comm_world);
252 }
253
/* Job identifier callback.  hcoll does not currently consume this value,
 * so a constant 0 is returned. */
static uint32_t jobid(void)
{
    /* not used currently */
    return 0;
}
259
/* Stable integer id for the group: the communicator's context id. */
static int group_id(rte_grp_handle_t group)
{
    return ((MPIR_Comm *) group)->context_id;
}
266
get_coll_handle(void)267 static void *get_coll_handle(void)
268 {
269 MPIR_Request *req;
270 req = MPIR_Request_create(MPIR_REQUEST_KIND__COLL);
271 MPIR_Request_add_ref(req);
272 return (void *) req;
273 }
274
coll_handle_test(void * handle)275 static int coll_handle_test(void *handle)
276 {
277 int completed;
278 MPIR_Request *req;
279 req = (MPIR_Request *) handle;
280 completed = (int) MPIR_Request_is_complete(req);
281 return completed;
282 }
283
coll_handle_free(void * handle)284 static void coll_handle_free(void *handle)
285 {
286 MPIR_Request *req;
287 if (NULL != handle) {
288 req = (MPIR_Request *) handle;
289 MPIR_Request_free(req);
290 }
291 }
292
coll_handle_complete(void * handle)293 static void coll_handle_complete(void *handle)
294 {
295 MPIR_Request *req;
296 if (NULL != handle) {
297 req = (MPIR_Request *) handle;
298 MPIR_Request_complete(req);
299 }
300 }
301
/* Translate an endpoint of 'grp_h' into its rank in the world group.
 * CH4 resolves it through the local-pid lookup; CH3 reads the
 * process-group rank stored in the endpoint's virtual connection. */
static int world_rank(rte_grp_handle_t grp_h, rte_ec_handle_t ec)
{
#ifdef MPIDCH4_H_INCLUDED
    return MPIDIU_rank_to_lpid(ec.rank, (MPIR_Comm *) grp_h);
#else
    return ((struct MPIDI_VC *) ec.handle)->pg_rank;
#endif
}
310
311 #if HCOLL_API >= HCOLL_VERSION(3,6)
mpi_combiner_2_hcoll_combiner(int combiner)312 hcoll_mpi_type_combiner_t mpi_combiner_2_hcoll_combiner(int combiner)
313 {
314 switch (combiner) {
315 case MPI_COMBINER_CONTIGUOUS:
316 return HCOLL_MPI_COMBINER_CONTIGUOUS;
317 case MPI_COMBINER_VECTOR:
318 return HCOLL_MPI_COMBINER_VECTOR;
319 case MPI_COMBINER_HVECTOR:
320 return HCOLL_MPI_COMBINER_HVECTOR;
321 case MPI_COMBINER_INDEXED:
322 return HCOLL_MPI_COMBINER_INDEXED;
323 case MPI_COMBINER_HINDEXED_INTEGER:
324 case MPI_COMBINER_HINDEXED:
325 return HCOLL_MPI_COMBINER_HINDEXED;
326 case MPI_COMBINER_DUP:
327 return HCOLL_MPI_COMBINER_DUP;
328 case MPI_COMBINER_INDEXED_BLOCK:
329 return HCOLL_MPI_COMBINER_INDEXED_BLOCK;
330 case MPI_COMBINER_HINDEXED_BLOCK:
331 return HCOLL_MPI_COMBINER_HINDEXED_BLOCK;
332 case MPI_COMBINER_SUBARRAY:
333 return HCOLL_MPI_COMBINER_SUBARRAY;
334 case MPI_COMBINER_DARRAY:
335 return HCOLL_MPI_COMBINER_DARRAY;
336 case MPI_COMBINER_F90_REAL:
337 return HCOLL_MPI_COMBINER_F90_REAL;
338 case MPI_COMBINER_F90_COMPLEX:
339 return HCOLL_MPI_COMBINER_F90_COMPLEX;
340 case MPI_COMBINER_F90_INTEGER:
341 return HCOLL_MPI_COMBINER_F90_INTEGER;
342 case MPI_COMBINER_RESIZED:
343 return HCOLL_MPI_COMBINER_RESIZED;
344 case MPI_COMBINER_STRUCT:
345 case MPI_COMBINER_STRUCT_INTEGER:
346 return HCOLL_MPI_COMBINER_STRUCT;
347 default:
348 break;
349 }
350 return HCOLL_MPI_COMBINER_LAST;
351 }
352
/* hcoll callback: fetch a datatype's envelope (counts for the contents
 * arrays) and translate the MPI combiner code into hcoll's enum. */
static int get_mpi_type_envelope(void *mpi_type, int *num_integers,
                                 int *num_addresses, int *num_datatypes,
                                 hcoll_mpi_type_combiner_t * combiner)
{
    /* hcoll passes the MPI_Datatype handle smuggled through a void * */
    MPI_Datatype dt = (MPI_Datatype) (intptr_t) mpi_type;
    int mpi_combiner = 0;

    MPIR_Type_get_envelope(dt, num_integers, num_addresses, num_datatypes, &mpi_combiner);
    *combiner = mpi_combiner_2_hcoll_combiner(mpi_combiner);

    return HCOLL_SUCCESS;
}
366
get_mpi_type_contents(void * mpi_type,int max_integers,int max_addresses,int max_datatypes,int * array_of_integers,void * array_of_addresses,void * array_of_datatypes)367 static int get_mpi_type_contents(void *mpi_type, int max_integers, int max_addresses,
368 int max_datatypes, int *array_of_integers,
369 void *array_of_addresses, void *array_of_datatypes)
370 {
371 int ret;
372 MPI_Datatype dt_handle = (MPI_Datatype) (intptr_t) mpi_type;
373
374 ret = MPIR_Type_get_contents(dt_handle,
375 max_integers, max_addresses, max_datatypes,
376 array_of_integers,
377 (MPI_Aint *) array_of_addresses,
378 (MPI_Datatype *) array_of_datatypes);
379
380 return ret == MPI_SUCCESS ? HCOLL_SUCCESS : HCOLL_ERROR;
381 }
382
/* Map an MPI datatype handle to hcoll's dte representation.
 * Returns HCOLL_ERR_NOT_FOUND when no mapping exists (zero dte).
 * (Removed the unused local 'dt_ptr', which only generated a
 * set-but-unused warning.) */
static int get_hcoll_type(void *mpi_type, dte_data_representation_t * hcoll_type)
{
    MPI_Datatype dt_handle = (MPI_Datatype) (intptr_t) mpi_type;

    *hcoll_type = mpi_dtype_2_hcoll_dtype(dt_handle, -1, TRY_FIND_DERIVED);

    return HCOL_DTE_IS_ZERO((*hcoll_type)) ? HCOLL_ERR_NOT_FOUND : HCOLL_SUCCESS;
}
392
/* hcoll callback for caching a dte representation on an MPI datatype.
 * This RTE keeps no such cache, so the call is currently a no-op that
 * always reports success. */
static int set_hcoll_type(void *mpi_type, dte_data_representation_t hcoll_type)
{
    return HCOLL_SUCCESS;
}
397
/* Export the MPI constants hcoll needs to interpret datatype envelopes
 * and contents: the size of an MPI_Datatype handle, the array-order
 * codes, and the DARRAY distribution codes. */
static int get_mpi_constants(size_t * mpi_datatype_size,
                             int *mpi_order_c, int *mpi_order_fortran,
                             int *mpi_distribute_block,
                             int *mpi_distribute_cyclic,
                             int *mpi_distribute_none, int *mpi_distribute_dflt_darg)
{
    *mpi_datatype_size = sizeof(MPI_Datatype);
    *mpi_order_c = MPI_ORDER_C;
    *mpi_order_fortran = MPI_ORDER_FORTRAN;
    *mpi_distribute_block = MPI_DISTRIBUTE_BLOCK;
    *mpi_distribute_cyclic = MPI_DISTRIBUTE_CYCLIC;
    *mpi_distribute_none = MPI_DISTRIBUTE_NONE;
    *mpi_distribute_dflt_darg = MPI_DISTRIBUTE_DFLT_DARG;

    return HCOLL_SUCCESS;
}
414
415 #endif
416