1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpiimpl.h"
7 #include "mpidimpl.h"
8 #include "hcoll.h"
9 #include "hcoll/api/hcoll_dte.h"
10 #include "hcoll_dtypes.h"
11 
12 static int recv_nb(dte_data_representation_t data,
13                    uint32_t count,
14                    void *buffer,
15                    rte_ec_handle_t, rte_grp_handle_t, uint32_t tag, rte_request_handle_t * req);
16 
17 static int send_nb(dte_data_representation_t data,
18                    uint32_t count,
19                    void *buffer,
20                    rte_ec_handle_t ec_h,
21                    rte_grp_handle_t grp_h, uint32_t tag, rte_request_handle_t * req);
22 
23 static int test(rte_request_handle_t * request, int *completed);
24 
25 static int ec_handle_compare(rte_ec_handle_t handle_1,
26                              rte_grp_handle_t
27                              group_handle_1,
28                              rte_ec_handle_t handle_2, rte_grp_handle_t group_handle_2);
29 
30 static int get_ec_handles(int num_ec,
31                           int *ec_indexes, rte_grp_handle_t, rte_ec_handle_t * ec_handles);
32 
33 static int group_size(rte_grp_handle_t group);
34 static int my_rank(rte_grp_handle_t grp_h);
35 static int ec_on_local_node(rte_ec_handle_t ec, rte_grp_handle_t group);
36 static rte_grp_handle_t get_world_group_handle(void);
37 static uint32_t jobid(void);
38 
39 static void *get_coll_handle(void);
40 static int coll_handle_test(void *handle);
41 static void coll_handle_free(void *handle);
42 static void coll_handle_complete(void *handle);
43 static int group_id(rte_grp_handle_t group);
44 
45 static int world_rank(rte_grp_handle_t grp_h, rte_ec_handle_t ec);
46 
progress(void)47 static void progress(void)
48 {
49     int ret;
50     int made_progress;
51 
52     if (0 == world_comm_destroying) {
53         MPID_Progress_test(NULL);
54     } else {
55         /* FIXME: The hcoll library needs to be updated to return
56          * error codes.  The progress function pointer right now
57          * expects that the function returns void. */
58         ret = hcoll_do_progress(&made_progress);
59         MPIR_Assert(ret == MPI_SUCCESS);
60     }
61 }
62 
63 #if HCOLL_API >= HCOLL_VERSION(3,6)
64 static int get_mpi_type_envelope(void *mpi_type, int *num_integers,
65                                  int *num_addresses, int *num_datatypes,
66                                  hcoll_mpi_type_combiner_t * combiner);
67 static int get_mpi_type_contents(void *mpi_type, int max_integers, int max_addresses,
68                                  int max_datatypes, int *array_of_integers,
69                                  void *array_of_addresses, void *array_of_datatypes);
70 static int get_hcoll_type(void *mpi_type, dte_data_representation_t * hcoll_type);
71 static int set_hcoll_type(void *mpi_type, dte_data_representation_t hcoll_type);
72 static int get_mpi_constants(size_t * mpi_datatype_size,
73                              int *mpi_order_c, int *mpi_order_fortran,
74                              int *mpi_distribute_block,
75                              int *mpi_distribute_cyclic,
76                              int *mpi_distribute_none, int *mpi_distribute_dflt_darg);
77 #endif
78 
init_module_fns(void)79 static void init_module_fns(void)
80 {
81     hcoll_rte_functions.send_fn = send_nb;
82     hcoll_rte_functions.recv_fn = recv_nb;
83     hcoll_rte_functions.ec_cmp_fn = ec_handle_compare;
84     hcoll_rte_functions.get_ec_handles_fn = get_ec_handles;
85     hcoll_rte_functions.rte_group_size_fn = group_size;
86     hcoll_rte_functions.test_fn = test;
87     hcoll_rte_functions.rte_my_rank_fn = my_rank;
88     hcoll_rte_functions.rte_ec_on_local_node_fn = ec_on_local_node;
89     hcoll_rte_functions.rte_world_group_fn = get_world_group_handle;
90     hcoll_rte_functions.rte_jobid_fn = jobid;
91     hcoll_rte_functions.rte_progress_fn = progress;
92     hcoll_rte_functions.rte_get_coll_handle_fn = get_coll_handle;
93     hcoll_rte_functions.rte_coll_handle_test_fn = coll_handle_test;
94     hcoll_rte_functions.rte_coll_handle_free_fn = coll_handle_free;
95     hcoll_rte_functions.rte_coll_handle_complete_fn = coll_handle_complete;
96     hcoll_rte_functions.rte_group_id_fn = group_id;
97     hcoll_rte_functions.rte_world_rank_fn = world_rank;
98 #if HCOLL_API >= HCOLL_VERSION(3,6)
99     hcoll_rte_functions.rte_get_mpi_type_envelope_fn = get_mpi_type_envelope;
100     hcoll_rte_functions.rte_get_mpi_type_contents_fn = get_mpi_type_contents;
101     hcoll_rte_functions.rte_get_hcoll_type_fn = get_hcoll_type;
102     hcoll_rte_functions.rte_set_hcoll_type_fn = set_hcoll_type;
103     hcoll_rte_functions.rte_get_mpi_constants_fn = get_mpi_constants;
104 #endif
105 }
106 
/*
 * Externally visible entry point: fills in the hcoll RTE function table by
 * delegating to init_module_fns().
 */
