1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpiimpl.h"
7 #include "ibcast.h"
8 
/* Schedule callback that validates the intranode receive performed on the
 * root's node by MPIR_Ibcast_intra_sched_smp (it is only scheduled when
 * HAVE_ERROR_CHECKING is defined).
 *
 * Compares the number of bytes actually received, as recorded in
 * ibcast_state->status, with the number of bytes the broadcast expects
 * (ibcast_state->n_bytes).  If they differ, or the receive status carries an
 * error, a recoverable MPI_ERR_OTHER error code is returned; otherwise
 * MPI_SUCCESS is returned.
 *
 * comm and tag are required by the schedule-callback signature but are not
 * used here.  state must point to the struct MPII_Ibcast_state that was
 * registered together with this callback. */
static int sched_test_length(MPIR_Comm * comm, int tag, void *state)
{
    int mpi_errno = MPI_SUCCESS;
    MPI_Aint recv_size;
    struct MPII_Ibcast_state *ibcast_state = (struct MPII_Ibcast_state *) state;

    MPIR_Get_count_impl(&ibcast_state->status, MPI_BYTE, &recv_size);
    if (ibcast_state->n_bytes != recv_size || ibcast_state->status.MPI_ERROR != MPI_SUCCESS) {
        /* The registered message format consumes two ints (%d %d).  recv_size
         * is an MPI_Aint (and n_bytes presumably is too -- see ibcast.h), which
         * may be wider than int; passing it unconverted through the varargs
         * list is undefined behavior and garbles the reported sizes. */
        mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE,
                                         __func__, __LINE__, MPI_ERR_OTHER,
                                         "**collective_size_mismatch",
                                         "**collective_size_mismatch %d %d",
                                         (int) ibcast_state->n_bytes, (int) recv_size);
    }
    return mpi_errno;
}
24 
25 /* This routine purely handles the hierarchical version of bcast, and does not
26  * currently make any decision about which particular algorithm to use for any
27  * subcommunicator. */
MPIR_Ibcast_intra_sched_smp(void * buffer,int count,MPI_Datatype datatype,int root,MPIR_Comm * comm_ptr,MPIR_Sched_t s)28 int MPIR_Ibcast_intra_sched_smp(void *buffer, int count, MPI_Datatype datatype, int root,
29                                 MPIR_Comm * comm_ptr, MPIR_Sched_t s)
30 {
31     int mpi_errno = MPI_SUCCESS;
32     MPI_Aint type_size;
33     struct MPII_Ibcast_state *ibcast_state;
34     MPIR_SCHED_CHKPMEM_DECL(1);
35 
36 #ifdef HAVE_ERROR_CHECKING
37     MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr));
38 #endif
39     MPIR_SCHED_CHKPMEM_MALLOC(ibcast_state, struct MPII_Ibcast_state *,
40                               sizeof(struct MPII_Ibcast_state), mpi_errno, "MPI_Status",
41                               MPL_MEM_BUFFER);
42 
43     MPIR_Datatype_get_size_macro(datatype, type_size);
44 
45     ibcast_state->n_bytes = type_size * count;
46     /* TODO insert packing here */
47 
48     /* send to intranode-rank 0 on the root's node */
49     if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) > 0) {   /* is not the node root (0) *//* and is on our node (!-1) */
50         if (root == comm_ptr->rank) {
51             mpi_errno = MPIR_Sched_send(buffer, count, datatype, 0, comm_ptr->node_comm, s);
52             MPIR_ERR_CHECK(mpi_errno);
53         } else if (0 == comm_ptr->node_comm->rank) {
54             mpi_errno =
55                 MPIR_Sched_recv_status(buffer, count, datatype,
56                                        MPIR_Get_intranode_rank(comm_ptr, root), comm_ptr->node_comm,
57                                        &ibcast_state->status, s);
58             MPIR_ERR_CHECK(mpi_errno);
59 #ifdef HAVE_ERROR_CHECKING
60             MPIR_SCHED_BARRIER(s);
61             mpi_errno = MPIR_Sched_cb(&sched_test_length, ibcast_state, s);
62             MPIR_ERR_CHECK(mpi_errno);
63 #endif
64         }
65         MPIR_SCHED_BARRIER(s);
66     }
67 
68     /* perform the internode broadcast */
69     if (comm_ptr->node_roots_comm != NULL) {
70         mpi_errno = MPIR_Ibcast_sched_auto(buffer, count, datatype,
71                                            MPIR_Get_internode_rank(comm_ptr, root),
72                                            comm_ptr->node_roots_comm, s);
73         MPIR_ERR_CHECK(mpi_errno);
74 
75         /* don't allow the local ops for the intranode phase to start until this has completed */
76         MPIR_SCHED_BARRIER(s);
77     }
78     /* perform the intranode broadcast on all except for the root's node */
79     if (comm_ptr->node_comm != NULL) {
80         mpi_errno = MPIR_Ibcast_sched_auto(buffer, count, datatype, 0, comm_ptr->node_comm, s);
81         MPIR_ERR_CHECK(mpi_errno);
82     }
83 
84     MPIR_SCHED_CHKPMEM_COMMIT(s);
85   fn_exit:
86     return mpi_errno;
87   fn_fail:
88     MPIR_SCHED_CHKPMEM_REAP(s);
89     goto fn_exit;
90 }
91