1 /*
2 * Copyright (C) by Argonne National Laboratory
3 * See COPYRIGHT in top-level directory
4 */
5
6 #include "mpiimpl.h"
7 #include "ibcast.h"
8
9 /*
10 * Broadcast based on a scatter followed by an allgather.
11
12 * We first scatter the buffer using a binomial tree algorithm. This
13 * costs lgp.alpha + n.((p-1)/p).beta
14 *
15 * If the datatype is contiguous, we treat the data as bytes and
16 * divide (scatter) it among processes by using ceiling division.
17 * For the noncontiguous cases, we first pack the data into a
18 * temporary buffer by using MPI_Pack, scatter it as bytes, and
19 * unpack it after the allgather.
20 *
21 * We use a ring algorithm for the allgather, which takes p-1 steps.
22 * This may perform better than recursive doubling for long messages
23 * and medium-sized non-power-of-two messages.
24 *
25 * Total Cost = (lgp+p-1).alpha + 2.n.((p-1)/p).beta
26 */
MPIR_Ibcast_intra_sched_scatter_ring_allgather(void * buffer,int count,MPI_Datatype datatype,int root,MPIR_Comm * comm_ptr,MPIR_Sched_t s)27 int MPIR_Ibcast_intra_sched_scatter_ring_allgather(void *buffer, int count, MPI_Datatype datatype,
28 int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s)
29 {
30 int mpi_errno = MPI_SUCCESS;
31 int comm_size, rank;
32 int is_contig, type_size, nbytes;
33 int scatter_size, curr_size;
34 int i, j, jnext, left, right;
35 MPI_Aint true_extent, true_lb;
36 void *tmp_buf = NULL;
37
38 struct MPII_Ibcast_state *ibcast_state;
39 MPIR_SCHED_CHKPMEM_DECL(4);
40
41 comm_size = comm_ptr->local_size;
42 rank = comm_ptr->rank;
43
44 /* If there is only one process, return */
45 if (comm_size == 1)
46 goto fn_exit;
47
48 if (HANDLE_IS_BUILTIN(datatype))
49 is_contig = 1;
50 else {
51 MPIR_Datatype_is_contig(datatype, &is_contig);
52 }
53
54 MPIR_SCHED_CHKPMEM_MALLOC(ibcast_state, struct MPII_Ibcast_state *,
55 sizeof(struct MPII_Ibcast_state), mpi_errno, "MPI_Status",
56 MPL_MEM_BUFFER);
57 MPIR_Datatype_get_size_macro(datatype, type_size);
58 nbytes = type_size * count;
59 ibcast_state->n_bytes = nbytes;
60 ibcast_state->curr_bytes = 0;
61 if (is_contig) {
62 /* contiguous, no need to pack. */
63 MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);
64
65 tmp_buf = (char *) buffer + true_lb;
66 } else {
67 MPIR_SCHED_CHKPMEM_MALLOC(tmp_buf, void *, nbytes, mpi_errno, "tmp_buf", MPL_MEM_BUFFER);
68
69 if (rank == root) {
70 mpi_errno = MPIR_Sched_copy(buffer, count, datatype, tmp_buf, nbytes, MPI_BYTE, s);
71 MPIR_ERR_CHECK(mpi_errno);
72 MPIR_SCHED_BARRIER(s);
73 }
74 }
75
76 mpi_errno = MPII_Iscatter_for_bcast_sched(tmp_buf, root, comm_ptr, nbytes, s);
77 MPIR_ERR_CHECK(mpi_errno);
78
79 /* this is the block size used for the scatter operation */
80 scatter_size = (nbytes + comm_size - 1) / comm_size; /* ceiling division */
81
82 /* curr_size is the amount of data that this process now has stored in
83 * buffer at byte offset (rank*scatter_size) */
84 curr_size = MPL_MIN(scatter_size, (nbytes - (rank * scatter_size)));
85 if (curr_size < 0)
86 curr_size = 0;
87 /* curr_size bytes already inplace */
88 ibcast_state->curr_bytes = curr_size;
89
90 /* long-message allgather or medium-size but non-power-of-two. use ring algorithm. */
91
92 left = (comm_size + rank - 1) % comm_size;
93 right = (rank + 1) % comm_size;
94
95 j = rank;
96 jnext = left;
97 for (i = 1; i < comm_size; i++) {
98 int left_count, right_count, left_disp, right_disp, rel_j, rel_jnext;
99
100 rel_j = (j - root + comm_size) % comm_size;
101 rel_jnext = (jnext - root + comm_size) % comm_size;
102 left_count = MPL_MIN(scatter_size, (nbytes - rel_jnext * scatter_size));
103 if (left_count < 0)
104 left_count = 0;
105 left_disp = rel_jnext * scatter_size;
106 right_count = MPL_MIN(scatter_size, (nbytes - rel_j * scatter_size));
107 if (right_count < 0)
108 right_count = 0;
109 right_disp = rel_j * scatter_size;
110
111 mpi_errno = MPIR_Sched_send(((char *) tmp_buf + right_disp),
112 right_count, MPI_BYTE, right, comm_ptr, s);
113 MPIR_ERR_CHECK(mpi_errno);
114 /* sendrecv, no barrier here */
115 mpi_errno = MPIR_Sched_recv_status(((char *) tmp_buf + left_disp),
116 left_count, MPI_BYTE, left, comm_ptr,
117 &ibcast_state->status, s);
118 MPIR_ERR_CHECK(mpi_errno);
119 MPIR_SCHED_BARRIER(s);
120 mpi_errno = MPIR_Sched_cb(&MPII_Ibcast_sched_add_length, ibcast_state, s);
121 MPIR_ERR_CHECK(mpi_errno);
122 MPIR_SCHED_BARRIER(s);
123
124 j = jnext;
125 jnext = (comm_size + jnext - 1) % comm_size;
126 }
127 mpi_errno = MPIR_Sched_cb(&MPII_Ibcast_sched_test_curr_length, ibcast_state, s);
128 MPIR_ERR_CHECK(mpi_errno);
129
130 if (!is_contig && rank != root) {
131 mpi_errno = MPIR_Sched_copy(tmp_buf, nbytes, MPI_BYTE, buffer, count, datatype, s);
132 MPIR_ERR_CHECK(mpi_errno);
133 }
134
135 MPIR_SCHED_CHKPMEM_COMMIT(s);
136 fn_exit:
137 return mpi_errno;
138 fn_fail:
139 MPIR_SCHED_CHKPMEM_REAP(s);
140 goto fn_exit;
141 }
142