void hcoll_rte_fns_setup(void)
{
    init_module_fns();
    return;
}
111 
/*
 * Post a non-blocking receive on behalf of hcoll.
 *
 *   data   - hcoll datatype descriptor; asserted to be an inline type
 *   count  - number of elements of 'data' to receive
 *   buffer - destination buffer; may only be NULL when 'data' is the zero type
 *   ec_h   - source endpoint; ec_h.handle must be non-NULL, ec_h.rank is the
 *            rank passed to MPIC_Irecv
 *   grp_h  - group handle, used here as an MPIR_Comm pointer
 *   tag    - message tag
 *   req    - out: receives the MPIR_Request pointer and is marked
 *            HCOLRTE_REQUEST_ACTIVE on success; left untouched on failure
 *
 * Returns MPI_SUCCESS once the receive has been posted and HCOLL_ERROR on an
 * invalid argument or when MPIC_Irecv reports an error.
 */
static int recv_nb(struct dte_data_representation_t data,
                   uint32_t count,
                   void *buffer,
                   rte_ec_handle_t ec_h,
                   rte_grp_handle_t grp_h, uint32_t tag, rte_request_handle_t * req)
{
    int mpi_errno = MPI_SUCCESS;
    MPI_Datatype dtype = MPI_CHAR;
    MPIR_Request *request = NULL;
    MPIR_Comm *comm = (MPIR_Comm *) grp_h;
    size_t size;

    if (!ec_h.handle) {
        MPIR_ERR_SETANDJUMP2(mpi_errno, MPI_ERR_OTHER, "**hcoll_wrong_arg",
                             "**hcoll_wrong_arg %p %d", ec_h.handle, ec_h.rank);
    }

    MPIR_Assert(HCOL_DTE_IS_INLINE(data));
    if (!buffer && !HCOL_DTE_IS_ZERO(data)) {
        MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**null_buff_ptr");
    }

    /* packed_size is divided by 8 to obtain a byte count, so the message is
     * transferred as 'size' MPI_CHARs. */
    size = (size_t) data.rep.in_line_rep.data_handle.in_line.packed_size * count / 8;

    mpi_errno = MPIC_Irecv(buffer, size, dtype, ec_h.rank, tag, comm, &request);
    if (mpi_errno != MPI_SUCCESS) {
        /* Do not hand hcoll a request that was never posted: leaving 'req'
         * untouched keeps test() from dereferencing a bogus pointer. */
        goto fn_fail;
    }
    MPIR_Assert(request);

    req->data = (void *) request;
    req->status = HCOLRTE_REQUEST_ACTIVE;
  fn_exit:
    return mpi_errno;
  fn_fail:
    return HCOLL_ERROR;
}
146 
/*
 * Post a non-blocking send on behalf of hcoll.
 *
 *   data   - hcoll datatype descriptor; asserted to be an inline type
 *   count  - number of elements of 'data' to send
 *   buffer - source buffer; may only be NULL when 'data' is the zero type
 *   ec_h   - destination endpoint; ec_h.handle must be non-NULL, ec_h.rank is
 *            the rank passed to MPIC_Isend
 *   grp_h  - group handle, used here as an MPIR_Comm pointer
 *   tag    - message tag
 *   req    - out: receives the MPIR_Request pointer and is marked
 *            HCOLRTE_REQUEST_ACTIVE on success; left untouched on failure
 *
 * Returns MPI_SUCCESS once the send has been posted and HCOLL_ERROR on an
 * invalid argument or when MPIC_Isend reports an error.
 */
