1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpiimpl.h"
7 #include "ibcast.h"
8 
9 /*
 * Broadcast based on a scatter followed by an allgather.
 *
 * We first scatter the buffer using a binomial tree algorithm. This
 * costs lg(p)*alpha + n*((p-1)/p)*beta
14  *
 * If the datatype is contiguous, we treat the data as bytes and
 * divide (scatter) it among processes by using ceiling division.
 * For the noncontiguous case, the root first packs the data into a
 * temporary byte buffer (a scheduled copy to MPI_BYTE), the bytes are
 * scattered and allgathered, and each non-root process unpacks the
 * result back into the user buffer after the allgather.
20  *
21  * We use a ring algorithm for the allgather, which takes p-1 steps.
22  * This may perform better than recursive doubling for long messages
23  * and medium-sized non-power-of-two messages.
24  *
25  * Total Cost = (lgp+p-1).alpha + 2.n.((p-1)/p).beta
26  */
/* Size in bytes of the scatter block at root-relative position rel_pos.
 * Blocks are scatter_size bytes each; the last occupied block may be shorter,
 * and positions past the end of the nbytes-byte message hold 0 bytes. */
static int scatter_ring_block_bytes(int rel_pos, int scatter_size, int nbytes)
{
    int remaining = nbytes - rel_pos * scatter_size;

    if (remaining <= 0)
        return 0;
    return remaining < scatter_size ? remaining : scatter_size;
}

/*
 * Schedule a broadcast of count elements of datatype from root as a
 * binomial scatter followed by a ring allgather.
 *
 *   buffer   - send buffer on root, receive buffer elsewhere
 *   count    - number of datatype elements in buffer
 *   datatype - element datatype
 *   root     - rank of the broadcast root in comm_ptr
 *   comm_ptr - communicator
 *   s        - schedule the operations are appended to
 *
 * Returns MPI_SUCCESS or an MPI error code; on failure, schedule-owned
 * allocations made here are reaped.
 */
int MPIR_Ibcast_intra_sched_scatter_ring_allgather(void *buffer, int count, MPI_Datatype datatype,
                                                   int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s)
{
    int mpi_errno = MPI_SUCCESS;
    int comm_size, rank, relative_rank;
    int is_contig, type_size, nbytes;
    int scatter_size, curr_size;
    int i, j, jnext, left, right;
    MPI_Aint true_extent, true_lb;
    void *tmp_buf = NULL;

    struct MPII_Ibcast_state *ibcast_state;
    MPIR_SCHED_CHKPMEM_DECL(4);

    comm_size = comm_ptr->local_size;
    rank = comm_ptr->rank;

    /* If there is only one process, return */
    if (comm_size == 1)
        goto fn_exit;

    /* Blocks are laid out relative to the root (the ring below indexes them
     * by (rank - root) mod comm_size), so this process's own block is at
     * position relative_rank, not rank. */
    relative_rank = (rank - root + comm_size) % comm_size;

    if (HANDLE_IS_BUILTIN(datatype))
        is_contig = 1;
    else {
        MPIR_Datatype_is_contig(datatype, &is_contig);
    }

    MPIR_SCHED_CHKPMEM_MALLOC(ibcast_state, struct MPII_Ibcast_state *,
                              sizeof(struct MPII_Ibcast_state), mpi_errno, "MPI_Status",
                              MPL_MEM_BUFFER);
    MPIR_Datatype_get_size_macro(datatype, type_size);
    nbytes = type_size * count;
    ibcast_state->n_bytes = nbytes;
    ibcast_state->curr_bytes = 0;
    if (is_contig) {
        /* contiguous, no need to pack. */
        MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);

        tmp_buf = (char *) buffer + true_lb;
    } else {
        MPIR_SCHED_CHKPMEM_MALLOC(tmp_buf, void *, nbytes, mpi_errno, "tmp_buf", MPL_MEM_BUFFER);

        /* only the root has valid data to pack before the scatter */
        if (rank == root) {
            mpi_errno = MPIR_Sched_copy(buffer, count, datatype, tmp_buf, nbytes, MPI_BYTE, s);
            MPIR_ERR_CHECK(mpi_errno);
            MPIR_SCHED_BARRIER(s);
        }
    }

    mpi_errno = MPII_Iscatter_for_bcast_sched(tmp_buf, root, comm_ptr, nbytes, s);
    MPIR_ERR_CHECK(mpi_errno);

    /* this is the block size used for the scatter operation */
    scatter_size = (nbytes + comm_size - 1) / comm_size;        /* ceiling division */

    /* curr_size is the amount of data that this process now has stored in
     * buffer at byte offset (relative_rank*scatter_size); it must match the
     * block this process sends first in the ring below, otherwise the final
     * length check does not add up to nbytes. */
    curr_size = scatter_ring_block_bytes(relative_rank, scatter_size, nbytes);
    /* curr_size bytes already inplace */
    ibcast_state->curr_bytes = curr_size;

    /* long-message allgather or medium-size but non-power-of-two. use ring algorithm. */

    left = (comm_size + rank - 1) % comm_size;
    right = (rank + 1) % comm_size;

    /* At each step, forward block j to the right neighbor and receive block
     * jnext from the left neighbor; both indices walk leftwards around the ring. */
    j = rank;
    jnext = left;
    for (i = 1; i < comm_size; i++) {
        int left_count, right_count, left_disp, right_disp, rel_j, rel_jnext;

        rel_j = (j - root + comm_size) % comm_size;
        rel_jnext = (jnext - root + comm_size) % comm_size;
        left_count = scatter_ring_block_bytes(rel_jnext, scatter_size, nbytes);
        left_disp = rel_jnext * scatter_size;
        right_count = scatter_ring_block_bytes(rel_j, scatter_size, nbytes);
        right_disp = rel_j * scatter_size;

        mpi_errno = MPIR_Sched_send(((char *) tmp_buf + right_disp),
                                    right_count, MPI_BYTE, right, comm_ptr, s);
        MPIR_ERR_CHECK(mpi_errno);
        /* sendrecv, no barrier here */
        mpi_errno = MPIR_Sched_recv_status(((char *) tmp_buf + left_disp),
                                           left_count, MPI_BYTE, left, comm_ptr,
                                           &ibcast_state->status, s);
        MPIR_ERR_CHECK(mpi_errno);
        MPIR_SCHED_BARRIER(s);
        /* accumulate the received byte count into ibcast_state->curr_bytes */
        mpi_errno = MPIR_Sched_cb(&MPII_Ibcast_sched_add_length, ibcast_state, s);
        MPIR_ERR_CHECK(mpi_errno);
        MPIR_SCHED_BARRIER(s);

        j = jnext;
        jnext = (comm_size + jnext - 1) % comm_size;
    }
    /* verify the total bytes held (own block + everything received) */
    mpi_errno = MPIR_Sched_cb(&MPII_Ibcast_sched_test_curr_length, ibcast_state, s);
    MPIR_ERR_CHECK(mpi_errno);

    /* non-root processes unpack the byte buffer into the user's datatype */
    if (!is_contig && rank != root) {
        mpi_errno = MPIR_Sched_copy(tmp_buf, nbytes, MPI_BYTE, buffer, count, datatype, s);
        MPIR_ERR_CHECK(mpi_errno);
    }

    MPIR_SCHED_CHKPMEM_COMMIT(s);
  fn_exit:
    return mpi_errno;
  fn_fail:
    MPIR_SCHED_CHKPMEM_REAP(s);
    goto fn_exit;
}
142