static int send_nb(dte_data_representation_t data,
                   uint32_t count,
                   void *buffer,
                   rte_ec_handle_t ec_h,
                   rte_grp_handle_t grp_h, uint32_t tag, rte_request_handle_t * req)
{
    int mpi_errno = MPI_SUCCESS;
    MPI_Datatype dtype = MPI_CHAR;
    MPIR_Request *request = NULL;
    MPIR_Comm *comm = (MPIR_Comm *) grp_h;
    MPIR_Errflag_t err = MPIR_ERR_NONE;
    size_t size;

    if (!ec_h.handle) {
        MPIR_ERR_SETANDJUMP2(mpi_errno, MPI_ERR_OTHER, "**hcoll_wrong_arg",
                             "**hcoll_wrong_arg %p %d", ec_h.handle, ec_h.rank);
    }

    MPIR_Assert(HCOL_DTE_IS_INLINE(data));
    if (!buffer && !HCOL_DTE_IS_ZERO(data)) {
        MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**null_buff_ptr");
    }

    /* packed_size is divided by 8 to obtain a byte count, so the message is
     * transferred as 'size' MPI_CHARs. */
    size = (size_t) data.rep.in_line_rep.data_handle.in_line.packed_size * count / 8;

    mpi_errno = MPIC_Isend(buffer, size, dtype, ec_h.rank, tag, comm, &request, &err);
    if (mpi_errno != MPI_SUCCESS) {
        /* Do not hand hcoll a request that was never posted: leaving 'req'
         * untouched keeps test() from dereferencing a bogus pointer. */
        goto fn_fail;
    }
    MPIR_Assert(request);

    req->data = (void *) request;
    req->status = HCOLRTE_REQUEST_ACTIVE;
  fn_exit:
    return mpi_errno;
  fn_fail:
    return HCOLL_ERROR;
}
182 
/*
 * Poll a request previously filled in by send_nb()/recv_nb().
 *
 * '*completed' is set non-zero when the request is finished.  A request
 * whose status is not HCOLRTE_REQUEST_ACTIVE is reported as complete
 * without consulting the MPI request.  When an active request is found to
 * be complete, the underlying MPIR_Request is freed and the status becomes
 * HCOLRTE_REQUEST_DONE.  Always returns HCOLL_SUCCESS.
 */
static int test(rte_request_handle_t * request, int *completed)
{
    if (request->status != HCOLRTE_REQUEST_ACTIVE) {
        *completed = true;
    } else {
        MPIR_Request *mpi_req = (MPIR_Request *) request->data;

        *completed = (int) MPIR_Request_is_complete(mpi_req);
        if (*completed) {
            MPIR_Request_free(mpi_req);
            request->status = HCOLRTE_REQUEST_DONE;
        }
    }

    return HCOLL_SUCCESS;
}
200 
/*
 * Compare two endpoint handles.  Only the 'handle' members are compared;
 * the group arguments are not consulted.  Returns 1 when the handles are
 * the same pointer and 0 otherwise.
 */
static int ec_handle_compare(rte_ec_handle_t handle_1,
                             rte_grp_handle_t
                             group_handle_1,
                             rte_ec_handle_t handle_2, rte_grp_handle_t group_handle_2)
{
    int same_endpoint = (handle_1.handle == handle_2.handle);

    return same_endpoint;
}
208 
/*
 * Fill 'ec_handles[0..num_ec-1]' for the ranks listed in 'ec_indexes'.
 *
 * For every entry the rank is copied verbatim and 'handle' is set to the
 * device-level object for that rank in the communicator 'grp_h': the value
 * returned by MPIDIU_comm_rank_to_av() when MPIDCH4_H_INCLUDED is defined,
 * otherwise the communicator's vcr_table entry.  Always returns
 * HCOLL_SUCCESS.
 */
static int get_ec_handles(int num_ec,
                          int *ec_indexes, rte_grp_handle_t grp_h, rte_ec_handle_t * ec_handles)
{
    MPIR_Comm *comm_ptr = (MPIR_Comm *) grp_h;
    int idx;

    for (idx = 0; idx < num_ec; ++idx) {
        rte_ec_handle_t *out = &ec_handles[idx];

        out->rank = ec_indexes[idx];
#ifdef MPIDCH4_H_INCLUDED
        out->handle = (void *) (MPIDIU_comm_rank_to_av(comm_ptr, ec_indexes[idx]));
#else
        out->handle = (void *) (comm_ptr->dev.vcrt->vcr_table[ec_indexes[idx]]);
#endif
    }

    return HCOLL_SUCCESS;
}
225 
/* Size of the communicator behind the group handle 'grp_h'. */
static int group_size(rte_grp_handle_t grp_h)
{
    MPIR_Comm *comm_ptr = (MPIR_Comm *) grp_h;

    return MPIR_Comm_size(comm_ptr);
}
230 
/* Rank of the calling process in the communicator behind 'grp_h'. */
static int my_rank(rte_grp_handle_t grp_h)
{
    MPIR_Comm *comm_ptr = (MPIR_Comm *) grp_h;

    return MPIR_Comm_rank(comm_ptr);
}
235 
/*
 * Report whether endpoint 'ec' has the same node id as the calling process.
 *
 * Both node ids are looked up with MPID_Get_node_id() on the communicator
 * behind 'group'.  Returns 1 if they match and 0 otherwise.
 */
static int ec_on_local_node(rte_ec_handle_t ec, rte_grp_handle_t group)
{
    MPIR_Comm *comm_ptr = (MPIR_Comm *) group;
    int peer_node;
    int self_node;
    int self_rank;

    MPID_Get_node_id(comm_ptr, ec.rank, &peer_node);

    self_rank = MPIR_Comm_rank(comm_ptr);
    MPID_Get_node_id(comm_ptr, self_rank, &self_node);

    return (peer_node == self_node);
}
247 
248 
get_world_group_handle(void)249 static rte_grp_handle_t get_world_group_handle(void)
250 {
251     return (rte_grp_handle_t) (MPIR_Process.comm_world);
252 }
253 
/* Job id callback: currently unused, so a constant 0 is reported. */
static uint32_t jobid(void)
{
    const uint32_t unused_job_id = 0;

    return unused_job_id;
}
259 
/* Identify a group by the context_id of the communicator behind it. */
static int group_id(rte_grp_handle_t group)
{
    return ((MPIR_Comm *) group)->context_id;
}
266 
get_coll_handle(void)267 static void *get_coll_handle(void)
268 {
269     MPIR_Request *req;
270     req = MPIR_Request_create(MPIR_REQUEST_KIND__COLL);
271     MPIR_Request_add_ref(req);
272     return (void *) req;
273 }
274 
coll_handle_test(void * handle)275 static int coll_handle_test(void *handle)
276 {
277     int completed;
278     MPIR_Request *req;
279     req = (MPIR_Request *) handle;
280     completed = (int) MPIR_Request_is_complete(req);
281     return completed;
282 }
283 
coll_handle_free(void * handle)284 static void coll_handle_free(void *handle)
285 {
286     MPIR_Request *req;
287     if (NULL != handle) {
288         req = (MPIR_Request *) handle;
289         MPIR_Request_free(req);
290     }
291 }
292 
coll_handle_complete(void * handle)293 static void coll_handle_complete(void *handle)
294 {
295     MPIR_Request *req;
296     if (NULL != handle) {
297         req = (MPIR_Request *) handle;
298         MPIR_Request_complete(req);
299     }
300 }
301 
/*
 * Translate endpoint 'ec' into a world-level rank.
 *
 * When MPIDCH4_H_INCLUDED is defined the value comes from
 * MPIDIU_rank_to_lpid() on the communicator behind 'grp_h'; otherwise it is
 * the pg_rank stored in the MPIDI_VC that ec.handle points to.
 */
static int world_rank(rte_grp_handle_t grp_h, rte_ec_handle_t ec)
{
    int rank;

#ifdef MPIDCH4_H_INCLUDED
    rank = MPIDIU_rank_to_lpid(ec.rank, (MPIR_Comm *) grp_h);
#else
    rank = ((struct MPIDI_VC *) ec.handle)->pg_rank;
#endif

    return rank;
}
310 
311 #if HCOLL_API >= HCOLL_VERSION(3,6)
mpi_combiner_2_hcoll_combiner(int combiner)312 hcoll_mpi_type_combiner_t mpi_combiner_2_hcoll_combiner(int combiner)
313 {
314     switch (combiner) {
315         case MPI_COMBINER_CONTIGUOUS:
316             return HCOLL_MPI_COMBINER_CONTIGUOUS;
317         case MPI_COMBINER_VECTOR:
318             return HCOLL_MPI_COMBINER_VECTOR;
319         case MPI_COMBINER_HVECTOR:
320             return HCOLL_MPI_COMBINER_HVECTOR;
321         case MPI_COMBINER_INDEXED:
322             return HCOLL_MPI_COMBINER_INDEXED;
323         case MPI_COMBINER_HINDEXED_INTEGER:
324         case MPI_COMBINER_HINDEXED:
325             return HCOLL_MPI_COMBINER_HINDEXED;
326         case MPI_COMBINER_DUP:
327             return HCOLL_MPI_COMBINER_DUP;
328         case MPI_COMBINER_INDEXED_BLOCK:
329             return HCOLL_MPI_COMBINER_INDEXED_BLOCK;
330         case MPI_COMBINER_HINDEXED_BLOCK:
331             return HCOLL_MPI_COMBINER_HINDEXED_BLOCK;
332         case MPI_COMBINER_SUBARRAY:
333             return HCOLL_MPI_COMBINER_SUBARRAY;
334         case MPI_COMBINER_DARRAY:
335             return HCOLL_MPI_COMBINER_DARRAY;
336         case MPI_COMBINER_F90_REAL:
337             return HCOLL_MPI_COMBINER_F90_REAL;
338         case MPI_COMBINER_F90_COMPLEX:
339             return HCOLL_MPI_COMBINER_F90_COMPLEX;
340         case MPI_COMBINER_F90_INTEGER:
341             return HCOLL_MPI_COMBINER_F90_INTEGER;
342         case MPI_COMBINER_RESIZED:
343             return HCOLL_MPI_COMBINER_RESIZED;
344         case MPI_COMBINER_STRUCT:
345         case MPI_COMBINER_STRUCT_INTEGER:
346             return HCOLL_MPI_COMBINER_STRUCT;
347         default:
348             break;
349     }
350     return HCOLL_MPI_COMBINER_LAST;
351 }
352 
/*
 * Envelope query for hcoll.
 *
 * 'mpi_type' carries an MPI_Datatype handle encoded in a pointer value.
 * The three counts are filled in by MPIR_Type_get_envelope(), and the MPI
 * combiner it reports is translated into '*combiner' with
 * mpi_combiner_2_hcoll_combiner().  Always returns HCOLL_SUCCESS.
 */
static int get_mpi_type_envelope(void *mpi_type, int *num_integers,
                                 int *num_addresses, int *num_datatypes,
                                 hcoll_mpi_type_combiner_t * combiner)
{
    MPI_Datatype datatype = (MPI_Datatype) (intptr_t) mpi_type;
    int raw_combiner;

    MPIR_Type_get_envelope(datatype, num_integers, num_addresses, num_datatypes, &raw_combiner);
    *combiner = mpi_combiner_2_hcoll_combiner(raw_combiner);

    return HCOLL_SUCCESS;
}
366 
get_mpi_type_contents(void * mpi_type,int max_integers,int max_addresses,int max_datatypes,int * array_of_integers,void * array_of_addresses,void * array_of_datatypes)367 static int get_mpi_type_contents(void *mpi_type, int max_integers, int max_addresses,
368                                  int max_datatypes, int *array_of_integers,
369                                  void *array_of_addresses, void *array_of_datatypes)
370 {
371     int ret;
372     MPI_Datatype dt_handle = (MPI_Datatype) (intptr_t) mpi_type;
373 
374     ret = MPIR_Type_get_contents(dt_handle,
375                                  max_integers, max_addresses, max_datatypes,
376                                  array_of_integers,
377                                  (MPI_Aint *) array_of_addresses,
378                                  (MPI_Datatype *) array_of_datatypes);
379 
380     return ret == MPI_SUCCESS ? HCOLL_SUCCESS : HCOLL_ERROR;
381 }
382 
/*
 * Look up the hcoll representation of an MPI datatype.
 *
 * 'mpi_type' carries an MPI_Datatype handle encoded in a pointer value.
 * '*hcoll_type' receives the result of mpi_dtype_2_hcoll_dtype() called with
 * a count of -1 and the TRY_FIND_DERIVED mode.  Returns HCOLL_ERR_NOT_FOUND
 * when that result is the zero type and HCOLL_SUCCESS otherwise.
 */
static int get_hcoll_type(void *mpi_type, dte_data_representation_t * hcoll_type)
{
    MPI_Datatype dt_handle = (MPI_Datatype) (intptr_t) mpi_type;

    *hcoll_type = mpi_dtype_2_hcoll_dtype(dt_handle, -1, TRY_FIND_DERIVED);

    return HCOL_DTE_IS_ZERO((*hcoll_type)) ? HCOLL_ERR_NOT_FOUND : HCOLL_SUCCESS;
}
392 
/*
 * Callback for attaching an hcoll type to an MPI datatype.  Nothing is
 * stored here: both arguments are ignored and HCOLL_SUCCESS is reported.
 */
static int set_hcoll_type(void *mpi_type, dte_data_representation_t hcoll_type)
{
    (void) mpi_type;
    (void) hcoll_type;

    return HCOLL_SUCCESS;
}
397 
get_mpi_constants(size_t * mpi_datatype_size,int * mpi_order_c,int * mpi_order_fortran,int * mpi_distribute_block,int * mpi_distribute_cyclic,int * mpi_distribute_none,int * mpi_distribute_dflt_darg)398 static int get_mpi_constants(size_t * mpi_datatype_size,
399                              int *mpi_order_c, int *mpi_order_fortran,
400                              int *mpi_distribute_block,
401                              int *mpi_distribute_cyclic,
402                              int *mpi_distribute_none, int *mpi_distribute_dflt_darg)
403 {
404     *mpi_datatype_size = sizeof(MPI_Datatype);
405     *mpi_order_c = MPI_ORDER_C;
406     *mpi_order_fortran = MPI_ORDER_FORTRAN;
407     *mpi_distribute_block = MPI_DISTRIBUTE_BLOCK;
408     *mpi_distribute_cyclic = MPI_DISTRIBUTE_CYCLIC;
409     *mpi_distribute_none = MPI_DISTRIBUTE_NONE;
410     *mpi_distribute_dflt_darg = MPI_DISTRIBUTE_DFLT_DARG;
411 
412     return HCOLL_SUCCESS;
413 }
414 
415 #endif
416